mirror of https://github.com/jackc/pgx.git
Add inet[] and cidr[] support
parent
36fb7a3aec
commit
c726a51450
6
conn.go
6
conn.go
|
@ -783,7 +783,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid:
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
default:
|
default:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
@ -839,6 +839,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
err = encodeTimestamp(wbuf, arguments[i])
|
err = encodeTimestamp(wbuf, arguments[i])
|
||||||
case InetOid, CidrOid:
|
case InetOid, CidrOid:
|
||||||
err = encodeInet(wbuf, arguments[i])
|
err = encodeInet(wbuf, arguments[i])
|
||||||
|
case InetArrayOid:
|
||||||
|
err = encodeInetArray(wbuf, arguments[i], InetOid)
|
||||||
|
case CidrArrayOid:
|
||||||
|
err = encodeInetArray(wbuf, arguments[i], CidrOid)
|
||||||
case BoolArrayOid:
|
case BoolArrayOid:
|
||||||
err = encodeBoolArray(wbuf, arguments[i])
|
err = encodeBoolArray(wbuf, arguments[i])
|
||||||
case Int2ArrayOid:
|
case Int2ArrayOid:
|
||||||
|
|
2
query.go
2
query.go
|
@ -295,6 +295,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
||||||
}
|
}
|
||||||
case *net.IPNet:
|
case *net.IPNet:
|
||||||
*v = decodeInet(vr)
|
*v = decodeInet(vr)
|
||||||
|
case *[]net.IPNet:
|
||||||
|
*v = decodeInetArray(vr)
|
||||||
default:
|
default:
|
||||||
// if d is a pointer to pointer, strip the pointer and try again
|
// if d is a pointer to pointer, strip the pointer and try again
|
||||||
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
||||||
|
|
73
values.go
73
values.go
|
@ -22,6 +22,7 @@ const (
|
||||||
OidOid = 26
|
OidOid = 26
|
||||||
JsonOid = 114
|
JsonOid = 114
|
||||||
CidrOid = 650
|
CidrOid = 650
|
||||||
|
CidrArrayOid = 651
|
||||||
Float4Oid = 700
|
Float4Oid = 700
|
||||||
Float8Oid = 701
|
Float8Oid = 701
|
||||||
InetOid = 869
|
InetOid = 869
|
||||||
|
@ -33,6 +34,7 @@ const (
|
||||||
Int8ArrayOid = 1016
|
Int8ArrayOid = 1016
|
||||||
Float4ArrayOid = 1021
|
Float4ArrayOid = 1021
|
||||||
Float8ArrayOid = 1022
|
Float8ArrayOid = 1022
|
||||||
|
InetArrayOid = 1041
|
||||||
VarcharOid = 1043
|
VarcharOid = 1043
|
||||||
DateOid = 1082
|
DateOid = 1082
|
||||||
TimestampOid = 1114
|
TimestampOid = 1114
|
||||||
|
@ -59,8 +61,10 @@ var DefaultTypeFormats map[string]int16
|
||||||
func init() {
|
func init() {
|
||||||
DefaultTypeFormats = map[string]int16{
|
DefaultTypeFormats = map[string]int16{
|
||||||
"_bool": BinaryFormatCode,
|
"_bool": BinaryFormatCode,
|
||||||
|
"_cidr": BinaryFormatCode,
|
||||||
"_float4": BinaryFormatCode,
|
"_float4": BinaryFormatCode,
|
||||||
"_float8": BinaryFormatCode,
|
"_float8": BinaryFormatCode,
|
||||||
|
"_inet": BinaryFormatCode,
|
||||||
"_int2": BinaryFormatCode,
|
"_int2": BinaryFormatCode,
|
||||||
"_int4": BinaryFormatCode,
|
"_int4": BinaryFormatCode,
|
||||||
"_int8": BinaryFormatCode,
|
"_int8": BinaryFormatCode,
|
||||||
|
@ -1703,6 +1707,75 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeInetArray(vr *ValueReader) []net.IPNet {
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numElems, err := decode1dArrayHeader(vr)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a := make([]net.IPNet, int(numElems))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
elSize := vr.ReadInt32()
|
||||||
|
if elSize == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null element"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vr.ReadByte() // ignore family
|
||||||
|
bits := vr.ReadByte()
|
||||||
|
vr.ReadByte() // ignore is_cidr
|
||||||
|
addressLength := vr.ReadByte()
|
||||||
|
|
||||||
|
var ipnet net.IPNet
|
||||||
|
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
||||||
|
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
||||||
|
|
||||||
|
a[i] = ipnet
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInetArray(w *WriteBuf, value interface{}, elOid Oid) error {
|
||||||
|
slice, ok := value.([]net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("Expected []net.IPNet, received %T", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := int32(20) // array header size
|
||||||
|
for _, ipnet := range slice {
|
||||||
|
size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes
|
||||||
|
}
|
||||||
|
w.WriteInt32(int32(size))
|
||||||
|
|
||||||
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
w.WriteInt32(0) // no nulls
|
||||||
|
w.WriteInt32(int32(elOid)) // type of elements
|
||||||
|
w.WriteInt32(int32(len(slice))) // number of elements
|
||||||
|
w.WriteInt32(1) // index of first element
|
||||||
|
|
||||||
|
for _, ipnet := range slice {
|
||||||
|
encodeInet(w, ipnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
|
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
|
||||||
w.WriteInt32(int32(20 + length*sizePerItem))
|
w.WriteInt32(int32(20 + length*sizePerItem))
|
||||||
w.WriteInt32(1) // number of dimensions
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
|
|
@ -308,6 +308,65 @@ func TestInetCidrTranscode(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInetCidrArrayTranscode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
value []net.IPNet
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"select $1::inet[]",
|
||||||
|
[]net.IPNet{
|
||||||
|
mustParseCIDR(t, "0.0.0.0/32"),
|
||||||
|
mustParseCIDR(t, "127.0.0.1/32"),
|
||||||
|
mustParseCIDR(t, "12.34.56.0/32"),
|
||||||
|
mustParseCIDR(t, "192.168.1.0/24"),
|
||||||
|
mustParseCIDR(t, "255.0.0.0/8"),
|
||||||
|
mustParseCIDR(t, "255.255.255.255/32"),
|
||||||
|
mustParseCIDR(t, "::/128"),
|
||||||
|
mustParseCIDR(t, "::/0"),
|
||||||
|
mustParseCIDR(t, "::1/128"),
|
||||||
|
mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"select $1::cidr[]",
|
||||||
|
[]net.IPNet{
|
||||||
|
mustParseCIDR(t, "0.0.0.0/32"),
|
||||||
|
mustParseCIDR(t, "127.0.0.1/32"),
|
||||||
|
mustParseCIDR(t, "12.34.56.0/32"),
|
||||||
|
mustParseCIDR(t, "192.168.1.0/24"),
|
||||||
|
mustParseCIDR(t, "255.0.0.0/8"),
|
||||||
|
mustParseCIDR(t, "255.255.255.255/32"),
|
||||||
|
mustParseCIDR(t, "::/128"),
|
||||||
|
mustParseCIDR(t, "::/0"),
|
||||||
|
mustParseCIDR(t, "::1/128"),
|
||||||
|
mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual []net.IPNet
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, tt.value) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInetCidrTranscodeWithJustIP(t *testing.T) {
|
func TestInetCidrTranscodeWithJustIP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue