diff --git a/conn.go b/conn.go index 6bfe515d..d49eb436 100644 --- a/conn.go +++ b/conn.go @@ -49,8 +49,7 @@ type Conn struct { wbuf []byte preallocatedRows []connRows - int16SlicePool int16SlicePool - paramValues [][]byte + eqb extendedQueryBuilder } // PreparedStatement is a description of a prepared statement @@ -159,7 +158,6 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) - c.paramValues = make([][]byte, 0, 16) // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet @@ -610,6 +608,7 @@ optionLoop: } } + c.eqb.Reset() rows := c.getRows(sql, args) ps, ok := c.preparedStatements[sql] @@ -648,18 +647,8 @@ optionLoop: return rows, rows.err } - paramFormats := c.int16SlicePool.get(len(args)) - - var paramValues [][]byte - if len(args) > cap(c.paramValues) { - paramValues = make([][]byte, len(args)) - } else { - paramValues = c.paramValues[:len(args)] - } - for i := range args { - paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i]) - paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err @@ -674,22 +663,20 @@ optionLoop: } if resultFormats == nil { - resultFormats = c.int16SlicePool.get(len(ps.FieldDescriptions)) - - for i := range resultFormats { + for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode + c.eqb.AppendResultFormat(BinaryFormatCode) } else { - resultFormats[i] = TextFormatCode + c.eqb.AppendResultFormat(TextFormatCode) } } } + + resultFormats = c.eqb.resultFormats } - rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats) - - c.int16SlicePool.reset() + rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) return rows, rows.err } diff --git a/extended_query_builder.go b/extended_query_builder.go new file mode 100644 index 00000000..b41ba6be --- /dev/null +++ b/extended_query_builder.go @@ -0,0 +1,128 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + "reflect" + + "github.com/jackc/pgtype" +) + +type extendedQueryBuilder struct { + paramValues [][]byte + paramValueBytes []byte + paramFormats []int16 + resultFormats []int16 + + resetCount int +} + +func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) error { + f := chooseParameterFormatCode(ci, oid, arg) + eqb.paramFormats = append(eqb.paramFormats, f) + + v, err := eqb.encodeExtendedParamValue(ci, oid, arg) + if err != nil { + return err + } + eqb.paramValues = append(eqb.paramValues, v) + + return nil +} + +func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { + eqb.resultFormats = append(eqb.resultFormats, f) +} + +func (eqb *extendedQueryBuilder) Reset() { + eqb.paramValues = eqb.paramValues[0:0] + eqb.paramValueBytes = eqb.paramValueBytes[0:0] + eqb.paramFormats = eqb.paramFormats[0:0] + eqb.resultFormats = eqb.resultFormats[0:0] + + eqb.resetCount += 1 + + // Every so often shrink our reserved memory if it is abnormally high + if eqb.resetCount%128 == 0 { + if cap(eqb.paramValues) > 64 { + eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2) + } + + if cap(eqb.paramValueBytes) > 256 { + eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2) + } + + if cap(eqb.paramFormats) > 64 { + eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2) + } + if cap(eqb.resultFormats) > 64 { + eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2) + } + } + +} + +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) ([]byte, error) { + if arg == nil { + return nil, nil + } + + if eqb.paramValueBytes == nil { + eqb.paramValueBytes = make([]byte, 0, 128) + } + + var err error + pos := len(eqb.paramValueBytes) + + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + eqb.paramValueBytes, err = arg.EncodeBinary(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + return eqb.paramValueBytes[pos:], nil + case pgtype.TextEncoder: + eqb.paramValueBytes, err = arg.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + return eqb.paramValueBytes[pos:], nil + case string: + return []byte(arg), nil + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return nil, nil + } + arg = refVal.Elem().Interface() + return eqb.encodeExtendedParamValue(ci, oid, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return eqb.encodeExtendedParamValue(ci, oid, v) + } + } + + return nil, err + } + + return eqb.encodeExtendedParamValue(ci, oid, value) + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return eqb.encodeExtendedParamValue(ci, oid, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +}