Change Scan / decodeX to use ValueReader

Also improve Rows.Scan and Rows.Values error handling.
scan-io
Jack Christensen 2014-07-12 19:39:58 -05:00
parent 70c32fadc6
commit f215c8bf5f
5 changed files with 320 additions and 184 deletions

101
query.go
View File

@ -33,6 +33,7 @@ type Rows struct {
conn *Conn conn *Conn
mr *MsgReader mr *MsgReader
fields []FieldDescription fields []FieldDescription
vr ValueReader
rowCount int rowCount int
columnIdx int columnIdx int
err error err error
@ -153,19 +154,20 @@ func (rows *Rows) Next() bool {
} }
} }
func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) { func (rows *Rows) nextColumn() (*ValueReader, bool) {
if rows.closed { if rows.closed {
return nil, 0, false return nil, false
} }
if len(rows.fields) <= rows.columnIdx { if len(rows.fields) <= rows.columnIdx {
rows.Fatal(ProtocolError("No next column available")) rows.Fatal(ProtocolError("No next column available"))
return nil, 0, false return nil, false
} }
fd := &rows.fields[rows.columnIdx] fd := &rows.fields[rows.columnIdx]
rows.columnIdx++ rows.columnIdx++
size := rows.mr.ReadInt32() size := rows.mr.ReadInt32()
return fd, size, true rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size}
return &rows.vr, true
} }
func (rows *Rows) Scan(dest ...interface{}) (err error) { func (rows *Rows) Scan(dest ...interface{}) (err error) {
@ -175,46 +177,53 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
return err return err
} }
// TODO - decodeX should return err and Scan should Fatal the rows
for _, d := range dest { for _, d := range dest {
fd, size, _ := rows.nextColumn() vr, _ := rows.nextColumn()
switch d := d.(type) { switch d := d.(type) {
case *bool: case *bool:
*d = decodeBool(rows, fd, size) *d = decodeBool(vr)
case *[]byte: case *[]byte:
*d = decodeBytea(rows, fd, size) *d = decodeBytea(vr)
case *int64: case *int64:
*d = decodeInt8(rows, fd, size) *d = decodeInt8(vr)
case *int16: case *int16:
*d = decodeInt2(rows, fd, size) *d = decodeInt2(vr)
case *int32: case *int32:
*d = decodeInt4(rows, fd, size) *d = decodeInt4(vr)
case *string: case *string:
*d = decodeText(rows, fd, size) *d = decodeText(vr)
case *float32: case *float32:
*d = decodeFloat4(rows, fd, size) *d = decodeFloat4(vr)
case *float64: case *float64:
*d = decodeFloat8(rows, fd, size) *d = decodeFloat8(vr)
case *time.Time: case *time.Time:
switch fd.DataType { switch vr.Type().DataType {
case DateOid: case DateOid:
*d = decodeDate(rows, fd, size) *d = decodeDate(vr)
case TimestampTzOid: case TimestampTzOid:
*d = decodeTimestampTz(rows, fd, size) *d = decodeTimestampTz(vr)
case TimestampOid: case TimestampOid:
*d = decodeTimestamp(rows, fd, size) *d = decodeTimestamp(vr)
default: default:
err = fmt.Errorf("Can't convert OID %v to time.Time", fd.DataType) rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
rows.Fatal(err)
return err
} }
case Scanner: case Scanner:
err = d.Scan(rows, fd, size) err = d.Scan(vr)
if err != nil { if err != nil {
return err rows.Fatal(err)
} }
default: default:
return errors.New("Unknown type") rows.Fatal(errors.New("Unknown type"))
}
if vr.Err() != nil {
rows.Fatal(vr.Err())
}
if rows.Err() != nil {
return rows.Err()
} }
} }
@ -230,46 +239,50 @@ func (rows *Rows) Values() ([]interface{}, error) {
values := make([]interface{}, 0, len(rows.fields)) values := make([]interface{}, 0, len(rows.fields))
for _, _ = range rows.fields { for _, _ = range rows.fields {
if rows.Err() != nil { vr, _ := rows.nextColumn()
return nil, rows.Err()
}
fd, size, _ := rows.nextColumn() switch vr.Type().DataType {
switch fd.DataType {
case BoolOid: case BoolOid:
values = append(values, decodeBool(rows, fd, size)) values = append(values, decodeBool(vr))
case ByteaOid: case ByteaOid:
values = append(values, decodeBytea(rows, fd, size)) values = append(values, decodeBytea(vr))
case Int8Oid: case Int8Oid:
values = append(values, decodeInt8(rows, fd, size)) values = append(values, decodeInt8(vr))
case Int2Oid: case Int2Oid:
values = append(values, decodeInt2(rows, fd, size)) values = append(values, decodeInt2(vr))
case Int4Oid: case Int4Oid:
values = append(values, decodeInt4(rows, fd, size)) values = append(values, decodeInt4(vr))
case VarcharOid, TextOid: case VarcharOid, TextOid:
values = append(values, decodeText(rows, fd, size)) values = append(values, decodeText(vr))
case Float4Oid: case Float4Oid:
values = append(values, decodeFloat4(rows, fd, size)) values = append(values, decodeFloat4(vr))
case Float8Oid: case Float8Oid:
values = append(values, decodeFloat8(rows, fd, size)) values = append(values, decodeFloat8(vr))
case DateOid: case DateOid:
values = append(values, decodeDate(rows, fd, size)) values = append(values, decodeDate(vr))
case TimestampTzOid: case TimestampTzOid:
values = append(values, decodeTimestampTz(rows, fd, size)) values = append(values, decodeTimestampTz(vr))
case TimestampOid: case TimestampOid:
values = append(values, decodeTimestamp(rows, fd, size)) values = append(values, decodeTimestamp(vr))
default: default:
// if it is not an intrinsic type then return the text // if it is not an intrinsic type then return the text
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
values = append(values, rows.mr.ReadString(size)) values = append(values, vr.ReadString(vr.Len()))
case BinaryFormatCode: case BinaryFormatCode:
return nil, errors.New("Values cannot handle binary format non-intrinsic types") rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
default: default:
return nil, errors.New("Unknown format code") rows.Fatal(errors.New("Unknown format code"))
} }
} }
if vr.Err() != nil {
rows.Fatal(vr.Err())
}
if rows.Err() != nil {
return nil, rows.Err()
}
} }
return values, rows.Err() return values, rows.Err()

View File

@ -472,7 +472,7 @@ func TestQueryRowUnpreparedErrors(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
} }
if !strings.Contains(err.Error(), tt.err) { if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
} }
@ -511,7 +511,7 @@ func TestQueryRowPreparedErrors(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
} }
if !strings.Contains(err.Error(), tt.err) { if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
} }

123
value_reader.go Normal file
View File

@ -0,0 +1,123 @@
package pgx
import (
"errors"
)
// ValueReader the mechanism for implementing the BinaryDecoder interface.
type ValueReader struct {
mr *MsgReader
fd *FieldDescription
valueBytesRemaining int32
err error
}
// Err returns any error that the ValueReader has experienced
func (r *ValueReader) Err() error {
return r.err
}
// Fatal tells r that a Fatal error has occurred
func (r *ValueReader) Fatal(err error) {
r.err = err
}
// Len returns the number of unread bytes
func (r *ValueReader) Len() int32 {
return r.valueBytesRemaining
}
// Type returns the *FieldDescription of the value
func (r *ValueReader) Type() *FieldDescription {
return r.fd
}
func (r *ValueReader) ReadByte() byte {
if r.err != nil {
return 0
}
r.valueBytesRemaining -= 1
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return 0
}
return r.mr.ReadByte()
}
func (r *ValueReader) ReadInt16() int16 {
if r.err != nil {
return 0
}
r.valueBytesRemaining -= 2
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return 0
}
return r.mr.ReadInt16()
}
func (r *ValueReader) ReadInt32() int32 {
if r.err != nil {
return 0
}
r.valueBytesRemaining -= 4
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return 0
}
return r.mr.ReadInt32()
}
func (r *ValueReader) ReadInt64() int64 {
if r.err != nil {
return 0
}
r.valueBytesRemaining -= 8
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return 0
}
return r.mr.ReadInt64()
}
func (r *ValueReader) ReadOid() Oid {
return Oid(r.ReadInt32())
}
// ReadString reads count bytes and returns as string
func (r *ValueReader) ReadString(count int32) string {
if r.err != nil {
return ""
}
r.valueBytesRemaining -= count
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return ""
}
return r.mr.ReadString(count)
}
// ReadBytes reads count bytes and returns as []byte
func (r *ValueReader) ReadBytes(count int32) []byte {
if r.err != nil {
return nil
}
r.valueBytesRemaining -= count
if r.valueBytesRemaining < 0 {
r.Fatal(errors.New("read past end of value"))
return nil
}
return r.mr.ReadBytes(count)
}

274
values.go
View File

@ -47,7 +47,7 @@ func (e SerializationError) Error() string {
type Scanner interface { type Scanner interface {
// Scan MUST check fd's DataType and FormatCode before decoding. It should // Scan MUST check fd's DataType and FormatCode before decoding. It should
// not assume that it was called on the type of value. // not assume that it was called on the type of value.
Scan(rows *Rows, fd *FieldDescription, size int32) error Scan(r *ValueReader) error
} }
// TextEncoder is an interface used to encode values in text format for // TextEncoder is an interface used to encode values in text format for
@ -85,18 +85,18 @@ type NullFloat32 struct {
Valid bool // Valid is true if Float32 is not NULL Valid bool // Valid is true if Float32 is not NULL
} }
func (n *NullFloat32) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullFloat32) Scan(vr *ValueReader) error {
if fd.DataType != Float4Oid { if vr.Type().DataType != Float4Oid {
return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot decode OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Float32, n.Valid = 0, false n.Float32, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Float32 = decodeFloat4(rows, fd, size) n.Float32 = decodeFloat4(vr)
return rows.Err() return vr.Err()
} }
func (n NullFloat32) EncodeText() (string, byte, error) { func (n NullFloat32) EncodeText() (string, byte, error) {
@ -131,18 +131,18 @@ type NullFloat64 struct {
Valid bool // Valid is true if Float64 is not NULL Valid bool // Valid is true if Float64 is not NULL
} }
func (n *NullFloat64) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullFloat64) Scan(vr *ValueReader) error {
if fd.DataType != Float8Oid { if vr.Type().DataType != Float8Oid {
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Float64, n.Valid = 0, false n.Float64, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Float64 = decodeFloat8(rows, fd, size) n.Float64 = decodeFloat8(vr)
return rows.Err() return vr.Err()
} }
func (n NullFloat64) EncodeText() (string, byte, error) { func (n NullFloat64) EncodeText() (string, byte, error) {
@ -177,17 +177,17 @@ type NullString struct {
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
} }
func (s *NullString) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (s *NullString) Scan(vr *ValueReader) error {
// Not checking oid as so we can scan anything into into a NullString - may revisit this decision later // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later
if size == -1 { if vr.Len() == -1 {
s.String, s.Valid = "", false s.String, s.Valid = "", false
return nil return nil
} }
s.Valid = true s.Valid = true
s.String = decodeText(rows, fd, size) s.String = decodeText(vr)
return rows.Err() return vr.Err()
} }
func (s NullString) EncodeText() (string, byte, error) { func (s NullString) EncodeText() (string, byte, error) {
@ -209,18 +209,18 @@ type NullInt16 struct {
Valid bool // Valid is true if Int16 is not NULL Valid bool // Valid is true if Int16 is not NULL
} }
func (n *NullInt16) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullInt16) Scan(vr *ValueReader) error {
if fd.DataType != Int2Oid { if vr.Type().DataType != Int2Oid {
return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Int16, n.Valid = 0, false n.Int16, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Int16 = decodeInt2(rows, fd, size) n.Int16 = decodeInt2(vr)
return rows.Err() return vr.Err()
} }
func (n NullInt16) EncodeText() (string, byte, error) { func (n NullInt16) EncodeText() (string, byte, error) {
@ -255,18 +255,18 @@ type NullInt32 struct {
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
} }
func (n *NullInt32) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullInt32) Scan(vr *ValueReader) error {
if fd.DataType != Int4Oid { if vr.Type().DataType != Int4Oid {
return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Int32, n.Valid = 0, false n.Int32, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Int32 = decodeInt4(rows, fd, size) n.Int32 = decodeInt4(vr)
return rows.Err() return vr.Err()
} }
func (n NullInt32) EncodeText() (string, byte, error) { func (n NullInt32) EncodeText() (string, byte, error) {
@ -301,18 +301,18 @@ type NullInt64 struct {
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
} }
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullInt64) Scan(vr *ValueReader) error {
if fd.DataType != Int8Oid { if vr.Type().DataType != Int8Oid {
return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Int64, n.Valid = 0, false n.Int64, n.Valid = 0, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Int64 = decodeInt8(rows, fd, size) n.Int64 = decodeInt8(vr)
return rows.Err() return vr.Err()
} }
func (n NullInt64) EncodeText() (string, byte, error) { func (n NullInt64) EncodeText() (string, byte, error) {
@ -347,18 +347,18 @@ type NullBool struct {
Valid bool // Valid is true if Bool is not NULL Valid bool // Valid is true if Bool is not NULL
} }
func (n *NullBool) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullBool) Scan(vr *ValueReader) error {
if fd.DataType != BoolOid { if vr.Type().DataType != BoolOid {
return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Bool, n.Valid = false, false n.Bool, n.Valid = false, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Bool = decodeBool(rows, fd, size) n.Bool = decodeBool(vr)
return rows.Err() return vr.Err()
} }
func (n NullBool) EncodeText() (string, byte, error) { func (n NullBool) EncodeText() (string, byte, error) {
@ -393,20 +393,20 @@ type NullTime struct {
Valid bool // Valid is true if Time is not NULL Valid bool // Valid is true if Time is not NULL
} }
func (n *NullTime) Scan(rows *Rows, fd *FieldDescription, size int32) error { func (n *NullTime) Scan(vr *ValueReader) error {
if fd.DataType != TimestampTzOid { if vr.Type().DataType != TimestampTzOid {
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType)) return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType))
} }
if size == -1 { if vr.Len() == -1 {
n.Time, n.Valid = time.Time{}, false n.Time, n.Valid = time.Time{}, false
return nil return nil
} }
n.Valid = true n.Valid = true
n.Time = decodeTimestampTz(rows, fd, size) n.Time = decodeTimestampTz(vr)
return rows.Err() return vr.Err()
} }
func (n NullTime) EncodeText() (string, byte, error) { func (n NullTime) EncodeText() (string, byte, error) {
@ -523,28 +523,28 @@ func sanitizeArg(arg interface{}) (string, error) {
} }
} }
func decodeBool(rows *Rows, fd *FieldDescription, size int32) bool { func decodeBool(vr *ValueReader) bool {
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
switch s { switch s {
case "t": case "t":
return true return true
case "f": case "f":
return false return false
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s)))
return false return false
} }
case BinaryFormatCode: case BinaryFormatCode:
if size != 1 { if vr.Len() != 1 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len())))
return false return false
} }
b := rows.mr.ReadByte() b := vr.ReadByte()
return b != 0 return b != 0
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return false return false
} }
} }
@ -567,29 +567,29 @@ func encodeBool(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt8(rows *Rows, fd *FieldDescription, size int32) int64 { func decodeInt8(vr *ValueReader) int64 {
if fd.DataType != Int8Oid { if vr.Type().DataType != Int8Oid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
n, err := strconv.ParseInt(s, 10, 64) n, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s)))
return 0 return 0
} }
return n return n
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if vr.Len() != 8 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len())))
return 0 return 0
} }
return rows.mr.ReadInt64() return vr.ReadInt64()
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0 return 0
} }
} }
@ -628,29 +628,29 @@ func encodeInt8(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt2(rows *Rows, fd *FieldDescription, size int32) int16 { func decodeInt2(vr *ValueReader) int16 {
if fd.DataType != Int2Oid { if vr.Type().DataType != Int2Oid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
n, err := strconv.ParseInt(s, 10, 16) n, err := strconv.ParseInt(s, 10, 16)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s)))
return 0 return 0
} }
return int16(n) return int16(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 2 { if vr.Len() != 2 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len())))
return 0 return 0
} }
return rows.mr.ReadInt16() return vr.ReadInt16()
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0 return 0
} }
} }
@ -704,28 +704,28 @@ func encodeInt2(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeInt4(rows *Rows, fd *FieldDescription, size int32) int32 { func decodeInt4(vr *ValueReader) int32 {
if fd.DataType != Int4Oid { if vr.Type().DataType != Int4Oid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType)))
return 0 return 0
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s)))
} }
return int32(n) return int32(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if vr.Len() != 4 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len())))
return 0 return 0
} }
return rows.mr.ReadInt32() return vr.ReadInt32()
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0 return 0
} }
} }
@ -773,27 +773,27 @@ func encodeInt4(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeFloat4(rows *Rows, fd *FieldDescription, size int32) float32 { func decodeFloat4(vr *ValueReader) float32 {
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
n, err := strconv.ParseFloat(s, 32) n, err := strconv.ParseFloat(s, 32)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s)))
return 0 return 0
} }
return float32(n) return float32(n)
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if vr.Len() != 4 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
return 0 return 0
} }
i := rows.mr.ReadInt32() i := vr.ReadInt32()
p := unsafe.Pointer(&i) p := unsafe.Pointer(&i)
return *(*float32)(p) return *(*float32)(p)
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0 return 0
} }
} }
@ -820,27 +820,27 @@ func encodeFloat4(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeFloat8(rows *Rows, fd *FieldDescription, size int32) float64 { func decodeFloat8(vr *ValueReader) float64 {
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
v, err := strconv.ParseFloat(s, 64) v, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s)))
return 0 return 0
} }
return v return v
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if vr.Len() != 8 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len())))
return 0 return 0
} }
i := rows.mr.ReadInt64() i := vr.ReadInt64()
p := unsafe.Pointer(&i) p := unsafe.Pointer(&i)
return *(*float64)(p) return *(*float64)(p)
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0 return 0
} }
} }
@ -864,8 +864,8 @@ func encodeFloat8(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeText(rows *Rows, fd *FieldDescription, size int32) string { func decodeText(vr *ValueReader) string {
return rows.mr.ReadString(size) return vr.ReadString(vr.Len())
} }
func encodeText(w *WriteBuf, value interface{}) error { func encodeText(w *WriteBuf, value interface{}) error {
@ -880,20 +880,20 @@ func encodeText(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeBytea(rows *Rows, fd *FieldDescription, size int32) []byte { func decodeBytea(vr *ValueReader) []byte {
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
b, err := hex.DecodeString(s[2:]) b, err := hex.DecodeString(s[2:])
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s)))
return nil return nil
} }
return b return b
case BinaryFormatCode: case BinaryFormatCode:
return rows.mr.ReadBytes(size) return vr.ReadBytes(vr.Len())
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil return nil
} }
} }
@ -910,31 +910,31 @@ func encodeBytea(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeDate(rows *Rows, fd *FieldDescription, size int32) time.Time { func decodeDate(vr *ValueReader) time.Time {
var zeroTime time.Time var zeroTime time.Time
if fd.DataType != DateOid { if vr.Type().DataType != DateOid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType)))
return zeroTime return zeroTime
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
t, err := time.ParseInLocation("2006-01-02", s, time.Local) t, err := time.ParseInLocation("2006-01-02", s, time.Local)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s)))
return zeroTime return zeroTime
} }
return t return t
case BinaryFormatCode: case BinaryFormatCode:
if size != 4 { if vr.Len() != 4 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len())))
} }
dayOffset := rows.mr.ReadInt32() dayOffset := vr.ReadInt32()
return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return zeroTime return zeroTime
} }
} }
@ -951,32 +951,32 @@ func encodeDate(w *WriteBuf, value interface{}) error {
const microsecFromUnixEpochToY2K = 946684800 * 1000000 const microsecFromUnixEpochToY2K = 946684800 * 1000000
func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time { func decodeTimestampTz(vr *ValueReader) time.Time {
var zeroTime time.Time var zeroTime time.Time
if fd.DataType != TimestampTzOid { if vr.Type().DataType != TimestampTzOid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType)))
return zeroTime return zeroTime
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) t, err := time.Parse("2006-01-02 15:04:05.999999-07", s)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s)))
return zeroTime return zeroTime
} }
return t return t
case BinaryFormatCode: case BinaryFormatCode:
if size != 8 { if vr.Len() != 8 {
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size))) vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len())))
} }
microsecSinceY2K := rows.mr.ReadInt64() microsecSinceY2K := vr.ReadInt64()
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return zeroTime return zeroTime
} }
} }
@ -996,28 +996,28 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeTimestamp(rows *Rows, fd *FieldDescription, size int32) time.Time { func decodeTimestamp(vr *ValueReader) time.Time {
var zeroTime time.Time var zeroTime time.Time
if fd.DataType != TimestampOid { if vr.Type().DataType != TimestampOid {
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, fd.DataType))) vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType)))
return zeroTime return zeroTime
} }
switch fd.FormatCode { switch vr.Type().FormatCode {
case TextFormatCode: case TextFormatCode:
s := rows.mr.ReadString(size) s := vr.ReadString(vr.Len())
t, err := time.ParseInLocation("2006-01-02 15:04:05.999999", s, time.Local) t, err := time.ParseInLocation("2006-01-02 15:04:05.999999", s, time.Local)
if err != nil { if err != nil {
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s))) vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s)))
return zeroTime return zeroTime
} }
return t return t
case BinaryFormatCode: case BinaryFormatCode:
rows.Fatal(ProtocolError("Can't decode binary timestamp")) vr.Fatal(ProtocolError("Can't decode binary timestamp"))
return zeroTime return zeroTime
default: default:
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return zeroTime return zeroTime
} }
} }

View File

@ -199,7 +199,7 @@ func TestNullX(t *testing.T) {
{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}},
{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}},
{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}},
{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.b}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
{"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}},
} }