diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b653260e..df955f18 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -3,6 +3,7 @@ package pgtype import ( "fmt" "math" + "net" "strconv" "time" ) @@ -340,3 +341,43 @@ func (w *timeWrapper) ScanDate(v Date) error { func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } + +type netIPNetWrapper net.IPNet + +func (w *netIPNetWrapper) ScanInet(v Inet) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *net.IPNet") + } + + *w = (netIPNetWrapper)(*v.IPNet) + return nil +} + +func (w netIPNetWrapper) InetValue() (Inet, error) { + return Inet{IPNet: (*net.IPNet)(&w), Valid: true}, nil +} + +type netIPWrapper net.IP + +func (w *netIPWrapper) ScanInet(v Inet) error { + if !v.Valid { + *w = nil + return nil + } + + if oneCount, bitCount := v.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot scan %v to *net.IP", v) + } + *w = netIPWrapper(v.IPNet.IP) + return nil +} + +func (w netIPWrapper) InetValue() (Inet, error) { + if w == nil { + return Inet{}, nil + } + + bitCount := len(w) * 8 + mask := net.CIDRMask(bitCount, bitCount) + return Inet{IPNet: &net.IPNet{Mask: mask, IP: net.IP(w)}, Valid: true}, nil +} diff --git a/pgtype/cidr.go b/pgtype/cidr.go deleted file mode 100644 index 2241ca1c..00000000 --- a/pgtype/cidr.go +++ /dev/null @@ -1,31 +0,0 @@ -package pgtype - -type CIDR Inet - -func (dst *CIDR) Set(src interface{}) error { - return (*Inet)(dst).Set(src) -} - -func (dst CIDR) Get() interface{} { - return (Inet)(dst).Get() -} - -func (src *CIDR) AssignTo(dst interface{}) error { - return (*Inet)(src).AssignTo(dst) -} - -func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeText(ci, src) -} - -func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { - return (*Inet)(dst).DecodeBinary(ci, src) -} - -func (src CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Inet)(src).EncodeText(ci, buf) -} - -func (src CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - return (Inet)(src).EncodeBinary(ci, buf) -} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go deleted file mode 100644 index 48a6a4c1..00000000 --- a/pgtype/cidr_array.go +++ /dev/null @@ -1,533 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "net" - "reflect" - - "github.com/jackc/pgio" -) - -type CIDRArray struct { - Elements []CIDR - Dimensions []ArrayDimension - Valid bool -} - -func (dst *CIDRArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = CIDRArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []net.IP: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*net.IP: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - elements := make([]CIDR, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = CIDRArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []CIDR: - if value == nil { - *dst = CIDRArray{} - } else if len(value) == 0 { - *dst = CIDRArray{Valid: true} - } else { - *dst = CIDRArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = CIDRArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) - } - if elementsLength == 0 { - *dst = CIDRArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to CIDRArray", src) - } - - *dst = CIDRArray{ - Elements: make([]CIDR, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]CIDR, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to CIDRArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in CIDRArray", err) - } - index++ - - return index, nil -} - -func (dst CIDRArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *CIDRArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from CIDRArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from CIDRArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []CIDR - - if len(uta.Elements) > 0 { - elements = make([]CIDR, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem CIDR - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = CIDRArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]CIDR, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("cidr"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *CIDRArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src CIDRArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go deleted file mode 100644 index 550bf9d1..00000000 --- a/pgtype/cidr_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestCIDRArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CIDRArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.CIDRArray{}, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - {}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestCIDRArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.CIDRArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.CIDRArray{}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.CIDRArray{}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.CIDRArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestCIDRArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.CIDRArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.CIDRArray{}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.CIDRArray{Valid: true}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.CIDRArray{}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.CIDRArray{Valid: true}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.CIDRArray{ - Elements: []pgtype.CIDR{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/inet.go b/pgtype/inet.go index 4b3217a9..f88d1712 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -14,119 +14,208 @@ const ( defaultAFInet6 = 3 ) +type InetScanner interface { + ScanInet(v Inet) error +} + +type InetValuer interface { + InetValue() (Inet, error) +} + // Inet represents both inet and cidr PostgreSQL types. type Inet struct { IPNet *net.IPNet Valid bool } -func (dst *Inet) Set(src interface{}) error { +func (inet *Inet) ScanInet(v Inet) error { + *inet = v + return nil +} + +func (inet Inet) InetValue() (Inet, error) { + return inet, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { if src == nil { *dst = Inet{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToInetScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) } - switch value := src.(type) { - case net.IPNet: - *dst = Inet{IPNet: &value, Valid: true} - case net.IP: - if len(value) == 0 { - *dst = Inet{} - } else { - bitCount := len(value) * 8 - mask := net.CIDRMask(bitCount, bitCount) - *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Valid: true} - } - case string: - ip, ipnet, err := net.ParseCIDR(value) - if err != nil { - ip = net.ParseIP(value) - if ip == nil { - return fmt.Errorf("unable to parse inet address: %s", value) - } - ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - ipnet.Mask = net.CIDRMask(32, 32) - } - } - ipnet.IP = ip - *dst = Inet{IPNet: ipnet, Valid: true} - case *net.IPNet: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - case *net.IP: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Inet{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Inet", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := InetCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type InetCodec struct{} + +func (InetCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (InetCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (InetCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(InetValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanInetCodecBinary{} + case TextFormatCode: + return encodePlanInetCodecText{} } return nil } -func (dst Inet) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.IPNet -} +type encodePlanInetCodecBinary struct{} -func (src *Inet) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +func (encodePlanInetCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + inet, err := value.(InetValuer).InetValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *net.IPNet: - *v = net.IPNet{ - IP: make(net.IP, len(src.IPNet.IP)), - Mask: make(net.IPMask, len(src.IPNet.Mask)), - } - copy(v.IP, src.IPNet.IP) - copy(v.Mask, src.IPNet.Mask) - return nil - case *net.IP: - if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = make(net.IP, len(src.IPNet.IP)) - copy(*v, src.IPNet.IP) - return nil + if !inet.Valid { + return nil, nil + } + + var family byte + switch len(inet.IPNet.IP) { + case net.IPv4len: + family = defaultAFInet + case net.IPv6len: + family = defaultAFInet6 default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) + return nil, fmt.Errorf("Unexpected IP length: %v", len(inet.IPNet.IP)) } + + buf = append(buf, family) + + ones, _ := inet.IPNet.Mask.Size() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + buf = append(buf, byte(len(inet.IPNet.IP))) + + return append(buf, inet.IPNet.IP...), nil } -func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { +type encodePlanInetCodecText struct{} + +func (encodePlanInetCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + inet, err := value.(InetValuer).InetValue() + if err != nil { + return nil, err + } + + if !inet.Valid { + return nil, nil + } + + return append(buf, inet.IPNet.String()...), nil +} + +func (InetCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case InetScanner: + return scanPlanBinaryInetToInetScanner{} + } + case TextFormatCode: + switch target.(type) { + case InetScanner: + return scanPlanTextAnyToInetScanner{} + } + } + + return nil +} + +func (c InetCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c InetCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Inet{} - return nil + return nil, nil + } + + var inet Inet + err := codecScan(c, ci, oid, format, src, &inet) + if err != nil { + return nil, err + } + + if !inet.Valid { + return nil, nil + } + + return inet.IPNet, nil +} + +type scanPlanBinaryInetToInetScanner struct{} + +func (scanPlanBinaryInetToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(InetScanner) + + if src == nil { + return scanner.ScanInet(Inet{}) + } + + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + addressLength := src[3] + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + copy(ipnet.IP, src[4:]) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) + + return scanner.ScanInet(Inet{IPNet: &ipnet, Valid: true}) +} + +type scanPlanTextAnyToInetScanner struct{} + +func (scanPlanTextAnyToInetScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(InetScanner) + + if src == nil { + return scanner.ScanInet(Inet{}) } var ipnet *net.IPNet @@ -151,95 +240,5 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } - *dst = Inet{IPNet: ipnet, Valid: true} - return nil -} - -func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Inet{} - return nil - } - - if len(src) != 8 && len(src) != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) - } - - // ignore family - bits := src[1] - // ignore is_cidr - addressLength := src[3] - - var ipnet net.IPNet - ipnet.IP = make(net.IP, int(addressLength)) - copy(ipnet.IP, src[4:]) - if ipv4 := ipnet.IP.To4(); ipv4 != nil { - ipnet.IP = ipv4 - } - ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) - - *dst = Inet{IPNet: &ipnet, Valid: true} - - return nil -} - -func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - return append(buf, src.IPNet.String()...), nil -} - -// EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - var family byte - switch len(src.IPNet.IP) { - case net.IPv4len: - family = defaultAFInet - case net.IPv6len: - family = defaultAFInet6 - default: - return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) - } - - buf = append(buf, family) - - ones, _ := src.IPNet.Mask.Size() - buf = append(buf, byte(ones)) - - // is_cidr is ignored on server - buf = append(buf, 0) - - buf = append(buf, byte(len(src.IPNet.IP))) - - return append(buf, src.IPNet.IP...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Inet) Scan(src interface{}) error { - if src == nil { - *dst = Inet{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Inet) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanInet(Inet{IPNet: ipnet, Valid: true}) } diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go deleted file mode 100644 index 7f41c4e5..00000000 --- a/pgtype/inet_array.go +++ /dev/null @@ -1,533 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "net" - "reflect" - - "github.com/jackc/pgio" -) - -type InetArray struct { - Elements []Inet - Dimensions []ArrayDimension - Valid bool -} - -func (dst *InetArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = InetArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []*net.IPNet: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []net.IP: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*net.IP: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - elements := make([]Inet, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = InetArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Inet: - if value == nil { - *dst = InetArray{} - } else if len(value) == 0 { - *dst = InetArray{Valid: true} - } else { - *dst = InetArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = InetArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for InetArray", src) - } - if elementsLength == 0 { - *dst = InetArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to InetArray", src) - } - - *dst = InetArray{ - Elements: make([]Inet, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Inet, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to InetArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in InetArray", err) - } - index++ - - return index, nil -} - -func (dst InetArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *InetArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]*net.IPNet: - *v = make([]*net.IPNet, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]net.IP: - *v = make([]net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*net.IP: - *v = make([]*net.IP, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from InetArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from InetArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Inet - - if len(uta.Elements) > 0 { - elements = make([]Inet, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Inet - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = InetArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = InetArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Inet, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("inet"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "inet") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *InetArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src InetArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go deleted file mode 100644 index da7ee975..00000000 --- a/pgtype/inet_array_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgtype_test - -import ( - "net" - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestInetArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ - &pgtype.InetArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.InetArray{}, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - {}, - {IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestInetArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.InetArray - }{ - { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]*net.IPNet)(nil)), - result: pgtype.InetArray{}, - }, - { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]net.IP)(nil)), - result: pgtype.InetArray{}, - }, - { - source: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - result: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.InetArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestInetArrayAssignTo(t *testing.T) { - var ipnetSlice []*net.IPNet - var ipSlice []net.IP - var ipSliceDim2 [][]net.IP - var ipnetSliceDim4 [][][][]*net.IPNet - var ipArrayDim2 [2][1]net.IP - var ipnetArrayDim4 [2][1][1][3]*net.IPNet - - simpleTests := []struct { - src pgtype.InetArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipnetSlice, - expected: []*net.IPNet{nil}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &ipSlice, - expected: []net.IP{nil}, - }, - { - src: pgtype.InetArray{}, - dst: &ipnetSlice, - expected: (([]*net.IPNet)(nil)), - }, - { - src: pgtype.InetArray{Valid: true}, - dst: &ipnetSlice, - expected: []*net.IPNet{}, - }, - { - src: pgtype.InetArray{}, - dst: &ipSlice, - expected: (([]net.IP)(nil)), - }, - { - src: pgtype.InetArray{Valid: true}, - dst: &ipSlice, - expected: []net.IP{}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipSliceDim2, - expected: [][]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetSliceDim4, - expected: [][][][]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/32"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &ipArrayDim2, - expected: [2][1]net.IP{{mustParseCIDR(t, "127.0.0.1/32").IP}, {mustParseCIDR(t, "10.0.0.1/32").IP}}, - }, - { - src: pgtype.InetArray{ - Elements: []pgtype.Inet{ - {IPNet: mustParseCIDR(t, "127.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "10.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "172.16.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "192.168.0.1/16"), Valid: true}, - {IPNet: mustParseCIDR(t, "224.0.0.1/24"), Valid: true}, - {IPNet: mustParseCIDR(t, "169.168.0.1/16"), Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - dst: &ipnetArrayDim4, - expected: [2][1][1][3]*net.IPNet{ - {{{ - mustParseCIDR(t, "127.0.0.1/24"), - mustParseCIDR(t, "10.0.0.1/24"), - mustParseCIDR(t, "172.16.0.1/16")}}}, - {{{ - mustParseCIDR(t, "192.168.0.1/16"), - mustParseCIDR(t, "224.0.0.1/24"), - mustParseCIDR(t, "169.168.0.1/16")}}}}, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index d4716479..4ead4672 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -2,138 +2,48 @@ package pgtype_test import ( "net" - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" ) +func isExpectedEqIPNet(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + ap := a.(*net.IPNet) + vp := v.(net.IPNet) + + return ap.IP.Equal(vp.IP) && ap.Mask.String() == vp.Mask.String() + } +} + func TestInetTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ - &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Valid: true}, - &pgtype.Inet{}, + testPgxCodec(t, "inet", []PgxTranscodeTestCase{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/8"))}, + {mustParseInet(t, "12.34.56.65/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.65/32"))}, + {mustParseInet(t, "192.168.1.16/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.16/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "255.255.255.255/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.255.255.255/32"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e"))}, + {mustParseInet(t, "::1/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/64"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/64"))}, + {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, }) } func TestCidrTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Valid: true}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Valid: true}, - &pgtype.Inet{}, + testPgxCodec(t, "cidr", []PgxTranscodeTestCase{ + {mustParseInet(t, "0.0.0.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "0.0.0.0/32"))}, + {mustParseInet(t, "127.0.0.1/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "127.0.0.1/32"))}, + {mustParseInet(t, "12.34.56.0/32"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "12.34.56.0/32"))}, + {mustParseInet(t, "192.168.1.0/24"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "192.168.1.0/24"))}, + {mustParseInet(t, "255.0.0.0/8"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "255.0.0.0/8"))}, + {mustParseInet(t, "::/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/128"))}, + {mustParseInet(t, "::/0"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::/0"))}, + {mustParseInet(t, "::1/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "::1/128"))}, + {mustParseInet(t, "2607:f8b0:4009:80b::200e/128"), new(net.IPNet), isExpectedEqIPNet(mustParseInet(t, "2607:f8b0:4009:80b::200e/128"))}, + {nil, new(pgtype.Inet), isExpectedEq(pgtype.Inet{})}, }) } - -func TestInetSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Inet - }{ - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Valid: true}}, - {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Valid: true}}, - {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}}, - {source: net.ParseIP(""), result: pgtype.Inet{}}, - } - - for i, tt := range successfulTests { - var r pgtype.Inet - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - continue - } - - assert.Equalf(t, tt.result.Valid, r.Valid, "%d: Status", i) - if tt.result.Valid { - assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: IP", i) - assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: Mask", i) - } - } -} - -func TestInetAssignTo(t *testing.T) { - var ipnet net.IPNet - var pipnet *net.IPNet - var ip net.IP - var pip *net.IP - - simpleTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - {src: pgtype.Inet{}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, - {src: pgtype.Inet{}, dst: &pip, expected: ((*net.IP)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Inet - dst interface{} - expected interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Inet - dst interface{} - }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Valid: true}, dst: &ip}, - {src: pgtype.Inet{}, dst: &ipnet}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e48bec51..89c7b348 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -262,11 +262,11 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementCodec: TextCodec{}, ElementOID: BPCharOID}}) ci.RegisterDataType(DataType{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementCodec: ByteaCodec{}, ElementOID: ByteaOID}}) - ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: CIDROID}}) ci.RegisterDataType(DataType{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementCodec: DateCodec{}, ElementOID: DateOID}}) ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) - ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementCodec: InetCodec{}, ElementOID: InetOID}}) ci.RegisterDataType(DataType{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementCodec: Int2Codec{}, ElementOID: Int2OID}}) ci.RegisterDataType(DataType{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementCodec: Int4Codec{}, ElementOID: Int4OID}}) ci.RegisterDataType(DataType{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementCodec: Int8Codec{}, ElementOID: Int8OID}}) @@ -293,13 +293,13 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Name: "date", OID: DateOID, Codec: DateCodec{}}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) - ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Name: "inet", OID: InetOID, Codec: InetCodec{}}) ci.RegisterDataType(DataType{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) ci.RegisterDataType(DataType{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) @@ -336,15 +336,26 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + // T ci.RegisterDefaultPgType(value, name) - valueType := reflect.TypeOf(value) + // *T + valueType := reflect.TypeOf(value) ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + // []T sliceType := reflect.SliceOf(valueType) ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + // *[]T ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + + // []*T + sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface())) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName) + + // *[]*T + ci.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName) } // Integer types that directly map to a PostgreSQL type @@ -368,8 +379,7 @@ func NewConnInfo() *ConnInfo { registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) - ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") - ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{}) return ci } @@ -816,6 +826,10 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter switch dst := dst.(type) { case *time.Time: return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + case *net.IPNet: + return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true + case *net.IP: + return &wrapNetIPScanPlan{}, (*netIPWrapper)(dst), true } return nil, nil, false @@ -831,6 +845,26 @@ func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) } +type wrapNetIPNetScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPNetScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPNetScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*netIPNetWrapper)(dst.(*net.IPNet))) +} + +type wrapNetIPScanPlan struct { + next ScanPlan +} + +func (plan *wrapNetIPScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapNetIPScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*netIPWrapper)(dst.(*net.IP))) +} + type pointerEmptyInterfaceScanPlan struct { codec Codec } @@ -901,6 +935,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if oid == 0 { if dataType, ok := ci.DataTypeForValue(dst); ok { dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { if dataType, ok := ci.DataTypeForOID(oid); ok { @@ -1031,6 +1066,7 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco if oid == 0 { if dataType, ok := ci.DataTypeForValue(value); ok { dt = dataType + oid = dt.OID // Preserve assumed OID in case we are recursively called below. } } else { if dataType, ok := ci.DataTypeForOID(oid); ok { @@ -1166,6 +1202,10 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapStringEncodePlan{}, stringWrapper(value), true case time.Time: return &wrapTimeEncodePlan{}, timeWrapper(value), true + case net.IPNet: + return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true + case net.IP: + return &wrapNetIPEncodePlan{}, netIPWrapper(value), true } return nil, nil, false @@ -1311,6 +1351,26 @@ func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf [] return plan.next.Encode(timeWrapper(value.(time.Time)), buf) } +type wrapNetIPNetEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPNetEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPNetEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPNetWrapper(value.(net.IPNet)), buf) +} + +type wrapNetIPEncodePlan struct { + next EncodePlan +} + +func (plan *wrapNetIPEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapNetIPEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(netIPWrapper(value.(net.IP)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written.