Hard code standard PostgreSQL types

Instead of needing to instrospect the database on connection preload the
standard OID / type map. Types from extensions (like hstore) and custom
types can be registered by the application developer. Otherwise, they
will be treated as strings.
pull/483/head
Jack Christensen 2019-04-13 16:45:52 -05:00
parent 95058dc476
commit a6bdd8fd49
6 changed files with 118 additions and 287 deletions

244
conn.go
View File

@ -24,21 +24,6 @@ const (
connStatusBusy
)
// minimalConnInfo has just enough static type information to establish the
// connection and retrieve the type data.
var minimalConnInfo *pgtype.ConnInfo
func init() {
minimalConnInfo = pgtype.NewConnInfo()
minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{
"int4": pgtype.Int4OID,
"name": pgtype.NameOID,
"oid": pgtype.OIDOID,
"text": pgtype.TextOID,
"varchar": pgtype.VarcharOID,
})
}
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
pgconn.Config
@ -132,12 +117,12 @@ func Connect(ctx context.Context, connString string) (*Conn, error) {
if err != nil {
return nil, err
}
return connect(ctx, connConfig, minimalConnInfo)
return connect(ctx, connConfig)
}
// Connect establishes a connection with a PostgreSQL server with a configuration struct.
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
return connect(ctx, connConfig, minimalConnInfo)
return connect(ctx, connConfig)
}
func ParseConfig(connString string) (*ConnConfig, error) {
@ -152,11 +137,11 @@ func ParseConfig(connString string) (*ConnConfig, error) {
return connConfig, nil
}
func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c = new(Conn)
c.config = config
c.ConnInfo = connInfo
c.ConnInfo = pgtype.NewConnInfo()
if c.config.LogLevel != 0 {
c.logLevel = c.config.LogLevel
@ -193,230 +178,9 @@ func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo)
return c, nil
}
if c.ConnInfo == minimalConnInfo {
err = c.initConnInfo()
if err != nil {
c.Close(ctx)
return nil, err
}
}
return c, nil
}
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
const (
namedOIDQuery = `select t.oid,
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
else nsp.nspname||'.'||t.typname
end
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
left join pg_namespace nsp on t.typnamespace=nsp.oid
where (
t.typtype in('b', 'p', 'r', 'e')
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
)`
)
nameOIDs, err := connInfoFromRows(c.Query(context.TODO(), namedOIDQuery))
if err != nil {
return nil, err
}
cinfo := pgtype.NewConnInfo()
cinfo.InitializeDataTypes(nameOIDs)
if err = c.initConnInfoEnumArray(cinfo); err != nil {
return nil, err
}
if err = c.initConnInfoDomains(cinfo); err != nil {
return nil, err
}
return cinfo, nil
}
func (c *Conn) initConnInfo() (err error) {
var (
connInfo *pgtype.ConnInfo
)
if c.config.CustomConnInfo != nil {
if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil {
return err
}
return nil
}
if connInfo, err = initPostgresql(c); err == nil {
c.ConnInfo = connInfo
return err
}
// Check if CrateDB specific approach might still allow us to connect.
if connInfo, err = c.crateDBTypesQuery(err); err == nil {
c.ConnInfo = connInfo
}
return err
}
// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them.
func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error {
nameOIDs := make(map[string]pgtype.OID, 16)
rows, err := c.Query(context.TODO(), `select t.oid, t.typname
from pg_type t
join pg_type base_type on t.typelem=base_type.oid
where t.typtype = 'b'
and base_type.typtype = 'e'`)
if err != nil {
return err
}
for rows.Next() {
var oid pgtype.OID
var name pgtype.Text
if err := rows.Scan(&oid, &name); err != nil {
return err
}
nameOIDs[name.String] = oid
}
if rows.Err() != nil {
return rows.Err()
}
for name, oid := range nameOIDs {
cinfo.RegisterDataType(pgtype.DataType{
Value: &pgtype.EnumArray{},
Name: name,
OID: oid,
})
}
return nil
}
// initConnInfoDomains introspects for domains and registers a data type for them.
func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error {
type domain struct {
oid pgtype.OID
name pgtype.Text
baseOID pgtype.OID
}
domains := make([]*domain, 0, 16)
rows, err := c.Query(context.TODO(), `select t.oid, t.typname, t.typbasetype
from pg_type t
join pg_type base_type on t.typbasetype=base_type.oid
where t.typtype = 'd'
and base_type.typtype = 'b'`)
if err != nil {
return err
}
for rows.Next() {
var d domain
if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil {
return err
}
domains = append(domains, &d)
}
if rows.Err() != nil {
return rows.Err()
}
for _, d := range domains {
baseDataType, ok := cinfo.DataTypeForOID(d.baseOID)
if ok {
cinfo.RegisterDataType(pgtype.DataType{
Value: reflect.New(reflect.ValueOf(baseDataType.Value).Elem().Type()).Interface().(pgtype.Value),
Name: d.name.String,
OID: d.oid,
})
}
}
return nil
}
// crateDBTypesQuery checks if the given err is likely to be the result of
// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB
// specific query against pg_types is executed and its results are returned. If
// not, the original error is returned.
func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) {
// CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol,
// but not perfectly. In particular, the pg_catalog schema containing the
// pg_type table is not visible by default and the pg_type.typtype column is
// not implemented. Therefor the query above currently returns the following
// error:
//
// pgx.PgError{Severity:"ERROR", Code:"XX000",
// Message:"TableUnknownException: Table 'test.pg_type' unknown",
// Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"",
// Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
// ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"}
//
// If CrateDB was to fix the pg_type table visbility in the future, we'd
// still get this error until typtype column is implemented:
//
// pgx.PgError{Severity:"ERROR", Code:"XX000",
// Message:"ColumnUnknownException: Column typtype unknown", Detail:"",
// Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"",
// SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
// ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132,
//
// Additionally CrateDB doesn't implement Postgres error codes [2], and
// instead always returns "XX000" (internal_error). The code below uses all
// of this knowledge as a heuristic to detect CrateDB. If CrateDB is
// detected, a CrateDB specific pg_type query is executed instead.
//
// The heuristic is designed to still work even if CrateDB fixes [2] or
// renames its internal exception names. If both are changed but pg_types
// isn't fixed, this code will need to be changed.
//
// There is also a small chance the heuristic will yield a false positive for
// non-CrateDB databases (e.g. if a real Postgres instance returns a XX000
// error), but hopefully there will be no harm in attempting the alternative
// query in this case.
//
// CrateDB also uses the type varchar for the typname column which required
// adding varchar to the minimalConnInfo init code.
//
// Also see the discussion here [3].
//
// [1] https://crate.io/
// [2] https://github.com/crate/crate/issues/5027
// [3] https://github.com/jackc/pgx/issues/320
if pgErr, ok := err.(*pgconn.PgError); ok &&
(pgErr.Code == "XX000" ||
strings.Contains(pgErr.Message, "TableUnknownException") ||
strings.Contains(pgErr.Message, "ColumnUnknownException")) {
var (
nameOIDs map[string]pgtype.OID
)
if nameOIDs, err = connInfoFromRows(c.Query(context.TODO(), `select oid, typname from pg_catalog.pg_type`)); err != nil {
return nil, err
}
cinfo := pgtype.NewConnInfo()
cinfo.InitializeDataTypes(nameOIDs)
return cinfo, err
}
return nil, err
}
// PID returns the backend PID for this connection.
func (c *Conn) PID() uint32 {
return c.pgConn.PID()

View File

@ -680,26 +680,25 @@ func TestConnInitConnInfo(t *testing.T) {
ensureConnValid(t, conn)
}
func TestDomainType(t *testing.T) {
func TestRegisteredDomainType(t *testing.T) {
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
dt, ok := conn.ConnInfo.DataTypeForName("uint64")
if !ok {
t.Fatal("Expected data type for domain uint64 to be present")
}
if dt, ok := dt.Value.(*pgtype.Numeric); !ok {
t.Fatalf("Expected data type value for domain uint64 to be *pgtype.Numeric, but it was %T", dt)
var uint64OID pgtype.OID
err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
if err != nil {
t.Fatalf("did not find uint64 OID, %v", err)
}
conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID})
var n uint64
err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(42)).Scan(&n)
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
if err != nil {
t.Fatal(err)
}
if n != 42 {
t.Fatalf("Expected n to be 42, but was %v", n)
if n != 24 {
t.Fatalf("Expected n to be 24, but was %v", n)
}
ensureConnValid(t, conn)

View File

@ -1,31 +0,0 @@
package pgtype
type Decimal Numeric
func (dst *Decimal) Set(src interface{}) error {
return (*Numeric)(dst).Set(src)
}
func (dst *Decimal) Get() interface{} {
return (*Numeric)(dst).Get()
}
func (src *Decimal) AssignTo(dst interface{}) error {
return (*Numeric)(src).AssignTo(dst)
}
func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeText(ci, src)
}
func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error {
return (*Numeric)(dst).DecodeBinary(ci, src)
}
func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
return (*Numeric)(src).EncodeText(ci, buf)
}
func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
return (*Numeric)(src).EncodeBinary(ci, buf)
}

View File

@ -14,6 +14,20 @@ func TestHstoreArrayTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
var hstoreOID pgtype.OID
err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID)
if err != nil {
t.Fatalf("did not find hstore OID, %v", err)
}
conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID})
var hstoreArrayOID pgtype.OID
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID)
if err != nil {
t.Fatalf("did not find _hstore OID, %v", err)
}
conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID})
text := func(s string) pgtype.Text {
return pgtype.Text{String: s, Status: pgtype.Present}
}

View File

@ -11,7 +11,7 @@ import (
const (
BoolOID = 16
ByteaOID = 17
CharOID = 18
QCharOID = 18
NameOID = 19
Int8OID = 20
Int2OID = 21
@ -22,11 +22,19 @@ const (
XIDOID = 28
CIDOID = 29
JSONOID = 114
PointOID = 600
LsegOID = 601
PathOID = 602
BoxOID = 603
PolygonOID = 604
LineOID = 628
CIDROID = 650
CIDRArrayOID = 651
Float4OID = 700
Float8OID = 701
CircleOID = 718
UnknownOID = 705
MacaddrOID = 829
InetOID = 869
BoolArrayOID = 1000
Int2ArrayOID = 1005
@ -49,11 +57,21 @@ const (
DateArrayOID = 1182
TimestamptzOID = 1184
TimestamptzArrayOID = 1185
IntervalOID = 1186
NumericArrayOID = 1231
BitOID = 1560
VarbitOID = 1562
NumericOID = 1700
RecordOID = 2249
UUIDOID = 2950
UUIDArrayOID = 2951
JSONBOID = 3802
DaterangeOID = 3812
Int4rangeOID = 3904
NumrangeOID = 3906
TsrangeOID = 3908
TstzrangeOID = 3910
Int8rangeOID = 3926
)
type Status byte
@ -155,11 +173,77 @@ type ConnInfo struct {
}
func NewConnInfo() *ConnInfo {
return &ConnInfo{
oidToDataType: make(map[OID]*DataType, 256),
nameToDataType: make(map[string]*DataType, 256),
reflectTypeToDataType: make(map[reflect.Type]*DataType, 256),
ci := &ConnInfo{
oidToDataType: make(map[OID]*DataType, 128),
nameToDataType: make(map[string]*DataType, 128),
reflectTypeToDataType: make(map[reflect.Type]*DataType, 128),
}
ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID})
ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID})
ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID})
ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID})
ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID})
ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID})
ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID})
ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID})
ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID})
ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID})
ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID})
ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID})
ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID})
ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID})
ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID})
ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID})
ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID})
ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID})
ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID})
ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID})
ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID})
ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID})
ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID})
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID})
ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID})
ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID})
ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID})
ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID})
ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID})
ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID})
ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID})
ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID})
ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID})
ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID})
ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID})
ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID})
ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID})
ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID})
ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID})
ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID})
ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
return ci
}
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
@ -295,7 +379,6 @@ func init() {
"circle": &Circle{},
"date": &Date{},
"daterange": &Daterange{},
"decimal": &Decimal{},
"float4": &Float4{},
"float8": &Float8{},
"hstore": &Hstore{},

View File

@ -41,6 +41,8 @@ func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string {
// - Checks the wal position of the slot on the server to make sure
// the update succeeded
func TestSimpleReplicationConnection(t *testing.T) {
t.Skipf("TODO - replication needs to be revisited when v4 churn settles down. For now just skip")
var err error
connString := os.Getenv("PGX_TEST_REPLICATION_CONN_STRING")