Add Rows.Values

scan-io
Jack Christensen 2014-07-11 16:55:45 -05:00
parent 6c1c819a5e
commit 19537badff
3 changed files with 105 additions and 40 deletions

46
conn.go
View File

@ -609,45 +609,57 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
return nil return nil
} }
func (rows *Rows) ReadValue() (v interface{}, err error) { // Values returns an array of the row values
fd, size, _ := rows.nextColumn() func (rows *Rows) Values() ([]interface{}, error) {
if rows.closed {
return nil, errors.New("rows is closed")
}
values := make([]interface{}, 0, len(rows.fields))
for _, _ = range rows.fields {
if rows.Err() != nil { if rows.Err() != nil {
return nil, rows.Err() return nil, rows.Err()
} }
fd, size, _ := rows.nextColumn()
switch fd.DataType { switch fd.DataType {
case BoolOid: case BoolOid:
return decodeBool(rows, fd, size), rows.Err() values = append(values, decodeBool(rows, fd, size))
case ByteaOid: case ByteaOid:
return decodeBytea(rows, fd, size), rows.Err() values = append(values, decodeBytea(rows, fd, size))
case Int8Oid: case Int8Oid:
return decodeInt8(rows, fd, size), rows.Err() values = append(values, decodeInt8(rows, fd, size))
case Int2Oid: case Int2Oid:
return decodeInt2(rows, fd, size), rows.Err() values = append(values, decodeInt2(rows, fd, size))
case Int4Oid: case Int4Oid:
return decodeInt4(rows, fd, size), rows.Err() values = append(values, decodeInt4(rows, fd, size))
case VarcharOid, TextOid: case VarcharOid, TextOid:
return decodeText(rows, fd, size), rows.Err() values = append(values, decodeText(rows, fd, size))
case Float4Oid: case Float4Oid:
return decodeFloat4(rows, fd, size), rows.Err() values = append(values, decodeFloat4(rows, fd, size))
case Float8Oid: case Float8Oid:
return decodeFloat8(rows, fd, size), rows.Err() values = append(values, decodeFloat8(rows, fd, size))
case DateOid: case DateOid:
return decodeDate(rows, fd, size), rows.Err() values = append(values, decodeDate(rows, fd, size))
case TimestampTzOid: case TimestampTzOid:
return decodeTimestampTz(rows, fd, size), rows.Err() values = append(values, decodeTimestampTz(rows, fd, size))
} default:
// if it is not an intrinsic type then return the text // if it is not an intrinsic type then return the text
switch fd.FormatCode { switch fd.FormatCode {
case TextFormatCode: case TextFormatCode:
return rows.MsgReader().ReadString(size), rows.Err() values = append(values, rows.MsgReader().ReadString(size))
// TODO case BinaryFormatCode:
//case BinaryFormatCode: return nil, errors.New("Values cannot handle binary format non-intrinsic types")
default: default:
return nil, errors.New("Unknown format code") return nil, errors.New("Unknown format code")
} }
} }
}
return values, rows.Err()
}
// TODO - document // TODO - document
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {

View File

@ -302,7 +302,7 @@ func TestExecFailure(t *testing.T) {
} }
} }
func TestConnQuery(t *testing.T) { func TestConnQueryScan(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -335,6 +335,51 @@ func TestConnQuery(t *testing.T) {
} }
} }
func TestConnQueryValues(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var rowCount int32
rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
defer rows.Close()
for rows.Next() {
rowCount++
values, err := rows.Values()
if err != nil {
t.Fatalf("rows.Values failed: %v", err)
}
if len(values) != 2 {
t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values))
}
if values[0] != "foo" {
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
}
if values[0] != "foo" {
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
}
if values[1] != rowCount {
t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1])
}
}
if rows.Err() != nil {
t.Fatalf("conn.Query failed: ", err)
}
if rowCount != 10 {
t.Error("Select called onDataRow wrong number of times")
}
}
// Do a simple query to ensure the connection is still usable // Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, conn *pgx.Conn) { func ensureConnValid(t *testing.T, conn *pgx.Conn) {
var sum, rowCount int32 var sum, rowCount int32

View File

@ -3,6 +3,7 @@ package stdlib
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"io" "io"
@ -191,11 +192,18 @@ func (r *Rows) Next(dest []driver.Value) error {
} }
} }
for i, _ := range r.rows.FieldDescriptions() { values, err := r.rows.Values()
v, err := r.rows.ReadValue()
if err != nil { if err != nil {
return err return err
} }
if len(dest) < len(values) {
fmt.Printf("%d: %#v\n", len(dest), dest)
fmt.Printf("%d: %#v\n", len(values), values)
return errors.New("expected more values than were received")
}
for i, v := range values {
dest[i] = driver.Value(v) dest[i] = driver.Value(v)
} }