mirror of https://github.com/jackc/pgx.git
Automatically register Array and FlatArray
parent
fccaebc93d
commit
a01a9ee6df
155
pgtype/pgtype.go
155
pgtype/pgtype.go
|
@ -316,91 +316,100 @@ func NewMap() *Map {
|
|||
m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}})
|
||||
m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}})
|
||||
|
||||
registerDefaultPgTypeVariants := func(name, arrayName string, value any) {
|
||||
// T
|
||||
m.RegisterDefaultPgType(value, name)
|
||||
|
||||
// *T
|
||||
valueType := reflect.TypeOf(value)
|
||||
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
|
||||
|
||||
// []T
|
||||
sliceType := reflect.SliceOf(valueType)
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
|
||||
|
||||
// []*T
|
||||
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]*T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
|
||||
}
|
||||
|
||||
// Integer types that directly map to a PostgreSQL type
|
||||
registerDefaultPgTypeVariants("int2", "_int2", int16(0))
|
||||
registerDefaultPgTypeVariants("int4", "_int4", int32(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
|
||||
registerDefaultPgTypeVariants[int16](m, "int2")
|
||||
registerDefaultPgTypeVariants[int32](m, "int4")
|
||||
registerDefaultPgTypeVariants[int64](m, "int8")
|
||||
|
||||
// Integer types that do not have a direct match to a PostgreSQL type
|
||||
registerDefaultPgTypeVariants("int8", "_int8", int8(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", int(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint8(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
|
||||
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
|
||||
registerDefaultPgTypeVariants[int8](m, "int8")
|
||||
registerDefaultPgTypeVariants[int](m, "int8")
|
||||
registerDefaultPgTypeVariants[uint8](m, "int8")
|
||||
registerDefaultPgTypeVariants[uint16](m, "int8")
|
||||
registerDefaultPgTypeVariants[uint32](m, "int8")
|
||||
registerDefaultPgTypeVariants[uint64](m, "int8")
|
||||
registerDefaultPgTypeVariants[uint](m, "int8")
|
||||
|
||||
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
|
||||
registerDefaultPgTypeVariants("float8", "_float8", float64(0))
|
||||
registerDefaultPgTypeVariants[float32](m, "float4")
|
||||
registerDefaultPgTypeVariants[float64](m, "float8")
|
||||
|
||||
registerDefaultPgTypeVariants("bool", "_bool", false)
|
||||
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
|
||||
registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0))
|
||||
registerDefaultPgTypeVariants("text", "_text", "")
|
||||
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
|
||||
registerDefaultPgTypeVariants[bool](m, "bool")
|
||||
registerDefaultPgTypeVariants[time.Time](m, "timestamptz")
|
||||
registerDefaultPgTypeVariants[time.Duration](m, "interval")
|
||||
registerDefaultPgTypeVariants[string](m, "text")
|
||||
registerDefaultPgTypeVariants[[]byte](m, "bytea")
|
||||
|
||||
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
|
||||
registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{})
|
||||
registerDefaultPgTypeVariants[net.IP](m, "inet")
|
||||
registerDefaultPgTypeVariants[net.IPNet](m, "cidr")
|
||||
|
||||
// pgtype provided structs
|
||||
registerDefaultPgTypeVariants("varbit", "_varbit", Bits{})
|
||||
registerDefaultPgTypeVariants("bool", "_bool", Bool{})
|
||||
registerDefaultPgTypeVariants("box", "_box", Box{})
|
||||
registerDefaultPgTypeVariants("circle", "_circle", Circle{})
|
||||
registerDefaultPgTypeVariants("date", "_date", Date{})
|
||||
registerDefaultPgTypeVariants("daterange", "_daterange", Range[Date]{})
|
||||
registerDefaultPgTypeVariants("float4", "_float4", Float4{})
|
||||
registerDefaultPgTypeVariants("float8", "_float8", Float8{})
|
||||
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Float8]{}) // There is no PostgreSQL builtin float8range so map it to numrange.
|
||||
registerDefaultPgTypeVariants("inet", "_inet", Inet{})
|
||||
registerDefaultPgTypeVariants("int2", "_int2", Int2{})
|
||||
registerDefaultPgTypeVariants("int4", "_int4", Int4{})
|
||||
registerDefaultPgTypeVariants("int4range", "_int4range", Range[Int4]{})
|
||||
registerDefaultPgTypeVariants("int8", "_int8", Int8{})
|
||||
registerDefaultPgTypeVariants("int8range", "_int8range", Range[Int8]{})
|
||||
registerDefaultPgTypeVariants("interval", "_interval", Interval{})
|
||||
registerDefaultPgTypeVariants("line", "_line", Line{})
|
||||
registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{})
|
||||
registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{})
|
||||
registerDefaultPgTypeVariants("numrange", "_numrange", Range[Numeric]{})
|
||||
registerDefaultPgTypeVariants("path", "_path", Path{})
|
||||
registerDefaultPgTypeVariants("point", "_point", Point{})
|
||||
registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{})
|
||||
registerDefaultPgTypeVariants("tid", "_tid", TID{})
|
||||
registerDefaultPgTypeVariants("text", "_text", Text{})
|
||||
registerDefaultPgTypeVariants("time", "_time", Time{})
|
||||
registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{})
|
||||
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{})
|
||||
registerDefaultPgTypeVariants("tsrange", "_tsrange", Range[Timestamp]{})
|
||||
registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Range[Timestamptz]{})
|
||||
registerDefaultPgTypeVariants("uuid", "_uuid", UUID{})
|
||||
registerDefaultPgTypeVariants[Bits](m, "varbit")
|
||||
registerDefaultPgTypeVariants[Bool](m, "bool")
|
||||
registerDefaultPgTypeVariants[Box](m, "box")
|
||||
registerDefaultPgTypeVariants[Circle](m, "circle")
|
||||
registerDefaultPgTypeVariants[Date](m, "date")
|
||||
registerDefaultPgTypeVariants[Range[Date]](m, "daterange")
|
||||
registerDefaultPgTypeVariants[Float4](m, "float4")
|
||||
registerDefaultPgTypeVariants[Float8](m, "float8")
|
||||
registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange.
|
||||
registerDefaultPgTypeVariants[Inet](m, "inet")
|
||||
registerDefaultPgTypeVariants[Int2](m, "int2")
|
||||
registerDefaultPgTypeVariants[Int4](m, "int4")
|
||||
registerDefaultPgTypeVariants[Range[Int4]](m, "int4range")
|
||||
registerDefaultPgTypeVariants[Int8](m, "int8")
|
||||
registerDefaultPgTypeVariants[Range[Int8]](m, "int8range")
|
||||
registerDefaultPgTypeVariants[Interval](m, "interval")
|
||||
registerDefaultPgTypeVariants[Line](m, "line")
|
||||
registerDefaultPgTypeVariants[Lseg](m, "lseg")
|
||||
registerDefaultPgTypeVariants[Numeric](m, "numeric")
|
||||
registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange")
|
||||
registerDefaultPgTypeVariants[Path](m, "path")
|
||||
registerDefaultPgTypeVariants[Point](m, "point")
|
||||
registerDefaultPgTypeVariants[Polygon](m, "polygon")
|
||||
registerDefaultPgTypeVariants[TID](m, "tid")
|
||||
registerDefaultPgTypeVariants[Text](m, "text")
|
||||
registerDefaultPgTypeVariants[Time](m, "time")
|
||||
registerDefaultPgTypeVariants[Timestamp](m, "timestamp")
|
||||
registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz")
|
||||
registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange")
|
||||
registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange")
|
||||
registerDefaultPgTypeVariants[UUID](m, "uuid")
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func registerDefaultPgTypeVariants[T any](m *Map, name string) {
|
||||
arrayName := "_" + name
|
||||
|
||||
var value T
|
||||
m.RegisterDefaultPgType(value, name) // T
|
||||
m.RegisterDefaultPgType(&value, name) // *T
|
||||
|
||||
var sliceT []T
|
||||
m.RegisterDefaultPgType(sliceT, arrayName) // []T
|
||||
m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T
|
||||
|
||||
var slicePtrT []*T
|
||||
m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T
|
||||
m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T
|
||||
|
||||
var arrayOfT Array[T]
|
||||
m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T]
|
||||
m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T]
|
||||
|
||||
var arrayOfPtrT Array[*T]
|
||||
m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T]
|
||||
m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T]
|
||||
|
||||
var flatArrayOfT FlatArray[T]
|
||||
m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T]
|
||||
m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T]
|
||||
|
||||
var flatArrayOfPtrT FlatArray[*T]
|
||||
m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T]
|
||||
m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T]
|
||||
}
|
||||
|
||||
func (m *Map) RegisterType(t *Type) {
|
||||
m.oidToType[t.OID] = t
|
||||
m.nameToType[t.Name] = t
|
||||
|
|
|
@ -374,7 +374,7 @@ func TestConnSimpleSlicePassThrough(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestConnQueryScanArray(t *testing.T) {
|
||||
func TestConnQueryScanGoArray(t *testing.T) {
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
||||
m := pgtype.NewMap()
|
||||
|
||||
|
@ -385,6 +385,17 @@ func TestConnQueryScanArray(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestConnQueryScanArray(t *testing.T) {
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
||||
m := pgtype.NewMap()
|
||||
|
||||
var a pgtype.Array[int64]
|
||||
err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnQueryScanRange(t *testing.T) {
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
|
||||
m := pgtype.NewMap()
|
||||
|
|
Loading…
Reference in New Issue