diff --git a/Rakefile b/Rakefile index 3fe26cb5..de174fae 100644 --- a/Rakefile +++ b/Rakefile @@ -11,7 +11,6 @@ generated_code_files = [ "pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", - "pgtype/range_types.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go" ] diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ffdb7020..e35299e5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -371,21 +371,21 @@ func NewMap() *Map { registerDefaultPgTypeVariants("box", "_box", Box{}) registerDefaultPgTypeVariants("circle", "_circle", Circle{}) registerDefaultPgTypeVariants("date", "_date", Date{}) - registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{}) + registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{}) registerDefaultPgTypeVariants("float4", "_float4", Float4{}) registerDefaultPgTypeVariants("float8", "_float8", Float8{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Float8range{}) // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange. registerDefaultPgTypeVariants("inet", "_inet", Inet{}) registerDefaultPgTypeVariants("int2", "_int2", Int2{}) registerDefaultPgTypeVariants("int4", "_int4", Int4{}) - registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{}) + registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{}) registerDefaultPgTypeVariants("int8", "_int8", Int8{}) - registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{}) + registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{}) registerDefaultPgTypeVariants("interval", "_interval", Interval{}) registerDefaultPgTypeVariants("line", "_line", Line{}) registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{}) + registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{}) registerDefaultPgTypeVariants("path", "_path", Path{}) registerDefaultPgTypeVariants("point", "_point", Point{}) registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) @@ -394,8 +394,8 @@ func NewMap() *Map { registerDefaultPgTypeVariants("time", "_time", Time{}) registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) - registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{}) - registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{}) + registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{}) + registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{}) registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) return m diff --git a/pgtype/range.go b/pgtype/range.go index e999f6a9..776bc9eb 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -275,3 +275,47 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { return ubr, nil } + +type Range[T any] struct { + Lower T + Upper T + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Range[T]) IsNull() bool { + return !r.Valid +} + +func (r Range[T]) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Range[T]) Bounds() (lower, upper any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) ScanNull() error { + *r = Range[T]{} + return nil +} + +func (r *Range[T]) ScanBounds() (lowerTarget, upperTarget any) { + return &r.Lower, &r.Upper +} + +func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + var zero T + r.Lower = zero + } + if upper == Unbounded || upper == Empty { + var zero T + r.Upper = zero + } + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index d467b750..ed91d3e8 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -15,27 +15,27 @@ func TestRangeCodecTranscode(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{ { - pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - new(pgtype.Int4range), - isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), }, { - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ LowerType: pgtype.Inclusive, Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, - new(pgtype.Int4range), - isExpectedEq(pgtype.Int4range{ + new(pgtype.Range[pgtype.Int4]), + isExpectedEq(pgtype.Range[pgtype.Int4]{ LowerType: pgtype.Inclusive, Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, - {pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, - {nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + {pgtype.Range[pgtype.Int4]{}, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, + {nil, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})}, }) } @@ -47,27 +47,27 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{ { - pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, - new(pgtype.Float8range), - isExpectedEq(pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), }, { - pgtype.Float8range{ + pgtype.Range[pgtype.Float8]{ LowerType: pgtype.Inclusive, Lower: pgtype.Float8{Float64: 1, Valid: true}, Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }, - new(pgtype.Float8range), - isExpectedEq(pgtype.Float8range{ + new(pgtype.Range[pgtype.Float8]), + isExpectedEq(pgtype.Range[pgtype.Float8]{ LowerType: pgtype.Inclusive, Lower: pgtype.Float8{Float64: 1, Valid: true}, Upper: pgtype.Float8{Float64: 5, Valid: true}, UpperType: pgtype.Exclusive, Valid: true, }), }, - {pgtype.Float8range{}, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, - {nil, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})}, + {pgtype.Range[pgtype.Float8]{}, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, + {nil, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})}, }) } @@ -76,14 +76,14 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - var r pgtype.Int4range + var r pgtype.Range[pgtype.Int4] err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) require.NoError(t, err) require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{Int32: 5, Valid: true}, LowerType: pgtype.Inclusive, @@ -98,7 +98,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{Int32: 1, Valid: true}, Upper: pgtype.Int4{}, LowerType: pgtype.Inclusive, @@ -113,7 +113,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { require.Equal( t, - pgtype.Int4range{ + pgtype.Range[pgtype.Int4]{ Lower: pgtype.Int4{}, Upper: pgtype.Int4{}, LowerType: pgtype.Empty, diff --git a/pgtype/range_types.go b/pgtype/range_types.go deleted file mode 100644 index c101fbdc..00000000 --- a/pgtype/range_types.go +++ /dev/null @@ -1,296 +0,0 @@ -// Do not edit. Generated from pgtype/range_types.go.erb -package pgtype - -type Int4range struct { - Lower Int4 - Upper Int4 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Int4range) IsNull() bool { - return !r.Valid -} - -func (r Int4range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Int4range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Int4range) ScanNull() error { - *r = Int4range{} - return nil -} - -func (r *Int4range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Int4{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Int4{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Int8range struct { - Lower Int8 - Upper Int8 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Int8range) IsNull() bool { - return !r.Valid -} - -func (r Int8range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Int8range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Int8range) ScanNull() error { - *r = Int8range{} - return nil -} - -func (r *Int8range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Int8{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Int8{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Numrange struct { - Lower Numeric - Upper Numeric - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Numrange) IsNull() bool { - return !r.Valid -} - -func (r Numrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Numrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Numrange) ScanNull() error { - *r = Numrange{} - return nil -} - -func (r *Numrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Numeric{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Numeric{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Tsrange struct { - Lower Timestamp - Upper Timestamp - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Tsrange) IsNull() bool { - return !r.Valid -} - -func (r Tsrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Tsrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Tsrange) ScanNull() error { - *r = Tsrange{} - return nil -} - -func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Timestamp{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Timestamp{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Tstzrange struct { - Lower Timestamptz - Upper Timestamptz - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Tstzrange) IsNull() bool { - return !r.Valid -} - -func (r Tstzrange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Tstzrange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Tstzrange) ScanNull() error { - *r = Tstzrange{} - return nil -} - -func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Timestamptz{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Timestamptz{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Daterange struct { - Lower Date - Upper Date - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Daterange) IsNull() bool { - return !r.Valid -} - -func (r Daterange) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Daterange) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Daterange) ScanNull() error { - *r = Daterange{} - return nil -} - -func (r *Daterange) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Date{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Date{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -type Float8range struct { - Lower Float8 - Upper Float8 - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r Float8range) IsNull() bool { - return !r.Valid -} - -func (r Float8range) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r Float8range) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *Float8range) ScanNull() error { - *r = Float8range{} - return nil -} - -func (r *Float8range) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *Float8range) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = Float8{} - } - if upper == Unbounded || upper == Empty { - r.Upper = Float8{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb deleted file mode 100644 index d181548c..00000000 --- a/pgtype/range_types.go.erb +++ /dev/null @@ -1,56 +0,0 @@ -package pgtype - -<% - [ - ["Int4range", "Int4"], - ["Int8range", "Int8"], - ["Numrange", "Numeric"], - ["Tsrange", "Timestamp"], - ["Tstzrange", "Timestamptz"], - ["Daterange", "Date"], - ["Float8range", "Float8"] - ].each do |range_type, element_type| -%> -type <%= range_type %> struct { - Lower <%= element_type %> - Upper <%= element_type %> - LowerType BoundType - UpperType BoundType - Valid bool -} - -func (r <%= range_type %>) IsNull() bool { - return !r.Valid -} - -func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) { - return r.LowerType, r.UpperType -} - -func (r <%= range_type %>) Bounds() (lower, upper any) { - return &r.Lower, &r.Upper -} - -func (r *<%= range_type %>) ScanNull() error { - *r = <%= range_type %>{} - return nil -} - -func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget any) { - return &r.Lower, &r.Upper -} - -func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { - if lower == Unbounded || lower == Empty { - r.Lower = <%= element_type %>{} - } - if upper == Unbounded || upper == Empty { - r.Upper = <%= element_type %>{} - } - r.LowerType = lower - r.UpperType = upper - r.Valid = true - return nil -} - -<% end %>