From 4cdea13f0f6a1e0239b2db00b68fde85eddd265b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 17:33:41 -0600 Subject: [PATCH] Add inet and cidr to pgtype --- conn.go | 4 + pgtype/cidrarray.go | 31 ++++ pgtype/convert.go | 16 ++ pgtype/inet.go | 240 ++++++++++++++++++++++++++++ pgtype/inet_test.go | 115 ++++++++++++++ pgtype/inetarray.go | 320 ++++++++++++++++++++++++++++++++++++++ pgtype/inetarray_test.go | 164 +++++++++++++++++++ pgtype/pgtype_test.go | 10 ++ pgtype/typed_array_gen.sh | 1 + values.go | 28 ---- values_test.go | 30 ++-- 11 files changed, 916 insertions(+), 43 deletions(-) create mode 100644 pgtype/cidrarray.go create mode 100644 pgtype/inet.go create mode 100644 pgtype/inet_test.go create mode 100644 pgtype/inetarray.go create mode 100644 pgtype/inetarray_test.go diff --git a/conn.go b/conn.go index 1e277a0e..b6670735 100644 --- a/conn.go +++ b/conn.go @@ -281,12 +281,16 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CidrArrayOID: &pgtype.CidrArray{}, + CidrOID: &pgtype.Inet{}, DateArrayOID: &pgtype.DateArray{}, DateOID: &pgtype.Date{}, Float4ArrayOID: &pgtype.Float4Array{}, Float4OID: &pgtype.Float4{}, Float8ArrayOID: &pgtype.Float8Array{}, Float8OID: &pgtype.Float8{}, + InetArrayOID: &pgtype.InetArray{}, + InetOID: &pgtype.Inet{}, Int2ArrayOID: &pgtype.Int2Array{}, Int2OID: &pgtype.Int2{}, Int4ArrayOID: &pgtype.Int4Array{}, diff --git a/pgtype/cidrarray.go b/pgtype/cidrarray.go new file mode 100644 index 00000000..66dd20d0 --- /dev/null +++ b/pgtype/cidrarray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type CidrArray InetArray + +func (dst *CidrArray) ConvertFrom(src interface{}) error { + return (*InetArray)(dst).ConvertFrom(src) +} + +func (src *CidrArray) AssignTo(dst interface{}) error { + return (*InetArray)(src).AssignTo(dst) +} + +func (dst *CidrArray) DecodeText(r io.Reader) error { + return (*InetArray)(dst).DecodeText(r) +} + +func (dst *CidrArray) DecodeBinary(r io.Reader) error { + return (*InetArray)(dst).DecodeBinary(r) +} + +func (src *CidrArray) EncodeText(w io.Writer) error { + return (*InetArray)(src).EncodeText(w) +} + +func (src *CidrArray) EncodeBinary(w io.Writer) error { + return (*InetArray)(src).encodeBinary(w, CidrOID) +} diff --git a/pgtype/convert.go b/pgtype/convert.go index c4b52322..7111f8bc 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,22 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + // underlyingTimeType gets the underlying type that can be converted to time.Time func underlyingTimeType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/inet.go b/pgtype/inet.go new file mode 100644 index 00000000..e47c64b0 --- /dev/null +++ b/pgtype/inet.go @@ -0,0 +1,240 @@ +package pgtype + +import ( + "fmt" + "io" + "net" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Status Status +} + +func (dst *Inet) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Inet: + *dst = value + case net.IPNet: + *dst = Inet{IPNet: &value, Status: Present} + case *net.IPNet: + *dst = Inet{IPNet: value, Status: Present} + case net.IP: + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + case string: + _, ipnet, err := net.ParseCIDR(value) + if err != nil { + return err + } + *dst = Inet{IPNet: ipnet, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *Inet) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *net.IPNet: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = *src.IPNet + case *net.IP: + if src.Status == Present { + + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.IPNet.IP + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Inet) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + var ipnet *net.IPNet + + if ip := net.ParseIP(string(buf)); ip != nil { + ipv4 := ip.To4() + if ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + _, ipnet, err = net.ParseCIDR(string(buf)) + if err != nil { + return err + } + } + + *dst = Inet{IPNet: ipnet, Status: Present} + return nil +} + +func (dst *Inet) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + if size != 8 && size != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", size) + } + + // ignore family + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + bits, err := pgio.ReadByte(r) + if err != nil { + return err + } + + // ignore is_cidr + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + addressLength, err := pgio.ReadByte(r) + if err != nil { + return err + } + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + _, err = r.Read(ipnet.IP) + if err != nil { + return err + } + + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + *dst = Inet{IPNet: &ipnet, Status: Present} + + return nil +} + +func (src Inet) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := src.IPNet.String() + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +// EncodeBinary encodes src into w. +func (src Inet) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var size int32 + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + size = 8 + family = defaultAFInet + case net.IPv6len: + size = 20 + family = defaultAFInet6 + default: + return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + if _, err := pgio.WriteInt32(w, size); err != nil { + return err + } + + if err := pgio.WriteByte(w, family); err != nil { + return err + } + + ones, _ := src.IPNet.Mask.Size() + if err := pgio.WriteByte(w, byte(ones)); err != nil { + return err + } + + // is_cidr is ignored on server + if err := pgio.WriteByte(w, 0); err != nil { + return err + } + + if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { + return err + } + + _, err := w.Write(src.IPNet.IP) + return err +} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go new file mode 100644 index 00000000..5e86376b --- /dev/null +++ b/pgtype/inet_test.go @@ -0,0 +1,115 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetTranscode(t *testing.T) { + for _, pgTypeName := range []string{"inet", "cidr"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }) + } +} + +func TestInetConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go new file mode 100644 index 00000000..eb5a4c88 --- /dev/null +++ b/pgtype/inetarray.go @@ -0,0 +1,320 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "net" + + "github.com/jackc/pgx/pgio" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Status Status +} + +func (dst *InetArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case InetArray: + *dst = value + case CidrArray: + *dst = InetArray(value) + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *InetArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *InetArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *InetArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *InetArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *InetArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, InetOID) +} + +func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/inetarray_test.go b/pgtype/inetarray_test.go new file mode 100644 index 00000000..8cab5355 --- /dev/null +++ b/pgtype/inetarray_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInetArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index a1dcd11b..7d34ae34 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -3,6 +3,7 @@ package pgtype_test import ( "fmt" "io" + "net" "os" "reflect" "testing" @@ -44,6 +45,15 @@ func mustClose(t testing.TB, conn interface { } } +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + type forceTextEncoder struct { e pgtype.TextEncoder } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 4ce6c3b5..47afdf1d 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -7,3 +7,4 @@ erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go diff --git a/values.go b/values.go index d2ec9fc2..8a2da367 100644 --- a/values.go +++ b/values.go @@ -1088,14 +1088,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case net.IP: - return encodeIP(wbuf, oid, arg) - case []net.IP: - return encodeIPSlice(wbuf, oid, arg) - case net.IPNet: - return encodeIPNet(wbuf, oid, arg) - case []net.IPNet: - return encodeIPNetSlice(wbuf, oid, arg) case OID: return encodeOID(wbuf, oid, arg) case Xid: @@ -1195,26 +1187,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeByteaArray(vr) case *[]interface{}: *v = decodeRecord(vr) - case *net.IP: - ipnet := decodeInet(vr) - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - *v = ipnet.IP - case *[]net.IP: - ipnets := decodeInetArray(vr) - ips := make([]net.IP, len(ipnets)) - for i, ipnet := range ipnets { - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - ips[i] = ipnet.IP - } - *v = ips - case *net.IPNet: - *v = decodeInet(vr) - case *[]net.IPNet: - *v = decodeInetArray(vr) default: if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { el := v.Elem() diff --git a/values_test.go b/values_test.go index 28f7371f..d6ce705a 100644 --- a/values_test.go +++ b/values_test.go @@ -232,13 +232,13 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) } } -func mustParseCIDR(t *testing.T, s string) net.IPNet { +func mustParseCIDR(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) } - return *ipnet + return ipnet } func TestStringToNotTextTypeTranscode(t *testing.T) { @@ -275,7 +275,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, @@ -358,7 +358,7 @@ func TestInetCidrTranscodeIP(t *testing.T) { failTests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, @@ -367,8 +367,8 @@ func TestInetCidrTranscodeIP(t *testing.T) { var actual net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) continue } @@ -384,11 +384,11 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -403,7 +403,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -419,7 +419,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { } for i, tt := range tests { - var actual []net.IPNet + var actual []*net.IPNet err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) if err != nil { @@ -485,18 +485,18 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { failTests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, @@ -507,8 +507,8 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { var actual []net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) continue }