Automatically register Array and FlatArray

non-blocking
Jack Christensen 2022-04-16 14:04:25 -05:00
parent fccaebc93d
commit a01a9ee6df
2 changed files with 94 additions and 74 deletions

View File

@ -316,91 +316,100 @@ func NewMap() *Map {
m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}})
m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}})
registerDefaultPgTypeVariants := func(name, arrayName string, value any) {
// T
m.RegisterDefaultPgType(value, name)
// *T
valueType := reflect.TypeOf(value)
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
// []T
sliceType := reflect.SliceOf(valueType)
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
// *[]T
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
// []*T
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
// *[]*T
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
}
// Integer types that directly map to a PostgreSQL type
registerDefaultPgTypeVariants("int2", "_int2", int16(0))
registerDefaultPgTypeVariants("int4", "_int4", int32(0))
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
registerDefaultPgTypeVariants[int16](m, "int2")
registerDefaultPgTypeVariants[int32](m, "int4")
registerDefaultPgTypeVariants[int64](m, "int8")
// Integer types that do not have a direct match to a PostgreSQL type
registerDefaultPgTypeVariants("int8", "_int8", int8(0))
registerDefaultPgTypeVariants("int8", "_int8", int(0))
registerDefaultPgTypeVariants("int8", "_int8", uint8(0))
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
registerDefaultPgTypeVariants[int8](m, "int8")
registerDefaultPgTypeVariants[int](m, "int8")
registerDefaultPgTypeVariants[uint8](m, "int8")
registerDefaultPgTypeVariants[uint16](m, "int8")
registerDefaultPgTypeVariants[uint32](m, "int8")
registerDefaultPgTypeVariants[uint64](m, "int8")
registerDefaultPgTypeVariants[uint](m, "int8")
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
registerDefaultPgTypeVariants("float8", "_float8", float64(0))
registerDefaultPgTypeVariants[float32](m, "float4")
registerDefaultPgTypeVariants[float64](m, "float8")
registerDefaultPgTypeVariants("bool", "_bool", false)
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0))
registerDefaultPgTypeVariants("text", "_text", "")
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
registerDefaultPgTypeVariants[bool](m, "bool")
registerDefaultPgTypeVariants[time.Time](m, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](m, "interval")
registerDefaultPgTypeVariants[string](m, "text")
registerDefaultPgTypeVariants[[]byte](m, "bytea")
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{})
registerDefaultPgTypeVariants[net.IP](m, "inet")
registerDefaultPgTypeVariants[net.IPNet](m, "cidr")
// pgtype provided structs
registerDefaultPgTypeVariants("varbit", "_varbit", Bits{})
registerDefaultPgTypeVariants("bool", "_bool", Bool{})
registerDefaultPgTypeVariants("box", "_box", Box{})
registerDefaultPgTypeVariants("circle", "_circle", Circle{})
registerDefaultPgTypeVariants("date", "_date", Date{})
registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{})
registerDefaultPgTypeVariants("float4", "_float4", Float4{})
registerDefaultPgTypeVariants("float8", "_float8", Float8{})
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants("inet", "_inet", Inet{})
registerDefaultPgTypeVariants("int2", "_int2", Int2{})
registerDefaultPgTypeVariants("int4", "_int4", Int4{})
registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{})
registerDefaultPgTypeVariants("int8", "_int8", Int8{})
registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{})
registerDefaultPgTypeVariants("interval", "_interval", Interval{})
registerDefaultPgTypeVariants("line", "_line", Line{})
registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{})
registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{})
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{})
registerDefaultPgTypeVariants("path", "_path", Path{})
registerDefaultPgTypeVariants("point", "_point", Point{})
registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{})
registerDefaultPgTypeVariants("tid", "_tid", TID{})
registerDefaultPgTypeVariants("text", "_text", Text{})
registerDefaultPgTypeVariants("time", "_time", Time{})
registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{})
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{})
registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{})
registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{})
registerDefaultPgTypeVariants("uuid", "_uuid", UUID{})
registerDefaultPgTypeVariants[Bits](m, "varbit")
registerDefaultPgTypeVariants[Bool](m, "bool")
registerDefaultPgTypeVariants[Box](m, "box")
registerDefaultPgTypeVariants[Circle](m, "circle")
registerDefaultPgTypeVariants[Date](m, "date")
registerDefaultPgTypeVariants[Range[Date]](m, "daterange")
registerDefaultPgTypeVariants[Float4](m, "float4")
registerDefaultPgTypeVariants[Float8](m, "float8")
registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants[Inet](m, "inet")
registerDefaultPgTypeVariants[Int2](m, "int2")
registerDefaultPgTypeVariants[Int4](m, "int4")
registerDefaultPgTypeVariants[Range[Int4]](m, "int4range")
registerDefaultPgTypeVariants[Int8](m, "int8")
registerDefaultPgTypeVariants[Range[Int8]](m, "int8range")
registerDefaultPgTypeVariants[Interval](m, "interval")
registerDefaultPgTypeVariants[Line](m, "line")
registerDefaultPgTypeVariants[Lseg](m, "lseg")
registerDefaultPgTypeVariants[Numeric](m, "numeric")
registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange")
registerDefaultPgTypeVariants[Path](m, "path")
registerDefaultPgTypeVariants[Point](m, "point")
registerDefaultPgTypeVariants[Polygon](m, "polygon")
registerDefaultPgTypeVariants[TID](m, "tid")
registerDefaultPgTypeVariants[Text](m, "text")
registerDefaultPgTypeVariants[Time](m, "time")
registerDefaultPgTypeVariants[Timestamp](m, "timestamp")
registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz")
registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange")
registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange")
registerDefaultPgTypeVariants[UUID](m, "uuid")
return m
}
func registerDefaultPgTypeVariants[T any](m *Map, name string) {
arrayName := "_" + name
var value T
m.RegisterDefaultPgType(value, name) // T
m.RegisterDefaultPgType(&value, name) // *T
var sliceT []T
m.RegisterDefaultPgType(sliceT, arrayName) // []T
m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T
var slicePtrT []*T
m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T
m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T
var arrayOfT Array[T]
m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T]
m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T]
var arrayOfPtrT Array[*T]
m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T]
m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T]
var flatArrayOfT FlatArray[T]
m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T]
m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T]
var flatArrayOfPtrT FlatArray[*T]
m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T]
m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T]
}
func (m *Map) RegisterType(t *Type) {
m.oidToType[t.OID] = t
m.nameToType[t.Name] = t

View File

@ -374,7 +374,7 @@ func TestConnSimpleSlicePassThrough(t *testing.T) {
})
}
func TestConnQueryScanArray(t *testing.T) {
func TestConnQueryScanGoArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()
@ -385,6 +385,17 @@ func TestConnQueryScanArray(t *testing.T) {
})
}
func TestConnQueryScanArray(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()
var a pgtype.Array[int64]
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
require.NoError(t, err)
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
})
}
func TestConnQueryScanRange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
m := pgtype.NewMap()