diff --git a/conn.go b/conn.go index 15a19501..dbfc4017 100644 --- a/conn.go +++ b/conn.go @@ -586,3 +586,119 @@ func (c *Conn) pgproto3FieldDescriptionToPgxFieldDescription(src *pgproto3.Field dst.DataTypeName = dt.Name } } + +func (c *Conn) getRows(sql string, args []interface{}) *connRows { + if len(c.preallocatedRows) == 0 { + c.preallocatedRows = make([]connRows, 64) + } + + r := &c.preallocatedRows[len(c.preallocatedRows)-1] + c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] + + r.conn = c + r.startTime = time.Now() + r.sql = sql + r.args = args + + return r +} + +type QueryResultFormats []int16 + +// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is +// allowed to ignore the error returned from Query and handle it in Rows. +func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + // rows = c.getRows(sql, args) + + var resultFormats QueryResultFormats + +optionLoop: + for len(args) > 0 { + switch arg := args[0].(type) { + case QueryResultFormats: + resultFormats = arg + args = args[1:] + default: + break optionLoop + } + } + + rows := &connRows{ + conn: c, + startTime: time.Now(), + sql: sql, + args: args, + } + + ps, ok := c.preparedStatements[sql] + if !ok { + psd, err := c.pgConn.Prepare(ctx, "", sql, nil) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + if len(psd.ParamOIDs) != len(args) { + rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args))) + return rows, rows.err + } + + ps = &PreparedStatement{ + Name: psd.Name, + SQL: psd.SQL, + ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), + FieldDescriptions: make([]FieldDescription, len(psd.Fields)), + } + + for i := range ps.ParameterOIDs { + ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) + } + for i := range ps.FieldDescriptions { + c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + } + } + rows.sql = ps.SQL + + var err error + args, err = convertDriverValuers(args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + paramFormats := make([]int16, len(args)) + paramValues := make([][]byte, len(args)) + for i := range args { + paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i]) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + + if resultFormats == nil { + resultFormats = make([]int16, len(ps.FieldDescriptions)) + for i := range resultFormats { + if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + resultFormats[i] = BinaryFormatCode + } else { + resultFormats[i] = TextFormatCode + } + } + } + } + + rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats) + + return rows, rows.err +} + +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned Row. That Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := c.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} diff --git a/query.go b/rows.go similarity index 66% rename from query.go rename to rows.go index 4d372e06..dece31ca 100644 --- a/query.go +++ b/rows.go @@ -1,7 +1,6 @@ package pgx import ( - "context" "fmt" "reflect" "time" @@ -264,119 +263,3 @@ type scanArgError struct { func (e scanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) } - -func (c *Conn) getRows(sql string, args []interface{}) *connRows { - if len(c.preallocatedRows) == 0 { - c.preallocatedRows = make([]connRows, 64) - } - - r := &c.preallocatedRows[len(c.preallocatedRows)-1] - c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] - - r.conn = c - r.startTime = time.Now() - r.sql = sql - r.args = args - - return r -} - -type QueryResultFormats []int16 - -// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is -// allowed to ignore the error returned from Query and handle it in Rows. -func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { - // rows = c.getRows(sql, args) - - var resultFormats QueryResultFormats - -optionLoop: - for len(args) > 0 { - switch arg := args[0].(type) { - case QueryResultFormats: - resultFormats = arg - args = args[1:] - default: - break optionLoop - } - } - - rows := &connRows{ - conn: c, - startTime: time.Now(), - sql: sql, - args: args, - } - - ps, ok := c.preparedStatements[sql] - if !ok { - psd, err := c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - - if len(psd.ParamOIDs) != len(args) { - rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args))) - return rows, rows.err - } - - ps = &PreparedStatement{ - Name: psd.Name, - SQL: psd.SQL, - ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), - FieldDescriptions: make([]FieldDescription, len(psd.Fields)), - } - - for i := range ps.ParameterOIDs { - ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) - } - for i := range ps.FieldDescriptions { - c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) - } - } - rows.sql = ps.SQL - - var err error - args, err = convertDriverValuers(args) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - - paramFormats := make([]int16, len(args)) - paramValues := make([][]byte, len(args)) - for i := range args { - paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, ps.ParameterOIDs[i], args[i]) - paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, ps.ParameterOIDs[i], args[i]) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - } - - if resultFormats == nil { - resultFormats = make([]int16, len(ps.FieldDescriptions)) - for i := range resultFormats { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode - } else { - resultFormats[i] = TextFormatCode - } - } - } - } - - rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, paramValues, paramFormats, resultFormats) - - return rows, rows.err -} - -// QueryRow is a convenience wrapper over Query. Any error that occurs while -// querying is deferred until calling Scan on the returned Row. That Row will -// error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { - rows, _ := c.Query(ctx, sql, args...) - return (*connRow)(rows.(*connRows)) -}