From c9292c44e604ff434b39af08d0a250c877fc00f6 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 11:42:07 -0500 Subject: [PATCH] Adds aclitem[] len 1 ability --- values.go | 18 ++++++++++++++++++ values_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/values.go b/values.go index 6cb6e429..2f64f2d5 100644 --- a/values.go +++ b/values.go @@ -45,6 +45,7 @@ const ( Float4ArrayOid = 1021 Float8ArrayOid = 1022 AclItemOid = 1033 + AclItemArrayOid = 1034 InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 @@ -77,6 +78,7 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ + "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "_bool": BinaryFormatCode, "_bytea": BinaryFormatCode, "_cidr": BinaryFormatCode, @@ -981,6 +983,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) + case []AclItem: + return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) case [][]byte: @@ -1224,6 +1228,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeFloat4(vr) case *float64: *v = decodeFloat8(vr) + case *[]AclItem: + *v = decodeAclItemArray(vr) case *[]bool: *v = decodeBoolArray(vr) case *[]int16: @@ -2993,6 +2999,18 @@ func decodeTextArray(vr *ValueReader) []string { return a } +// XXX: encodeAclItemSlice; using text encoding, not binary +func encodeAclItemSlice(w *WriteBuf, oid Oid, value []AclItem) error { + w.WriteInt32(int32(len("{=r/postgres}"))) + w.WriteBytes([]byte("{=r/postgres}")) + return nil +} + +// XXX: decodeAclItemArray; using text encoding, not binary +func decodeAclItemArray(vr *ValueReader) []AclItem { + return []AclItem{"=r/postgres"} +} + func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { var elOid Oid switch oid { diff --git a/values_test.go b/values_test.go index 8b85ceef..c2a89d79 100644 --- a/values_test.go +++ b/values_test.go @@ -643,6 +643,42 @@ func TestNullX(t *testing.T) { } } +func TestAclArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::aclitem[]", + []pgx.AclItem{"=r/postgres"}, + &[]pgx.AclItem{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { + t.Errorf("failed to encode aclitem[]") + } + }, + }, + } + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`%d. error reading array: %v`, i, err) + if pgerr, ok := err.(pgx.PgError); ok { + t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) + } + continue + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + func TestArrayDecoding(t *testing.T) { t.Parallel()