diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a1c6e7c..bb575b78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ standard database/sql package such as * Add ConnPool.Reset method * Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces * Rows.Scan errors now include which argument caused error +* Add Encode() to allow custom Encoders to reuse internal encoding functionality +* Add Decode() to allow customer Decoders to reuse internal decoding functionality ## Performance diff --git a/conn.go b/conn.go index ef22daba..12290507 100644 --- a/conn.go +++ b/conn.go @@ -4,7 +4,6 @@ import ( "bufio" "crypto/md5" "crypto/tls" - "database/sql/driver" "encoding/binary" "encoding/hex" "errors" @@ -15,7 +14,6 @@ import ( "os" "os/user" "path/filepath" - "reflect" "regexp" "strconv" "strings" @@ -849,92 +847,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(arguments))) for i, oid := range ps.ParameterOids { - encode: - if arguments[i] == nil { - wbuf.WriteInt32(-1) - continue - } - - switch arg := arguments[i].(type) { - case Encoder: - err = arg.Encode(wbuf, oid) - case driver.Valuer: - arguments[i], err = arg.Value() - if err == nil { - goto encode - } - case string: - err = encodeText(wbuf, arguments[i]) - case []byte: - err = encodeBytea(wbuf, arguments[i]) - default: - if v := reflect.ValueOf(arguments[i]); v.Kind() == reflect.Ptr { - if v.IsNil() { - wbuf.WriteInt32(-1) - continue - } else { - arguments[i] = v.Elem().Interface() - goto encode - } - } - switch oid { - case BoolOid: - err = encodeBool(wbuf, arguments[i]) - case ByteaOid: - err = encodeBytea(wbuf, arguments[i]) - case Int2Oid: - err = encodeInt2(wbuf, arguments[i]) - case Int4Oid: - err = encodeInt4(wbuf, arguments[i]) - case Int8Oid: - err = encodeInt8(wbuf, arguments[i]) - case Float4Oid: - err = encodeFloat4(wbuf, arguments[i]) - case Float8Oid: - err = encodeFloat8(wbuf, arguments[i]) - case TextOid, VarcharOid: - err = encodeText(wbuf, arguments[i]) - case DateOid: - err = encodeDate(wbuf, arguments[i]) - case TimestampTzOid: - err = encodeTimestampTz(wbuf, arguments[i]) - case TimestampOid: - err = encodeTimestamp(wbuf, arguments[i]) - case InetOid, CidrOid: - err = encodeInet(wbuf, arguments[i]) - case InetArrayOid: - err = encodeInetArray(wbuf, arguments[i], InetOid) - case CidrArrayOid: - err = encodeInetArray(wbuf, arguments[i], CidrOid) - case BoolArrayOid: - err = encodeBoolArray(wbuf, arguments[i]) - case Int2ArrayOid: - err = encodeInt2Array(wbuf, arguments[i]) - case Int4ArrayOid: - err = encodeInt4Array(wbuf, arguments[i]) - case Int8ArrayOid: - err = encodeInt8Array(wbuf, arguments[i]) - case Float4ArrayOid: - err = encodeFloat4Array(wbuf, arguments[i]) - case Float8ArrayOid: - err = encodeFloat8Array(wbuf, arguments[i]) - case TextArrayOid: - err = encodeTextArray(wbuf, arguments[i], TextOid) - case VarcharArrayOid: - err = encodeTextArray(wbuf, arguments[i], VarcharOid) - case TimestampArrayOid: - err = encodeTimestampArray(wbuf, arguments[i], TimestampOid) - case TimestampTzArrayOid: - err = encodeTimestampArray(wbuf, arguments[i], TimestampTzOid) - case OidOid: - err = encodeOid(wbuf, arguments[i]) - case JsonOid, JsonbOid: - err = encodeJson(wbuf, arguments[i]) - default: - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) - } - } - if err != nil { + if err := Encode(wbuf, oid, arguments[i]); err != nil { return err } } diff --git a/query.go b/query.go index 49c1e311..a60043cd 100644 --- a/query.go +++ b/query.go @@ -4,8 +4,6 @@ import ( "database/sql" "errors" "fmt" - "net" - "reflect" "time" ) @@ -297,79 +295,9 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { decodeJson(vr, &d) } else { - decode: - switch v := d.(type) { - case *bool: - *v = decodeBool(vr) - case *int64: - *v = decodeInt8(vr) - case *int16: - *v = decodeInt2(vr) - case *int32: - *v = decodeInt4(vr) - case *Oid: - *v = decodeOid(vr) - case *string: - *v = decodeText(vr) - case *float32: - *v = decodeFloat4(vr) - case *float64: - *v = decodeFloat8(vr) - case *[]bool: - *v = decodeBoolArray(vr) - case *[]int16: - *v = decodeInt2Array(vr) - case *[]int32: - *v = decodeInt4Array(vr) - case *[]int64: - *v = decodeInt8Array(vr) - case *[]float32: - *v = decodeFloat4Array(vr) - case *[]float64: - *v = decodeFloat8Array(vr) - case *[]string: - *v = decodeTextArray(vr) - case *[]time.Time: - *v = decodeTimestampArray(vr) - case *time.Time: - switch vr.Type().DataType { - case DateOid: - *v = decodeDate(vr) - case TimestampTzOid: - *v = decodeTimestampTz(vr) - case TimestampOid: - *v = decodeTimestamp(vr) - default: - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)}) - } - case *net.IPNet: - *v = decodeInet(vr) - case *[]net.IPNet: - *v = decodeInetArray(vr) - default: - // if d is a pointer to pointer, strip the pointer and try again - if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { - if el := v.Elem(); el.Kind() == reflect.Ptr { - // -1 is a null value - if vr.Len() == -1 { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } - continue - } else { - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - goto decode - } - } - } - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("Scan cannot decode into %T", d)}) + if err := Decode(vr, d); err != nil { + rows.Fatal(scanArgError{col: i, err: err}) } - } if vr.Err() != nil { rows.Fatal(scanArgError{col: i, err: vr.Err()}) diff --git a/query_test.go b/query_test.go index 0a36f161..6ba44f50 100644 --- a/query_test.go +++ b/query_test.go @@ -530,7 +530,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into int16"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600 - int must implement Encoder or be converted to a string"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode uint64 into oid 600"}, {"select 42::int4", []interface{}{}, []interface{}{&actual.i}, "Scan cannot decode into *int"}, } diff --git a/values.go b/values.go index 54326833..662499d0 100644 --- a/values.go +++ b/values.go @@ -2,10 +2,12 @@ package pgx import ( "bytes" + "database/sql/driver" "encoding/json" "fmt" "math" "net" + "reflect" "strconv" "strings" "time" @@ -159,7 +161,7 @@ func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeFloat4(w, n.Float32) + return encodeFloat32(w, oid, n.Float32) } // NullFloat64 represents an float8 that may be null. NullFloat64 implements the @@ -190,7 +192,7 @@ func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { if oid != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", oid)) + return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) } if !n.Valid { @@ -198,7 +200,7 @@ func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeFloat8(w, n.Float64) + return encodeFloat64(w, oid, n.Float64) } // NullString represents an string that may be null. NullString implements the @@ -232,7 +234,7 @@ func (s NullString) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeText(w, s.String) + return encodeString(w, oid, s.String) } // NullInt16 represents an smallint that may be null. NullInt16 implements the @@ -271,7 +273,7 @@ func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeInt2(w, n.Int16) + return encodeInt16(w, oid, n.Int16) } // NullInt32 represents an integer that may be null. NullInt32 implements the @@ -310,7 +312,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeInt4(w, n.Int32) + return encodeInt32(w, oid, n.Int32) } // NullInt64 represents an bigint that may be null. NullInt64 implements the @@ -349,7 +351,7 @@ func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeInt8(w, n.Int64) + return encodeInt64(w, oid, n.Int64) } // NullBool represents an bool that may be null. NullBool implements the Scanner @@ -388,7 +390,7 @@ func (n NullBool) Encode(w *WriteBuf, oid Oid) error { return nil } - return encodeBool(w, n.Bool) + return encodeBool(w, oid, n.Bool) } // NullTime represents an time.Time that may be null. NullTime implements the @@ -438,16 +440,7 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return nil } - switch oid { - case TimestampTzOid: - return encodeTimestampTz(w, n.Time) - case TimestampOid: - return encodeTimestamp(w, n.Time) - case DateOid: - return encodeDate(w, n.Time) - default: - panic("unreachable") - } + return encodeTime(w, oid, n.Time) } // Hstore represents an hstore column. It does not support a null column or null @@ -584,6 +577,181 @@ func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { return nil } +// Encode encodes arg into wbuf as the type oid. This allows implementations +// of the Encoder interface to delegate the actual work of encoding to the +// built-in functionality. +func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { + if arg == nil { + wbuf.WriteInt32(-1) + return nil + } + + switch arg := arg.(type) { + case Encoder: + return arg.Encode(wbuf, oid) + case driver.Valuer: + v, err := arg.Value() + if err != nil { + return err + } + return Encode(wbuf, oid, v) + case string: + return encodeString(wbuf, oid, arg) + case []byte: + return encodeByteSlice(wbuf, oid, arg) + } + + if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr { + if v.IsNil() { + wbuf.WriteInt32(-1) + return nil + } else { + arg = v.Elem().Interface() + return Encode(wbuf, oid, arg) + } + } + + if oid == JsonOid || oid == JsonbOid { + return encodeJson(wbuf, oid, arg) + } + + switch arg := arg.(type) { + case []string: + return encodeStringSlice(wbuf, oid, arg) + case bool: + return encodeBool(wbuf, oid, arg) + case []bool: + return encodeBoolSlice(wbuf, oid, arg) + case int8: + return encodeInt8(wbuf, oid, arg) + case uint8: + return encodeUInt8(wbuf, oid, arg) + case int16: + return encodeInt16(wbuf, oid, arg) + case []int16: + return encodeInt16Slice(wbuf, oid, arg) + case uint16: + return encodeUInt16(wbuf, oid, arg) + case int32: + return encodeInt32(wbuf, oid, arg) + case []int32: + return encodeInt32Slice(wbuf, oid, arg) + case uint32: + return encodeUInt32(wbuf, oid, arg) + case int64: + return encodeInt64(wbuf, oid, arg) + case []int64: + return encodeInt64Slice(wbuf, oid, arg) + case uint64: + return encodeUInt64(wbuf, oid, arg) + case int: + return encodeInt(wbuf, oid, arg) + case float32: + return encodeFloat32(wbuf, oid, arg) + case []float32: + return encodeFloat32Slice(wbuf, oid, arg) + case float64: + return encodeFloat64(wbuf, oid, arg) + case []float64: + return encodeFloat64Slice(wbuf, oid, arg) + case time.Time: + return encodeTime(wbuf, oid, arg) + case []time.Time: + return encodeTimeSlice(wbuf, oid, arg) + case net.IP: + return encodeIP(wbuf, oid, arg) + case []net.IP: + return encodeIPSlice(wbuf, oid, arg) + case net.IPNet: + return encodeIPNet(wbuf, oid, arg) + case []net.IPNet: + return encodeIPNetSlice(wbuf, oid, arg) + case Oid: + return encodeOid(wbuf, oid, arg) + default: + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + } +} + +// Decode decodes from vr into d. d must be a pointer. This allows +// implementations of the Decoder interface to delegate the actual work of +// decoding to the built-in functionality. +func Decode(vr *ValueReader, d interface{}) error { + switch v := d.(type) { + case *bool: + *v = decodeBool(vr) + case *int64: + *v = decodeInt8(vr) + case *int16: + *v = decodeInt2(vr) + case *int32: + *v = decodeInt4(vr) + case *Oid: + *v = decodeOid(vr) + case *string: + *v = decodeText(vr) + case *float32: + *v = decodeFloat4(vr) + case *float64: + *v = decodeFloat8(vr) + case *[]bool: + *v = decodeBoolArray(vr) + case *[]int16: + *v = decodeInt2Array(vr) + case *[]int32: + *v = decodeInt4Array(vr) + case *[]int64: + *v = decodeInt8Array(vr) + case *[]float32: + *v = decodeFloat4Array(vr) + case *[]float64: + *v = decodeFloat8Array(vr) + case *[]string: + *v = decodeTextArray(vr) + case *[]time.Time: + *v = decodeTimestampArray(vr) + case *time.Time: + switch vr.Type().DataType { + case DateOid: + *v = decodeDate(vr) + case TimestampTzOid: + *v = decodeTimestampTz(vr) + case TimestampOid: + *v = decodeTimestamp(vr) + default: + return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) + } + case *net.IPNet: + *v = decodeInet(vr) + case *[]net.IPNet: + *v = decodeInetArray(vr) + default: + // if d is a pointer to pointer, strip the pointer and try again + if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { + if el := v.Elem(); el.Kind() == reflect.Ptr { + // -1 is a null value + if vr.Len() == -1 { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } else { + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + return Decode(vr, d) + } + } + } + return fmt.Errorf("Scan cannot decode into %T", d) + } + + return nil +} + func decodeBool(vr *ValueReader) bool { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into bool")) @@ -609,16 +777,15 @@ func decodeBool(vr *ValueReader) bool { return b != 0 } -func encodeBool(w *WriteBuf, value interface{}) error { - v, ok := value.(bool) - if !ok { - return fmt.Errorf("Expected bool, received %T", value) +func encodeBool(w *WriteBuf, oid Oid, value bool) error { + if oid != BoolOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } w.WriteInt32(1) var n byte - if v { + if value { n = 1 } @@ -651,40 +818,6 @@ func decodeInt8(vr *ValueReader) int64 { return vr.ReadInt64() } -func encodeInt8(w *WriteBuf, value interface{}) error { - var v int64 - switch value := value.(type) { - case int8: - v = int64(value) - case uint8: - v = int64(value) - case int16: - v = int64(value) - case uint16: - v = int64(value) - case int32: - v = int64(value) - case uint32: - v = int64(value) - case int64: - v = int64(value) - case uint64: - if value > math.MaxInt64 { - return fmt.Errorf("uint64 %d is larger than max int64 %d", value, int64(math.MaxInt64)) - } - v = int64(value) - case int: - v = int64(value) - default: - return fmt.Errorf("Expected integer representable in int64, received %T %v", value, value) - } - - w.WriteInt32(8) - w.WriteInt64(v) - - return nil -} - func decodeInt2(vr *ValueReader) int16 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int16")) @@ -709,51 +842,212 @@ func decodeInt2(vr *ValueReader) int16 { return vr.ReadInt16() } -func encodeInt2(w *WriteBuf, value interface{}) error { - var v int16 - switch value := value.(type) { - case int8: - v = int16(value) - case uint8: - v = int16(value) - case int16: - v = int16(value) - case uint16: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) - case int32: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) - case uint32: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) - case int64: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) - case uint64: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) - case int: - if value > math.MaxInt16 { - return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16) - } - v = int16(value) +func encodeInt8(w *WriteBuf, oid Oid, value int8) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) default: - return fmt.Errorf("Expected integer representable in int16, received %T %v", value, value) + return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) } - w.WriteInt32(2) - w.WriteInt16(v) + return nil +} + +func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(int16(value)) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) + } + + return nil +} + +func encodeInt16(w *WriteBuf, oid Oid, value int16) error { + switch oid { + case Int2Oid: + w.WriteInt32(2) + w.WriteInt16(value) + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) + } + + return nil +} + +func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(value)) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) + } + + return nil +} + +func encodeInt32(w *WriteBuf, oid Oid, value int32) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + w.WriteInt32(4) + w.WriteInt32(value) + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) + } + + return nil +} + +func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(value)) + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) + } + + return nil +} + +func encodeInt64(w *WriteBuf, oid Oid, value int64) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + w.WriteInt32(8) + w.WriteInt64(value) + default: + return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) + } + + return nil +} + +func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + if value <= math.MaxInt64 { + w.WriteInt32(8) + w.WriteInt64(int64(value)) + } else { + return fmt.Errorf("%d is larger than max int64 %d", value, math.MaxInt64) + } + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) + } + + return nil +} + +func encodeInt(w *WriteBuf, oid Oid, value int) error { + switch oid { + case Int2Oid: + if value <= math.MaxInt16 { + w.WriteInt32(2) + w.WriteInt16(int16(value)) + } else { + return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) + } + case Int4Oid: + if value <= math.MaxInt32 { + w.WriteInt32(4) + w.WriteInt32(int32(value)) + } else { + return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) + } + case Int8Oid: + if value <= math.MaxInt64 { + w.WriteInt32(8) + w.WriteInt64(int64(value)) + } else { + return fmt.Errorf("%d is larger than max int64 %d", value, math.MaxInt64) + } + default: + return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) + } return nil } @@ -782,49 +1076,6 @@ func decodeInt4(vr *ValueReader) int32 { return vr.ReadInt32() } -func encodeInt4(w *WriteBuf, value interface{}) error { - var v int32 - switch value := value.(type) { - case int8: - v = int32(value) - case uint8: - v = int32(value) - case int16: - v = int32(value) - case uint16: - v = int32(value) - case int32: - v = int32(value) - case uint32: - if value > math.MaxInt32 { - return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32) - } - v = int32(value) - case int64: - if value > math.MaxInt32 { - return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32) - } - v = int32(value) - case uint64: - if value > math.MaxInt32 { - return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32) - } - v = int32(value) - case int: - if value > math.MaxInt32 { - return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32) - } - v = int32(value) - default: - return fmt.Errorf("Expected integer representable in int32, received %T %v", value, value) - } - - w.WriteInt32(4) - w.WriteInt32(v) - - return nil -} - func decodeOid(vr *ValueReader) Oid { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into Oid")) @@ -857,14 +1108,13 @@ func decodeOid(vr *ValueReader) Oid { } } -func encodeOid(w *WriteBuf, value interface{}) error { - v, ok := value.(Oid) - if !ok { - return fmt.Errorf("Expected Oid, received %T", value) +func encodeOid(w *WriteBuf, oid Oid, value Oid) error { + if oid != OidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid) } w.WriteInt32(4) - w.WriteInt32(int32(v)) + w.WriteInt32(int32(value)) return nil } @@ -894,16 +1144,18 @@ func decodeFloat4(vr *ValueReader) float32 { return math.Float32frombits(uint32(i)) } -func encodeFloat4(w *WriteBuf, value interface{}) error { - v, ok := value.(float32) - if !ok { - return fmt.Errorf("Expected float32, received %T", value) +func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { + switch oid { + case Float4Oid: + w.WriteInt32(4) + w.WriteInt32(int32(math.Float32bits(value))) + case Float8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(math.Float64bits(float64(value)))) + default: + return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) } - w.WriteInt32(4) - - w.WriteInt32(int32(math.Float32bits(v))) - return nil } @@ -932,21 +1184,15 @@ func decodeFloat8(vr *ValueReader) float64 { return math.Float64frombits(uint64(i)) } -func encodeFloat8(w *WriteBuf, value interface{}) error { - var v float64 - switch value := value.(type) { - case float32: - v = float64(value) - case float64: - v = float64(value) +func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { + switch oid { + case Float8Oid: + w.WriteInt32(8) + w.WriteInt64(int64(math.Float64bits(value))) default: - return fmt.Errorf("Expected float representable in float64, received %T %v", value, value) + return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) } - w.WriteInt32(8) - - w.WriteInt64(int64(math.Float64bits(v))) - return nil } @@ -959,18 +1205,9 @@ func decodeText(vr *ValueReader) string { return vr.ReadString(vr.Len()) } -func encodeText(w *WriteBuf, value interface{}) error { - switch t := value.(type) { - case string: - w.WriteInt32(int32(len(t))) - w.WriteBytes([]byte(t)) - case []byte: - w.WriteInt32(int32(len(t))) - w.WriteBytes(t) - default: - return fmt.Errorf("Expected string, received %T", value) - } - +func encodeString(w *WriteBuf, oid Oid, value string) error { + w.WriteInt32(int32(len(value))) + w.WriteBytes([]byte(value)) return nil } @@ -992,14 +1229,9 @@ func decodeBytea(vr *ValueReader) []byte { return vr.ReadBytes(vr.Len()) } -func encodeBytea(w *WriteBuf, value interface{}) error { - b, ok := value.([]byte) - if !ok { - return fmt.Errorf("Expected []byte, received %T", value) - } - - w.WriteInt32(int32(len(b))) - w.WriteBytes(b) +func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { + w.WriteInt32(int32(len(value))) + w.WriteBytes(value) return nil } @@ -1021,13 +1253,20 @@ func decodeJson(vr *ValueReader, d interface{}) error { return err } -func encodeJson(w *WriteBuf, value interface{}) error { +func encodeJson(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonOid && oid != JsonbOid { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + s, err := json.Marshal(value) if err != nil { return fmt.Errorf("Failed to encode json from type: %T", value) } - return encodeText(w, s) + w.WriteInt32(int32(len(s))) + w.WriteBytes(s) + + return nil } func decodeDate(vr *ValueReader) time.Time { @@ -1055,22 +1294,30 @@ func decodeDate(vr *ValueReader) time.Time { return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) } -func encodeDate(w *WriteBuf, value interface{}) error { - t, ok := value.(time.Time) - if !ok { - return fmt.Errorf("Expected time.Time, received %T", value) +func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { + switch oid { + case DateOid: + tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch := secSinceDateEpoch / 86400 + + w.WriteInt32(4) + w.WriteInt32(int32(daysSinceDateEpoch)) + + return nil + case TimestampTzOid, TimestampOid: + microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + + w.WriteInt32(8) + w.WriteInt64(microsecSinceY2K) + + return nil + default: + return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) } - - tUnix := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch := secSinceDateEpoch / 86400 - - w.WriteInt32(4) - w.WriteInt32(int32(daysSinceDateEpoch)) - - return nil } const microsecFromUnixEpochToY2K = 946684800 * 1000000 @@ -1103,21 +1350,6 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) } -func encodeTimestampTz(w *WriteBuf, value interface{}) error { - t, ok := value.(time.Time) - if !ok { - return fmt.Errorf("Expected time.Time, received %T", value) - } - - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - - w.WriteInt32(8) - w.WriteInt64(microsecSinceY2K) - - return nil -} - func decodeTimestamp(vr *ValueReader) time.Time { var zeroTime time.Time @@ -1146,10 +1378,6 @@ func decodeTimestamp(vr *ValueReader) time.Time { return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) } -func encodeTimestamp(w *WriteBuf, value interface{}) error { - return encodeTimestampTz(w, value) -} - func decodeInet(vr *ValueReader) net.IPNet { var zero net.IPNet @@ -1186,23 +1414,14 @@ func decodeInet(vr *ValueReader) net.IPNet { return ipnet } -func encodeInet(w *WriteBuf, value interface{}) error { - var ipnet net.IPNet - - switch value := value.(type) { - case net.IPNet: - ipnet = value - case net.IP: - ipnet.IP = value - bitCount := len(value) * 8 - ipnet.Mask = net.CIDRMask(bitCount, bitCount) - default: - return fmt.Errorf("Expected net.IPNet, received %T %v", value, value) +func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { + if oid != InetOid && oid != CidrOid { + return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid) } var size int32 var family byte - switch len(ipnet.IP) { + switch len(value.IP) { case net.IPv4len: size = 8 family = w.conn.pgsql_af_inet @@ -1210,20 +1429,32 @@ func encodeInet(w *WriteBuf, value interface{}) error { size = 20 family = w.conn.pgsql_af_inet6 default: - return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP)) + return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) } w.WriteInt32(size) w.WriteByte(family) - ones, _ := ipnet.Mask.Size() + ones, _ := value.Mask.Size() w.WriteByte(byte(ones)) w.WriteByte(0) // is_cidr is ignored on server - w.WriteByte(byte(len(ipnet.IP))) - w.WriteBytes(ipnet.IP) + w.WriteByte(byte(len(value.IP))) + w.WriteBytes(value.IP) return nil } +func encodeIP(w *WriteBuf, oid Oid, value net.IP) error { + if oid != InetOid && oid != CidrOid { + return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid) + } + + var ipnet net.IPNet + ipnet.IP = value + bitCount := len(value) * 8 + ipnet.Mask = net.CIDRMask(bitCount, bitCount) + return encodeIPNet(w, oid, ipnet) +} + func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { numDims := vr.ReadInt32() if numDims > 1 { @@ -1288,10 +1519,9 @@ func decodeBoolArray(vr *ValueReader) []bool { return a } -func encodeBoolArray(w *WriteBuf, value interface{}) error { - slice, ok := value.([]bool) - if !ok { - return fmt.Errorf("Expected []bool, received %T", value) +func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error { + if oid != BoolArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid) } encodeArrayHeader(w, BoolOid, len(slice), 5) @@ -1346,10 +1576,9 @@ func decodeInt2Array(vr *ValueReader) []int16 { return a } -func encodeInt2Array(w *WriteBuf, value interface{}) error { - slice, ok := value.([]int16) - if !ok { - return fmt.Errorf("Expected []int16, received %T", value) +func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { + if oid != Int2ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) } encodeArrayHeader(w, Int2Oid, len(slice), 6) @@ -1400,10 +1629,9 @@ func decodeInt4Array(vr *ValueReader) []int32 { return a } -func encodeInt4Array(w *WriteBuf, value interface{}) error { - slice, ok := value.([]int32) - if !ok { - return fmt.Errorf("Expected []int32, received %T", value) +func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { + if oid != Int4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) } encodeArrayHeader(w, Int4Oid, len(slice), 8) @@ -1454,10 +1682,9 @@ func decodeInt8Array(vr *ValueReader) []int64 { return a } -func encodeInt8Array(w *WriteBuf, value interface{}) error { - slice, ok := value.([]int64) - if !ok { - return fmt.Errorf("Expected []int64, received %T", value) +func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { + if oid != Int8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) } encodeArrayHeader(w, Int8Oid, len(slice), 12) @@ -1509,11 +1736,11 @@ func decodeFloat4Array(vr *ValueReader) []float32 { return a } -func encodeFloat4Array(w *WriteBuf, value interface{}) error { - slice, ok := value.([]float32) - if !ok { - return fmt.Errorf("Expected []float32, received %T", value) +func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error { + if oid != Float4ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid) } + encodeArrayHeader(w, Float4Oid, len(slice), 8) for _, v := range slice { w.WriteInt32(4) @@ -1563,10 +1790,9 @@ func decodeFloat8Array(vr *ValueReader) []float64 { return a } -func encodeFloat8Array(w *WriteBuf, value interface{}) error { - slice, ok := value.([]float64) - if !ok { - return fmt.Errorf("Expected []float64, received %T", value) +func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error { + if oid != Float8ArrayOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid) } encodeArrayHeader(w, Float8Oid, len(slice), 12) @@ -1613,10 +1839,15 @@ func decodeTextArray(vr *ValueReader) []string { return a } -func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { - slice, ok := value.([]string) - if !ok { - return fmt.Errorf("Expected []string, received %T", value) +func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { + var elOid Oid + switch oid { + case VarcharArrayOid: + elOid = VarcharOid + case TextArrayOid: + elOid = TextOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) } var totalStringSize int @@ -1682,10 +1913,15 @@ func decodeTimestampArray(vr *ValueReader) []time.Time { return a } -func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { - slice, ok := value.([]time.Time) - if !ok { - return fmt.Errorf("Expected []time.Time, received %T", value) +func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error { + var elOid Oid + switch oid { + case TimestampArrayOid: + elOid = TimestampOid + case TimestampTzArrayOid: + elOid = TimestampTzOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid) } encodeArrayHeader(w, int(elOid), len(slice), 12) @@ -1743,10 +1979,15 @@ func decodeInetArray(vr *ValueReader) []net.IPNet { return a } -func encodeInetArray(w *WriteBuf, value interface{}, elOid Oid) error { - slice, ok := value.([]net.IPNet) - if !ok { - return fmt.Errorf("Expected []net.IPNet, received %T", value) +func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error { + var elOid Oid + switch oid { + case InetArrayOid: + elOid = InetOid + case CidrArrayOid: + elOid = CidrOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) } size := int32(20) // array header size @@ -1762,7 +2003,37 @@ func encodeInetArray(w *WriteBuf, value interface{}, elOid Oid) error { w.WriteInt32(1) // index of first element for _, ipnet := range slice { - encodeInet(w, ipnet) + encodeIPNet(w, elOid, ipnet) + } + + return nil +} + +func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error { + var elOid Oid + switch oid { + case InetArrayOid: + elOid = InetOid + case CidrArrayOid: + elOid = CidrOid + default: + return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) + } + + size := int32(20) // array header size + for _, ip := range slice { + size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes + } + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, ip := range slice { + encodeIP(w, elOid, ip) } return nil