Hard code standard PostgreSQL types

Instead of needing to instrospect the database on connection preload the
standard OID / type map. Types from extensions (like hstore) and custom
types can be registered by the application developer. Otherwise, they
will be treated as strings.
non-blocking
Jack Christensen 2019-04-13 16:45:52 -05:00
parent a0f487bc09
commit bd85fe870d
3 changed files with 103 additions and 37 deletions

View File

@ -1,31 +0,0 @@
package pgtype
type Decimal Numeric
func (dst *Decimal) Set(src interface{}) error {
return (*Numeric)(dst).Set(src)
}
func (dst *Decimal) Get() interface{} {
return (*Numeric)(dst).Get()
}
func (src *Decimal) AssignTo(dst interface{}) error {
return (*Numeric)(src).AssignTo(dst)
}
func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeText(ci, src)
}
func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeBinary(ci, src)
}
func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
return (*Numeric)(src).EncodeText(ci, buf)
}
func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
return (*Numeric)(src).EncodeBinary(ci, buf)
}

View File

@ -14,6 +14,20 @@ func TestHstoreArrayTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t) conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn) defer testutil.MustCloseContext(t, conn)
var hstoreOID pgtype.OID
err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID)
if err != nil {
t.Fatalf("did not find hstore OID, %v", err)
}
conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID})
var hstoreArrayOID pgtype.OID
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID)
if err != nil {
t.Fatalf("did not find _hstore OID, %v", err)
}
conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID})
text := func(s string) pgtype.Text { text := func(s string) pgtype.Text {
return pgtype.Text{String: s, Status: pgtype.Present} return pgtype.Text{String: s, Status: pgtype.Present}
} }

View File

@ -11,7 +11,7 @@ import (
const ( const (
BoolOID = 16 BoolOID = 16
ByteaOID = 17 ByteaOID = 17
CharOID = 18 QCharOID = 18
NameOID = 19 NameOID = 19
Int8OID = 20 Int8OID = 20
Int2OID = 21 Int2OID = 21
@ -22,11 +22,19 @@ const (
XIDOID = 28 XIDOID = 28
CIDOID = 29 CIDOID = 29
JSONOID = 114 JSONOID = 114
PointOID = 600
LsegOID = 601
PathOID = 602
BoxOID = 603
PolygonOID = 604
LineOID = 628
CIDROID = 650 CIDROID = 650
CIDRArrayOID = 651 CIDRArrayOID = 651
Float4OID = 700 Float4OID = 700
Float8OID = 701 Float8OID = 701
CircleOID = 718
UnknownOID = 705 UnknownOID = 705
MacaddrOID = 829
InetOID = 869 InetOID = 869
BoolArrayOID = 1000 BoolArrayOID = 1000
Int2ArrayOID = 1005 Int2ArrayOID = 1005
@ -49,11 +57,21 @@ const (
DateArrayOID = 1182 DateArrayOID = 1182
TimestamptzOID = 1184 TimestamptzOID = 1184
TimestamptzArrayOID = 1185 TimestamptzArrayOID = 1185
IntervalOID = 1186
NumericArrayOID = 1231
BitOID = 1560
VarbitOID = 1562
NumericOID = 1700 NumericOID = 1700
RecordOID = 2249 RecordOID = 2249
UUIDOID = 2950 UUIDOID = 2950
UUIDArrayOID = 2951 UUIDArrayOID = 2951
JSONBOID = 3802 JSONBOID = 3802
DaterangeOID = 3812
Int4rangeOID = 3904
NumrangeOID = 3906
TsrangeOID = 3908
TstzrangeOID = 3910
Int8rangeOID = 3926
) )
type Status byte type Status byte
@ -155,11 +173,77 @@ type ConnInfo struct {
} }
func NewConnInfo() *ConnInfo { func NewConnInfo() *ConnInfo {
return &ConnInfo{ ci := &ConnInfo{
oidToDataType: make(map[OID]*DataType, 256), oidToDataType: make(map[OID]*DataType, 128),
nameToDataType: make(map[string]*DataType, 256), nameToDataType: make(map[string]*DataType, 128),
reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), reflectTypeToDataType: make(map[reflect.Type]*DataType, 128),
} }
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID})
ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID})
ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID})
ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID})
ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID})
ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID})
ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID})
ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID})
ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID})
ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID})
ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID})
ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID})
ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID})
ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID})
ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID})
ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID})
ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID})
ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID})
ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID})
ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID})
ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID})
ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID})
ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID})
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID})
ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID})
ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID})
ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID})
ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID})
ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID})
ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID})
ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID})
ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID})
ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID})
ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID})
ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID})
ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID})
ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID})
ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID})
ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID})
ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID})
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
return ci
} }
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
@ -295,7 +379,6 @@ func init() {
"circle": &Circle{}, "circle": &Circle{},
"date": &Date{}, "date": &Date{},
"daterange": &Daterange{}, "daterange": &Daterange{},
"decimal": &Decimal{},
"float4": &Float4{}, "float4": &Float4{},
"float8": &Float8{}, "float8": &Float8{},
"hstore": &Hstore{}, "hstore": &Hstore{},