mirror of https://github.com/jackc/pgx.git
Expose encoding and decoding functions
parent
30feade829
commit
c6b6d7bad7
|
@ -24,6 +24,8 @@ standard database/sql package such as
|
||||||
* Add ConnPool.Reset method
|
* Add ConnPool.Reset method
|
||||||
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
|
* Add support for database/sql.Scanner and database/sql/driver.Valuer interfaces
|
||||||
* Rows.Scan errors now include which argument caused error
|
* Rows.Scan errors now include which argument caused error
|
||||||
|
* Add Encode() to allow custom Encoders to reuse internal encoding functionality
|
||||||
|
* Add Decode() to allow customer Decoders to reuse internal decoding functionality
|
||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
|
|
89
conn.go
89
conn.go
|
@ -4,7 +4,6 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -15,7 +14,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -849,92 +847,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(arguments)))
|
wbuf.WriteInt16(int16(len(arguments)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
encode:
|
if err := Encode(wbuf, oid, arguments[i]); err != nil {
|
||||||
if arguments[i] == nil {
|
|
||||||
wbuf.WriteInt32(-1)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch arg := arguments[i].(type) {
|
|
||||||
case Encoder:
|
|
||||||
err = arg.Encode(wbuf, oid)
|
|
||||||
case driver.Valuer:
|
|
||||||
arguments[i], err = arg.Value()
|
|
||||||
if err == nil {
|
|
||||||
goto encode
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
err = encodeText(wbuf, arguments[i])
|
|
||||||
case []byte:
|
|
||||||
err = encodeBytea(wbuf, arguments[i])
|
|
||||||
default:
|
|
||||||
if v := reflect.ValueOf(arguments[i]); v.Kind() == reflect.Ptr {
|
|
||||||
if v.IsNil() {
|
|
||||||
wbuf.WriteInt32(-1)
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
arguments[i] = v.Elem().Interface()
|
|
||||||
goto encode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
switch oid {
|
|
||||||
case BoolOid:
|
|
||||||
err = encodeBool(wbuf, arguments[i])
|
|
||||||
case ByteaOid:
|
|
||||||
err = encodeBytea(wbuf, arguments[i])
|
|
||||||
case Int2Oid:
|
|
||||||
err = encodeInt2(wbuf, arguments[i])
|
|
||||||
case Int4Oid:
|
|
||||||
err = encodeInt4(wbuf, arguments[i])
|
|
||||||
case Int8Oid:
|
|
||||||
err = encodeInt8(wbuf, arguments[i])
|
|
||||||
case Float4Oid:
|
|
||||||
err = encodeFloat4(wbuf, arguments[i])
|
|
||||||
case Float8Oid:
|
|
||||||
err = encodeFloat8(wbuf, arguments[i])
|
|
||||||
case TextOid, VarcharOid:
|
|
||||||
err = encodeText(wbuf, arguments[i])
|
|
||||||
case DateOid:
|
|
||||||
err = encodeDate(wbuf, arguments[i])
|
|
||||||
case TimestampTzOid:
|
|
||||||
err = encodeTimestampTz(wbuf, arguments[i])
|
|
||||||
case TimestampOid:
|
|
||||||
err = encodeTimestamp(wbuf, arguments[i])
|
|
||||||
case InetOid, CidrOid:
|
|
||||||
err = encodeInet(wbuf, arguments[i])
|
|
||||||
case InetArrayOid:
|
|
||||||
err = encodeInetArray(wbuf, arguments[i], InetOid)
|
|
||||||
case CidrArrayOid:
|
|
||||||
err = encodeInetArray(wbuf, arguments[i], CidrOid)
|
|
||||||
case BoolArrayOid:
|
|
||||||
err = encodeBoolArray(wbuf, arguments[i])
|
|
||||||
case Int2ArrayOid:
|
|
||||||
err = encodeInt2Array(wbuf, arguments[i])
|
|
||||||
case Int4ArrayOid:
|
|
||||||
err = encodeInt4Array(wbuf, arguments[i])
|
|
||||||
case Int8ArrayOid:
|
|
||||||
err = encodeInt8Array(wbuf, arguments[i])
|
|
||||||
case Float4ArrayOid:
|
|
||||||
err = encodeFloat4Array(wbuf, arguments[i])
|
|
||||||
case Float8ArrayOid:
|
|
||||||
err = encodeFloat8Array(wbuf, arguments[i])
|
|
||||||
case TextArrayOid:
|
|
||||||
err = encodeTextArray(wbuf, arguments[i], TextOid)
|
|
||||||
case VarcharArrayOid:
|
|
||||||
err = encodeTextArray(wbuf, arguments[i], VarcharOid)
|
|
||||||
case TimestampArrayOid:
|
|
||||||
err = encodeTimestampArray(wbuf, arguments[i], TimestampOid)
|
|
||||||
case TimestampTzArrayOid:
|
|
||||||
err = encodeTimestampArray(wbuf, arguments[i], TimestampTzOid)
|
|
||||||
case OidOid:
|
|
||||||
err = encodeOid(wbuf, arguments[i])
|
|
||||||
case JsonOid, JsonbOid:
|
|
||||||
err = encodeJson(wbuf, arguments[i])
|
|
||||||
default:
|
|
||||||
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
76
query.go
76
query.go
|
@ -4,8 +4,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -297,79 +295,9 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
||||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||||
decodeJson(vr, &d)
|
decodeJson(vr, &d)
|
||||||
} else {
|
} else {
|
||||||
decode:
|
if err := Decode(vr, d); err != nil {
|
||||||
switch v := d.(type) {
|
rows.Fatal(scanArgError{col: i, err: err})
|
||||||
case *bool:
|
|
||||||
*v = decodeBool(vr)
|
|
||||||
case *int64:
|
|
||||||
*v = decodeInt8(vr)
|
|
||||||
case *int16:
|
|
||||||
*v = decodeInt2(vr)
|
|
||||||
case *int32:
|
|
||||||
*v = decodeInt4(vr)
|
|
||||||
case *Oid:
|
|
||||||
*v = decodeOid(vr)
|
|
||||||
case *string:
|
|
||||||
*v = decodeText(vr)
|
|
||||||
case *float32:
|
|
||||||
*v = decodeFloat4(vr)
|
|
||||||
case *float64:
|
|
||||||
*v = decodeFloat8(vr)
|
|
||||||
case *[]bool:
|
|
||||||
*v = decodeBoolArray(vr)
|
|
||||||
case *[]int16:
|
|
||||||
*v = decodeInt2Array(vr)
|
|
||||||
case *[]int32:
|
|
||||||
*v = decodeInt4Array(vr)
|
|
||||||
case *[]int64:
|
|
||||||
*v = decodeInt8Array(vr)
|
|
||||||
case *[]float32:
|
|
||||||
*v = decodeFloat4Array(vr)
|
|
||||||
case *[]float64:
|
|
||||||
*v = decodeFloat8Array(vr)
|
|
||||||
case *[]string:
|
|
||||||
*v = decodeTextArray(vr)
|
|
||||||
case *[]time.Time:
|
|
||||||
*v = decodeTimestampArray(vr)
|
|
||||||
case *time.Time:
|
|
||||||
switch vr.Type().DataType {
|
|
||||||
case DateOid:
|
|
||||||
*v = decodeDate(vr)
|
|
||||||
case TimestampTzOid:
|
|
||||||
*v = decodeTimestampTz(vr)
|
|
||||||
case TimestampOid:
|
|
||||||
*v = decodeTimestamp(vr)
|
|
||||||
default:
|
|
||||||
rows.Fatal(scanArgError{col: i, err: fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)})
|
|
||||||
}
|
}
|
||||||
case *net.IPNet:
|
|
||||||
*v = decodeInet(vr)
|
|
||||||
case *[]net.IPNet:
|
|
||||||
*v = decodeInetArray(vr)
|
|
||||||
default:
|
|
||||||
// if d is a pointer to pointer, strip the pointer and try again
|
|
||||||
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
|
||||||
if el := v.Elem(); el.Kind() == reflect.Ptr {
|
|
||||||
// -1 is a null value
|
|
||||||
if vr.Len() == -1 {
|
|
||||||
if !el.IsNil() {
|
|
||||||
// if the destination pointer is not nil, nil it out
|
|
||||||
el.Set(reflect.Zero(el.Type()))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
if el.IsNil() {
|
|
||||||
// allocate destination
|
|
||||||
el.Set(reflect.New(el.Type().Elem()))
|
|
||||||
}
|
|
||||||
d = el.Interface()
|
|
||||||
goto decode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rows.Fatal(scanArgError{col: i, err: fmt.Errorf("Scan cannot decode into %T", d)})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
if vr.Err() != nil {
|
if vr.Err() != nil {
|
||||||
rows.Fatal(scanArgError{col: i, err: vr.Err()})
|
rows.Fatal(scanArgError{col: i, err: vr.Err()})
|
||||||
|
|
|
@ -530,7 +530,7 @@ func TestQueryRowErrors(t *testing.T) {
|
||||||
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
{"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
|
||||||
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
{"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
|
||||||
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into int16"},
|
{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into int16"},
|
||||||
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600 - int must implement Encoder or be converted to a string"},
|
{"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode uint64 into oid 600"},
|
||||||
{"select 42::int4", []interface{}{}, []interface{}{&actual.i}, "Scan cannot decode into *int"},
|
{"select 42::int4", []interface{}{}, []interface{}{&actual.i}, "Scan cannot decode into *int"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue