diff --git a/conn.go b/conn.go index d46c3b4c..3c7c413f 100644 --- a/conn.go +++ b/conn.go @@ -2,9 +2,6 @@ package pgx import ( "context" - "database/sql/driver" - "fmt" - "reflect" "strings" "time" @@ -393,6 +390,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + c.eqb.Reset() if ps, ok := c.preparedStatements[sql]; ok { args, err := convertDriverValuers(arguments) @@ -400,28 +398,24 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( return nil, err } - paramFormats := make([]int16, len(args)) - paramValues := make([][]byte, len(args)) for i := range args { - paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i]) - paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) if err != nil { return nil, err } } - resultFormats := make([]int16, len(ps.FieldDescriptions)) - for i := range resultFormats { + for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode + c.eqb.AppendResultFormat(BinaryFormatCode) } else { - resultFormats[i] = TextFormatCode + c.eqb.AppendResultFormat(TextFormatCode) } } } - result := c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats).Read() + result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() return result.CommandTag, result.Err } @@ -461,87 +455,29 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( return nil, err } - paramFormats := make([]int16, len(arguments)) - paramValues := make([][]byte, len(arguments)) for i := range arguments { - paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) - paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) + err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) if err != nil { return nil, err } - } - resultFormats := make([]int16, len(ps.FieldDescriptions)) - for i := range resultFormats { + for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode + c.eqb.AppendResultFormat(BinaryFormatCode) } else { - resultFormats[i] = TextFormatCode + c.eqb.AppendResultFormat(TextFormatCode) } } } - result := c.pgConn.ExecPrepared(ctx, psd.Name, paramValues, paramFormats, resultFormats).Read() + result := c.pgConn.ExecPrepared(ctx, psd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() return result.CommandTag, result.Err } } -func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) { - if arg == nil { - return nil, nil - } - - // TODO - don't allocate a new buf for each encoded prepared statement. The empty slice is necessary because otherwise empty strings may be encoded as []byte(nil) instead of []byte{} - buf := make([]byte, 0) - - switch arg := arg.(type) { - case pgtype.BinaryEncoder: - return arg.EncodeBinary(ci, buf) - case pgtype.TextEncoder: - return arg.EncodeText(ci, buf) - case string: - return []byte(arg), nil - } - - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return nil, nil - } - arg = refVal.Elem().Interface() - return newencodePreparedStatementArgument(ci, oid, arg) - } - - if dt, ok := ci.DataTypeForOID(oid); ok { - value := dt.Value - err := value.Set(arg) - if err != nil { - { - if arg, ok := arg.(driver.Valuer); ok { - v, err := callValuerValue(arg) - if err != nil { - return nil, err - } - return newencodePreparedStatementArgument(ci, oid, v) - } - } - - return nil, err - } - - return value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return newencodePreparedStatementArgument(ci, oid, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) -} - func (c *Conn) getRows(sql string, args []interface{}) *connRows { if len(c.preallocatedRows) == 0 { c.preallocatedRows = make([]connRows, 64) @@ -669,6 +605,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { batch := &pgconn.Batch{} for _, bi := range b.items { + c.eqb.Reset() + var parameterOIDs []pgtype.OID ps := c.preparedStatements[bi.query] @@ -683,11 +621,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { return &batchResults{err: err} } - paramFormats := make([]int16, len(args)) - paramValues := make([][]byte, len(args)) for i := range args { - paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i]) - paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, parameterOIDs[i], args[i]) if err != nil { return &batchResults{err: err} } @@ -697,25 +632,27 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { if ps != nil { resultFormats := bi.resultFormatCodes if resultFormats == nil { - resultFormats = make([]int16, len(ps.FieldDescriptions)) - for i := range resultFormats { + + for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode + c.eqb.AppendResultFormat(BinaryFormatCode) } else { - resultFormats[i] = TextFormatCode + c.eqb.AppendResultFormat(TextFormatCode) } } } + + resultFormats = c.eqb.resultFormats } - batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats) + batch.ExecPrepared(ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) } else { oids := make([]uint32, len(parameterOIDs)) for i := 0; i < len(parameterOIDs); i++ { oids[i] = uint32(parameterOIDs[i]) } - batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes) + batch.ExecParams(bi.query, c.eqb.paramValues, oids, c.eqb.paramFormats, bi.resultFormatCodes) } }