diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e35299e5..78ed341e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -263,12 +263,12 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) + m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}}) + m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}}) + m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}}) + m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}}) + m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}}) + m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}}) m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) diff --git a/pgtype/range.go b/pgtype/range.go index 776bc9eb..c775239d 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } +// Range is a generic range type. type Range[T any] struct { Lower T Upper T diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 6d62e7ff..49a39a47 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -34,79 +34,43 @@ type RangeScanner interface { SetBoundTypes(lower, upper BoundType) error } -type GenericRange struct { - Lower any - Upper any - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r GenericRange) IsNull() bool { - return !r.Valid -} - -func (r GenericRange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r GenericRange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *GenericRange) ScanNull() error { - *r = GenericRange{} - return nil -} - -func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error { - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - // RangeCodec is a codec for any range type. -type RangeCodec struct { +type RangeCodec[T any] struct { ElementType *Type } -func (c *RangeCodec) FormatSupported(format int16) bool { +func (c *RangeCodec[T]) FormatSupported(format int16) bool { return c.ElementType.Codec.FormatSupported(format) } -func (c *RangeCodec) PreferredFormat() int16 { +func (c *RangeCodec[T]) PreferredFormat() int16 { if c.FormatSupported(BinaryFormatCode) { return BinaryFormatCode } return TextFormatCode } -func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(RangeValuer); !ok { return nil } switch format { case BinaryFormatCode: - return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m} case TextFormatCode: - return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} + return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m} } return nil } -type encodePlanRangeCodecRangeValuerToBinary struct { - rc *RangeCodec +type encodePlanRangeCodecRangeValuerToBinary[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt return buf, nil } -type encodePlanRangeCodecRangeValuerToText struct { - rc *RangeCodec +type encodePlanRangeCodecRangeValuerToText[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { +func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { getter := value.(RangeValuer) if getter.IsNull() { @@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return buf, nil } -func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} + return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m} } case TextFormatCode: switch target.(type) { case RangeScanner: - return &scanPlanTextRangeToRangeScanner{rc: c, m: m} + return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m} } } return nil } -type scanPlanBinaryRangeToRangeScanner struct { - rc *RangeCodec +type scanPlanBinaryRangeToRangeScanner[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { +func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) } -type scanPlanTextRangeToRangeScanner struct { - rc *RangeCodec +type scanPlanTextRangeToRangeScanner[T any] struct { + rc *RangeCodec[T] m *Map } -func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { +func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error { rangeScanner := (target).(RangeScanner) if src == nil { @@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) } -func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr } } -func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } - var r GenericRange + var r Range[T] err := c.PlanScan(m, oid, format, &r).Scan(src, &r) return r, err } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index ed91d3e8..23e93105 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) { }{ { sql: `select '[1,5)'::int4range`, - expected: pgtype.GenericRange{ - Lower: int32(1), - Upper: int32(5), + expected: pgtype.Range[pgtype.Int4]{ + Lower: pgtype.Int4{Int32: 1, Valid: true}, + Upper: pgtype.Int4{Int32: 5, Valid: true}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Valid: true,