From 243f9031b3dd25b8ed3f8b28cd47fece644f7e50 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Apr 2019 15:45:30 -0500 Subject: [PATCH] Remove extra prepare in stdlib --- conn.go | 15 +++++++++ stdlib/sql.go | 84 ++++++++++++++------------------------------------- 2 files changed, 37 insertions(+), 62 deletions(-) diff --git a/conn.go b/conn.go index 80d040db..70f608f0 100644 --- a/conn.go +++ b/conn.go @@ -580,14 +580,19 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows { return r } +// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 +// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. +type QueryResultFormatsByOID map[pgtype.OID]int16 + // Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is // allowed to ignore the error returned from Query and handle it in Rows. func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { // rows = c.getRows(sql, args) var resultFormats QueryResultFormats + var resultFormatsByOID QueryResultFormatsByOID optionLoop: for len(args) > 0 { @@ -595,6 +600,9 @@ optionLoop: case QueryResultFormats: resultFormats = arg args = args[1:] + case QueryResultFormatsByOID: + resultFormatsByOID = arg + args = args[1:] default: break optionLoop } @@ -655,6 +663,13 @@ optionLoop: } } + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(ps.FieldDescriptions)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType] + } + } + if resultFormats == nil { resultFormats = make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { diff --git a/stdlib/sql.go b/stdlib/sql.go index 0cfafe2c..df2d0572 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -87,9 +87,8 @@ import ( "github.com/jackc/pgx/v4" ) -// oids that map to intrinsic database/sql types. These will be allowed to be -// binary, anything else will be forced to text format -var databaseSqlOIDs map[pgtype.OID]bool +// Only intrinsic types should be binary format with database/sql. +var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver @@ -104,20 +103,21 @@ func init() { fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) - databaseSqlOIDs = make(map[pgtype.OID]bool) - databaseSqlOIDs[pgtype.BoolOID] = true - databaseSqlOIDs[pgtype.ByteaOID] = true - databaseSqlOIDs[pgtype.CIDOID] = true - databaseSqlOIDs[pgtype.DateOID] = true - databaseSqlOIDs[pgtype.Float4OID] = true - databaseSqlOIDs[pgtype.Float8OID] = true - databaseSqlOIDs[pgtype.Int2OID] = true - databaseSqlOIDs[pgtype.Int4OID] = true - databaseSqlOIDs[pgtype.Int8OID] = true - databaseSqlOIDs[pgtype.OIDOID] = true - databaseSqlOIDs[pgtype.TimestampOID] = true - databaseSqlOIDs[pgtype.TimestamptzOID] = true - databaseSqlOIDs[pgtype.XIDOID] = true + databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ + pgtype.BoolOID: 1, + pgtype.ByteaOID: 1, + pgtype.CIDOID: 1, + pgtype.DateOID: 1, + pgtype.Float4OID: 1, + pgtype.Float8OID: 1, + pgtype.Int2OID: 1, + pgtype.Int4OID: 1, + pgtype.Int8OID: 1, + pgtype.OIDOID: 1, + pgtype.TimestampOID: 1, + pgtype.TimestamptzOID: 1, + pgtype.XIDOID: 1, + } } var ( @@ -168,8 +168,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e return nil, err } - restrictBinaryToDatabaseSqlTypes(ps) - return &Stmt{ps: ps, conn: c}, nil } @@ -241,48 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - // TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name - psname := fmt.Sprintf("stdlibpx%v", &argsV) - - ps, err := c.conn.Prepare(ctx, psname, query) - if err != nil { - // since PrepareEx failed, we didn't actually get to send the values, so - // we can safely retry - if _, is := err.(net.Error); is { - return nil, driver.ErrBadConn - } - return nil, err - } - - restrictBinaryToDatabaseSqlTypes(ps) - return c.queryPreparedContext(ctx, psname, argsV) + return c.queryPreparedContext(ctx, query, argsV) } -// func (c *Conn) execParams(ctx context.Context, sql string, argsV []driver.NamedValue) (*pgconn.ResultReader, error) { -// if !c.conn.IsAlive() { -// return nil, driver.ErrBadConn -// } - -// paramValues := make([][]byte, len(argsV)) -// for i := 0;i< len(paramValues); i++ { -// v := argsV[i].Value -// paramValues -// } - -// return c.conn.PgConn().ExecParams(ctx, sql,paramValues, nil, nil, nil) -// } - func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } - // TODO - don't always use text - args := []interface{}{pgx.QueryResultFormats{0}} + args := []interface{}{databaseSQLResultFormats} args = append(args, namedValueToInterface(argsV)...) rows, err := c.conn.Query(ctx, name, args...) if err != nil { + if errors.Is(err, pgconn.ErrNoBytesSent) { + return nil, driver.ErrBadConn + } return nil, err } @@ -299,18 +271,6 @@ func (c *Conn) Ping(ctx context.Context) error { return c.conn.Ping(ctx) } -// Anything that isn't a database/sql compatible type needs to be forced to -// text format so that pgx.Rows.Values doesn't decode it into a native type -// (e.g. []int32) -func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { - for i := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] - if !intrinsic { - ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode - } - } -} - type Stmt struct { ps *pgx.PreparedStatement conn *Conn