mirror of https://github.com/jackc/pgx.git
Use generics for RangeCodec
This allows DecodeValue to return a more strongly typed value.pull/1185/head
parent
c8025fd79a
commit
976b1e03a9
|
@ -263,12 +263,12 @@ func NewMap() *Map {
|
|||
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
||||
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
||||
|
||||
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}})
|
||||
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}})
|
||||
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}})
|
||||
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}})
|
||||
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}})
|
||||
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}})
|
||||
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}})
|
||||
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}})
|
||||
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}})
|
||||
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}})
|
||||
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}})
|
||||
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}})
|
||||
|
||||
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
|
||||
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})
|
||||
|
|
|
@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
|
|||
|
||||
}
|
||||
|
||||
// Range is a generic range type.
|
||||
type Range[T any] struct {
|
||||
Lower T
|
||||
Upper T
|
||||
|
|
|
@ -34,79 +34,43 @@ type RangeScanner interface {
|
|||
SetBoundTypes(lower, upper BoundType) error
|
||||
}
|
||||
|
||||
type GenericRange struct {
|
||||
Lower any
|
||||
Upper any
|
||||
LowerType BoundType
|
||||
UpperType BoundType
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (r GenericRange) IsNull() bool {
|
||||
return !r.Valid
|
||||
}
|
||||
|
||||
func (r GenericRange) BoundTypes() (lower, upper BoundType) {
|
||||
return r.LowerType, r.UpperType
|
||||
}
|
||||
|
||||
func (r GenericRange) Bounds() (lower, upper any) {
|
||||
return &r.Lower, &r.Upper
|
||||
}
|
||||
|
||||
func (r *GenericRange) ScanNull() error {
|
||||
*r = GenericRange{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) {
|
||||
return &r.Lower, &r.Upper
|
||||
}
|
||||
|
||||
func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error {
|
||||
r.LowerType = lower
|
||||
r.UpperType = upper
|
||||
r.Valid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// RangeCodec is a codec for any range type.
|
||||
type RangeCodec struct {
|
||||
type RangeCodec[T any] struct {
|
||||
ElementType *Type
|
||||
}
|
||||
|
||||
func (c *RangeCodec) FormatSupported(format int16) bool {
|
||||
func (c *RangeCodec[T]) FormatSupported(format int16) bool {
|
||||
return c.ElementType.Codec.FormatSupported(format)
|
||||
}
|
||||
|
||||
func (c *RangeCodec) PreferredFormat() int16 {
|
||||
func (c *RangeCodec[T]) PreferredFormat() int16 {
|
||||
if c.FormatSupported(BinaryFormatCode) {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||
if _, ok := value.(RangeValuer); !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
|
||||
return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m}
|
||||
case TextFormatCode:
|
||||
return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
|
||||
return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodePlanRangeCodecRangeValuerToBinary struct {
|
||||
rc *RangeCodec
|
||||
type encodePlanRangeCodecRangeValuerToBinary[T any] struct {
|
||||
rc *RangeCodec[T]
|
||||
m *Map
|
||||
}
|
||||
|
||||
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
getter := value.(RangeValuer)
|
||||
|
||||
if getter.IsNull() {
|
||||
|
@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
type encodePlanRangeCodecRangeValuerToText struct {
|
||||
rc *RangeCodec
|
||||
type encodePlanRangeCodecRangeValuerToText[T any] struct {
|
||||
rc *RangeCodec[T]
|
||||
m *Map
|
||||
}
|
||||
|
||||
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
getter := value.(RangeValuer)
|
||||
|
||||
if getter.IsNull() {
|
||||
|
@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
switch target.(type) {
|
||||
case RangeScanner:
|
||||
return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
|
||||
return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m}
|
||||
}
|
||||
case TextFormatCode:
|
||||
switch target.(type) {
|
||||
case RangeScanner:
|
||||
return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
|
||||
return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanBinaryRangeToRangeScanner struct {
|
||||
rc *RangeCodec
|
||||
type scanPlanBinaryRangeToRangeScanner[T any] struct {
|
||||
rc *RangeCodec[T]
|
||||
m *Map
|
||||
}
|
||||
|
||||
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
|
||||
func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error {
|
||||
rangeScanner := (target).(RangeScanner)
|
||||
|
||||
if src == nil {
|
||||
|
@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
|
|||
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
|
||||
}
|
||||
|
||||
type scanPlanTextRangeToRangeScanner struct {
|
||||
rc *RangeCodec
|
||||
type scanPlanTextRangeToRangeScanner[T any] struct {
|
||||
rc *RangeCodec[T]
|
||||
m *Map
|
||||
}
|
||||
|
||||
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
|
||||
func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error {
|
||||
rangeScanner := (target).(RangeScanner)
|
||||
|
||||
if src == nil {
|
||||
|
@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
|
|||
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
|
||||
}
|
||||
|
||||
func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
|
|||
}
|
||||
}
|
||||
|
||||
func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var r GenericRange
|
||||
var r Range[T]
|
||||
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
|
||||
return r, err
|
||||
}
|
||||
|
|
|
@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
sql: `select '[1,5)'::int4range`,
|
||||
expected: pgtype.GenericRange{
|
||||
Lower: int32(1),
|
||||
Upper: int32(5),
|
||||
expected: pgtype.Range[pgtype.Int4]{
|
||||
Lower: pgtype.Int4{Int32: 1, Valid: true},
|
||||
Upper: pgtype.Int4{Int32: 5, Valid: true},
|
||||
LowerType: pgtype.Inclusive,
|
||||
UpperType: pgtype.Exclusive,
|
||||
Valid: true,
|
||||
|
|
Loading…
Reference in New Issue