Remove extra prepare in stdlib

pull/483/head
Jack Christensen 2019-04-27 15:45:30 -05:00
parent 71d8503b81
commit 243f9031b3
2 changed files with 37 additions and 62 deletions

15
conn.go
View File

@ -580,14 +580,19 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
return r
}
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
type QueryResultFormats []int16
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
type QueryResultFormatsByOID map[pgtype.OID]int16
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
// allowed to ignore the error returned from Query and handle it in Rows.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
// rows = c.getRows(sql, args)
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
optionLoop:
for len(args) > 0 {
@ -595,6 +600,9 @@ optionLoop:
case QueryResultFormats:
resultFormats = arg
args = args[1:]
case QueryResultFormatsByOID:
resultFormatsByOID = arg
args = args[1:]
default:
break optionLoop
}
@ -655,6 +663,13 @@ optionLoop:
}
}
if resultFormatsByOID != nil {
resultFormats = make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType]
}
}
if resultFormats == nil {
resultFormats = make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {

View File

@ -87,9 +87,8 @@ import (
"github.com/jackc/pgx/v4"
)
// oids that map to intrinsic database/sql types. These will be allowed to be
// binary, anything else will be forced to text format
var databaseSqlOIDs map[pgtype.OID]bool
// Only intrinsic types should be binary format with database/sql.
var databaseSQLResultFormats pgx.QueryResultFormatsByOID
var pgxDriver *Driver
@ -104,20 +103,21 @@ func init() {
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
databaseSqlOIDs = make(map[pgtype.OID]bool)
databaseSqlOIDs[pgtype.BoolOID] = true
databaseSqlOIDs[pgtype.ByteaOID] = true
databaseSqlOIDs[pgtype.CIDOID] = true
databaseSqlOIDs[pgtype.DateOID] = true
databaseSqlOIDs[pgtype.Float4OID] = true
databaseSqlOIDs[pgtype.Float8OID] = true
databaseSqlOIDs[pgtype.Int2OID] = true
databaseSqlOIDs[pgtype.Int4OID] = true
databaseSqlOIDs[pgtype.Int8OID] = true
databaseSqlOIDs[pgtype.OIDOID] = true
databaseSqlOIDs[pgtype.TimestampOID] = true
databaseSqlOIDs[pgtype.TimestamptzOID] = true
databaseSqlOIDs[pgtype.XIDOID] = true
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
pgtype.BoolOID: 1,
pgtype.ByteaOID: 1,
pgtype.CIDOID: 1,
pgtype.DateOID: 1,
pgtype.Float4OID: 1,
pgtype.Float8OID: 1,
pgtype.Int2OID: 1,
pgtype.Int4OID: 1,
pgtype.Int8OID: 1,
pgtype.OIDOID: 1,
pgtype.TimestampOID: 1,
pgtype.TimestamptzOID: 1,
pgtype.XIDOID: 1,
}
}
var (
@ -168,8 +168,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return &Stmt{ps: ps, conn: c}, nil
}
@ -241,48 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
return nil, driver.ErrBadConn
}
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
psname := fmt.Sprintf("stdlibpx%v", &argsV)
ps, err := c.conn.Prepare(ctx, psname, query)
if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, psname, argsV)
return c.queryPreparedContext(ctx, query, argsV)
}
// func (c *Conn) execParams(ctx context.Context, sql string, argsV []driver.NamedValue) (*pgconn.ResultReader, error) {
// if !c.conn.IsAlive() {
// return nil, driver.ErrBadConn
// }
// paramValues := make([][]byte, len(argsV))
// for i := 0;i< len(paramValues); i++ {
// v := argsV[i].Value
// paramValues
// }
// return c.conn.PgConn().ExecParams(ctx, sql,paramValues, nil, nil, nil)
// }
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
// TODO - don't always use text
args := []interface{}{pgx.QueryResultFormats{0}}
args := []interface{}{databaseSQLResultFormats}
args = append(args, namedValueToInterface(argsV)...)
rows, err := c.conn.Query(ctx, name, args...)
if err != nil {
if errors.Is(err, pgconn.ErrNoBytesSent) {
return nil, driver.ErrBadConn
}
return nil, err
}
@ -299,18 +271,6 @@ func (c *Conn) Ping(ctx context.Context) error {
return c.conn.Ping(ctx)
}
// Anything that isn't a database/sql compatible type needs to be forced to
// text format so that pgx.Rows.Values doesn't decode it into a native type
// (e.g. []int32)
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
for i := range ps.FieldDescriptions {
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
if !intrinsic {
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
}
}
}
type Stmt struct {
ps *pgx.PreparedStatement
conn *Conn