diff --git a/pgtype/numrange.go b/pgtype/numrange.go new file mode 100644 index 00000000..cf42dcbd --- /dev/null +++ b/pgtype/numrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Numrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Numrange", src) +} + +func (dst *Numrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go new file mode 100644 index 00000000..81202362 --- /dev/null +++ b/pgtype/numrange_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumrangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numrange", []interface{}{ + pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 84939b58..d7e28641 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -242,6 +242,7 @@ func init() { "jsonb": &Jsonb{}, "name": &Name{}, "numeric": &Numeric{}, + "numrange": &Numrange{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh index b4220f09..bedda292 100644 --- a/pgtype/typed_range_gen.sh +++ b/pgtype/typed_range_gen.sh @@ -3,4 +3,5 @@ erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go goimports -w *range.go diff --git a/v3.md b/v3.md index a2384ace..b79ce9cd 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -numrange numeric[] point line