diff --git a/conn.go b/conn.go index cba6bcd3..7121bb27 100644 --- a/conn.go +++ b/conn.go @@ -469,10 +469,12 @@ where t.typtype = 'b' } for name, oid := range nameOIDs { + v := &pgtype.EnumArray{} c.ConnInfo.RegisterDataType(pgtype.DataType{ - &pgtype.EnumArray{}, - name, - oid, + Value: v, + Name: name, + OID: oid, + FormatCode: pgtype.DetermineFormatCode(v), }) } @@ -942,11 +944,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { ps.FieldDescriptions[i].DataTypeName = dt.Name - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - ps.FieldDescriptions[i].FormatCode = BinaryFormatCode - } else { - ps.FieldDescriptions[i].FormatCode = TextFormatCode - } + ps.FieldDescriptions[i].FormatCode = dt.FormatCode } else { return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f7a1a300..23d672a4 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -53,6 +53,12 @@ const ( JSONBOID = 3802 ) +// PostgreSQL format codes +const ( + textFormatCode int16 = 0 + binaryFormatCode = 1 +) + type Status byte const ( @@ -134,9 +140,10 @@ var errUndefined = errors.New("cannot encode status undefined") var errBadStatus = errors.New("invalid status") type DataType struct { - Value Value - Name string - OID OID + Value Value + Name string + OID OID + FormatCode int16 } type ConnInfo struct { @@ -153,6 +160,16 @@ func NewConnInfo() *ConnInfo { } } +func (ci *ConnInfo) DataTypes() map[OID]DataType { + out := make(map[OID]DataType, len(ci.oidToDataType)) + for _, dt := range ci.oidToDataType { + tmp := *dt + out[dt.OID] = tmp + } + + return out +} + func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { for name, oid := range nameOIDs { var value Value @@ -161,7 +178,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { } else { value = &GenericText{} } - ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) + ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid, FormatCode: DetermineFormatCode(value)}) } } @@ -196,9 +213,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { for _, dt := range ci.oidToDataType { ci2.RegisterDataType(DataType{ - Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), - Name: dt.Name, - OID: dt.OID, + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + OID: dt.OID, + FormatCode: dt.FormatCode, }) } @@ -274,3 +292,13 @@ func init() { "xid": &XID{}, } } + +// DetermineFormatCode determines the default format code to use +// for the given value. +func DetermineFormatCode(v Value) int16 { + if _, ok := v.(BinaryDecoder); ok { + return binaryFormatCode + } + + return textFormatCode +} diff --git a/query.go b/query.go index 3576091f..9bb6bbd4 100644 --- a/query.go +++ b/query.go @@ -134,7 +134,7 @@ func (rows *Rows) Next() bool { for i := range rows.fields { if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { rows.fields[i].DataTypeName = dt.Name - rows.fields[i].FormatCode = TextFormatCode + rows.fields[i].FormatCode = dt.FormatCode } else { rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) return false @@ -367,6 +367,9 @@ type QueryExOptions struct { } func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { + var ( + fieldDescriptions []FieldDescription + ) c.lastActivityTime = time.Now() rows = c.getRows(sql, args) @@ -376,12 +379,12 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, err } - if err := c.ensureConnectionReadyForQuery(); err != nil { + if err = c.ensureConnectionReadyForQuery(); err != nil { rows.fatal(err) return rows, err } - if err := c.lock(); err != nil { + if err = c.lock(); err != nil { rows.fatal(err) return rows, err } @@ -394,12 +397,18 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { - err = c.sanitizeAndSendSimpleQuery(sql, args...) - if err != nil { + if err = c.sanitizeAndSendSimpleQuery(sql, args...); err != nil { rows.fatal(err) return rows, err } + if fieldDescriptions, err = c.readFieldDescriptions(QueryExOptions{}); err != nil { + rows.fatal(err) + return rows, err + } + + rows.sql = sql + rows.fields = fieldDescriptions return rows, nil } @@ -421,27 +430,11 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } c.pendingReadyForQueryCount++ - fieldDescriptions, err := c.readUntilRowDescription() - if err != nil { + if fieldDescriptions, err = c.readFieldDescriptions(*options); err != nil { rows.fatal(err) return rows, err } - if len(options.ResultFormatCodes) == 0 { - for i := range fieldDescriptions { - fieldDescriptions[i].FormatCode = TextFormatCode - } - } else if len(options.ResultFormatCodes) == 1 { - fc := options.ResultFormatCodes[0] - for i := range fieldDescriptions { - fieldDescriptions[i].FormatCode = fc - } - } else { - for i := range options.ResultFormatCodes { - fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] - } - } - rows.sql = sql rows.fields = fieldDescriptions return rows, nil @@ -449,7 +442,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, ps, ok := c.preparedStatements[sql] if !ok { - var err error ps, err = c.prepareEx("", sql, nil) if err != nil { rows.fatal(err) @@ -543,3 +535,27 @@ func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptio rows, _ := c.QueryEx(ctx, sql, options, args...) return (*Row)(rows) } + +func (c *Conn) readFieldDescriptions(options QueryExOptions) (fieldDescriptions []FieldDescription, err error) { + if fieldDescriptions, err = c.readUntilRowDescription(); err != nil { + return fieldDescriptions, err + } + + switch len(options.ResultFormatCodes) { + case 0: + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = TextFormatCode + } + case 1: + fc := options.ResultFormatCodes[0] + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = fc + } + default: + for i := range options.ResultFormatCodes { + fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] + } + } + + return fieldDescriptions, err +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 04369095..1da4cc1d 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -85,10 +85,6 @@ import ( "github.com/jackc/pgx/pgtype" ) -// oids that map to intrinsic database/sql types. These will be allowed to be -// binary, anything else will be forced to text format -var databaseSqlOIDs map[pgtype.OID]bool - var pgxDriver *Driver type ctxKey int @@ -97,27 +93,30 @@ var ctxKeyFakeTx ctxKey = 0 var ErrNotPgx = errors.New("not pgx *sql.DB") +// oids that map to intrinsic database/sql types. These will be allowed to be +// binary, anything else will be forced to text format +var allowedBinaryOID = []pgtype.OID{ + pgtype.BoolOID, + pgtype.ByteaOID, + pgtype.CIDOID, + pgtype.DateOID, + pgtype.Float4OID, + pgtype.Float8OID, + pgtype.Int2OID, + pgtype.Int4OID, + pgtype.Int8OID, + pgtype.OIDOID, + pgtype.TimestampOID, + pgtype.TimestamptzOID, + pgtype.XIDOID, +} + func init() { pgxDriver = &Driver{ configs: make(map[int64]*DriverConfig), fakeTxConns: make(map[*pgx.Conn]*sql.Tx), } sql.Register("pgx", pgxDriver) - - databaseSqlOIDs = make(map[pgtype.OID]bool) - databaseSqlOIDs[pgtype.BoolOID] = true - databaseSqlOIDs[pgtype.ByteaOID] = true - databaseSqlOIDs[pgtype.CIDOID] = true - databaseSqlOIDs[pgtype.DateOID] = true - databaseSqlOIDs[pgtype.Float4OID] = true - databaseSqlOIDs[pgtype.Float8OID] = true - databaseSqlOIDs[pgtype.Int2OID] = true - databaseSqlOIDs[pgtype.Int4OID] = true - databaseSqlOIDs[pgtype.Int8OID] = true - databaseSqlOIDs[pgtype.OIDOID] = true - databaseSqlOIDs[pgtype.TimestampOID] = true - databaseSqlOIDs[pgtype.TimestamptzOID] = true - databaseSqlOIDs[pgtype.XIDOID] = true } type Driver struct { @@ -130,8 +129,11 @@ type Driver struct { } func (d *Driver) Open(name string) (driver.Conn, error) { - var connConfig pgx.ConnConfig - var afterConnect func(*pgx.Conn) error + var ( + afterConnect func(*pgx.Conn) error + connConfig pgx.ConnConfig + ) + if len(name) >= 9 && name[0] == 0 { idBuf := []byte(name)[1:9] id := int64(binary.BigEndian.Uint64(idBuf)) @@ -151,6 +153,8 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return nil, err } + conn.ConnInfo = restrictBinary(conn.ConnInfo) + if afterConnect != nil { err = afterConnect(conn) if err != nil { @@ -232,8 +236,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e return nil, err } - restrictBinaryToDatabaseSqlTypes(ps) - return &Stmt{ps: ps, conn: c}, nil } @@ -299,33 +301,28 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam } func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { - if !c.conn.IsAlive() { - return nil, driver.ErrBadConn - } - - ps, err := c.conn.Prepare("", query) - if err != nil { - return nil, err - } - - restrictBinaryToDatabaseSqlTypes(ps) - - return c.queryPrepared("", argsV) + return c.query(context.Background(), query, valueToInterface(argsV)) } func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + return c.query(ctx, query, namedValueToInterface(argsV)) +} + +func (c *Conn) query(ctx context.Context, query string, args []interface{}) (driver.Rows, error) { + var ( + err error + rows *pgx.Rows + ) + if !c.conn.IsAlive() { return nil, driver.ErrBadConn } - ps, err := c.conn.PrepareEx(ctx, "", query, nil) - if err != nil { + if rows, err = c.conn.QueryEx(ctx, query, nil, args...); err != nil { return nil, err } - restrictBinaryToDatabaseSqlTypes(ps) - - return c.queryPreparedContext(ctx, "", argsV) + return &Rows{rows: rows}, nil } func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { @@ -369,13 +366,26 @@ func (c *Conn) Ping(ctx context.Context) error { // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) -func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { - for i, _ := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] - if !intrinsic { - ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode +func restrictBinary(in *pgtype.ConnInfo) (out *pgtype.ConnInfo) { + out = in.DeepCopy() + for oid, dt := range out.DataTypes() { + if textOID(oid) { + dt.FormatCode = pgx.TextFormatCode + out.RegisterDataType(dt) } } + + return out +} + +func textOID(oid pgtype.OID) bool { + for _, roid := range allowedBinaryOID { + if roid == oid { + return false + } + } + + return true } type Stmt struct { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 1880429d..c0b30080 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -81,6 +81,63 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) { } } +func TestSimpleQueryLifeCycle(t *testing.T) { + driverConfig := stdlib.DriverConfig{ + ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true}, + } + + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) + + db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + var ( + s string + n int64 + ) + + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + + if err = rows.Err(); err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +} + func TestNormalLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db)