Use generics for RangeCodec

This allows DecodeValue to return a more strongly typed value.
pull/1185/head
Jack Christensen 2022-04-09 10:18:51 -05:00
parent c8025fd79a
commit 976b1e03a9
4 changed files with 34 additions and 69 deletions

View File

@ -263,12 +263,12 @@ func NewMap() *Map {
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}})
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}})
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}})
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}})
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}})
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}})
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})

View File

@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
} }
// Range is a generic range type.
type Range[T any] struct { type Range[T any] struct {
Lower T Lower T
Upper T Upper T

View File

@ -34,79 +34,43 @@ type RangeScanner interface {
SetBoundTypes(lower, upper BoundType) error SetBoundTypes(lower, upper BoundType) error
} }
type GenericRange struct {
Lower any
Upper any
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r GenericRange) IsNull() bool {
return !r.Valid
}
func (r GenericRange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r GenericRange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) ScanNull() error {
*r = GenericRange{}
return nil
}
func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
// RangeCodec is a codec for any range type. // RangeCodec is a codec for any range type.
type RangeCodec struct { type RangeCodec[T any] struct {
ElementType *Type ElementType *Type
} }
func (c *RangeCodec) FormatSupported(format int16) bool { func (c *RangeCodec[T]) FormatSupported(format int16) bool {
return c.ElementType.Codec.FormatSupported(format) return c.ElementType.Codec.FormatSupported(format)
} }
func (c *RangeCodec) PreferredFormat() int16 { func (c *RangeCodec[T]) PreferredFormat() int16 {
if c.FormatSupported(BinaryFormatCode) { if c.FormatSupported(BinaryFormatCode) {
return BinaryFormatCode return BinaryFormatCode
} }
return TextFormatCode return TextFormatCode
} }
func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(RangeValuer); !ok { if _, ok := value.(RangeValuer); !ok {
return nil return nil
} }
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m} return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m}
case TextFormatCode: case TextFormatCode:
return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m} return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m}
} }
return nil return nil
} }
type encodePlanRangeCodecRangeValuerToBinary struct { type encodePlanRangeCodecRangeValuerToBinary[T any] struct {
rc *RangeCodec rc *RangeCodec[T]
m *Map m *Map
} }
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer) getter := value.(RangeValuer)
if getter.IsNull() { if getter.IsNull() {
@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
return buf, nil return buf, nil
} }
type encodePlanRangeCodecRangeValuerToText struct { type encodePlanRangeCodecRangeValuerToText[T any] struct {
rc *RangeCodec rc *RangeCodec[T]
m *Map m *Map
} }
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer) getter := value.(RangeValuer)
if getter.IsNull() { if getter.IsNull() {
@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
return buf, nil return buf, nil
} }
func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {
case RangeScanner: case RangeScanner:
return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m} return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m}
} }
case TextFormatCode: case TextFormatCode:
switch target.(type) { switch target.(type) {
case RangeScanner: case RangeScanner:
return &scanPlanTextRangeToRangeScanner{rc: c, m: m} return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m}
} }
} }
return nil return nil
} }
type scanPlanBinaryRangeToRangeScanner struct { type scanPlanBinaryRangeToRangeScanner[T any] struct {
rc *RangeCodec rc *RangeCodec[T]
m *Map m *Map
} }
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error { func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error {
rangeScanner := (target).(RangeScanner) rangeScanner := (target).(RangeScanner)
if src == nil { if src == nil {
@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
} }
type scanPlanTextRangeToRangeScanner struct { type scanPlanTextRangeToRangeScanner[T any] struct {
rc *RangeCodec rc *RangeCodec[T]
m *Map m *Map
} }
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error { func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error {
rangeScanner := (target).(RangeScanner) rangeScanner := (target).(RangeScanner)
if src == nil { if src == nil {
@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
} }
func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
} }
} }
func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil { if src == nil {
return nil, nil return nil, nil
} }
var r GenericRange var r Range[T]
err := c.PlanScan(m, oid, format, &r).Scan(src, &r) err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
return r, err return r, err
} }

View File

@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) {
}{ }{
{ {
sql: `select '[1,5)'::int4range`, sql: `select '[1,5)'::int4range`,
expected: pgtype.GenericRange{ expected: pgtype.Range[pgtype.Int4]{
Lower: int32(1), Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: int32(5), Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive, LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive, UpperType: pgtype.Exclusive,
Valid: true, Valid: true,