mirror of https://github.com/jackc/pgx.git
Use generics for RangeCodec
This allows DecodeValue to return a more strongly typed value.pull/1185/head
parent
c8025fd79a
commit
976b1e03a9
|
@ -263,12 +263,12 @@ func NewMap() *Map {
|
||||||
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
||||||
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
||||||
|
|
||||||
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}})
|
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec[Date]{ElementType: m.oidToType[DateOID]}})
|
||||||
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}})
|
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec[Int4]{ElementType: m.oidToType[Int4OID]}})
|
||||||
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}})
|
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec[Int8]{ElementType: m.oidToType[Int8OID]}})
|
||||||
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}})
|
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec[Numeric]{ElementType: m.oidToType[NumericOID]}})
|
||||||
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}})
|
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec[Timestamp]{ElementType: m.oidToType[TimestampOID]}})
|
||||||
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}})
|
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec[Timestamptz]{ElementType: m.oidToType[TimestamptzOID]}})
|
||||||
|
|
||||||
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
|
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
|
||||||
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})
|
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})
|
||||||
|
|
|
@ -276,6 +276,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Range is a generic range type.
|
||||||
type Range[T any] struct {
|
type Range[T any] struct {
|
||||||
Lower T
|
Lower T
|
||||||
Upper T
|
Upper T
|
||||||
|
|
|
@ -34,79 +34,43 @@ type RangeScanner interface {
|
||||||
SetBoundTypes(lower, upper BoundType) error
|
SetBoundTypes(lower, upper BoundType) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenericRange struct {
|
|
||||||
Lower any
|
|
||||||
Upper any
|
|
||||||
LowerType BoundType
|
|
||||||
UpperType BoundType
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r GenericRange) IsNull() bool {
|
|
||||||
return !r.Valid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r GenericRange) BoundTypes() (lower, upper BoundType) {
|
|
||||||
return r.LowerType, r.UpperType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r GenericRange) Bounds() (lower, upper any) {
|
|
||||||
return &r.Lower, &r.Upper
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *GenericRange) ScanNull() error {
|
|
||||||
*r = GenericRange{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget any) {
|
|
||||||
return &r.Lower, &r.Upper
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error {
|
|
||||||
r.LowerType = lower
|
|
||||||
r.UpperType = upper
|
|
||||||
r.Valid = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RangeCodec is a codec for any range type.
|
// RangeCodec is a codec for any range type.
|
||||||
type RangeCodec struct {
|
type RangeCodec[T any] struct {
|
||||||
ElementType *Type
|
ElementType *Type
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) FormatSupported(format int16) bool {
|
func (c *RangeCodec[T]) FormatSupported(format int16) bool {
|
||||||
return c.ElementType.Codec.FormatSupported(format)
|
return c.ElementType.Codec.FormatSupported(format)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) PreferredFormat() int16 {
|
func (c *RangeCodec[T]) PreferredFormat() int16 {
|
||||||
if c.FormatSupported(BinaryFormatCode) {
|
if c.FormatSupported(BinaryFormatCode) {
|
||||||
return BinaryFormatCode
|
return BinaryFormatCode
|
||||||
}
|
}
|
||||||
return TextFormatCode
|
return TextFormatCode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
func (c *RangeCodec[T]) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
||||||
if _, ok := value.(RangeValuer); !ok {
|
if _, ok := value.(RangeValuer); !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
|
return &encodePlanRangeCodecRangeValuerToBinary[T]{rc: c, m: m}
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
|
return &encodePlanRangeCodecRangeValuerToText[T]{rc: c, m: m}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type encodePlanRangeCodecRangeValuerToBinary struct {
|
type encodePlanRangeCodecRangeValuerToBinary[T any] struct {
|
||||||
rc *RangeCodec
|
rc *RangeCodec[T]
|
||||||
m *Map
|
m *Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
func (plan *encodePlanRangeCodecRangeValuerToBinary[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
getter := value.(RangeValuer)
|
getter := value.(RangeValuer)
|
||||||
|
|
||||||
if getter.IsNull() {
|
if getter.IsNull() {
|
||||||
|
@ -192,12 +156,12 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type encodePlanRangeCodecRangeValuerToText struct {
|
type encodePlanRangeCodecRangeValuerToText[T any] struct {
|
||||||
rc *RangeCodec
|
rc *RangeCodec[T]
|
||||||
m *Map
|
m *Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
func (plan *encodePlanRangeCodecRangeValuerToText[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
getter := value.(RangeValuer)
|
getter := value.(RangeValuer)
|
||||||
|
|
||||||
if getter.IsNull() {
|
if getter.IsNull() {
|
||||||
|
@ -270,29 +234,29 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (c *RangeCodec[T]) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
case RangeScanner:
|
case RangeScanner:
|
||||||
return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
|
return &scanPlanBinaryRangeToRangeScanner[T]{rc: c, m: m}
|
||||||
}
|
}
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
case RangeScanner:
|
case RangeScanner:
|
||||||
return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
|
return &scanPlanTextRangeToRangeScanner[T]{rc: c, m: m}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanBinaryRangeToRangeScanner struct {
|
type scanPlanBinaryRangeToRangeScanner[T any] struct {
|
||||||
rc *RangeCodec
|
rc *RangeCodec[T]
|
||||||
m *Map
|
m *Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
|
func (plan *scanPlanBinaryRangeToRangeScanner[T]) Scan(src []byte, target any) error {
|
||||||
rangeScanner := (target).(RangeScanner)
|
rangeScanner := (target).(RangeScanner)
|
||||||
|
|
||||||
if src == nil {
|
if src == nil {
|
||||||
|
@ -337,12 +301,12 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro
|
||||||
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
|
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanTextRangeToRangeScanner struct {
|
type scanPlanTextRangeToRangeScanner[T any] struct {
|
||||||
rc *RangeCodec
|
rc *RangeCodec[T]
|
||||||
m *Map
|
m *Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
|
func (plan *scanPlanTextRangeToRangeScanner[T]) Scan(src []byte, target any) error {
|
||||||
rangeScanner := (target).(RangeScanner)
|
rangeScanner := (target).(RangeScanner)
|
||||||
|
|
||||||
if src == nil {
|
if src == nil {
|
||||||
|
@ -387,7 +351,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error
|
||||||
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
|
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
func (c *RangeCodec[T]) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -404,12 +368,12 @@ func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
func (c *RangeCodec[T]) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var r GenericRange
|
var r Range[T]
|
||||||
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
|
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
|
||||||
return r, err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,9 +136,9 @@ func TestRangeCodecDecodeValue(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
sql: `select '[1,5)'::int4range`,
|
sql: `select '[1,5)'::int4range`,
|
||||||
expected: pgtype.GenericRange{
|
expected: pgtype.Range[pgtype.Int4]{
|
||||||
Lower: int32(1),
|
Lower: pgtype.Int4{Int32: 1, Valid: true},
|
||||||
Upper: int32(5),
|
Upper: pgtype.Int4{Int32: 5, Valid: true},
|
||||||
LowerType: pgtype.Inclusive,
|
LowerType: pgtype.Inclusive,
|
||||||
UpperType: pgtype.Exclusive,
|
UpperType: pgtype.Exclusive,
|
||||||
Valid: true,
|
Valid: true,
|
||||||
|
|
Loading…
Reference in New Issue