From bd85fe870d0ee82e9ee6c57c0e01f1319a397053 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 16:45:52 -0500 Subject: [PATCH] Hard code standard PostgreSQL types Instead of needing to instrospect the database on connection preload the standard OID / type map. Types from extensions (like hstore) and custom types can be registered by the application developer. Otherwise, they will be treated as strings. --- decimal.go | 31 --------------- hstore_array_test.go | 14 +++++++ pgtype.go | 95 +++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 37 deletions(-) delete mode 100644 decimal.go diff --git a/decimal.go b/decimal.go deleted file mode 100644 index 79653cf3..00000000 --- a/decimal.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type Decimal Numeric - -func (dst *Decimal) Set(src interface{}) error { - return (*Numeric)(dst).Set(src) -} - -func (dst *Decimal) Get() interface{} { - return (*Numeric)(dst).Get() -} - -func (src *Decimal) AssignTo(dst interface{}) error { - return (*Numeric)(src).AssignTo(dst) -} - -func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeText(ci, src) -} - -func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeBinary(ci, src) -} - -func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeText(ci, buf) -} - -func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeBinary(ci, buf) -} diff --git a/hstore_array_test.go b/hstore_array_test.go index c8104d28..03dc2ff1 100644 --- a/hstore_array_test.go +++ b/hstore_array_test.go @@ -14,6 +14,20 @@ func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) + var hstoreOID pgtype.OID + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) + if err != nil { + t.Fatalf("did not find hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + + var hstoreArrayOID pgtype.OID + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) + if err != nil { + t.Fatalf("did not find _hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} } diff --git a/pgtype.go b/pgtype.go index 8f41d068..4faf23e1 100644 --- a/pgtype.go +++ b/pgtype.go @@ -11,7 +11,7 @@ import ( const ( BoolOID = 16 ByteaOID = 17 - CharOID = 18 + QCharOID = 18 NameOID = 19 Int8OID = 20 Int2OID = 21 @@ -22,11 +22,19 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 CIDROID = 650 CIDRArrayOID = 651 Float4OID = 700 Float8OID = 701 + CircleOID = 718 UnknownOID = 705 + MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 Int2ArrayOID = 1005 @@ -49,11 +57,21 @@ const ( DateArrayOID = 1182 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 + IntervalOID = 1186 + NumericArrayOID = 1231 + BitOID = 1560 + VarbitOID = 1562 NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 + DaterangeOID = 3812 + Int4rangeOID = 3904 + NumrangeOID = 3906 + TsrangeOID = 3908 + TstzrangeOID = 3910 + Int8rangeOID = 3926 ) type Status byte @@ -155,11 +173,77 @@ type ConnInfo struct { } func NewConnInfo() *ConnInfo { - return &ConnInfo{ - oidToDataType: make(map[OID]*DataType, 256), - nameToDataType: make(map[string]*DataType, 256), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + ci := &ConnInfo{ + oidToDataType: make(map[OID]*DataType, 128), + nameToDataType: make(map[string]*DataType, 128), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), } + + ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) + ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) + ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) + ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + + return ci } func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { @@ -295,7 +379,6 @@ func init() { "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, - "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{},