diff --git a/query.go b/query.go index 12e3c136..dfdb3b2d 100644 --- a/query.go +++ b/query.go @@ -33,6 +33,7 @@ type Rows struct { conn *Conn mr *MsgReader fields []FieldDescription + vr ValueReader rowCount int columnIdx int err error @@ -153,19 +154,20 @@ func (rows *Rows) Next() bool { } } -func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) { +func (rows *Rows) nextColumn() (*ValueReader, bool) { if rows.closed { - return nil, 0, false + return nil, false } if len(rows.fields) <= rows.columnIdx { rows.Fatal(ProtocolError("No next column available")) - return nil, 0, false + return nil, false } fd := &rows.fields[rows.columnIdx] rows.columnIdx++ size := rows.mr.ReadInt32() - return fd, size, true + rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} + return &rows.vr, true } func (rows *Rows) Scan(dest ...interface{}) (err error) { @@ -175,46 +177,53 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { return err } + // TODO - decodeX should return err and Scan should Fatal the rows for _, d := range dest { - fd, size, _ := rows.nextColumn() + vr, _ := rows.nextColumn() switch d := d.(type) { case *bool: - *d = decodeBool(rows, fd, size) + *d = decodeBool(vr) case *[]byte: - *d = decodeBytea(rows, fd, size) + *d = decodeBytea(vr) case *int64: - *d = decodeInt8(rows, fd, size) + *d = decodeInt8(vr) case *int16: - *d = decodeInt2(rows, fd, size) + *d = decodeInt2(vr) case *int32: - *d = decodeInt4(rows, fd, size) + *d = decodeInt4(vr) case *string: - *d = decodeText(rows, fd, size) + *d = decodeText(vr) case *float32: - *d = decodeFloat4(rows, fd, size) + *d = decodeFloat4(vr) case *float64: - *d = decodeFloat8(rows, fd, size) + *d = decodeFloat8(vr) case *time.Time: - switch fd.DataType { + switch vr.Type().DataType { case DateOid: - *d = decodeDate(rows, fd, size) + *d = decodeDate(vr) case TimestampTzOid: - *d = decodeTimestampTz(rows, fd, size) + *d = decodeTimestampTz(vr) case TimestampOid: - *d = decodeTimestamp(rows, fd, size) + *d = decodeTimestamp(vr) default: - err = fmt.Errorf("Can't convert OID %v to time.Time", fd.DataType) - rows.Fatal(err) - return err + rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) } case Scanner: - err = d.Scan(rows, fd, size) + err = d.Scan(vr) if err != nil { - return err + rows.Fatal(err) } default: - return errors.New("Unknown type") + rows.Fatal(errors.New("Unknown type")) + } + + if vr.Err() != nil { + rows.Fatal(vr.Err()) + } + + if rows.Err() != nil { + return rows.Err() } } @@ -230,46 +239,50 @@ func (rows *Rows) Values() ([]interface{}, error) { values := make([]interface{}, 0, len(rows.fields)) for _, _ = range rows.fields { - if rows.Err() != nil { - return nil, rows.Err() - } + vr, _ := rows.nextColumn() - fd, size, _ := rows.nextColumn() - - switch fd.DataType { + switch vr.Type().DataType { case BoolOid: - values = append(values, decodeBool(rows, fd, size)) + values = append(values, decodeBool(vr)) case ByteaOid: - values = append(values, decodeBytea(rows, fd, size)) + values = append(values, decodeBytea(vr)) case Int8Oid: - values = append(values, decodeInt8(rows, fd, size)) + values = append(values, decodeInt8(vr)) case Int2Oid: - values = append(values, decodeInt2(rows, fd, size)) + values = append(values, decodeInt2(vr)) case Int4Oid: - values = append(values, decodeInt4(rows, fd, size)) + values = append(values, decodeInt4(vr)) case VarcharOid, TextOid: - values = append(values, decodeText(rows, fd, size)) + values = append(values, decodeText(vr)) case Float4Oid: - values = append(values, decodeFloat4(rows, fd, size)) + values = append(values, decodeFloat4(vr)) case Float8Oid: - values = append(values, decodeFloat8(rows, fd, size)) + values = append(values, decodeFloat8(vr)) case DateOid: - values = append(values, decodeDate(rows, fd, size)) + values = append(values, decodeDate(vr)) case TimestampTzOid: - values = append(values, decodeTimestampTz(rows, fd, size)) + values = append(values, decodeTimestampTz(vr)) case TimestampOid: - values = append(values, decodeTimestamp(rows, fd, size)) + values = append(values, decodeTimestamp(vr)) default: // if it is not an intrinsic type then return the text - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - values = append(values, rows.mr.ReadString(size)) + values = append(values, vr.ReadString(vr.Len())) case BinaryFormatCode: - return nil, errors.New("Values cannot handle binary format non-intrinsic types") + rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) default: - return nil, errors.New("Unknown format code") + rows.Fatal(errors.New("Unknown format code")) } } + + if vr.Err() != nil { + rows.Fatal(vr.Err()) + } + + if rows.Err() != nil { + return nil, rows.Err() + } } return values, rows.Err() diff --git a/query_test.go b/query_test.go index 040b5c06..8af481f2 100644 --- a/query_test.go +++ b/query_test.go @@ -472,7 +472,7 @@ func TestQueryRowUnpreparedErrors(t *testing.T) { if err == nil { t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) } - if !strings.Contains(err.Error(), tt.err) { + if err != nil && !strings.Contains(err.Error(), tt.err) { t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) } @@ -511,7 +511,7 @@ func TestQueryRowPreparedErrors(t *testing.T) { if err == nil { t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) } - if !strings.Contains(err.Error(), tt.err) { + if err != nil && !strings.Contains(err.Error(), tt.err) { t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) } diff --git a/value_reader.go b/value_reader.go new file mode 100644 index 00000000..6e8f65e0 --- /dev/null +++ b/value_reader.go @@ -0,0 +1,123 @@ +package pgx + +import ( + "errors" +) + +// ValueReader the mechanism for implementing the BinaryDecoder interface. +type ValueReader struct { + mr *MsgReader + fd *FieldDescription + valueBytesRemaining int32 + err error +} + +// Err returns any error that the ValueReader has experienced +func (r *ValueReader) Err() error { + return r.err +} + +// Fatal tells r that a Fatal error has occurred +func (r *ValueReader) Fatal(err error) { + r.err = err +} + +// Len returns the number of unread bytes +func (r *ValueReader) Len() int32 { + return r.valueBytesRemaining +} + +// Type returns the *FieldDescription of the value +func (r *ValueReader) Type() *FieldDescription { + return r.fd +} + +func (r *ValueReader) ReadByte() byte { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 1 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.ReadByte() +} + +func (r *ValueReader) ReadInt16() int16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.ReadInt16() +} + +func (r *ValueReader) ReadInt32() int32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.ReadInt32() +} + +func (r *ValueReader) ReadInt64() int64 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 8 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.ReadInt64() +} + +func (r *ValueReader) ReadOid() Oid { + return Oid(r.ReadInt32()) +} + +// ReadString reads count bytes and returns as string +func (r *ValueReader) ReadString(count int32) string { + if r.err != nil { + return "" + } + + r.valueBytesRemaining -= count + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return "" + } + + return r.mr.ReadString(count) +} + +// ReadBytes reads count bytes and returns as []byte +func (r *ValueReader) ReadBytes(count int32) []byte { + if r.err != nil { + return nil + } + + r.valueBytesRemaining -= count + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return nil + } + + return r.mr.ReadBytes(count) +} diff --git a/values.go b/values.go index 5ca69b36..ec924ea7 100644 --- a/values.go +++ b/values.go @@ -47,7 +47,7 @@ func (e SerializationError) Error() string { type Scanner interface { // Scan MUST check fd's DataType and FormatCode before decoding. It should // not assume that it was called on the type of value. - Scan(rows *Rows, fd *FieldDescription, size int32) error + Scan(r *ValueReader) error } // TextEncoder is an interface used to encode values in text format for @@ -85,18 +85,18 @@ type NullFloat32 struct { Valid bool // Valid is true if Float32 is not NULL } -func (n *NullFloat32) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot decode OID %d", fd.DataType)) +func (n *NullFloat32) Scan(vr *ValueReader) error { + if vr.Type().DataType != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Float32, n.Valid = 0, false return nil } n.Valid = true - n.Float32 = decodeFloat4(rows, fd, size) - return rows.Err() + n.Float32 = decodeFloat4(vr) + return vr.Err() } func (n NullFloat32) EncodeText() (string, byte, error) { @@ -131,18 +131,18 @@ type NullFloat64 struct { Valid bool // Valid is true if Float64 is not NULL } -func (n *NullFloat64) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullFloat64) Scan(vr *ValueReader) error { + if vr.Type().DataType != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Float64, n.Valid = 0, false return nil } n.Valid = true - n.Float64 = decodeFloat8(rows, fd, size) - return rows.Err() + n.Float64 = decodeFloat8(vr) + return vr.Err() } func (n NullFloat64) EncodeText() (string, byte, error) { @@ -177,17 +177,17 @@ type NullString struct { Valid bool // Valid is true if Int64 is not NULL } -func (s *NullString) Scan(rows *Rows, fd *FieldDescription, size int32) error { +func (s *NullString) Scan(vr *ValueReader) error { // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later - if size == -1 { + if vr.Len() == -1 { s.String, s.Valid = "", false return nil } s.Valid = true - s.String = decodeText(rows, fd, size) - return rows.Err() + s.String = decodeText(vr) + return vr.Err() } func (s NullString) EncodeText() (string, byte, error) { @@ -209,18 +209,18 @@ type NullInt16 struct { Valid bool // Valid is true if Int16 is not NULL } -func (n *NullInt16) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullInt16) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Int16, n.Valid = 0, false return nil } n.Valid = true - n.Int16 = decodeInt2(rows, fd, size) - return rows.Err() + n.Int16 = decodeInt2(vr) + return vr.Err() } func (n NullInt16) EncodeText() (string, byte, error) { @@ -255,18 +255,18 @@ type NullInt32 struct { Valid bool // Valid is true if Int64 is not NULL } -func (n *NullInt32) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullInt32) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Int32, n.Valid = 0, false return nil } n.Valid = true - n.Int32 = decodeInt4(rows, fd, size) - return rows.Err() + n.Int32 = decodeInt4(vr) + return vr.Err() } func (n NullInt32) EncodeText() (string, byte, error) { @@ -301,18 +301,18 @@ type NullInt64 struct { Valid bool // Valid is true if Int64 is not NULL } -func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullInt64) Scan(vr *ValueReader) error { + if vr.Type().DataType != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Int64, n.Valid = 0, false return nil } n.Valid = true - n.Int64 = decodeInt8(rows, fd, size) - return rows.Err() + n.Int64 = decodeInt8(vr) + return vr.Err() } func (n NullInt64) EncodeText() (string, byte, error) { @@ -347,18 +347,18 @@ type NullBool struct { Valid bool // Valid is true if Bool is not NULL } -func (n *NullBool) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullBool) Scan(vr *ValueReader) error { + if vr.Type().DataType != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Bool, n.Valid = false, false return nil } n.Valid = true - n.Bool = decodeBool(rows, fd, size) - return rows.Err() + n.Bool = decodeBool(vr) + return vr.Err() } func (n NullBool) EncodeText() (string, byte, error) { @@ -393,20 +393,20 @@ type NullTime struct { Valid bool // Valid is true if Time is not NULL } -func (n *NullTime) Scan(rows *Rows, fd *FieldDescription, size int32) error { - if fd.DataType != TimestampTzOid { - return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType)) +func (n *NullTime) Scan(vr *ValueReader) error { + if vr.Type().DataType != TimestampTzOid { + return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) } - if size == -1 { + if vr.Len() == -1 { n.Time, n.Valid = time.Time{}, false return nil } n.Valid = true - n.Time = decodeTimestampTz(rows, fd, size) + n.Time = decodeTimestampTz(vr) - return rows.Err() + return vr.Err() } func (n NullTime) EncodeText() (string, byte, error) { @@ -523,28 +523,28 @@ func sanitizeArg(arg interface{}) (string, error) { } } -func decodeBool(rows *Rows, fd *FieldDescription, size int32) bool { - switch fd.FormatCode { +func decodeBool(vr *ValueReader) bool { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) switch s { case "t": return true case "f": return false default: - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) return false } case BinaryFormatCode: - if size != 1 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))) + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) return false } - b := rows.mr.ReadByte() + b := vr.ReadByte() return b != 0 default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } } @@ -567,29 +567,29 @@ func encodeBool(w *WriteBuf, value interface{}) error { return nil } -func decodeInt8(rows *Rows, fd *FieldDescription, size int32) int64 { - if fd.DataType != Int8Oid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType))) +func decodeInt8(vr *ValueReader) int64 { + if vr.Type().DataType != Int8Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType))) return 0 } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) n, err := strconv.ParseInt(s, 10, 64) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) return 0 } return n case BinaryFormatCode: - if size != 8 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))) + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) return 0 } - return rows.mr.ReadInt64() + return vr.ReadInt64() default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } } @@ -628,29 +628,29 @@ func encodeInt8(w *WriteBuf, value interface{}) error { return nil } -func decodeInt2(rows *Rows, fd *FieldDescription, size int32) int16 { - if fd.DataType != Int2Oid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType))) +func decodeInt2(vr *ValueReader) int16 { + if vr.Type().DataType != Int2Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType))) return 0 } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) n, err := strconv.ParseInt(s, 10, 16) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) return 0 } return int16(n) case BinaryFormatCode: - if size != 2 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))) + if vr.Len() != 2 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) return 0 } - return rows.mr.ReadInt16() + return vr.ReadInt16() default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } } @@ -704,28 +704,28 @@ func encodeInt2(w *WriteBuf, value interface{}) error { return nil } -func decodeInt4(rows *Rows, fd *FieldDescription, size int32) int32 { - if fd.DataType != Int4Oid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType))) +func decodeInt4(vr *ValueReader) int32 { + if vr.Type().DataType != Int4Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType))) return 0 } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) n, err := strconv.ParseInt(s, 10, 32) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) } return int32(n) case BinaryFormatCode: - if size != 4 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))) + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) return 0 } - return rows.mr.ReadInt32() + return vr.ReadInt32() default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } } @@ -773,27 +773,27 @@ func encodeInt4(w *WriteBuf, value interface{}) error { return nil } -func decodeFloat4(rows *Rows, fd *FieldDescription, size int32) float32 { - switch fd.FormatCode { +func decodeFloat4(vr *ValueReader) float32 { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) n, err := strconv.ParseFloat(s, 32) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) return 0 } return float32(n) case BinaryFormatCode: - if size != 4 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))) + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) return 0 } - i := rows.mr.ReadInt32() + i := vr.ReadInt32() p := unsafe.Pointer(&i) return *(*float32)(p) default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } } @@ -820,27 +820,27 @@ func encodeFloat4(w *WriteBuf, value interface{}) error { return nil } -func decodeFloat8(rows *Rows, fd *FieldDescription, size int32) float64 { - switch fd.FormatCode { +func decodeFloat8(vr *ValueReader) float64 { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) v, err := strconv.ParseFloat(s, 64) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) return 0 } return v case BinaryFormatCode: - if size != 8 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))) + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) return 0 } - i := rows.mr.ReadInt64() + i := vr.ReadInt64() p := unsafe.Pointer(&i) return *(*float64)(p) default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } } @@ -864,8 +864,8 @@ func encodeFloat8(w *WriteBuf, value interface{}) error { return nil } -func decodeText(rows *Rows, fd *FieldDescription, size int32) string { - return rows.mr.ReadString(size) +func decodeText(vr *ValueReader) string { + return vr.ReadString(vr.Len()) } func encodeText(w *WriteBuf, value interface{}) error { @@ -880,20 +880,20 @@ func encodeText(w *WriteBuf, value interface{}) error { return nil } -func decodeBytea(rows *Rows, fd *FieldDescription, size int32) []byte { - switch fd.FormatCode { +func decodeBytea(vr *ValueReader) []byte { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) b, err := hex.DecodeString(s[2:]) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) return nil } return b case BinaryFormatCode: - return rows.mr.ReadBytes(size) + return vr.ReadBytes(vr.Len()) default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } } @@ -910,31 +910,31 @@ func encodeBytea(w *WriteBuf, value interface{}) error { return nil } -func decodeDate(rows *Rows, fd *FieldDescription, size int32) time.Time { +func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time - if fd.DataType != DateOid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType))) + if vr.Type().DataType != DateOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType))) return zeroTime } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) t, err := time.ParseInLocation("2006-01-02", s, time.Local) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) return zeroTime } return t case BinaryFormatCode: - if size != 4 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size))) + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) } - dayOffset := rows.mr.ReadInt32() + dayOffset := vr.ReadInt32() return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } } @@ -951,32 +951,32 @@ func encodeDate(w *WriteBuf, value interface{}) error { const microsecFromUnixEpochToY2K = 946684800 * 1000000 -func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time { +func decodeTimestampTz(vr *ValueReader) time.Time { var zeroTime time.Time - if fd.DataType != TimestampTzOid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType))) + if vr.Type().DataType != TimestampTzOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType))) return zeroTime } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) return zeroTime } return t case BinaryFormatCode: - if size != 8 { - rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size))) + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) } - microsecSinceY2K := rows.mr.ReadInt64() + microsecSinceY2K := vr.ReadInt64() microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } } @@ -996,28 +996,28 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error { return nil } -func decodeTimestamp(rows *Rows, fd *FieldDescription, size int32) time.Time { +func decodeTimestamp(vr *ValueReader) time.Time { var zeroTime time.Time - if fd.DataType != TimestampOid { - rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, fd.DataType))) + if vr.Type().DataType != TimestampOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType))) return zeroTime } - switch fd.FormatCode { + switch vr.Type().FormatCode { case TextFormatCode: - s := rows.mr.ReadString(size) + s := vr.ReadString(vr.Len()) t, err := time.ParseInLocation("2006-01-02 15:04:05.999999", s, time.Local) if err != nil { - rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s))) return zeroTime } return t case BinaryFormatCode: - rows.Fatal(ProtocolError("Can't decode binary timestamp")) + vr.Fatal(ProtocolError("Can't decode binary timestamp")) return zeroTime default: - rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } } diff --git a/values_test.go b/values_test.go index 6bc9c1a3..01e2f2cb 100644 --- a/values_test.go +++ b/values_test.go @@ -199,7 +199,7 @@ func TestNullX(t *testing.T) { {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.b}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, }