Add Rows.Values

scan-io
Jack Christensen 2014-07-11 16:55:45 -05:00
parent 6c1c819a5e
commit 19537badff
3 changed files with 105 additions and 40 deletions

80
conn.go
View File

@ -609,44 +609,56 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
return nil
}
func (rows *Rows) ReadValue() (v interface{}, err error) {
fd, size, _ := rows.nextColumn()
if rows.Err() != nil {
return nil, rows.Err()
// Values returns an array of the row values
func (rows *Rows) Values() ([]interface{}, error) {
if rows.closed {
return nil, errors.New("rows is closed")
}
switch fd.DataType {
case BoolOid:
return decodeBool(rows, fd, size), rows.Err()
case ByteaOid:
return decodeBytea(rows, fd, size), rows.Err()
case Int8Oid:
return decodeInt8(rows, fd, size), rows.Err()
case Int2Oid:
return decodeInt2(rows, fd, size), rows.Err()
case Int4Oid:
return decodeInt4(rows, fd, size), rows.Err()
case VarcharOid, TextOid:
return decodeText(rows, fd, size), rows.Err()
case Float4Oid:
return decodeFloat4(rows, fd, size), rows.Err()
case Float8Oid:
return decodeFloat8(rows, fd, size), rows.Err()
case DateOid:
return decodeDate(rows, fd, size), rows.Err()
case TimestampTzOid:
return decodeTimestampTz(rows, fd, size), rows.Err()
values := make([]interface{}, 0, len(rows.fields))
for _, _ = range rows.fields {
if rows.Err() != nil {
return nil, rows.Err()
}
fd, size, _ := rows.nextColumn()
switch fd.DataType {
case BoolOid:
values = append(values, decodeBool(rows, fd, size))
case ByteaOid:
values = append(values, decodeBytea(rows, fd, size))
case Int8Oid:
values = append(values, decodeInt8(rows, fd, size))
case Int2Oid:
values = append(values, decodeInt2(rows, fd, size))
case Int4Oid:
values = append(values, decodeInt4(rows, fd, size))
case VarcharOid, TextOid:
values = append(values, decodeText(rows, fd, size))
case Float4Oid:
values = append(values, decodeFloat4(rows, fd, size))
case Float8Oid:
values = append(values, decodeFloat8(rows, fd, size))
case DateOid:
values = append(values, decodeDate(rows, fd, size))
case TimestampTzOid:
values = append(values, decodeTimestampTz(rows, fd, size))
default:
// if it is not an intrinsic type then return the text
switch fd.FormatCode {
case TextFormatCode:
values = append(values, rows.MsgReader().ReadString(size))
case BinaryFormatCode:
return nil, errors.New("Values cannot handle binary format non-intrinsic types")
default:
return nil, errors.New("Unknown format code")
}
}
}
// if it is not an intrinsic type then return the text
switch fd.FormatCode {
case TextFormatCode:
return rows.MsgReader().ReadString(size), rows.Err()
// TODO
//case BinaryFormatCode:
default:
return nil, errors.New("Unknown format code")
}
return values, rows.Err()
}
// TODO - document

View File

@ -302,7 +302,7 @@ func TestExecFailure(t *testing.T) {
}
}
func TestConnQuery(t *testing.T) {
func TestConnQueryScan(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -335,6 +335,51 @@ func TestConnQuery(t *testing.T) {
}
}
func TestConnQueryValues(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var rowCount int32
rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
defer rows.Close()
for rows.Next() {
rowCount++
values, err := rows.Values()
if err != nil {
t.Fatalf("rows.Values failed: %v", err)
}
if len(values) != 2 {
t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values))
}
if values[0] != "foo" {
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
}
if values[0] != "foo" {
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
}
if values[1] != rowCount {
t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1])
}
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: ", err)
}
if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times")
}
}
// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
var sum, rowCount int32

View File

@ -3,6 +3,7 @@ package stdlib
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/jackc/pgx"
"io"
@ -191,11 +192,18 @@ func (r *Rows) Next(dest []driver.Value) error {
}
}
for i, _ := range r.rows.FieldDescriptions() {
v, err := r.rows.ReadValue()
if err != nil {
return err
}
values, err := r.rows.Values()
if err != nil {
return err
}
if len(dest) < len(values) {
fmt.Printf("%d: %#v\n", len(dest), dest)
fmt.Printf("%d: %#v\n", len(values), values)
return errors.New("expected more values than were received")
}
for i, v := range values {
dest[i] = driver.Value(v)
}