Use extended query builder everywhere

pull/586/head
Jack Christensen 2019-05-17 13:59:41 -05:00
parent c418d45f75
commit b0dac84d77
1 changed files with 22 additions and 85 deletions

107
conn.go
View File

@ -2,9 +2,6 @@ package pgx
import (
"context"
"database/sql/driver"
"fmt"
"reflect"
"strings"
"time"
@ -393,6 +390,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
}
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
c.eqb.Reset()
if ps, ok := c.preparedStatements[sql]; ok {
args, err := convertDriverValuers(arguments)
@ -400,28 +398,24 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
return nil, err
}
paramFormats := make([]int16, len(args))
paramValues := make([][]byte, len(args))
for i := range args {
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i])
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i])
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i])
if err != nil {
return nil, err
}
}
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
c.eqb.AppendResultFormat(BinaryFormatCode)
} else {
resultFormats[i] = TextFormatCode
c.eqb.AppendResultFormat(TextFormatCode)
}
}
}
result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read()
result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
return result.CommandTag, result.Err
}
@ -461,87 +455,29 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
return nil, err
}
paramFormats := make([]int16, len(arguments))
paramValues := make([][]byte, len(arguments))
for i := range arguments {
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], arguments[i])
if err != nil {
return nil, err
}
}
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
c.eqb.AppendResultFormat(BinaryFormatCode)
} else {
resultFormats[i] = TextFormatCode
c.eqb.AppendResultFormat(TextFormatCode)
}
}
}
result := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats).Read()
result := c.pgConn.ExecPrepared(ctx, psd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
return result.CommandTag, result.Err
}
}
func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) {
if arg == nil {
return nil, nil
}
// TODO - don't allocate a new buf for each encoded prepared statement. The empty slice is necessary because otherwise empty strings may be encoded as []byte(nil) instead of []byte{}
buf := make([]byte, 0)
switch arg := arg.(type) {
case pgtype.BinaryEncoder:
return arg.EncodeBinary(ci, buf)
case pgtype.TextEncoder:
return arg.EncodeText(ci, buf)
case string:
return []byte(arg), nil
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
return nil, nil
}
arg = refVal.Elem().Interface()
return newencodePreparedStatementArgument(ci, oid, arg)
}
if dt, ok := ci.DataTypeForOID(oid); ok {
value := dt.Value
err := value.Set(arg)
if err != nil {
{
if arg, ok := arg.(driver.Valuer); ok {
v, err := callValuerValue(arg)
if err != nil {
return nil, err
}
return newencodePreparedStatementArgument(ci, oid, v)
}
}
return nil, err
}
return value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return newencodePreparedStatementArgument(ci, oid, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
if len(c.preallocatedRows) == 0 {
c.preallocatedRows = make([]connRows, 64)
@ -669,6 +605,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
batch := &pgconn.Batch{}
for _, bi := range b.items {
c.eqb.Reset()
var parameterOIDs []pgtype.OID
ps := c.preparedStatements[bi.query]
@ -683,11 +621,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
return &batchResults{err: err}
}
paramFormats := make([]int16, len(args))
paramValues := make([][]byte, len(args))
for i := range args {
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i])
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i])
err = c.eqb.AppendParam(c.ConnInfo, parameterOIDs[i], args[i])
if err != nil {
return &batchResults{err: err}
}
@ -697,25 +632,27 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
if ps != nil {
resultFormats := bi.resultFormatCodes
if resultFormats == nil {
resultFormats = make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
c.eqb.AppendResultFormat(BinaryFormatCode)
} else {
resultFormats[i] = TextFormatCode
c.eqb.AppendResultFormat(TextFormatCode)
}
}
}
resultFormats = c.eqb.resultFormats
}
batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats)
batch.ExecPrepared(ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
} else {
oids := make([]uint32, len(parameterOIDs))
for i := 0; i < len(parameterOIDs); i++ {
oids[i] = uint32(parameterOIDs[i])
}
batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes)
batch.ExecParams(bi.query, c.eqb.paramValues, oids, c.eqb.paramFormats, bi.resultFormatCodes)
}
}