Use generics for Range values

pull/1185/head
Jack Christensen 2022-04-09 09:34:37 -05:00
parent f14fb3d692
commit c8025fd79a
6 changed files with 71 additions and 380 deletions

View File

@ -11,7 +11,6 @@ generated_code_files = [
"pgtype/int.go",
"pgtype/int_test.go",
"pgtype/integration_benchmark_test.go",
"pgtype/range_types.go",
"pgtype/zeronull/int.go",
"pgtype/zeronull/int_test.go"
]

View File

@ -371,21 +371,21 @@ func NewMap() *Map {
registerDefaultPgTypeVariants("box", "_box", Box{})
registerDefaultPgTypeVariants("circle", "_circle", Circle{})
registerDefaultPgTypeVariants("date", "_date", Date{})
registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{})
registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{})
registerDefaultPgTypeVariants("float4", "_float4", Float4{})
registerDefaultPgTypeVariants("float8", "_float8", Float8{})
registerDefaultPgTypeVariants("numrange", "_numrange", Float8range{}) // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants("inet", "_inet", Inet{})
registerDefaultPgTypeVariants("int2", "_int2", Int2{})
registerDefaultPgTypeVariants("int4", "_int4", Int4{})
registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{})
registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{})
registerDefaultPgTypeVariants("int8", "_int8", Int8{})
registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{})
registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{})
registerDefaultPgTypeVariants("interval", "_interval", Interval{})
registerDefaultPgTypeVariants("line", "_line", Line{})
registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{})
registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{})
registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{})
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{})
registerDefaultPgTypeVariants("path", "_path", Path{})
registerDefaultPgTypeVariants("point", "_point", Point{})
registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{})
@ -394,8 +394,8 @@ func NewMap() *Map {
registerDefaultPgTypeVariants("time", "_time", Time{})
registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{})
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{})
registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{})
registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{})
registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{})
registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{})
registerDefaultPgTypeVariants("uuid", "_uuid", UUID{})
return m

View File

@ -275,3 +275,47 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
return ubr, nil
}
type Range[T any] struct {
Lower T
Upper T
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Range[T]) IsNull() bool {
return !r.Valid
}
func (r Range[T]) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Range[T]) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Range[T]) ScanNull() error {
*r = Range[T]{}
return nil
}
func (r *Range[T]) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
var zero T
r.Lower = zero
}
if upper == Unbounded || upper == Empty {
var zero T
r.Upper = zero
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}

View File

@ -15,27 +15,27 @@ func TestRangeCodecTranscode(t *testing.T) {
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4range", []pgxtest.ValueRoundTripTest{
{
pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true},
new(pgtype.Int4range),
isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}),
pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true},
new(pgtype.Range[pgtype.Int4]),
isExpectedEq(pgtype.Range[pgtype.Int4]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}),
},
{
pgtype.Int4range{
pgtype.Range[pgtype.Int4]{
LowerType: pgtype.Inclusive,
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
},
new(pgtype.Int4range),
isExpectedEq(pgtype.Int4range{
new(pgtype.Range[pgtype.Int4]),
isExpectedEq(pgtype.Range[pgtype.Int4]{
LowerType: pgtype.Inclusive,
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
}),
},
{pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})},
{nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})},
{pgtype.Range[pgtype.Int4]{}, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})},
{nil, new(pgtype.Range[pgtype.Int4]), isExpectedEq(pgtype.Range[pgtype.Int4]{})},
})
}
@ -47,27 +47,27 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) {
pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "numrange", []pgxtest.ValueRoundTripTest{
{
pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true},
new(pgtype.Float8range),
isExpectedEq(pgtype.Float8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}),
pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true},
new(pgtype.Range[pgtype.Float8]),
isExpectedEq(pgtype.Range[pgtype.Float8]{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}),
},
{
pgtype.Float8range{
pgtype.Range[pgtype.Float8]{
LowerType: pgtype.Inclusive,
Lower: pgtype.Float8{Float64: 1, Valid: true},
Upper: pgtype.Float8{Float64: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
},
new(pgtype.Float8range),
isExpectedEq(pgtype.Float8range{
new(pgtype.Range[pgtype.Float8]),
isExpectedEq(pgtype.Range[pgtype.Float8]{
LowerType: pgtype.Inclusive,
Lower: pgtype.Float8{Float64: 1, Valid: true},
Upper: pgtype.Float8{Float64: 5, Valid: true},
UpperType: pgtype.Exclusive, Valid: true,
}),
},
{pgtype.Float8range{}, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})},
{nil, new(pgtype.Float8range), isExpectedEq(pgtype.Float8range{})},
{pgtype.Range[pgtype.Float8]{}, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})},
{nil, new(pgtype.Range[pgtype.Float8]), isExpectedEq(pgtype.Range[pgtype.Float8]{})},
})
}
@ -76,14 +76,14 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var r pgtype.Int4range
var r pgtype.Range[pgtype.Int4]
err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r)
require.NoError(t, err)
require.Equal(
t,
pgtype.Int4range{
pgtype.Range[pgtype.Int4]{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive,
@ -98,7 +98,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) {
require.Equal(
t,
pgtype.Int4range{
pgtype.Range[pgtype.Int4]{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{},
LowerType: pgtype.Inclusive,
@ -113,7 +113,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) {
require.Equal(
t,
pgtype.Int4range{
pgtype.Range[pgtype.Int4]{
Lower: pgtype.Int4{},
Upper: pgtype.Int4{},
LowerType: pgtype.Empty,

View File

@ -1,296 +0,0 @@
// Do not edit. Generated from pgtype/range_types.go.erb
package pgtype
type Int4range struct {
Lower Int4
Upper Int4
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Int4range) IsNull() bool {
return !r.Valid
}
func (r Int4range) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Int4range) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Int4range) ScanNull() error {
*r = Int4range{}
return nil
}
func (r *Int4range) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Int4range) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Int4{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Int4{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Int8range struct {
Lower Int8
Upper Int8
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Int8range) IsNull() bool {
return !r.Valid
}
func (r Int8range) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Int8range) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Int8range) ScanNull() error {
*r = Int8range{}
return nil
}
func (r *Int8range) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Int8range) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Int8{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Int8{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Numrange struct {
Lower Numeric
Upper Numeric
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Numrange) IsNull() bool {
return !r.Valid
}
func (r Numrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Numrange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Numrange) ScanNull() error {
*r = Numrange{}
return nil
}
func (r *Numrange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Numrange) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Numeric{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Numeric{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Tsrange struct {
Lower Timestamp
Upper Timestamp
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Tsrange) IsNull() bool {
return !r.Valid
}
func (r Tsrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Tsrange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Tsrange) ScanNull() error {
*r = Tsrange{}
return nil
}
func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Timestamp{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Timestamp{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Tstzrange struct {
Lower Timestamptz
Upper Timestamptz
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Tstzrange) IsNull() bool {
return !r.Valid
}
func (r Tstzrange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Tstzrange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Tstzrange) ScanNull() error {
*r = Tstzrange{}
return nil
}
func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Timestamptz{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Timestamptz{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Daterange struct {
Lower Date
Upper Date
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Daterange) IsNull() bool {
return !r.Valid
}
func (r Daterange) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Daterange) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Daterange) ScanNull() error {
*r = Daterange{}
return nil
}
func (r *Daterange) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Daterange) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Date{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Date{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
type Float8range struct {
Lower Float8
Upper Float8
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r Float8range) IsNull() bool {
return !r.Valid
}
func (r Float8range) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r Float8range) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *Float8range) ScanNull() error {
*r = Float8range{}
return nil
}
func (r *Float8range) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *Float8range) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = Float8{}
}
if upper == Unbounded || upper == Empty {
r.Upper = Float8{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}

View File

@ -1,56 +0,0 @@
package pgtype
<%
[
["Int4range", "Int4"],
["Int8range", "Int8"],
["Numrange", "Numeric"],
["Tsrange", "Timestamp"],
["Tstzrange", "Timestamptz"],
["Daterange", "Date"],
["Float8range", "Float8"]
].each do |range_type, element_type|
%>
type <%= range_type %> struct {
Lower <%= element_type %>
Upper <%= element_type %>
LowerType BoundType
UpperType BoundType
Valid bool
}
func (r <%= range_type %>) IsNull() bool {
return !r.Valid
}
func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) {
return r.LowerType, r.UpperType
}
func (r <%= range_type %>) Bounds() (lower, upper any) {
return &r.Lower, &r.Upper
}
func (r *<%= range_type %>) ScanNull() error {
*r = <%= range_type %>{}
return nil
}
func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget any) {
return &r.Lower, &r.Upper
}
func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error {
if lower == Unbounded || lower == Empty {
r.Lower = <%= element_type %>{}
}
if upper == Unbounded || upper == Empty {
r.Upper = <%= element_type %>{}
}
r.LowerType = lower
r.UpperType = upper
r.Valid = true
return nil
}
<% end %>