From 19537badff3588edd7d1d803300516cb47271f43 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 11 Jul 2014 16:55:45 -0500 Subject: [PATCH] Add Rows.Values --- conn.go | 80 +++++++++++++++++++++++++++++---------------------- conn_test.go | 47 +++++++++++++++++++++++++++++- stdlib/sql.go | 18 ++++++++---- 3 files changed, 105 insertions(+), 40 deletions(-) diff --git a/conn.go b/conn.go index f6a1e902..7b70e045 100644 --- a/conn.go +++ b/conn.go @@ -609,44 +609,56 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { return nil } -func (rows *Rows) ReadValue() (v interface{}, err error) { - fd, size, _ := rows.nextColumn() - if rows.Err() != nil { - return nil, rows.Err() +// Values returns an array of the row values +func (rows *Rows) Values() ([]interface{}, error) { + if rows.closed { + return nil, errors.New("rows is closed") } - switch fd.DataType { - case BoolOid: - return decodeBool(rows, fd, size), rows.Err() - case ByteaOid: - return decodeBytea(rows, fd, size), rows.Err() - case Int8Oid: - return decodeInt8(rows, fd, size), rows.Err() - case Int2Oid: - return decodeInt2(rows, fd, size), rows.Err() - case Int4Oid: - return decodeInt4(rows, fd, size), rows.Err() - case VarcharOid, TextOid: - return decodeText(rows, fd, size), rows.Err() - case Float4Oid: - return decodeFloat4(rows, fd, size), rows.Err() - case Float8Oid: - return decodeFloat8(rows, fd, size), rows.Err() - case DateOid: - return decodeDate(rows, fd, size), rows.Err() - case TimestampTzOid: - return decodeTimestampTz(rows, fd, size), rows.Err() + values := make([]interface{}, 0, len(rows.fields)) + + for _, _ = range rows.fields { + if rows.Err() != nil { + return nil, rows.Err() + } + + fd, size, _ := rows.nextColumn() + + switch fd.DataType { + case BoolOid: + values = append(values, decodeBool(rows, fd, size)) + case ByteaOid: + values = append(values, decodeBytea(rows, fd, size)) + case Int8Oid: + values = append(values, decodeInt8(rows, fd, size)) + case Int2Oid: + values = append(values, decodeInt2(rows, fd, size)) + case Int4Oid: + values = append(values, decodeInt4(rows, fd, size)) + case VarcharOid, TextOid: + values = append(values, decodeText(rows, fd, size)) + case Float4Oid: + values = append(values, decodeFloat4(rows, fd, size)) + case Float8Oid: + values = append(values, decodeFloat8(rows, fd, size)) + case DateOid: + values = append(values, decodeDate(rows, fd, size)) + case TimestampTzOid: + values = append(values, decodeTimestampTz(rows, fd, size)) + default: + // if it is not an intrinsic type then return the text + switch fd.FormatCode { + case TextFormatCode: + values = append(values, rows.MsgReader().ReadString(size)) + case BinaryFormatCode: + return nil, errors.New("Values cannot handle binary format non-intrinsic types") + default: + return nil, errors.New("Unknown format code") + } + } } - // if it is not an intrinsic type then return the text - switch fd.FormatCode { - case TextFormatCode: - return rows.MsgReader().ReadString(size), rows.Err() - // TODO - //case BinaryFormatCode: - default: - return nil, errors.New("Unknown format code") - } + return values, rows.Err() } // TODO - document diff --git a/conn_test.go b/conn_test.go index af01e53d..e68f91c9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -302,7 +302,7 @@ func TestExecFailure(t *testing.T) { } } -func TestConnQuery(t *testing.T) { +func TestConnQueryScan(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -335,6 +335,51 @@ func TestConnQuery(t *testing.T) { } } +func TestConnQueryValues(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var rowCount int32 + + rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + for rows.Next() { + rowCount++ + + values, err := rows.Values() + if err != nil { + t.Fatalf("rows.Values failed: %v", err) + } + if len(values) != 2 { + t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) + } + if values[0] != "foo" { + t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) + } + if values[0] != "foo" { + t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) + } + + if values[1] != rowCount { + t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1]) + } + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } +} + // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, conn *pgx.Conn) { var sum, rowCount int32 diff --git a/stdlib/sql.go b/stdlib/sql.go index 1993db06..5a8b63de 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -3,6 +3,7 @@ package stdlib import ( "database/sql" "database/sql/driver" + "errors" "fmt" "github.com/jackc/pgx" "io" @@ -191,11 +192,18 @@ func (r *Rows) Next(dest []driver.Value) error { } } - for i, _ := range r.rows.FieldDescriptions() { - v, err := r.rows.ReadValue() - if err != nil { - return err - } + values, err := r.rows.Values() + if err != nil { + return err + } + + if len(dest) < len(values) { + fmt.Printf("%d: %#v\n", len(dest), dest) + fmt.Printf("%d: %#v\n", len(values), values) + return errors.New("expected more values than were received") + } + + for i, v := range values { dest[i] = driver.Value(v) }