Rename QueryResult to Rows

This helps conform closer to database/sql
This commit is contained in:
Jack Christensen 2014-07-11 08:21:29 -05:00
parent 01f261c71c
commit d7529600e0
6 changed files with 249 additions and 249 deletions

218
conn.go
View File

@ -59,7 +59,7 @@ type Conn struct {
alive bool alive bool
causeOfDeath error causeOfDeath error
logger log.Logger logger log.Logger
qr QueryResult rows Rows
mr MsgReader mr MsgReader
} }
@ -397,29 +397,29 @@ func (c *Conn) CauseOfDeath() error {
return c.causeOfDeath return c.causeOfDeath
} }
type Row QueryResult type Row Rows
func (r *Row) Scan(dest ...interface{}) (err error) { func (r *Row) Scan(dest ...interface{}) (err error) {
qr := (*QueryResult)(r) rows := (*Rows)(r)
if qr.Err() != nil { if rows.Err() != nil {
return qr.Err() return rows.Err()
} }
if !qr.NextRow() { if !rows.NextRow() {
if qr.Err() == nil { if rows.Err() == nil {
return ErrNoRows return ErrNoRows
} else { } else {
return qr.Err() return rows.Err()
} }
} }
qr.Scan(dest...) rows.Scan(dest...)
qr.Close() rows.Close()
return qr.Err() return rows.Err()
} }
type QueryResult struct { type Rows struct {
pool *ConnPool pool *ConnPool
conn *Conn conn *Conn
mr *MsgReader mr *MsgReader
@ -430,174 +430,174 @@ type QueryResult struct {
closed bool closed bool
} }
func (qr *QueryResult) FieldDescriptions() []FieldDescription { func (rows *Rows) FieldDescriptions() []FieldDescription {
return qr.fields return rows.fields
} }
func (qr *QueryResult) MsgReader() *MsgReader { func (rows *Rows) MsgReader() *MsgReader {
return qr.mr return rows.mr
} }
func (qr *QueryResult) close() { func (rows *Rows) close() {
if qr.pool != nil { if rows.pool != nil {
qr.pool.Release(qr.conn) rows.pool.Release(rows.conn)
qr.pool = nil rows.pool = nil
} }
qr.closed = true rows.closed = true
} }
func (qr *QueryResult) readUntilReadyForQuery() { func (rows *Rows) readUntilReadyForQuery() {
for { for {
t, r, err := qr.conn.rxMsg() t, r, err := rows.conn.rxMsg()
if err != nil { if err != nil {
qr.close() rows.close()
return return
} }
switch t { switch t {
case readyForQuery: case readyForQuery:
qr.conn.rxReadyForQuery(r) rows.conn.rxReadyForQuery(r)
qr.close() rows.close()
return return
case rowDescription: case rowDescription:
case dataRow: case dataRow:
case commandComplete: case commandComplete:
case bindComplete: case bindComplete:
default: default:
err = qr.conn.processContextFreeMsg(t, r) err = rows.conn.processContextFreeMsg(t, r)
if err != nil { if err != nil {
qr.close() rows.close()
return return
} }
} }
} }
} }
func (qr *QueryResult) Close() { func (rows *Rows) Close() {
if qr.closed { if rows.closed {
return return
} }
qr.readUntilReadyForQuery() rows.readUntilReadyForQuery()
qr.close() rows.close()
} }
func (qr *QueryResult) Err() error { func (rows *Rows) Err() error {
return qr.err return rows.err
} }
// abort signals that the query was not successfully sent to the server. // abort signals that the query was not successfully sent to the server.
// This differs from Fatal in that it is not necessary to readUntilReadyForQuery // This differs from Fatal in that it is not necessary to readUntilReadyForQuery
func (qr *QueryResult) abort(err error) { func (rows *Rows) abort(err error) {
if qr.err != nil { if rows.err != nil {
return return
} }
qr.err = err rows.err = err
qr.close() rows.close()
} }
// Fatal signals an error occurred after the query was sent to the server // Fatal signals an error occurred after the query was sent to the server
func (qr *QueryResult) Fatal(err error) { func (rows *Rows) Fatal(err error) {
if qr.err != nil { if rows.err != nil {
return return
} }
qr.err = err rows.err = err
qr.Close() rows.Close()
} }
func (qr *QueryResult) NextRow() bool { func (rows *Rows) NextRow() bool {
if qr.closed { if rows.closed {
return false return false
} }
qr.rowCount++ rows.rowCount++
qr.columnIdx = 0 rows.columnIdx = 0
for { for {
t, r, err := qr.conn.rxMsg() t, r, err := rows.conn.rxMsg()
if err != nil { if err != nil {
qr.Fatal(err) rows.Fatal(err)
return false return false
} }
switch t { switch t {
case readyForQuery: case readyForQuery:
qr.conn.rxReadyForQuery(r) rows.conn.rxReadyForQuery(r)
qr.close() rows.close()
return false return false
case dataRow: case dataRow:
fieldCount := r.ReadInt16() fieldCount := r.ReadInt16()
if int(fieldCount) != len(qr.fields) { if int(fieldCount) != len(rows.fields) {
qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount))) rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount)))
return false return false
} }
qr.mr = r rows.mr = r
return true return true
case commandComplete: case commandComplete:
case bindComplete: case bindComplete:
default: default:
err = qr.conn.processContextFreeMsg(t, r) err = rows.conn.processContextFreeMsg(t, r)
if err != nil { if err != nil {
qr.Fatal(err) rows.Fatal(err)
return false return false
} }
} }
} }
} }
func (qr *QueryResult) nextColumn() (*FieldDescription, int32, bool) { func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) {
if qr.closed { if rows.closed {
return nil, 0, false return nil, 0, false
} }
if len(qr.fields) <= qr.columnIdx { if len(rows.fields) <= rows.columnIdx {
qr.Fatal(ProtocolError("No next column available")) rows.Fatal(ProtocolError("No next column available"))
return nil, 0, false return nil, 0, false
} }
fd := &qr.fields[qr.columnIdx] fd := &rows.fields[rows.columnIdx]
qr.columnIdx++ rows.columnIdx++
size := qr.mr.ReadInt32() size := rows.mr.ReadInt32()
return fd, size, true return fd, size, true
} }
func (qr *QueryResult) Scan(dest ...interface{}) (err error) { func (rows *Rows) Scan(dest ...interface{}) (err error) {
if len(qr.fields) != len(dest) { if len(rows.fields) != len(dest) {
err = errors.New("Scan received wrong number of arguments") err = errors.New("Scan received wrong number of arguments")
qr.Fatal(err) rows.Fatal(err)
return err return err
} }
for _, d := range dest { for _, d := range dest {
fd, size, _ := qr.nextColumn() fd, size, _ := rows.nextColumn()
switch d := d.(type) { switch d := d.(type) {
case *bool: case *bool:
*d = decodeBool(qr, fd, size) *d = decodeBool(rows, fd, size)
case *[]byte: case *[]byte:
*d = decodeBytea(qr, fd, size) *d = decodeBytea(rows, fd, size)
case *int64: case *int64:
*d = decodeInt8(qr, fd, size) *d = decodeInt8(rows, fd, size)
case *int16: case *int16:
*d = decodeInt2(qr, fd, size) *d = decodeInt2(rows, fd, size)
case *int32: case *int32:
*d = decodeInt4(qr, fd, size) *d = decodeInt4(rows, fd, size)
case *string: case *string:
*d = decodeText(qr, fd, size) *d = decodeText(rows, fd, size)
case *float32: case *float32:
*d = decodeFloat4(qr, fd, size) *d = decodeFloat4(rows, fd, size)
case *float64: case *float64:
*d = decodeFloat8(qr, fd, size) *d = decodeFloat8(rows, fd, size)
case *time.Time: case *time.Time:
if fd.DataType == DateOid { if fd.DataType == DateOid {
*d = decodeDate(qr, fd, size) *d = decodeDate(rows, fd, size)
} else { } else {
*d = decodeTimestampTz(qr, fd, size) *d = decodeTimestampTz(rows, fd, size)
} }
case Scanner: case Scanner:
err = d.Scan(qr, fd, size) err = d.Scan(rows, fd, size)
if err != nil { if err != nil {
return err return err
} }
@ -609,39 +609,39 @@ func (qr *QueryResult) Scan(dest ...interface{}) (err error) {
return nil return nil
} }
func (qr *QueryResult) ReadValue() (v interface{}, err error) { func (rows *Rows) ReadValue() (v interface{}, err error) {
fd, size, _ := qr.nextColumn() fd, size, _ := rows.nextColumn()
if qr.Err() != nil { if rows.Err() != nil {
return nil, qr.Err() return nil, rows.Err()
} }
switch fd.DataType { switch fd.DataType {
case BoolOid: case BoolOid:
return decodeBool(qr, fd, size), qr.Err() return decodeBool(rows, fd, size), rows.Err()
case ByteaOid: case ByteaOid:
return decodeBytea(qr, fd, size), qr.Err() return decodeBytea(rows, fd, size), rows.Err()
case Int8Oid: case Int8Oid:
return decodeInt8(qr, fd, size), qr.Err() return decodeInt8(rows, fd, size), rows.Err()
case Int2Oid: case Int2Oid:
return decodeInt2(qr, fd, size), qr.Err() return decodeInt2(rows, fd, size), rows.Err()
case Int4Oid: case Int4Oid:
return decodeInt4(qr, fd, size), qr.Err() return decodeInt4(rows, fd, size), rows.Err()
case VarcharOid, TextOid: case VarcharOid, TextOid:
return decodeText(qr, fd, size), qr.Err() return decodeText(rows, fd, size), rows.Err()
case Float4Oid: case Float4Oid:
return decodeFloat4(qr, fd, size), qr.Err() return decodeFloat4(rows, fd, size), rows.Err()
case Float8Oid: case Float8Oid:
return decodeFloat8(qr, fd, size), qr.Err() return decodeFloat8(rows, fd, size), rows.Err()
case DateOid: case DateOid:
return decodeDate(qr, fd, size), qr.Err() return decodeDate(rows, fd, size), rows.Err()
case TimestampTzOid: case TimestampTzOid:
return decodeTimestampTz(qr, fd, size), qr.Err() return decodeTimestampTz(rows, fd, size), rows.Err()
} }
// if it is not an intrinsic type then return the text // if it is not an intrinsic type then return the text
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
return qr.MsgReader().ReadString(size), qr.Err() return rows.MsgReader().ReadString(size), rows.Err()
// TODO // TODO
//case BinaryFormatCode: //case BinaryFormatCode:
default: default:
@ -650,23 +650,23 @@ func (qr *QueryResult) ReadValue() (v interface{}, err error) {
} }
// TODO - document // TODO - document
func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
c.qr = QueryResult{conn: c} c.rows = Rows{conn: c}
qr := &c.qr rows := &c.rows
if ps, present := c.preparedStatements[sql]; present { if ps, present := c.preparedStatements[sql]; present {
qr.fields = ps.FieldDescriptions rows.fields = ps.FieldDescriptions
err := c.sendPreparedQuery(ps, args...) err := c.sendPreparedQuery(ps, args...)
if err != nil { if err != nil {
qr.abort(err) rows.abort(err)
} }
return qr, qr.err return rows, rows.err
} }
err := c.sendSimpleQuery(sql, args...) err := c.sendSimpleQuery(sql, args...)
if err != nil { if err != nil {
qr.abort(err) rows.abort(err)
return qr, qr.err return rows, rows.err
} }
// Simple queries don't know the field descriptions of the result. // Simple queries don't know the field descriptions of the result.
@ -674,27 +674,27 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) {
for { for {
t, r, err := c.rxMsg() t, r, err := c.rxMsg()
if err != nil { if err != nil {
qr.Fatal(err) rows.Fatal(err)
return qr, qr.err return rows, rows.err
} }
switch t { switch t {
case rowDescription: case rowDescription:
qr.fields = qr.conn.rxRowDescription(r) rows.fields = rows.conn.rxRowDescription(r)
return qr, nil return rows, nil
default: default:
err = qr.conn.processContextFreeMsg(t, r) err = rows.conn.processContextFreeMsg(t, r)
if err != nil { if err != nil {
qr.Fatal(err) rows.Fatal(err)
return qr, qr.err return rows, rows.err
} }
} }
} }
} }
func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
qr, _ := c.Query(sql, args...) rows, _ := c.Query(sql, args...)
return (*Row)(qr) return (*Row)(rows)
} }
func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {

View File

@ -176,26 +176,26 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...) return c.Exec(sql, arguments...)
} }
func (p *ConnPool) Query(sql string, args ...interface{}) (*QueryResult, error) { func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
// Because checking for errors can be deferred to the *QueryResult, build one with the error // Because checking for errors can be deferred to the *Rows, build one with the error
return &QueryResult{closed: true, err: err}, err return &Rows{closed: true, err: err}, err
} }
qr, err := c.Query(sql, args...) rows, err := c.Query(sql, args...)
if err != nil { if err != nil {
p.Release(c) p.Release(c)
return qr, err return rows, err
} }
qr.pool = p rows.pool = p
return qr, nil return rows, nil
} }
func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
qr, _ := p.Query(sql, args...) rows, _ := p.Query(sql, args...)
return (*Row)(qr) return (*Row)(rows)
} }
// Transaction acquires a connection, delegates the call to that connection, // Transaction acquires a connection, delegates the call to that connection,

View File

@ -209,8 +209,8 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Unable to Acquire: %v", err) t.Fatalf("Unable to Acquire: %v", err)
} }
qr, _ := c.Query("select 1") rows, _ := c.Query("select 1")
qr.Close() rows.Close()
pool.Release(c) pool.Release(c)
} }
@ -273,9 +273,9 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) {
} }
// do something with the connection so it knows it's dead // do something with the connection so it knows it's dead
qr, _ := c1.Query("select 1") rows, _ := c1.Query("select 1")
qr.Close() rows.Close()
if qr.Err() == nil { if rows.Err() == nil {
t.Fatal("Expected error but none occurred") t.Fatal("Expected error but none occurred")
} }
@ -400,7 +400,7 @@ func TestConnPoolQuery(t *testing.T) {
var sum, rowCount int32 var sum, rowCount int32
qr, err := pool.Query("select generate_series(1,$1)", 10) rows, err := pool.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("pool.Query failed: %v", err) t.Fatalf("pool.Query failed: %v", err)
} }
@ -410,14 +410,14 @@ func TestConnPoolQuery(t *testing.T) {
t.Fatalf("Unexpected connection pool stats: %v", stats) t.Fatalf("Unexpected connection pool stats: %v", stats)
} }
for qr.NextRow() { for rows.NextRow() {
var n int32 var n int32
qr.Scan(&n) rows.Scan(&n)
sum += n sum += n
rowCount++ rowCount++
} }
if qr.Err() != nil { if rows.Err() != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }

View File

@ -295,10 +295,10 @@ func TestExecFailure(t *testing.T) {
t.Fatal("Expected SQL syntax error") t.Fatal("Expected SQL syntax error")
} }
qr, _ := conn.Query("select 1") rows, _ := conn.Query("select 1")
qr.Close() rows.Close()
if qr.Err() != nil { if rows.Err() != nil {
t.Fatalf("Exec failure appears to have broken connection: %v", qr.Err()) t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err())
} }
} }
@ -339,20 +339,20 @@ func TestConnQuery(t *testing.T) {
func ensureConnValid(t *testing.T, conn *pgx.Conn) { func ensureConnValid(t *testing.T, conn *pgx.Conn) {
var sum, rowCount int32 var sum, rowCount int32
qr, err := conn.Query("select generate_series(1,$1)", 10) rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
defer qr.Close() defer rows.Close()
for qr.NextRow() { for rows.NextRow() {
var n int32 var n int32
qr.Scan(&n) rows.Scan(&n)
sum += n sum += n
rowCount++ rowCount++
} }
if qr.Err() != nil { if rows.Err() != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
@ -372,32 +372,32 @@ func TestConnQueryCloseEarly(t *testing.T) {
defer closeConn(t, conn) defer closeConn(t, conn)
// Immediately close query without reading any rows // Immediately close query without reading any rows
qr, err := conn.Query("select generate_series(1,$1)", 10) rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
qr.Close() rows.Close()
ensureConnValid(t, conn) ensureConnValid(t, conn)
// Read partial response then close // Read partial response then close
qr, err = conn.Query("select generate_series(1,$1)", 10) rows, err = conn.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
ok := qr.NextRow() ok := rows.NextRow()
if !ok { if !ok {
t.Fatal("qr.NextRow terminated early") t.Fatal("rows.NextRow terminated early")
} }
var n int32 var n int32
qr.Scan(&n) rows.Scan(&n)
if n != 1 { if n != 1 {
t.Fatalf("Expected 1 from first row, but got %v", n) t.Fatalf("Expected 1 from first row, but got %v", n)
} }
qr.Close() rows.Close()
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@ -410,16 +410,16 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
defer closeConn(t, conn) defer closeConn(t, conn)
// Read a single value incorrectly // Read a single value incorrectly
qr, err := conn.Query("select generate_series(1,$1)", 10) rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
rowsRead := 0 rowsRead := 0
for qr.NextRow() { for rows.NextRow() {
var t time.Time var t time.Time
qr.Scan(&t) rows.Scan(&t)
rowsRead++ rowsRead++
} }
@ -427,8 +427,8 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
} }
if qr.Err() == nil { if rows.Err() == nil {
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") t.Fatal("Expected Rows to have an error after an improper read but it didn't")
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)
@ -442,16 +442,16 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
defer closeConn(t, conn) defer closeConn(t, conn)
// Read too many values // Read too many values
qr, err := conn.Query("select generate_series(1,$1)", 10) rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
rowsRead := 0 rowsRead := 0
for qr.NextRow() { for rows.NextRow() {
var n, m int32 var n, m int32
qr.Scan(&n, &m) rows.Scan(&n, &m)
rowsRead++ rowsRead++
} }
@ -459,8 +459,8 @@ func TestConnQueryReadTooManyValues(t *testing.T) {
t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
} }
if qr.Err() == nil { if rows.Err() == nil {
t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") t.Fatal("Expected Rows to have an error after an improper read but it didn't")
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)
@ -472,22 +472,22 @@ func TestConnQueryUnpreparedScanner(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn) defer closeConn(t, conn)
qr, err := conn.Query("select null::int8, 1::int8") rows, err := conn.Query("select null::int8, 1::int8")
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
ok := qr.NextRow() ok := rows.NextRow()
if !ok { if !ok {
t.Fatal("qr.NextRow terminated early") t.Fatal("rows.NextRow terminated early")
} }
var n, m pgx.NullInt64 var n, m pgx.NullInt64
err = qr.Scan(&n, &m) err = rows.Scan(&n, &m)
if err != nil { if err != nil {
t.Fatalf("qr.Scan failed: ", err) t.Fatalf("rows.Scan failed: ", err)
} }
qr.Close() rows.Close()
if n.Valid { if n.Valid {
t.Error("Null should not be valid, but it was") t.Error("Null should not be valid, but it was")
@ -512,22 +512,22 @@ func TestConnQueryPreparedScanner(t *testing.T) {
mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8") mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8")
qr, err := conn.Query("scannerTest") rows, err := conn.Query("scannerTest")
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
ok := qr.NextRow() ok := rows.NextRow()
if !ok { if !ok {
t.Fatal("qr.NextRow terminated early") t.Fatal("rows.NextRow terminated early")
} }
var n, m pgx.NullInt64 var n, m pgx.NullInt64
err = qr.Scan(&n, &m) err = rows.Scan(&n, &m)
if err != nil { if err != nil {
t.Fatalf("qr.Scan failed: ", err) t.Fatalf("rows.Scan failed: ", err)
} }
qr.Close() rows.Close()
if n.Valid { if n.Valid {
t.Error("Null should not be valid, but it was") t.Error("Null should not be valid, but it was")
@ -552,22 +552,22 @@ func TestConnQueryUnpreparedEncoder(t *testing.T) {
n := pgx.NullInt64{Int64: 1, Valid: true} n := pgx.NullInt64{Int64: 1, Valid: true}
qr, err := conn.Query("select $1::int8", &n) rows, err := conn.Query("select $1::int8", &n)
if err != nil { if err != nil {
t.Fatalf("conn.Query failed: ", err) t.Fatalf("conn.Query failed: ", err)
} }
ok := qr.NextRow() ok := rows.NextRow()
if !ok { if !ok {
t.Fatal("qr.NextRow terminated early") t.Fatal("rows.NextRow terminated early")
} }
var m pgx.NullInt64 var m pgx.NullInt64
err = qr.Scan(&m) err = rows.Scan(&m)
if err != nil { if err != nil {
t.Fatalf("qr.Scan failed: ", err) t.Fatalf("rows.Scan failed: ", err)
} }
qr.Close() rows.Close()
if !m.Valid { if !m.Valid {
t.Error("m should be valid, but it wasn't") t.Error("m should be valid, but it wasn't")
@ -787,10 +787,10 @@ func TestListenNotify(t *testing.T) {
// when notification has already been read during previous query // when notification has already been read during previous query
mustExec(t, notifier, "notify chat") mustExec(t, notifier, "notify chat")
qr, _ := listener.Query("select 1") rows, _ := listener.Query("select 1")
qr.Close() rows.Close()
if qr.Err() != nil { if rows.Err() != nil {
t.Fatalf("Unexpected error on Query: %v", qr.Err()) t.Fatalf("Unexpected error on Query: %v", rows.Err())
} }
notification, err = listener.WaitForNotification(0) notification, err = listener.WaitForNotification(0)
if err != nil { if err != nil {

View File

@ -133,12 +133,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
args := valueToInterface(argsV) args := valueToInterface(argsV)
qr, err := c.conn.Query(query, args...) rows, err := c.conn.Query(query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Rows{qr: qr}, nil return &Rows{rows: rows}, nil
} }
type Stmt struct { type Stmt struct {
@ -164,11 +164,11 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
// TODO - rename to avoid alloc // TODO - rename to avoid alloc
type Rows struct { type Rows struct {
qr *pgx.QueryResult rows *pgx.Rows
} }
func (r *Rows) Columns() []string { func (r *Rows) Columns() []string {
fieldDescriptions := r.qr.FieldDescriptions() fieldDescriptions := r.rows.FieldDescriptions()
names := make([]string, 0, len(fieldDescriptions)) names := make([]string, 0, len(fieldDescriptions))
for _, fd := range fieldDescriptions { for _, fd := range fieldDescriptions {
names = append(names, fd.Name) names = append(names, fd.Name)
@ -177,22 +177,22 @@ func (r *Rows) Columns() []string {
} }
func (r *Rows) Close() error { func (r *Rows) Close() error {
r.qr.Close() r.rows.Close()
return nil return nil
} }
func (r *Rows) Next(dest []driver.Value) error { func (r *Rows) Next(dest []driver.Value) error {
more := r.qr.NextRow() more := r.rows.NextRow()
if !more { if !more {
if r.qr.Err() == nil { if r.rows.Err() == nil {
return io.EOF return io.EOF
} else { } else {
return r.qr.Err() return r.rows.Err()
} }
} }
for i, _ := range r.qr.FieldDescriptions() { for i, _ := range r.rows.FieldDescriptions() {
v, err := r.qr.ReadValue() v, err := r.rows.ReadValue()
if err != nil { if err != nil {
return err return err
} }

128
values.go
View File

@ -40,7 +40,7 @@ func (e SerializationError) Error() string {
type Scanner interface { type Scanner interface {
// Scan MUST check fd's DataType and FormatCode before decoding. It should // Scan MUST check fd's DataType and FormatCode before decoding. It should
// not assume that it was called on the type of value. // not assume that it was called on the type of value.
Scan(qr *QueryResult, fd *FieldDescription, size int32) error Scan(rows *Rows, fd *FieldDescription, size int32) error
} }
// TextEncoder is an interface used to encode values in text format for // TextEncoder is an interface used to encode values in text format for
@ -145,14 +145,14 @@ func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
return return
} }
func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error { func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
if size == -1 { if size == -1 {
n.Int64, n.Valid = 0, false n.Int64, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Int64 = decodeInt8(qr, fd, size) n.Int64 = decodeInt8(rows, fd, size)
return qr.Err() return rows.Err()
} }
func (n *NullInt64) EncodeText() (string, error) { func (n *NullInt64) EncodeText() (string, error) {
@ -163,28 +163,28 @@ func (n *NullInt64) EncodeText() (string, error) {
} }
} }
func decodeBool(qr *QueryResult, fd *FieldDescription, size int32) bool { func decodeBool(rows *Rows, fd *FieldDescription, size int32) bool {
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
switch s { switch s {
case "t": case "t":
return true return true
case "f": case "f":
return false return false
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s)))
return false return false
} }
case BinaryFormatCode: case BinaryFormatCode:
if size != 1 { if size != 1 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size)))
return false return false
} }
b := qr.mr.ReadByte() b := rows.mr.ReadByte()
return b != 0 return b != 0
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return false return false
} }
} }
@ -207,29 +207,29 @@ func encodeBool(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt8(qr *QueryResult, fd *FieldDescription, size int32) int64 { func decodeInt8(rows *Rows, fd *FieldDescription, size int32) int64 {
if fd.DataType != Int8Oid { if fd.DataType != Int8Oid {
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType))) rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 64) n, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s)))
return 0 return 0
} }
return n return n
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if size != 8 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size)))
return 0 return 0
} }
return qr.mr.ReadInt64() return rows.mr.ReadInt64()
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return 0 return 0
} }
} }
@ -268,29 +268,29 @@ func encodeInt8(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt2(qr *QueryResult, fd *FieldDescription, size int32) int16 { func decodeInt2(rows *Rows, fd *FieldDescription, size int32) int16 {
if fd.DataType != Int2Oid { if fd.DataType != Int2Oid {
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType))) rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 16) n, err := strconv.ParseInt(s, 10, 16)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s)))
return 0 return 0
} }
return int16(n) return int16(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 2 { if size != 2 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size)))
return 0 return 0
} }
return qr.mr.ReadInt16() return rows.mr.ReadInt16()
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return 0 return 0
} }
} }
@ -344,28 +344,28 @@ func encodeInt2(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt4(qr *QueryResult, fd *FieldDescription, size int32) int32 { func decodeInt4(rows *Rows, fd *FieldDescription, size int32) int32 {
if fd.DataType != Int4Oid { if fd.DataType != Int4Oid {
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType))) rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s)))
} }
return int32(n) return int32(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if size != 4 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size)))
return 0 return 0
} }
return qr.mr.ReadInt32() return rows.mr.ReadInt32()
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return 0 return 0
} }
} }
@ -413,27 +413,27 @@ func encodeInt4(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeFloat4(qr *QueryResult, fd *FieldDescription, size int32) float32 { func decodeFloat4(rows *Rows, fd *FieldDescription, size int32) float32 {
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
n, err := strconv.ParseFloat(s, 32) n, err := strconv.ParseFloat(s, 32)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s)))
return 0 return 0
} }
return float32(n) return float32(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if size != 4 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size)))
return 0 return 0
} }
i := qr.mr.ReadInt32() i := rows.mr.ReadInt32()
p := unsafe.Pointer(&i) p := unsafe.Pointer(&i)
return *(*float32)(p) return *(*float32)(p)
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return 0 return 0
} }
} }
@ -460,27 +460,27 @@ func encodeFloat4(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeFloat8(qr *QueryResult, fd *FieldDescription, size int32) float64 { func decodeFloat8(rows *Rows, fd *FieldDescription, size int32) float64 {
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
v, err := strconv.ParseFloat(s, 64) v, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s)))
return 0 return 0
} }
return v return v
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if size != 8 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size)))
return 0 return 0
} }
i := qr.mr.ReadInt64() i := rows.mr.ReadInt64()
p := unsafe.Pointer(&i) p := unsafe.Pointer(&i)
return *(*float64)(p) return *(*float64)(p)
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return 0 return 0
} }
} }
@ -504,8 +504,8 @@ func encodeFloat8(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeText(qr *QueryResult, fd *FieldDescription, size int32) string { func decodeText(rows *Rows, fd *FieldDescription, size int32) string {
return qr.mr.ReadString(size) return rows.mr.ReadString(size)
} }
func encodeText(w *WriteBuf, value interface{}) error { func encodeText(w *WriteBuf, value interface{}) error {
@ -520,20 +520,20 @@ func encodeText(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeBytea(qr *QueryResult, fd *FieldDescription, size int32) []byte { func decodeBytea(rows *Rows, fd *FieldDescription, size int32) []byte {
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
b, err := hex.DecodeString(s[2:]) b, err := hex.DecodeString(s[2:])
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s)))
return nil return nil
} }
return b return b
case BinaryFormatCode: case BinaryFormatCode:
return qr.mr.ReadBytes(size) return rows.mr.ReadBytes(size)
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return nil return nil
} }
} }
@ -550,31 +550,31 @@ func encodeBytea(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeDate(qr *QueryResult, fd *FieldDescription, size int32) time.Time { func decodeDate(rows *Rows, fd *FieldDescription, size int32) time.Time {
var zeroTime time.Time var zeroTime time.Time
if fd.DataType != DateOid { if fd.DataType != DateOid {
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType))) rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType)))
return zeroTime return zeroTime
} }
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
t, err := time.ParseInLocation("2006-01-02", s, time.Local) t, err := time.ParseInLocation("2006-01-02", s, time.Local)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s)))
return zeroTime return zeroTime
} }
return t return t
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if size != 4 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size)))
} }
dayOffset := qr.mr.ReadInt32() dayOffset := rows.mr.ReadInt32()
return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return zeroTime return zeroTime
} }
} }
@ -589,33 +589,33 @@ func encodeDate(w *WriteBuf, value interface{}) error {
return encodeText(w, s) return encodeText(w, s)
} }
func decodeTimestampTz(qr *QueryResult, fd *FieldDescription, size int32) time.Time { func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time {
var zeroTime time.Time var zeroTime time.Time
if fd.DataType != TimestampTzOid { if fd.DataType != TimestampTzOid {
qr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType))) rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType)))
return zeroTime return zeroTime
} }
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
s := qr.mr.ReadString(size) s := rows.mr.ReadString(size)
t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) t, err := time.Parse("2006-01-02 15:04:05.999999-07", s)
if err != nil { if err != nil {
qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s)))
return zeroTime return zeroTime
} }
return t return t
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if size != 8 {
qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size))) rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size)))
} }
microsecFromUnixEpochToY2K := int64(946684800 * 1000000) microsecFromUnixEpochToY2K := int64(946684800 * 1000000)
microsecSinceY2K := qr.mr.ReadInt64() microsecSinceY2K := rows.mr.ReadInt64()
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
default: default:
qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
return zeroTime return zeroTime
} }
} }