diff --git a/conn.go b/conn.go index b204697b..447e5171 100644 --- a/conn.go +++ b/conn.go @@ -783,7 +783,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -839,6 +839,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestamp(wbuf, arguments[i]) case InetOid, CidrOid: err = encodeInet(wbuf, arguments[i]) + case InetArrayOid: + err = encodeInetArray(wbuf, arguments[i], InetOid) + case CidrArrayOid: + err = encodeInetArray(wbuf, arguments[i], CidrOid) case BoolArrayOid: err = encodeBoolArray(wbuf, arguments[i]) case Int2ArrayOid: diff --git a/query.go b/query.go index 1b29c425..c9a53e82 100644 --- a/query.go +++ b/query.go @@ -295,6 +295,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } case *net.IPNet: *v = decodeInet(vr) + case *[]net.IPNet: + *v = decodeInetArray(vr) default: // if d is a pointer to pointer, strip the pointer and try again if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { diff --git a/values.go b/values.go index 2a1ca8b8..e894618b 100644 --- a/values.go +++ b/values.go @@ -22,6 +22,7 @@ const ( OidOid = 26 JsonOid = 114 CidrOid = 650 + CidrArrayOid = 651 Float4Oid = 700 Float8Oid = 701 InetOid = 869 @@ -33,6 +34,7 @@ const ( Int8ArrayOid = 1016 Float4ArrayOid = 1021 Float8ArrayOid = 1022 + InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 TimestampOid = 1114 @@ -59,8 +61,10 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ "_bool": BinaryFormatCode, + "_cidr": BinaryFormatCode, "_float4": BinaryFormatCode, "_float8": BinaryFormatCode, + "_inet": BinaryFormatCode, "_int2": BinaryFormatCode, "_int4": BinaryFormatCode, "_int8": BinaryFormatCode, @@ -1703,6 +1707,75 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { return nil } +func decodeInetArray(vr *ValueReader) []net.IPNet { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType))) + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + numElems, err := decode1dArrayHeader(vr) + if err != nil { + vr.Fatal(err) + return nil + } + + a := make([]net.IPNet, int(numElems)) + for i := 0; i < len(a); i++ { + elSize := vr.ReadInt32() + if elSize == -1 { + vr.Fatal(ProtocolError("Cannot decode null element")) + return nil + } + + vr.ReadByte() // ignore family + bits := vr.ReadByte() + vr.ReadByte() // ignore is_cidr + addressLength := vr.ReadByte() + + var ipnet net.IPNet + ipnet.IP = vr.ReadBytes(int32(addressLength)) + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + a[i] = ipnet + } + + return a +} + +func encodeInetArray(w *WriteBuf, value interface{}, elOid Oid) error { + slice, ok := value.([]net.IPNet) + if !ok { + return fmt.Errorf("Expected []net.IPNet, received %T", value) + } + + size := int32(20) // array header size + for _, ipnet := range slice { + size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes + } + w.WriteInt32(int32(size)) + + w.WriteInt32(1) // number of dimensions + w.WriteInt32(0) // no nulls + w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(len(slice))) // number of elements + w.WriteInt32(1) // index of first element + + for _, ipnet := range slice { + encodeInet(w, ipnet) + } + + return nil +} + func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { w.WriteInt32(int32(20 + length*sizePerItem)) w.WriteInt32(1) // number of dimensions diff --git a/values_test.go b/values_test.go index e7d52e6d..a63a8cc9 100644 --- a/values_test.go +++ b/values_test.go @@ -308,6 +308,65 @@ func TestInetCidrTranscode(t *testing.T) { } } +func TestInetCidrArrayTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value []net.IPNet + }{ + { + "select $1::inet[]", + []net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, + }, + { + "select $1::cidr[]", + []net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, + }, + } + + for i, tt := range tests { + var actual []net.IPNet + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + func TestInetCidrTranscodeWithJustIP(t *testing.T) { t.Parallel()