From a01a9ee6dfde03a8a458ff2872dc31d56e79bd4e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 Apr 2022 14:04:25 -0500 Subject: [PATCH] Automatically register Array and FlatArray --- pgtype/pgtype.go | 155 ++++++++++++++++++++++++--------------------- stdlib/sql_test.go | 13 +++- 2 files changed, 94 insertions(+), 74 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ce06e738..787eaead 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -316,91 +316,100 @@ func NewMap() *Map { m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - registerDefaultPgTypeVariants := func(name, arrayName string, value any) { - // T - m.RegisterDefaultPgType(value, name) - - // *T - valueType := reflect.TypeOf(value) - m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) - - // []T - sliceType := reflect.SliceOf(valueType) - m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) - - // *[]T - m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) - - // []*T - sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface())) - m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) - - // *[]*T - m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) - } - // Integer types that directly map to a PostgreSQL type - registerDefaultPgTypeVariants("int2", "_int2", int16(0)) - registerDefaultPgTypeVariants("int4", "_int4", int32(0)) - registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + registerDefaultPgTypeVariants[int16](m, "int2") + registerDefaultPgTypeVariants[int32](m, "int4") + registerDefaultPgTypeVariants[int64](m, "int8") // Integer types that do not have a direct match to a PostgreSQL type - registerDefaultPgTypeVariants("int8", "_int8", int8(0)) - registerDefaultPgTypeVariants("int8", "_int8", int(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint8(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) - registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + registerDefaultPgTypeVariants[int8](m, "int8") + registerDefaultPgTypeVariants[int](m, "int8") + registerDefaultPgTypeVariants[uint8](m, "int8") + registerDefaultPgTypeVariants[uint16](m, "int8") + registerDefaultPgTypeVariants[uint32](m, "int8") + registerDefaultPgTypeVariants[uint64](m, "int8") + registerDefaultPgTypeVariants[uint](m, "int8") - registerDefaultPgTypeVariants("float4", "_float4", float32(0)) - registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + registerDefaultPgTypeVariants[float32](m, "float4") + registerDefaultPgTypeVariants[float64](m, "float8") - registerDefaultPgTypeVariants("bool", "_bool", false) - registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) - registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0)) - registerDefaultPgTypeVariants("text", "_text", "") - registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + registerDefaultPgTypeVariants[bool](m, "bool") + registerDefaultPgTypeVariants[time.Time](m, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](m, "interval") + registerDefaultPgTypeVariants[string](m, "text") + registerDefaultPgTypeVariants[[]byte](m, "bytea") - registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) - registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) + registerDefaultPgTypeVariants[net.IP](m, "inet") + registerDefaultPgTypeVariants[net.IPNet](m, "cidr") // pgtype provided structs - registerDefaultPgTypeVariants("varbit", "_varbit", Bits{}) - registerDefaultPgTypeVariants("bool", "_bool", Bool{}) - registerDefaultPgTypeVariants("box", "_box", Box{}) - registerDefaultPgTypeVariants("circle", "_circle", Circle{}) - registerDefaultPgTypeVariants("date", "_date", Date{}) - registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{}) - registerDefaultPgTypeVariants("float4", "_float4", Float4{}) - registerDefaultPgTypeVariants("float8", "_float8", Float8{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange. - registerDefaultPgTypeVariants("inet", "_inet", Inet{}) - registerDefaultPgTypeVariants("int2", "_int2", Int2{}) - registerDefaultPgTypeVariants("int4", "_int4", Int4{}) - registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{}) - registerDefaultPgTypeVariants("int8", "_int8", Int8{}) - registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{}) - registerDefaultPgTypeVariants("interval", "_interval", Interval{}) - registerDefaultPgTypeVariants("line", "_line", Line{}) - registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{}) - registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{}) - registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{}) - registerDefaultPgTypeVariants("path", "_path", Path{}) - registerDefaultPgTypeVariants("point", "_point", Point{}) - registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{}) - registerDefaultPgTypeVariants("tid", "_tid", TID{}) - registerDefaultPgTypeVariants("text", "_text", Text{}) - registerDefaultPgTypeVariants("time", "_time", Time{}) - registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{}) - registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{}) - registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{}) - registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{}) - registerDefaultPgTypeVariants("uuid", "_uuid", UUID{}) + registerDefaultPgTypeVariants[Bits](m, "varbit") + registerDefaultPgTypeVariants[Bool](m, "bool") + registerDefaultPgTypeVariants[Box](m, "box") + registerDefaultPgTypeVariants[Circle](m, "circle") + registerDefaultPgTypeVariants[Date](m, "date") + registerDefaultPgTypeVariants[Range[Date]](m, "daterange") + registerDefaultPgTypeVariants[Float4](m, "float4") + registerDefaultPgTypeVariants[Float8](m, "float8") + registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Inet](m, "inet") + registerDefaultPgTypeVariants[Int2](m, "int2") + registerDefaultPgTypeVariants[Int4](m, "int4") + registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") + registerDefaultPgTypeVariants[Int8](m, "int8") + registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") + registerDefaultPgTypeVariants[Interval](m, "interval") + registerDefaultPgTypeVariants[Line](m, "line") + registerDefaultPgTypeVariants[Lseg](m, "lseg") + registerDefaultPgTypeVariants[Numeric](m, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") + registerDefaultPgTypeVariants[Path](m, "path") + registerDefaultPgTypeVariants[Point](m, "point") + registerDefaultPgTypeVariants[Polygon](m, "polygon") + registerDefaultPgTypeVariants[TID](m, "tid") + registerDefaultPgTypeVariants[Text](m, "text") + registerDefaultPgTypeVariants[Time](m, "time") + registerDefaultPgTypeVariants[Timestamp](m, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") + registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") + registerDefaultPgTypeVariants[UUID](m, "uuid") return m } +func registerDefaultPgTypeVariants[T any](m *Map, name string) { + arrayName := "_" + name + + var value T + m.RegisterDefaultPgType(value, name) // T + m.RegisterDefaultPgType(&value, name) // *T + + var sliceT []T + m.RegisterDefaultPgType(sliceT, arrayName) // []T + m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T + + var slicePtrT []*T + m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T + m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T + + var arrayOfT Array[T] + m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] + m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] + + var arrayOfPtrT Array[*T] + m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] + m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] + + var flatArrayOfT FlatArray[T] + m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] + m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] + + var flatArrayOfPtrT FlatArray[*T] + m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] + m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] +} + func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 75f0caf4..30cea7d6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -374,7 +374,7 @@ func TestConnSimpleSlicePassThrough(t *testing.T) { }) } -func TestConnQueryScanArray(t *testing.T) { +func TestConnQueryScanGoArray(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { m := pgtype.NewMap() @@ -385,6 +385,17 @@ func TestConnQueryScanArray(t *testing.T) { }) } +func TestConnQueryScanArray(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + m := pgtype.NewMap() + + var a pgtype.Array[int64] + err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) + require.NoError(t, err) + assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a) + }) +} + func TestConnQueryScanRange(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { m := pgtype.NewMap()