diff --git a/batch.go b/batch.go index 994f4e4b..dbe82a67 100644 --- a/batch.go +++ b/batch.go @@ -144,7 +144,7 @@ func (b *Batch) ExecResults() (pgconn.CommandTag, error) { // QueryResults reads the results from the next query in the batch as if the // query has been sent with Query. -func (b *Batch) QueryResults() (*Rows, error) { +func (b *Batch) QueryResults() (Rows, error) { rows := b.conn.getRows("batch query", nil) if !b.mrr.NextResult() { @@ -162,9 +162,9 @@ func (b *Batch) QueryResults() (*Rows, error) { // QueryRowResults reads the results from the next query in the batch as if the // query has been sent with QueryRow. -func (b *Batch) QueryRowResults() *Row { +func (b *Batch) QueryRowResults() Row { rows, _ := b.QueryResults() - return (*Row)(rows) + return (*connRow)(rows.(*connRows)) } diff --git a/conn.go b/conn.go index aca77dcf..06a3f266 100644 --- a/conn.go +++ b/conn.go @@ -70,7 +70,7 @@ type Conn struct { logLevel LogLevel fp *fastpath poolResetCount int - preallocatedRows []Rows + preallocatedRows []connRows mux sync.Mutex status byte // One of connStatus* constants @@ -681,7 +681,7 @@ func (c *Conn) Ping(ctx context.Context) error { return err } -func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { +func connInfoFromRows(rows Rows, err error) (map[string]pgtype.OID, error) { if err != nil { return nil, err } diff --git a/fastpath.go b/fastpath.go index 6ac81b2c..e3ce3969 100644 --- a/fastpath.go +++ b/fastpath.go @@ -25,7 +25,7 @@ func (f *fastpath) addFunction(name string, oid pgtype.OID) { f.fns[name] = oid } -func (f *fastpath) addFunctions(rows *Rows) error { +func (f *fastpath) addFunctions(rows Rows) error { for rows.Next() { var name string var oid pgtype.OID diff --git a/pool/common_test.go b/pool/common_test.go index 1ba9d0ed..e53bea8b 100644 --- a/pool/common_test.go +++ b/pool/common_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgx/pool" + "github.com/jackc/pgx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,7 +37,7 @@ func testExec(t *testing.T, db execer) { } type queryer interface { - Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (*pool.Rows, error) + Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) } func testQuery(t *testing.T, db queryer) { @@ -59,7 +59,7 @@ func testQuery(t *testing.T, db queryer) { } type queryRower interface { - QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *pool.Row + QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row } func testQueryRow(t *testing.T, db queryRower) { diff --git a/pool/conn.go b/pool/conn.go index 81f5a625..86dc9507 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -50,23 +50,19 @@ func (c *Conn) Release() { } func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - conn := c.res.Value().(*pgx.Conn) - return conn.Exec(ctx, sql, arguments...) + return c.Conn().Exec(ctx, sql, arguments...) } -func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (*Rows, error) { - r, err := c.res.Value().(*pgx.Conn).Query(ctx, sql, optionsAndArgs...) - rows := &Rows{r: r, err: err} - return rows, err +func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { + return c.Conn().Query(ctx, sql, optionsAndArgs...) } -func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *Row { - r := c.res.Value().(*pgx.Conn).QueryRow(ctx, sql, optionsAndArgs...) - return &Row{r: r} +func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { + return c.Conn().QueryRow(ctx, sql, optionsAndArgs...) } func (c *Conn) Begin() (*pgx.Tx, error) { - return c.res.Value().(*pgx.Conn).Begin() + return c.Conn().Begin() } func (c *Conn) Conn() *pgx.Conn { diff --git a/pool/pool.go b/pool/pool.go index 157007c7..11401de8 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -68,31 +68,29 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) ( return c.Exec(ctx, sql, arguments...) } -func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (*Rows, error) { +func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { c, err := p.Acquire(ctx) if err != nil { - return &Rows{err: err}, err + return errRows{err: err}, err } rows, err := c.Query(ctx, sql, optionsAndArgs...) - if err == nil { - rows.c = c - } else { + if err != nil { c.Release() + return errRows{err: err}, err } - return rows, err + return &poolRows{r: rows, c: c}, nil } -func (p *Pool) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *Row { +func (p *Pool) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { c, err := p.Acquire(ctx) if err != nil { - return &Row{err: err} + return errRow{err: err} } row := c.QueryRow(ctx, sql, optionsAndArgs...) - row.c = c - return row + return &poolRow{r: row, c: c} } func (p *Pool) Begin() (*Tx, error) { diff --git a/pool/rows.go b/pool/rows.go index 340ea54e..43a3192b 100644 --- a/pool/rows.go +++ b/pool/rows.go @@ -4,13 +4,30 @@ import ( "github.com/jackc/pgx" ) -type Rows struct { - r *pgx.Rows +type errRows struct { + err error +} + +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) FieldDescriptions() []pgx.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...interface{}) error { return e.err } +func (e errRows) Values() ([]interface{}, error) { return nil, e.err } + +type errRow struct { + err error +} + +func (e errRow) Scan(dest ...interface{}) error { return e.err } + +type poolRows struct { + r pgx.Rows c *Conn err error } -func (rows *Rows) Close() { +func (rows *poolRows) Close() { rows.r.Close() if rows.c != nil { rows.c.Release() @@ -18,18 +35,18 @@ func (rows *Rows) Close() { } } -func (rows *Rows) Err() error { +func (rows *poolRows) Err() error { if rows.err != nil { return rows.err } return rows.r.Err() } -func (rows *Rows) FieldDescriptions() []pgx.FieldDescription { +func (rows *poolRows) FieldDescriptions() []pgx.FieldDescription { return rows.r.FieldDescriptions() } -func (rows *Rows) Next() bool { +func (rows *poolRows) Next() bool { if rows.err != nil { return false } @@ -41,7 +58,7 @@ func (rows *Rows) Next() bool { return n } -func (rows *Rows) Scan(dest ...interface{}) error { +func (rows *poolRows) Scan(dest ...interface{}) error { err := rows.r.Scan(dest...) if err != nil { rows.Close() @@ -49,7 +66,7 @@ func (rows *Rows) Scan(dest ...interface{}) error { return err } -func (rows *Rows) Values() ([]interface{}, error) { +func (rows *poolRows) Values() ([]interface{}, error) { values, err := rows.r.Values() if err != nil { rows.Close() @@ -57,13 +74,13 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, err } -type Row struct { - r *pgx.Row +type poolRow struct { + r pgx.Row c *Conn err error } -func (row *Row) Scan(dest ...interface{}) error { +func (row *poolRow) Scan(dest ...interface{}) error { if row.err != nil { return row.err } diff --git a/pool/tx.go b/pool/tx.go index 2898a21d..4ab1c2f9 100644 --- a/pool/tx.go +++ b/pool/tx.go @@ -38,10 +38,10 @@ func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (p return tx.c.Exec(ctx, sql, arguments...) } -func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (*Rows, error) { +func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { return tx.c.Query(ctx, sql, optionsAndArgs...) } -func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *Row { +func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { return tx.c.QueryRow(ctx, sql, optionsAndArgs...) } diff --git a/query.go b/query.go index 53fb1e72..0bdaa440 100644 --- a/query.go +++ b/query.go @@ -14,14 +14,45 @@ import ( "github.com/jackc/pgx/pgtype" ) -// Row is a convenience wrapper over Rows that is returned by QueryRow. -type Row Rows +// Rows is the result set returned from *Conn.Query. Rows must be closed before +// the *Conn can be used again. Rows are closed by explicitly calling Close(), +// calling Next() until it returns false, or when a fatal error occurs. +type Rows interface { + // Close closes the rows, making the connection ready for use again. It is safe + // to call Close after rows is already closed. + Close() -// Scan works the same as (*Rows Scan) with the following exceptions. If no -// rows were found it returns ErrNoRows. If multiple rows are returned it -// ignores all but the first. -func (r *Row) Scan(dest ...interface{}) (err error) { - rows := (*Rows)(r) + Err() error + FieldDescriptions() []FieldDescription + + // Next prepares the next row for reading. It returns true if there is another + // row and false if no more rows are available. It automatically closes rows + // when all rows are read. + Next() bool + + // Scan reads the values from the current row into dest values positionally. + // dest can include pointers to core types, values implementing the Scanner + // interface, []byte, and nil. []byte will skip the decoding process and directly + // copy the raw bytes received from PostgreSQL. nil will skip the value entirely. + Scan(dest ...interface{}) error + + // Values returns an array of the row values + Values() ([]interface{}, error) +} + +// Row is a convenience wrapper over Rows that is returned by QueryRow. +type Row interface { + // Scan works the same as Rows. with the following exceptions. If no + // rows were found it returns ErrNoRows. If multiple rows are returned it + // ignores all but the first. + Scan(dest ...interface{}) error +} + +// connRow implements the Row interface for Conn.QueryRow. +type connRow connRows + +func (r *connRow) Scan(dest ...interface{}) (err error) { + rows := (*connRows)(r) if rows.Err() != nil { return rows.Err() @@ -39,10 +70,8 @@ func (r *Row) Scan(dest ...interface{}) (err error) { return rows.Err() } -// Rows is the result set returned from *Conn.Query. Rows must be closed before -// the *Conn can be used again. Rows are closed by explicitly calling Close(), -// calling Next() until it returns false, or when a fatal error occurs. -type Rows struct { +// connRows implements the Rows interface for Conn.Query. +type connRows struct { conn *Conn batch *Batch values [][]byte @@ -60,13 +89,11 @@ type Rows struct { multiResultReader *pgconn.MultiResultReader } -func (rows *Rows) FieldDescriptions() []FieldDescription { +func (rows *connRows) FieldDescriptions() []FieldDescription { return rows.fields } -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { +func (rows *connRows) Close() { if rows.closed { return } @@ -106,13 +133,13 @@ func (rows *Rows) Close() { } } -func (rows *Rows) Err() error { +func (rows *connRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *Rows) fatal(err error) { +func (rows *connRows) fatal(err error) { if rows.err != nil { return } @@ -121,10 +148,7 @@ func (rows *Rows) fatal(err error) { rows.Close() } -// Next prepares the next row for reading. It returns true if there is another -// row and false if no more rows are available. It automatically closes rows -// when all rows are read. -func (rows *Rows) Next() bool { +func (rows *connRows) Next() bool { if rows.closed { return false } @@ -147,7 +171,7 @@ func (rows *Rows) Next() bool { } } -func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { +func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) { if rows.closed { return nil, nil, false } @@ -162,11 +186,7 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { return buf, fd, true } -// Scan reads the values from the current row into dest values positionally. -// dest can include pointers to core types, values implementing the Scanner -// interface, []byte, and nil. []byte will skip the decoding process and directly -// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. -func (rows *Rows) Scan(dest ...interface{}) (err error) { +func (rows *connRows) Scan(dest ...interface{}) (err error) { if len(rows.fields) != len(dest) { err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) rows.fatal(err) @@ -243,8 +263,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { return nil } -// Values returns an array of the row values -func (rows *Rows) Values() ([]interface{}, error) { +func (rows *connRows) Values() ([]interface{}, error) { if rows.closed { return nil, errors.New("rows is closed") } @@ -307,9 +326,9 @@ func (e scanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.col, e.err) } -func (c *Conn) getRows(sql string, args []interface{}) *Rows { +func (c *Conn) getRows(sql string, args []interface{}) *connRows { if len(c.preallocatedRows) == 0 { - c.preallocatedRows = make([]Rows, 64) + c.preallocatedRows = make([]connRows, 64) } r := &c.preallocatedRows[len(c.preallocatedRows)-1] @@ -333,10 +352,9 @@ type QueryExOptions struct { SimpleProtocol bool } -// Query executes sql with args. If there is an error the returned *Rows will -// be returned in an error state. So it is allowed to ignore the error returned -// from Query and handle it in *Rows. -func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (rows *Rows, err error) { +// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is +// allowed to ignore the error returned from Query and handle it in Rows. +func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) { c.lastStmtSent = false // rows = c.getRows(sql, args) @@ -349,7 +367,7 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac } } - rows = &Rows{ + rows := &connRows{ conn: c, startTime: time.Now(), sql: sql, @@ -368,6 +386,7 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac // return rows, rows.err // } + var err error if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { @@ -519,9 +538,9 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, } // QueryRow is a convenience wrapper over Query. Any error that occurs while -// querying is deferred until calling Scan on the returned *Row. That *Row will +// querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *Row { +func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row { rows, _ := c.Query(ctx, sql, optionsAndArgs...) - return (*Row)(rows) + return (*connRow)(rows.(*connRows)) } diff --git a/replication.go b/replication.go index cfe7583a..60f8bcfc 100644 --- a/replication.go +++ b/replication.go @@ -340,7 +340,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only // OIDs and no DataTypeName values -func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { +func (rc *ReplicationConn) IdentifySystem() (r Rows, err error) { return nil, errors.New("TODO") // return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") } @@ -356,7 +356,7 @@ func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only // OIDs and no DataTypeName values -func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { +func (rc *ReplicationConn) TimelineHistory(timeline int) (r Rows, err error) { return nil, errors.New("TODO") // return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) } diff --git a/tx.go b/tx.go index def6dbad..19b9159b 100644 --- a/tx.go +++ b/tx.go @@ -172,20 +172,20 @@ func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOp } // Query delegates to the underlying *Conn -func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (*Rows, error) { +func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) { if tx.status != TxStatusInProgress { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed - return &Rows{closed: true, err: err}, err + return &connRows{closed: true, err: err}, err } return tx.conn.Query(ctx, sql, optionsAndArgs...) } // QueryRow delegates to the underlying *Conn -func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) *Row { +func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row { rows, _ := tx.Query(ctx, sql, optionsAndArgs...) - return (*Row)(rows) + return (*connRow)(rows.(*connRows)) } // CopyFrom delegates to the underlying *Conn