mirror of https://github.com/jackc/pgx.git
parent
9d284da48e
commit
30cb421551
|
@ -1,15 +1,13 @@
|
|||
# Unreleased
|
||||
|
||||
## Features
|
||||
|
||||
* Add PrepareEx
|
||||
|
||||
## Fixes
|
||||
|
||||
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
|
||||
|
||||
## Features
|
||||
|
||||
* Add PrepareEx
|
||||
* Add basic record to []interface{} decoding
|
||||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||
* Decode inet/cidr to net.IP
|
||||
* Encode/decode [][]byte to/from bytea[]
|
||||
|
|
11
conn.go
11
conn.go
|
@ -332,7 +332,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
}
|
||||
|
||||
func (c *Conn) loadPgTypes() error {
|
||||
rows, err := c.Query("select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where t.typtype='b' and (base_type.oid is null or base_type.typtype='b');")
|
||||
rows, err := c.Query(`select t.oid, t.typname
|
||||
from pg_type t
|
||||
left join pg_type base_type on t.typelem=base_type.oid
|
||||
where (
|
||||
t.typtype='b'
|
||||
and (base_type.oid is null or base_type.typtype='b')
|
||||
)
|
||||
or t.typname in('record');`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -910,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
|
78
values.go
78
values.go
|
@ -27,6 +27,7 @@ const (
|
|||
CidrArrayOid = 651
|
||||
Float4Oid = 700
|
||||
Float8Oid = 701
|
||||
UnknownOid = 705
|
||||
InetOid = 869
|
||||
BoolArrayOid = 1000
|
||||
Int2ArrayOid = 1005
|
||||
|
@ -44,6 +45,7 @@ const (
|
|||
TimestampArrayOid = 1115
|
||||
TimestampTzOid = 1184
|
||||
TimestampTzArrayOid = 1185
|
||||
RecordOid = 2249
|
||||
UuidOid = 2950
|
||||
JsonbOid = 3802
|
||||
)
|
||||
|
@ -91,8 +93,11 @@ func init() {
|
|||
"int4": BinaryFormatCode,
|
||||
"int8": BinaryFormatCode,
|
||||
"oid": BinaryFormatCode,
|
||||
"record": BinaryFormatCode,
|
||||
"text": BinaryFormatCode,
|
||||
"timestamp": BinaryFormatCode,
|
||||
"timestamptz": BinaryFormatCode,
|
||||
"varchar": BinaryFormatCode,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -807,6 +812,8 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
*v = decodeTimestampArray(vr)
|
||||
case *[][]byte:
|
||||
*v = decodeByteaArray(vr)
|
||||
case *[]interface{}:
|
||||
*v = decodeRecord(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
|
@ -1613,6 +1620,77 @@ func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
|
|||
return encodeIPNet(w, oid, ipnet)
|
||||
}
|
||||
|
||||
func decodeRecord(vr *ValueReader) []interface{} {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != RecordOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
valueCount := vr.ReadInt32()
|
||||
record := make([]interface{}, 0, int(valueCount))
|
||||
|
||||
for i := int32(0); i < valueCount; i++ {
|
||||
fd := FieldDescription{FormatCode: BinaryFormatCode}
|
||||
fieldVR := ValueReader{mr: vr.mr, fd: &fd}
|
||||
fd.DataType = vr.ReadOid()
|
||||
fieldVR.valueBytesRemaining = vr.ReadInt32()
|
||||
vr.valueBytesRemaining -= fieldVR.valueBytesRemaining
|
||||
|
||||
switch fd.DataType {
|
||||
case BoolOid:
|
||||
record = append(record, decodeBool(&fieldVR))
|
||||
case ByteaOid:
|
||||
record = append(record, decodeBytea(&fieldVR))
|
||||
case Int8Oid:
|
||||
record = append(record, decodeInt8(&fieldVR))
|
||||
case Int2Oid:
|
||||
record = append(record, decodeInt2(&fieldVR))
|
||||
case Int4Oid:
|
||||
record = append(record, decodeInt4(&fieldVR))
|
||||
case OidOid:
|
||||
record = append(record, decodeOid(&fieldVR))
|
||||
case Float4Oid:
|
||||
record = append(record, decodeFloat4(&fieldVR))
|
||||
case Float8Oid:
|
||||
record = append(record, decodeFloat8(&fieldVR))
|
||||
case DateOid:
|
||||
record = append(record, decodeDate(&fieldVR))
|
||||
case TimestampTzOid:
|
||||
record = append(record, decodeTimestampTz(&fieldVR))
|
||||
case TimestampOid:
|
||||
record = append(record, decodeTimestamp(&fieldVR))
|
||||
case InetOid, CidrOid:
|
||||
record = append(record, decodeInet(&fieldVR))
|
||||
case TextOid, VarcharOid, UnknownOid:
|
||||
record = append(record, decodeText(&fieldVR))
|
||||
default:
|
||||
vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Consume any remaining data
|
||||
if fieldVR.Len() > 0 {
|
||||
fieldVR.ReadBytes(fieldVR.Len())
|
||||
}
|
||||
|
||||
if fieldVR.Err() != nil {
|
||||
vr.Fatal(fieldVR.Err())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||
numDims := vr.ReadInt32()
|
||||
if numDims > 1 {
|
||||
|
|
|
@ -959,3 +959,40 @@ func TestPointerPointerNonZero(t *testing.T) {
|
|||
t.Errorf("Expected dest to be nil, got %#v", dest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []interface{}
|
||||
}{
|
||||
{
|
||||
"select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)",
|
||||
[]interface{}{
|
||||
int32(1),
|
||||
"cat",
|
||||
time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var actual []interface{}
|
||||
|
||||
err := conn.QueryRow(tt.sql).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue