diff --git a/conn.go b/conn.go index fdd649f8..1f644b66 100644 --- a/conn.go +++ b/conn.go @@ -24,21 +24,6 @@ const ( connStatusBusy ) -// minimalConnInfo has just enough static type information to establish the -// connection and retrieve the type data. -var minimalConnInfo *pgtype.ConnInfo - -func init() { - minimalConnInfo = pgtype.NewConnInfo() - minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{ - "int4": pgtype.Int4OID, - "name": pgtype.NameOID, - "oid": pgtype.OIDOID, - "text": pgtype.TextOID, - "varchar": pgtype.VarcharOID, - }) -} - // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { pgconn.Config @@ -132,12 +117,12 @@ func Connect(ctx context.Context, connString string) (*Conn, error) { if err != nil { return nil, err } - return connect(ctx, connConfig, minimalConnInfo) + return connect(ctx, connConfig) } // Connect establishes a connection with a PostgreSQL server with a configuration struct. func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { - return connect(ctx, connConfig, minimalConnInfo) + return connect(ctx, connConfig) } func ParseConfig(connString string) (*ConnConfig, error) { @@ -152,11 +137,11 @@ func ParseConfig(connString string) (*ConnConfig, error) { return connConfig, nil } -func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { +func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c = new(Conn) c.config = config - c.ConnInfo = connInfo + c.ConnInfo = pgtype.NewConnInfo() if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -193,230 +178,9 @@ func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) return c, nil } - if c.ConnInfo == minimalConnInfo { - err = c.initConnInfo() - if err != nil { - c.Close(ctx) - return nil, err - } - } - return c, nil } -func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { - const ( - namedOIDQuery = `select t.oid, - case when nsp.nspname in ('pg_catalog', 'public') then t.typname - else nsp.nspname||'.'||t.typname - end -from pg_type t -left join pg_type base_type on t.typelem=base_type.oid -left join pg_namespace nsp on t.typnamespace=nsp.oid -where ( - t.typtype in('b', 'p', 'r', 'e') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - )` - ) - - nameOIDs, err := connInfoFromRows(c.Query(context.TODO(), namedOIDQuery)) - if err != nil { - return nil, err - } - - cinfo := pgtype.NewConnInfo() - cinfo.InitializeDataTypes(nameOIDs) - - if err = c.initConnInfoEnumArray(cinfo); err != nil { - return nil, err - } - - if err = c.initConnInfoDomains(cinfo); err != nil { - return nil, err - } - - return cinfo, nil -} - -func (c *Conn) initConnInfo() (err error) { - var ( - connInfo *pgtype.ConnInfo - ) - - if c.config.CustomConnInfo != nil { - if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil { - return err - } - - return nil - } - - if connInfo, err = initPostgresql(c); err == nil { - c.ConnInfo = connInfo - return err - } - - // Check if CrateDB specific approach might still allow us to connect. - if connInfo, err = c.crateDBTypesQuery(err); err == nil { - c.ConnInfo = connInfo - } - - return err -} - -// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them. -func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error { - nameOIDs := make(map[string]pgtype.OID, 16) - rows, err := c.Query(context.TODO(), `select t.oid, t.typname -from pg_type t - join pg_type base_type on t.typelem=base_type.oid -where t.typtype = 'b' - and base_type.typtype = 'e'`) - if err != nil { - return err - } - - for rows.Next() { - var oid pgtype.OID - var name pgtype.Text - if err := rows.Scan(&oid, &name); err != nil { - return err - } - - nameOIDs[name.String] = oid - } - - if rows.Err() != nil { - return rows.Err() - } - - for name, oid := range nameOIDs { - cinfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.EnumArray{}, - Name: name, - OID: oid, - }) - } - - return nil -} - -// initConnInfoDomains introspects for domains and registers a data type for them. -func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error { - type domain struct { - oid pgtype.OID - name pgtype.Text - baseOID pgtype.OID - } - - domains := make([]*domain, 0, 16) - - rows, err := c.Query(context.TODO(), `select t.oid, t.typname, t.typbasetype -from pg_type t - join pg_type base_type on t.typbasetype=base_type.oid -where t.typtype = 'd' - and base_type.typtype = 'b'`) - if err != nil { - return err - } - - for rows.Next() { - var d domain - if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil { - return err - } - - domains = append(domains, &d) - } - - if rows.Err() != nil { - return rows.Err() - } - - for _, d := range domains { - baseDataType, ok := cinfo.DataTypeForOID(d.baseOID) - if ok { - cinfo.RegisterDataType(pgtype.DataType{ - Value: reflect.New(reflect.ValueOf(baseDataType.Value).Elem().Type()).Interface().(pgtype.Value), - Name: d.name.String, - OID: d.oid, - }) - } - } - - return nil -} - -// crateDBTypesQuery checks if the given err is likely to be the result of -// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB -// specific query against pg_types is executed and its results are returned. If -// not, the original error is returned. -func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { - // CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol, - // but not perfectly. In particular, the pg_catalog schema containing the - // pg_type table is not visible by default and the pg_type.typtype column is - // not implemented. Therefor the query above currently returns the following - // error: - // - // pgx.PgError{Severity:"ERROR", Code:"XX000", - // Message:"TableUnknownException: Table 'test.pg_type' unknown", - // Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"", - // Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", - // ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"} - // - // If CrateDB was to fix the pg_type table visbility in the future, we'd - // still get this error until typtype column is implemented: - // - // pgx.PgError{Severity:"ERROR", Code:"XX000", - // Message:"ColumnUnknownException: Column typtype unknown", Detail:"", - // Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"", - // SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"", - // ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132, - // - // Additionally CrateDB doesn't implement Postgres error codes [2], and - // instead always returns "XX000" (internal_error). The code below uses all - // of this knowledge as a heuristic to detect CrateDB. If CrateDB is - // detected, a CrateDB specific pg_type query is executed instead. - // - // The heuristic is designed to still work even if CrateDB fixes [2] or - // renames its internal exception names. If both are changed but pg_types - // isn't fixed, this code will need to be changed. - // - // There is also a small chance the heuristic will yield a false positive for - // non-CrateDB databases (e.g. if a real Postgres instance returns a XX000 - // error), but hopefully there will be no harm in attempting the alternative - // query in this case. - // - // CrateDB also uses the type varchar for the typname column which required - // adding varchar to the minimalConnInfo init code. - // - // Also see the discussion here [3]. - // - // [1] https://crate.io/ - // [2] https://github.com/crate/crate/issues/5027 - // [3] https://github.com/jackc/pgx/issues/320 - - if pgErr, ok := err.(*pgconn.PgError); ok && - (pgErr.Code == "XX000" || - strings.Contains(pgErr.Message, "TableUnknownException") || - strings.Contains(pgErr.Message, "ColumnUnknownException")) { - var ( - nameOIDs map[string]pgtype.OID - ) - - if nameOIDs, err = connInfoFromRows(c.Query(context.TODO(), `select oid, typname from pg_catalog.pg_type`)); err != nil { - return nil, err - } - - cinfo := pgtype.NewConnInfo() - cinfo.InitializeDataTypes(nameOIDs) - - return cinfo, err - } - - return nil, err -} - // PID returns the backend PID for this connection. func (c *Conn) PID() uint32 { return c.pgConn.PID() diff --git a/conn_test.go b/conn_test.go index a4b89cec..c0c1da64 100644 --- a/conn_test.go +++ b/conn_test.go @@ -680,26 +680,25 @@ func TestConnInitConnInfo(t *testing.T) { ensureConnValid(t, conn) } -func TestDomainType(t *testing.T) { +func TestRegisteredDomainType(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - dt, ok := conn.ConnInfo.DataTypeForName("uint64") - if !ok { - t.Fatal("Expected data type for domain uint64 to be present") - } - if dt, ok := dt.Value.(*pgtype.Numeric); !ok { - t.Fatalf("Expected data type value for domain uint64 to be *pgtype.Numeric, but it was %T", dt) + var uint64OID pgtype.OID + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) + if err != nil { + t.Fatalf("did not find uint64 OID, %v", err) } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) var n uint64 - err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(42)).Scan(&n) + err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) if err != nil { t.Fatal(err) } - if n != 42 { - t.Fatalf("Expected n to be 42, but was %v", n) + if n != 24 { + t.Fatalf("Expected n to be 24, but was %v", n) } ensureConnValid(t, conn) diff --git a/pgtype/decimal.go b/pgtype/decimal.go deleted file mode 100644 index 79653cf3..00000000 --- a/pgtype/decimal.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type Decimal Numeric - -func (dst *Decimal) Set(src interface{}) error { - return (*Numeric)(dst).Set(src) -} - -func (dst *Decimal) Get() interface{} { - return (*Numeric)(dst).Get() -} - -func (src *Decimal) AssignTo(dst interface{}) error { - return (*Numeric)(src).AssignTo(dst) -} - -func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeText(ci, src) -} - -func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Numeric)(dst).DecodeBinary(ci, src) -} - -func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeText(ci, buf) -} - -func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Numeric)(src).EncodeBinary(ci, buf) -} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go index c8104d28..03dc2ff1 100644 --- a/pgtype/hstore_array_test.go +++ b/pgtype/hstore_array_test.go @@ -14,6 +14,20 @@ func TestHstoreArrayTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) + var hstoreOID pgtype.OID + err := conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='hstore';").Scan(&hstoreOID) + if err != nil { + t.Fatalf("did not find hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Hstore{}, Name: "hstore", OID: hstoreOID}) + + var hstoreArrayOID pgtype.OID + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='_hstore';").Scan(&hstoreArrayOID) + if err != nil { + t.Fatalf("did not find _hstore OID, %v", err) + } + conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.HstoreArray{}, Name: "_hstore", OID: hstoreArrayOID}) + text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8f41d068..4faf23e1 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -11,7 +11,7 @@ import ( const ( BoolOID = 16 ByteaOID = 17 - CharOID = 18 + QCharOID = 18 NameOID = 19 Int8OID = 20 Int2OID = 21 @@ -22,11 +22,19 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 CIDROID = 650 CIDRArrayOID = 651 Float4OID = 700 Float8OID = 701 + CircleOID = 718 UnknownOID = 705 + MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 Int2ArrayOID = 1005 @@ -49,11 +57,21 @@ const ( DateArrayOID = 1182 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 + IntervalOID = 1186 + NumericArrayOID = 1231 + BitOID = 1560 + VarbitOID = 1562 NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 UUIDArrayOID = 2951 JSONBOID = 3802 + DaterangeOID = 3812 + Int4rangeOID = 3904 + NumrangeOID = 3906 + TsrangeOID = 3908 + TstzrangeOID = 3910 + Int8rangeOID = 3926 ) type Status byte @@ -155,11 +173,77 @@ type ConnInfo struct { } func NewConnInfo() *ConnInfo { - return &ConnInfo{ - oidToDataType: make(map[OID]*DataType, 256), - nameToDataType: make(map[string]*DataType, 256), - reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + ci := &ConnInfo{ + oidToDataType: make(map[OID]*DataType, 128), + nameToDataType: make(map[string]*DataType, 128), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 128), } + + ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) + ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) + ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) + ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + + return ci } func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { @@ -295,7 +379,6 @@ func init() { "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, - "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, diff --git a/replication_test.go b/replication_test.go index 0574e86f..070f0ed5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -41,6 +41,8 @@ func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string { // - Checks the wal position of the slot on the server to make sure // the update succeeded func TestSimpleReplicationConnection(t *testing.T) { + t.Skipf("TODO - replication needs to be revisited when v4 churn settles down. For now just skip") + var err error connString := os.Getenv("PGX_TEST_REPLICATION_CONN_STRING")