mirror of https://github.com/jackc/pgx.git
parent
be45d46b37
commit
7323d3f5a7
|
@ -8,6 +8,7 @@
|
|||
|
||||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||
* Decode inet/cidr to net.IP
|
||||
* Encode/decode [][]byte to/from bytea[]
|
||||
|
||||
## Performance
|
||||
|
||||
|
|
2
conn.go
2
conn.go
|
@ -857,7 +857,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
|
67
values.go
67
values.go
|
@ -32,6 +32,7 @@ const (
|
|||
Int2ArrayOid = 1005
|
||||
Int4ArrayOid = 1007
|
||||
TextArrayOid = 1009
|
||||
ByteaArrayOid = 1001
|
||||
VarcharArrayOid = 1015
|
||||
Int8ArrayOid = 1016
|
||||
Float4ArrayOid = 1021
|
||||
|
@ -67,6 +68,7 @@ var DefaultTypeFormats map[string]int16
|
|||
func init() {
|
||||
DefaultTypeFormats = map[string]int16{
|
||||
"_bool": BinaryFormatCode,
|
||||
"_bytea": BinaryFormatCode,
|
||||
"_cidr": BinaryFormatCode,
|
||||
"_float4": BinaryFormatCode,
|
||||
"_float8": BinaryFormatCode,
|
||||
|
@ -604,6 +606,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||
return encodeString(wbuf, oid, arg)
|
||||
case []byte:
|
||||
return encodeByteSlice(wbuf, oid, arg)
|
||||
case [][]byte:
|
||||
return encodeByteSliceSlice(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr {
|
||||
|
@ -801,6 +805,8 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
*v = decodeTextArray(vr)
|
||||
case *[]time.Time:
|
||||
*v = decodeTimestampArray(vr)
|
||||
case *[][]byte:
|
||||
*v = decodeByteaArray(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
|
@ -1683,6 +1689,67 @@ func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func decodeByteaArray(vr *ValueReader) [][]byte {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != ByteaArrayOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
numElems, err := decode1dArrayHeader(vr)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
a := make([][]byte, int(numElems))
|
||||
for i := 0; i < len(a); i++ {
|
||||
elSize := vr.ReadInt32()
|
||||
switch elSize {
|
||||
case -1:
|
||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||
return nil
|
||||
default:
|
||||
a[i] = vr.ReadBytes(elSize)
|
||||
}
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error {
|
||||
if oid != ByteaArrayOid {
|
||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid)
|
||||
}
|
||||
|
||||
size := 20 // array header size
|
||||
for _, el := range value {
|
||||
size += 4 + len(el)
|
||||
}
|
||||
|
||||
w.WriteInt32(int32(size))
|
||||
|
||||
w.WriteInt32(1) // number of dimensions
|
||||
w.WriteInt32(0) // no nulls
|
||||
w.WriteInt32(int32(ByteaOid)) // type of elements
|
||||
w.WriteInt32(int32(len(value))) // number of elements
|
||||
w.WriteInt32(1) // index of first element
|
||||
|
||||
for _, el := range value {
|
||||
encodeByteSlice(w, ByteaOid, el)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeInt2Array(vr *ValueReader) []int16 {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -706,12 +707,30 @@ func TestArrayDecoding(t *testing.T) {
|
|||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
|
||||
func(t *testing.T, query, scan interface{}) {
|
||||
queryBytesSliceSlice := query.([][]byte)
|
||||
scanBytesSliceSlice := *(scan.(*[][]byte))
|
||||
if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) {
|
||||
t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice))
|
||||
}
|
||||
for i := range queryBytesSliceSlice {
|
||||
qb := queryBytesSliceSlice[i]
|
||||
sb := scanBytesSliceSlice[i]
|
||||
if bytes.Compare(qb, sb) != 0 {
|
||||
t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan)
|
||||
if err != nil {
|
||||
t.Errorf(`%d. error reading array: %v`, i, err)
|
||||
continue
|
||||
}
|
||||
tt.assert(t, tt.query, tt.scan)
|
||||
ensureConnValid(t, conn)
|
||||
|
|
Loading…
Reference in New Issue