Use generics for RangeCodec

This allows DecodeValue to return a more strongly typed value.
pull/1185/head
Jack Christensen 2022-04-09 10:18:51 -05:00
parent c8025fd79a
commit 976b1e03a9
4 changed files with 34 additions and 69 deletions

View File

@ -263,12 +263,12 @@ func NewMap() *Map {
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}})
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}})
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}})
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}})
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}})
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}})
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}})
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}})
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}})
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}})
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}})
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}})
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})

View File

@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
}
// Range is a generic range type.
type Range[T any] struct {
Lower T
Upper T

View File

@ -34,79 +34,43 @@ type RangeScanner interface {
SetBoundTypes(lower, upper BoundType) error
}
type GenericRange struct {
Lower any
Upper any
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r GenericRange) IsNull() bool {
return !r.Valid
}
func (r GenericRange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r GenericRange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) ScanNull() error {
*r = GenericRange{}
return nil
}
func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
// RangeCodec is a codec for any range type.
type RangeCodec struct {
type RangeCodec[T any] struct {
ElementType *Type
}
func (c *RangeCodec) FormatSupported(format int16) bool {
func (c *RangeCodec[T]) FormatSupported(format int16) bool {
return c.ElementType.Codec.FormatSupported(format)
}
func (c *RangeCodec) PreferredFormat() int16 {
func (c *RangeCodec[T]) PreferredFormat() int16 {
if c.FormatSupported(BinaryFormatCode) {
return BinaryFormatCode
}
return TextFormatCode
}
func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(RangeValuer); !ok {
return nil
}
switch format {
case BinaryFormatCode:
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m}
case TextFormatCode:
return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m}
}
return nil
}
type encodePlanRangeCodecRangeValuerToBinary struct {
rc *RangeCodec
type encodePlanRangeCodecRangeValuerToBinary[T any] struct {
rc *RangeCodec[T]
m *Map
}
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer)
if getter.IsNull() {
@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
return buf, nil
}
type encodePlanRangeCodecRangeValuerToText struct {
rc *RangeCodec
type encodePlanRangeCodecRangeValuerToText[T any] struct {
rc *RangeCodec[T]
m *Map
}
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer)
if getter.IsNull() {
@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
return buf, nil
}
func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case RangeScanner:
return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m}
}
case TextFormatCode:
switch target.(type) {
case RangeScanner:
return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m}
}
}
return nil
}
type scanPlanBinaryRangeToRangeScanner struct {
rc *RangeCodec
type scanPlanBinaryRangeToRangeScanner[T any] struct {
rc *RangeCodec[T]
m *Map
}
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error {
rangeScanner := (target).(RangeScanner)
if src == nil {
@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
}
type scanPlanTextRangeToRangeScanner struct {
rc *RangeCodec
type scanPlanTextRangeToRangeScanner[T any] struct {
rc *RangeCodec[T]
m *Map
}
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error {
rangeScanner := (target).(RangeScanner)
if src == nil {
@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
}
func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
}
}
func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
var r GenericRange
var r Range[T]
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
return r, err
}

View File

@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) {
}{
{
sql: `select '[1,5)'::int4range`,
expected: pgtype.GenericRange{
Lower: int32(1),
Upper: int32(5),
expected: pgtype.Range[pgtype.Int4]{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,