Restore range support

query-exec-mode
Jack Christensen 2022-02-02 08:40:42 -06:00
parent 11223497b3
commit cebe44ee85
9 changed files with 1228 additions and 10 deletions

View File

@ -6,5 +6,14 @@ rule '.go' => '.go.erb' do |task|
sh "goimports", "-w", task.name
end
generated_code_files = [
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",
"pgtype/range_types.go",
"pgtype/zeronull/int.go",
"pgtype/zeronull/int_test.go"
]
desc "Generate code"
task generate: ["pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go"]
task generate: generated_code_files

View File

@ -10,7 +10,7 @@ import (
<% [2, 4, 8].each do |pg_byte_size| %>
<% pg_bit_size = pg_byte_size * 8 %>
func TestInt<%= pg_byte_size %>Codec(t *testing.T) {
testPgxCodec(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{
testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{
{int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))},
{int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))},
{int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))},

View File

@ -208,14 +208,6 @@ func NewConnInfo() *ConnInfo {
},
}
// ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID})
// ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
// ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
// ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
// ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
// ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID})
// ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
// ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID})
ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}})
ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}})
@ -257,6 +249,16 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
ci.RegisterDataType(DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}})
ci.RegisterDataType(DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}})
ci.RegisterDataType(DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}})
ci.RegisterDataType(DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}})
ci.RegisterDataType(DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}})
ci.RegisterDataType(DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}})
// ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID})
// ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID})
ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}})
ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}})
ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}})

277
pgtype/range.go Normal file
View File

@ -0,0 +1,277 @@
package pgtype
import (
"bytes"
"encoding/binary"
"fmt"
)
type BoundType byte
const (
Inclusive = BoundType('i')
Exclusive = BoundType('e')
Unbounded = BoundType('U')
Empty = BoundType('E')
)
func (bt BoundType) String() string {
return string(bt)
}
type UntypedTextRange struct {
Lower string
Upper string
LowerType BoundType
UpperType BoundType
}
func ParseUntypedTextRange(src string) (*UntypedTextRange, error) {
utr := &UntypedTextRange{}
if src == "empty" {
utr.LowerType = Empty
utr.UpperType = Empty
return utr, nil
}
buf := bytes.NewBufferString(src)
skipWhitespace(buf)
r, _, err := buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid lower bound: %v", err)
}
switch r {
case '(':
utr.LowerType = Exclusive
case '[':
utr.LowerType = Inclusive
default:
return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r))
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid lower value: %v", err)
}
buf.UnreadRune()
if r == ',' {
utr.LowerType = Unbounded
} else {
utr.Lower, err = rangeParseValue(buf)
if err != nil {
return nil, fmt.Errorf("invalid lower value: %v", err)
}
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("missing range separator: %v", err)
}
if r != ',' {
return nil, fmt.Errorf("missing range separator: %v", r)
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid upper value: %v", err)
}
if r == ')' || r == ']' {
utr.UpperType = Unbounded
} else {
buf.UnreadRune()
utr.Upper, err = rangeParseValue(buf)
if err != nil {
return nil, fmt.Errorf("invalid upper value: %v", err)
}
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("missing upper bound: %v", err)
}
switch r {
case ')':
utr.UpperType = Exclusive
case ']':
utr.UpperType = Inclusive
default:
return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r))
}
}
skipWhitespace(buf)
if buf.Len() > 0 {
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
}
return utr, nil
}
func rangeParseValue(buf *bytes.Buffer) (string, error) {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
if r == '"' {
return rangeParseQuotedValue(buf)
}
buf.UnreadRune()
s := &bytes.Buffer{}
for {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
switch r {
case '\\':
r, _, err = buf.ReadRune()
if err != nil {
return "", err
}
case ',', '[', ']', '(', ')':
buf.UnreadRune()
return s.String(), nil
}
s.WriteRune(r)
}
}
func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) {
s := &bytes.Buffer{}
for {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
switch r {
case '\\':
r, _, err = buf.ReadRune()
if err != nil {
return "", err
}
case '"':
r, _, err = buf.ReadRune()
if err != nil {
return "", err
}
if r != '"' {
buf.UnreadRune()
return s.String(), nil
}
}
s.WriteRune(r)
}
}
type UntypedBinaryRange struct {
Lower []byte
Upper []byte
LowerType BoundType
UpperType BoundType
}
// 0 = () = 00000
// 1 = empty = 00001
// 2 = [) = 00010
// 4 = (] = 00100
// 6 = [] = 00110
// 8 = ) = 01000
// 12 = ] = 01100
// 16 = ( = 10000
// 18 = [ = 10010
// 24 = = 11000
const emptyMask = 1
const lowerInclusiveMask = 2
const upperInclusiveMask = 4
const lowerUnboundedMask = 8
const upperUnboundedMask = 16
func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
ubr := &UntypedBinaryRange{}
if len(src) == 0 {
return nil, fmt.Errorf("range too short: %v", len(src))
}
rangeType := src[0]
rp := 1
if rangeType&emptyMask > 0 {
if len(src[rp:]) > 0 {
return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:]))
}
ubr.LowerType = Empty
ubr.UpperType = Empty
return ubr, nil
}
if rangeType&lowerInclusiveMask > 0 {
ubr.LowerType = Inclusive
} else if rangeType&lowerUnboundedMask > 0 {
ubr.LowerType = Unbounded
} else {
ubr.LowerType = Exclusive
}
if rangeType&upperInclusiveMask > 0 {
ubr.UpperType = Inclusive
} else if rangeType&upperUnboundedMask > 0 {
ubr.UpperType = Unbounded
} else {
ubr.UpperType = Exclusive
}
if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded {
if len(src[rp:]) > 0 {
return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:]))
}
return ubr, nil
}
if len(src[rp:]) < 4 {
return nil, fmt.Errorf("too few bytes for size: %v", src[rp:])
}
valueLen := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
val := src[rp : rp+valueLen]
rp += valueLen
if ubr.LowerType != Unbounded {
ubr.Lower = val
} else {
ubr.Upper = val
if len(src[rp:]) > 0 {
return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:]))
}
return ubr, nil
}
if ubr.UpperType != Unbounded {
if len(src[rp:]) < 4 {
return nil, fmt.Errorf("too few bytes for size: %v", src[rp:])
}
valueLen := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
ubr.Upper = src[rp : rp+valueLen]
rp += valueLen
}
if len(src[rp:]) > 0 {
return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:]))
}
return ubr, nil
}

414
pgtype/range_codec.go Normal file
View File

@ -0,0 +1,414 @@
package pgtype
import (
"database/sql/driver"
"fmt"
"github.com/jackc/pgio"
)
// RangeValuer is a type that can be converted into a PostgreSQL range.
type RangeValuer interface {
// IsNull returns true if the value is SQL NULL.
IsNull() bool
// BoundTypes returns the lower and upper bound types.
BoundTypes() (lower, upper BoundType)
// Bounds returns the lower and upper range values.
Bounds() (lower, upper interface{})
}
// RangeScanner is a type can be scanned from a PostgreSQL range.
type RangeScanner interface {
// ScanNull sets the value to SQL NULL.
ScanNull() error
// ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or
// the bound type is unbounded.
ScanBounds() (lowerTarget, upperTarget interface{})
// SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned
// (if appropriate) before SetBoundTypes is called.
SetBoundTypes(lower, upper BoundType) error
}
type GenericRange struct {
Lower interface{}
Upper interface{}
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r GenericRange) IsNull() bool {
return !r.Valid
}
func (r GenericRange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r GenericRange) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) ScanNull() error {
*r = GenericRange{}
return nil
}
func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
// RangeCodec is a codec for any range type.
type RangeCodec struct {
ElementDataType *DataType
}
func (c *RangeCodec) FormatSupported(format int16) bool {
return c.ElementDataType.Codec.FormatSupported(format)
}
func (c *RangeCodec) PreferredFormat() int16 {
if c.FormatSupported(BinaryFormatCode) {
return BinaryFormatCode
}
return TextFormatCode
}
func (c *RangeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
if _, ok := value.(RangeValuer); !ok {
return nil
}
switch format {
case BinaryFormatCode:
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, ci: ci}
case TextFormatCode:
return &encodePlanRangeCodecRangeValuerToText{rc: c, ci: ci}
}
return nil
}
type encodePlanRangeCodecRangeValuerToBinary struct {
rc *RangeCodec
ci *ConnInfo
}
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer)
if getter.IsNull() {
return nil, nil
}
lowerType, upperType := getter.BoundTypes()
lower, upper := getter.Bounds()
var rangeType byte
switch lowerType {
case Inclusive:
rangeType |= lowerInclusiveMask
case Unbounded:
rangeType |= lowerUnboundedMask
case Exclusive:
case Empty:
return append(buf, emptyMask), nil
default:
return nil, fmt.Errorf("unknown LowerType: %v", lowerType)
}
switch upperType {
case Inclusive:
rangeType |= upperInclusiveMask
case Unbounded:
rangeType |= upperUnboundedMask
case Exclusive:
default:
return nil, fmt.Errorf("unknown UpperType: %v", upperType)
}
buf = append(buf, rangeType)
if lowerType != Unbounded {
if lower == nil {
return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded")
}
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, lower)
if lowerPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", lower)
}
buf, err = lowerPlan.Encode(lower, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err)
}
if buf == nil {
return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded")
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
if upperType != Unbounded {
if upper == nil {
return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded")
}
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, upper)
if upperPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", upper)
}
buf, err = upperPlan.Encode(upper, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err)
}
if buf == nil {
return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded")
}
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
}
type encodePlanRangeCodecRangeValuerToText struct {
rc *RangeCodec
ci *ConnInfo
}
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
getter := value.(RangeValuer)
if getter.IsNull() {
return nil, nil
}
lowerType, upperType := getter.BoundTypes()
lower, upper := getter.Bounds()
switch lowerType {
case Exclusive, Unbounded:
buf = append(buf, '(')
case Inclusive:
buf = append(buf, '[')
case Empty:
return append(buf, "empty"...), nil
default:
return nil, fmt.Errorf("unknown lower bound type %v", lowerType)
}
if lowerType != Unbounded {
if lower == nil {
return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded")
}
lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, lower)
if lowerPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", lower)
}
buf, err = lowerPlan.Encode(lower, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err)
}
if buf == nil {
return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded")
}
}
buf = append(buf, ',')
if upperType != Unbounded {
if upper == nil {
return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded")
}
upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, upper)
if upperPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", upper)
}
buf, err = upperPlan.Encode(upper, buf)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err)
}
if buf == nil {
return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded")
}
}
switch upperType {
case Exclusive, Unbounded:
buf = append(buf, ')')
case Inclusive:
buf = append(buf, ']')
default:
return nil, fmt.Errorf("unknown upper bound type %v", upperType)
}
return buf, nil
}
func (c *RangeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case RangeScanner:
return &scanPlanBinaryRangeToRangeScanner{rc: c, ci: ci}
}
case TextFormatCode:
switch target.(type) {
case RangeScanner:
return &scanPlanTextRangeToRangeScanner{rc: c, ci: ci}
}
}
return nil
}
type scanPlanBinaryRangeToRangeScanner struct {
rc *RangeCodec
ci *ConnInfo
}
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface{}) error {
rangeScanner := (target).(RangeScanner)
if src == nil {
return rangeScanner.ScanNull()
}
ubr, err := ParseUntypedBinaryRange(src)
if err != nil {
return err
}
if ubr.LowerType == Empty {
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
}
lowerTarget, upperTarget := rangeScanner.ScanBounds()
if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive {
lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, lowerTarget)
if lowerPlan == nil {
return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
}
err = lowerPlan.Scan(ubr.Lower, lowerTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err)
}
}
if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive {
upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, upperTarget)
if upperPlan == nil {
return fmt.Errorf("cannot scan into %v from range element", upperTarget)
}
err = upperPlan.Scan(ubr.Upper, upperTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err)
}
}
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
}
type scanPlanTextRangeToRangeScanner struct {
rc *RangeCodec
ci *ConnInfo
}
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{}) error {
rangeScanner := (target).(RangeScanner)
if src == nil {
return rangeScanner.ScanNull()
}
utr, err := ParseUntypedTextRange(string(src))
if err != nil {
return err
}
if utr.LowerType == Empty {
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
}
lowerTarget, upperTarget := rangeScanner.ScanBounds()
if utr.LowerType == Inclusive || utr.LowerType == Exclusive {
lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, lowerTarget)
if lowerPlan == nil {
return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
}
err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err)
}
}
if utr.UpperType == Inclusive || utr.UpperType == Exclusive {
upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, upperTarget)
if upperPlan == nil {
return fmt.Errorf("cannot scan into %v from range element", upperTarget)
}
err = upperPlan.Scan([]byte(utr.Upper), upperTarget)
if err != nil {
return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err)
}
}
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
}
func (c *RangeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
switch format {
case TextFormatCode:
return string(src), nil
case BinaryFormatCode:
buf := make([]byte, len(src))
copy(buf, src)
return buf, nil
default:
return nil, fmt.Errorf("unknown format code %d", format)
}
}
func (c *RangeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
if src == nil {
return nil, nil
}
var r GenericRange
err := c.PlanScan(ci, oid, format, &r, true).Scan(src, &r)
return r, err
}

View File

@ -0,0 +1,72 @@
package pgtype_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
"github.com/stretchr/testify/require"
)
func TestRangeCodecTranscode(t *testing.T) {
testutil.RunTranscodeTests(t, "int4range", []testutil.TranscodeTestCase{
{
pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true},
new(pgtype.Int4range),
isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}),
},
{
pgtype.Int4range{
LowerType: pgtype.Inclusive,
Lower: pgtype.Int4{Int: 1, Valid: true},
Upper: pgtype.Int4{Int: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
},
new(pgtype.Int4range),
isExpectedEq(pgtype.Int4range{
LowerType: pgtype.Inclusive,
Lower: pgtype.Int4{Int: 1, Valid: true},
Upper: pgtype.Int4{Int: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
}),
},
{pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})},
{nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})},
})
}
func TestRangeCodecDecodeValue(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
for _, tt := range []struct {
sql string
expected interface{}
}{
{
sql: `select '[1,5)'::int4range`,
expected: pgtype.GenericRange{
Lower: int32(1),
Upper: int32(5),
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
},
} {
t.Run(tt.sql, func(t *testing.T) {
rows, err := conn.Query(context.Background(), tt.sql)
require.NoError(t, err)
for rows.Next() {
values, err := rows.Values()
require.NoError(t, err)
require.Len(t, values, 1)
require.Equal(t, tt.expected, values[0])
}
require.NoError(t, rows.Err())
})
}
}

177
pgtype/range_test.go Normal file
View File

@ -0,0 +1,177 @@
package pgtype
import (
"bytes"
"testing"
)
func TestParseUntypedTextRange(t *testing.T) {
tests := []struct {
src string
result UntypedTextRange
err error
}{
{
src: `[1,2)`,
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `[1,2]`,
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive},
err: nil,
},
{
src: `(1,3)`,
result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive},
err: nil,
},
{
src: ` [1,2) `,
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `[ foo , bar )`,
result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `["foo","bar")`,
result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `["f""oo","b""ar")`,
result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `["f""oo","b""ar")`,
result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `["","bar")`,
result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `[f\"oo\,,b\\ar\))`,
result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: `empty`,
result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty},
err: nil,
},
}
for i, tt := range tests {
r, err := ParseUntypedTextRange(tt.src)
if err != tt.err {
t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err)
continue
}
if r.LowerType != tt.result.LowerType {
t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType))
}
if r.UpperType != tt.result.UpperType {
t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType))
}
if r.Lower != tt.result.Lower {
t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower)
}
if r.Upper != tt.result.Upper {
t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper)
}
}
}
func TestParseUntypedBinaryRange(t *testing.T) {
tests := []struct {
src []byte
result UntypedBinaryRange
err error
}{
{
src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive},
err: nil,
},
{
src: []byte{1},
result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty},
err: nil,
},
{
src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive},
err: nil,
},
{
src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive},
err: nil,
},
{
src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive},
err: nil,
},
{
src: []byte{8, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive},
err: nil,
},
{
src: []byte{12, 0, 0, 0, 2, 0, 5},
result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive},
err: nil,
},
{
src: []byte{16, 0, 0, 0, 2, 0, 4},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded},
err: nil,
},
{
src: []byte{18, 0, 0, 0, 2, 0, 4},
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded},
err: nil,
},
{
src: []byte{24},
result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded},
err: nil,
},
}
for i, tt := range tests {
r, err := ParseUntypedBinaryRange(tt.src)
if err != tt.err {
t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err)
continue
}
if r.LowerType != tt.result.LowerType {
t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType))
}
if r.UpperType != tt.result.UpperType {
t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType))
}
if bytes.Compare(r.Lower, tt.result.Lower) != 0 {
t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower)
}
if bytes.Compare(r.Upper, tt.result.Upper) != 0 {
t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper)
}
}
}

218
pgtype/range_types.go Normal file
View File

@ -0,0 +1,218 @@
// Do not edit. Generated from pgtype/range_types.go.erb
package pgtype
type Int4range struct {
Lower Int4
Upper Int4
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Int4range) IsNull() bool {
return !r.Valid
}
func (r Int4range) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Int4range) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Int4range) ScanNull() error {
*r = Int4range{}
return nil
}
func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Int4range) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Int8range struct {
Lower Int8
Upper Int8
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Int8range) IsNull() bool {
return !r.Valid
}
func (r Int8range) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Int8range) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Int8range) ScanNull() error {
*r = Int8range{}
return nil
}
func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Int8range) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Numrange struct {
Lower Numeric
Upper Numeric
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Numrange) IsNull() bool {
return !r.Valid
}
func (r Numrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Numrange) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Numrange) ScanNull() error {
*r = Numrange{}
return nil
}
func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Numrange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Tsrange struct {
Lower Timestamp
Upper Timestamp
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Tsrange) IsNull() bool {
return !r.Valid
}
func (r Tsrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Tsrange) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Tsrange) ScanNull() error {
*r = Tsrange{}
return nil
}
func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Tstzrange struct {
Lower Timestamptz
Upper Timestamptz
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Tstzrange) IsNull() bool {
return !r.Valid
}
func (r Tstzrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Tstzrange) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Tstzrange) ScanNull() error {
*r = Tstzrange{}
return nil
}
func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Daterange struct {
Lower Date
Upper Date
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Daterange) IsNull() bool {
return !r.Valid
}
func (r Daterange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Daterange) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *Daterange) ScanNull() error {
*r = Daterange{}
return nil
}
func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *Daterange) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}

49
pgtype/range_types.go.erb Normal file
View File

@ -0,0 +1,49 @@
package pgtype
<%
[
["Int4range", "Int4"],
["Int8range", "Int8"],
["Numrange", "Numeric"],
["Tsrange", "Timestamp"],
["Tstzrange", "Timestamptz"],
["Daterange", "Date"]
].each do |range_type, element_type|
%>
type <%= range_type %> struct {
Lower <%= element_type %>
Upper <%= element_type %>
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r <%= range_type %>) IsNull() bool {
return !r.Valid
}
func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r <%= range_type %>) Bounds() (lower, upper interface{}) {
return &r.Lower, &r.Upper
}
func (r *<%= range_type %>) ScanNull() error {
*r = <%= range_type %>{}
return nil
}
func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) {
return &r.Lower, &r.Upper
}
func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error {
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
<% end %>