From 7312fb20e8702393e5da6038dc4d87e41921a6be Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 13:36:10 -0500 Subject: [PATCH] Add Int8range Add code generation for ranges --- pgtype/int8range.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/int8range_test.go | 25 ++++ pgtype/typed_range.go.erb | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/typed_range_gen.sh | 3 + 4 files changed, 564 insertions(+) create mode 100644 pgtype/int8range.go create mode 100644 pgtype/int8range_test.go create mode 100644 pgtype/typed_range.go.erb create mode 100644 pgtype/typed_range_gen.sh diff --git a/pgtype/int8range.go b/pgtype/int8range.go new file mode 100644 index 00000000..44946be9 --- /dev/null +++ b/pgtype/int8range.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int8range) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Int8range", src) +} + +func (dst *Int8range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8range) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go new file mode 100644 index 00000000..1b3e594c --- /dev/null +++ b/pgtype/int8range_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8rangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "Int8range", []interface{}{ + pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Status: pgtype.Null}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select Int8range(1, 10, '(]')", + value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb new file mode 100644 index 00000000..922b98b4 --- /dev/null +++ b/pgtype/typed_range.go.erb @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) +} + +func (dst *<%= range_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= range_type %>) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh new file mode 100644 index 00000000..af3e2cd1 --- /dev/null +++ b/pgtype/typed_range_gen.sh @@ -0,0 +1,3 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +goimports -w *range.go