mirror of https://github.com/jackc/pgx.git
Add Rows.Values
parent
6c1c819a5e
commit
19537badff
80
conn.go
80
conn.go
|
@ -609,44 +609,56 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (rows *Rows) ReadValue() (v interface{}, err error) {
|
||||
fd, size, _ := rows.nextColumn()
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
// Values returns an array of the row values
|
||||
func (rows *Rows) Values() ([]interface{}, error) {
|
||||
if rows.closed {
|
||||
return nil, errors.New("rows is closed")
|
||||
}
|
||||
|
||||
switch fd.DataType {
|
||||
case BoolOid:
|
||||
return decodeBool(rows, fd, size), rows.Err()
|
||||
case ByteaOid:
|
||||
return decodeBytea(rows, fd, size), rows.Err()
|
||||
case Int8Oid:
|
||||
return decodeInt8(rows, fd, size), rows.Err()
|
||||
case Int2Oid:
|
||||
return decodeInt2(rows, fd, size), rows.Err()
|
||||
case Int4Oid:
|
||||
return decodeInt4(rows, fd, size), rows.Err()
|
||||
case VarcharOid, TextOid:
|
||||
return decodeText(rows, fd, size), rows.Err()
|
||||
case Float4Oid:
|
||||
return decodeFloat4(rows, fd, size), rows.Err()
|
||||
case Float8Oid:
|
||||
return decodeFloat8(rows, fd, size), rows.Err()
|
||||
case DateOid:
|
||||
return decodeDate(rows, fd, size), rows.Err()
|
||||
case TimestampTzOid:
|
||||
return decodeTimestampTz(rows, fd, size), rows.Err()
|
||||
values := make([]interface{}, 0, len(rows.fields))
|
||||
|
||||
for _, _ = range rows.fields {
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
fd, size, _ := rows.nextColumn()
|
||||
|
||||
switch fd.DataType {
|
||||
case BoolOid:
|
||||
values = append(values, decodeBool(rows, fd, size))
|
||||
case ByteaOid:
|
||||
values = append(values, decodeBytea(rows, fd, size))
|
||||
case Int8Oid:
|
||||
values = append(values, decodeInt8(rows, fd, size))
|
||||
case Int2Oid:
|
||||
values = append(values, decodeInt2(rows, fd, size))
|
||||
case Int4Oid:
|
||||
values = append(values, decodeInt4(rows, fd, size))
|
||||
case VarcharOid, TextOid:
|
||||
values = append(values, decodeText(rows, fd, size))
|
||||
case Float4Oid:
|
||||
values = append(values, decodeFloat4(rows, fd, size))
|
||||
case Float8Oid:
|
||||
values = append(values, decodeFloat8(rows, fd, size))
|
||||
case DateOid:
|
||||
values = append(values, decodeDate(rows, fd, size))
|
||||
case TimestampTzOid:
|
||||
values = append(values, decodeTimestampTz(rows, fd, size))
|
||||
default:
|
||||
// if it is not an intrinsic type then return the text
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
values = append(values, rows.MsgReader().ReadString(size))
|
||||
case BinaryFormatCode:
|
||||
return nil, errors.New("Values cannot handle binary format non-intrinsic types")
|
||||
default:
|
||||
return nil, errors.New("Unknown format code")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if it is not an intrinsic type then return the text
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
return rows.MsgReader().ReadString(size), rows.Err()
|
||||
// TODO
|
||||
//case BinaryFormatCode:
|
||||
default:
|
||||
return nil, errors.New("Unknown format code")
|
||||
}
|
||||
return values, rows.Err()
|
||||
}
|
||||
|
||||
// TODO - document
|
||||
|
|
47
conn_test.go
47
conn_test.go
|
@ -302,7 +302,7 @@ func TestExecFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnQuery(t *testing.T) {
|
||||
func TestConnQueryScan(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -335,6 +335,51 @@ func TestConnQuery(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnQueryValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var rowCount int32
|
||||
|
||||
rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
|
||||
values, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Values failed: %v", err)
|
||||
}
|
||||
if len(values) != 2 {
|
||||
t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values))
|
||||
}
|
||||
if values[0] != "foo" {
|
||||
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
|
||||
}
|
||||
if values[0] != "foo" {
|
||||
t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0])
|
||||
}
|
||||
|
||||
if values[1] != rowCount {
|
||||
t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1])
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Error("Select called onDataRow wrong number of times")
|
||||
}
|
||||
}
|
||||
|
||||
// Do a simple query to ensure the connection is still usable
|
||||
func ensureConnValid(t *testing.T, conn *pgx.Conn) {
|
||||
var sum, rowCount int32
|
||||
|
|
|
@ -3,6 +3,7 @@ package stdlib
|
|||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"io"
|
||||
|
@ -191,11 +192,18 @@ func (r *Rows) Next(dest []driver.Value) error {
|
|||
}
|
||||
}
|
||||
|
||||
for i, _ := range r.rows.FieldDescriptions() {
|
||||
v, err := r.rows.ReadValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values, err := r.rows.Values()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dest) < len(values) {
|
||||
fmt.Printf("%d: %#v\n", len(dest), dest)
|
||||
fmt.Printf("%d: %#v\n", len(values), values)
|
||||
return errors.New("expected more values than were received")
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
dest[i] = driver.Value(v)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue