mirror of https://github.com/jackc/pgx.git
Add inet and cidr to pgtype
parent
2010bea555
commit
4cdea13f0f
4
conn.go
4
conn.go
|
@ -281,12 +281,16 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.oidPgtypeValues = map[OID]pgtype.Value{
|
||||
BoolArrayOID: &pgtype.BoolArray{},
|
||||
BoolOID: &pgtype.Bool{},
|
||||
CidrArrayOID: &pgtype.CidrArray{},
|
||||
CidrOID: &pgtype.Inet{},
|
||||
DateArrayOID: &pgtype.DateArray{},
|
||||
DateOID: &pgtype.Date{},
|
||||
Float4ArrayOID: &pgtype.Float4Array{},
|
||||
Float4OID: &pgtype.Float4{},
|
||||
Float8ArrayOID: &pgtype.Float8Array{},
|
||||
Float8OID: &pgtype.Float8{},
|
||||
InetArrayOID: &pgtype.InetArray{},
|
||||
InetOID: &pgtype.Inet{},
|
||||
Int2ArrayOID: &pgtype.Int2Array{},
|
||||
Int2OID: &pgtype.Int2{},
|
||||
Int4ArrayOID: &pgtype.Int4Array{},
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type CidrArray InetArray
|
||||
|
||||
func (dst *CidrArray) ConvertFrom(src interface{}) error {
|
||||
return (*InetArray)(dst).ConvertFrom(src)
|
||||
}
|
||||
|
||||
func (src *CidrArray) AssignTo(dst interface{}) error {
|
||||
return (*InetArray)(src).AssignTo(dst)
|
||||
}
|
||||
|
||||
func (dst *CidrArray) DecodeText(r io.Reader) error {
|
||||
return (*InetArray)(dst).DecodeText(r)
|
||||
}
|
||||
|
||||
func (dst *CidrArray) DecodeBinary(r io.Reader) error {
|
||||
return (*InetArray)(dst).DecodeBinary(r)
|
||||
}
|
||||
|
||||
func (src *CidrArray) EncodeText(w io.Writer) error {
|
||||
return (*InetArray)(src).EncodeText(w)
|
||||
}
|
||||
|
||||
func (src *CidrArray) EncodeBinary(w io.Writer) error {
|
||||
return (*InetArray)(src).encodeBinary(w, CidrOID)
|
||||
}
|
|
@ -85,6 +85,22 @@ func underlyingBoolType(val interface{}) (interface{}, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
// underlyingPtrType dereferences a pointer
|
||||
func underlyingPtrType(val interface{}) (interface{}, bool) {
|
||||
refVal := reflect.ValueOf(val)
|
||||
|
||||
switch refVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
if refVal.IsNil() {
|
||||
return nil, false
|
||||
}
|
||||
convVal := refVal.Elem().Interface()
|
||||
return convVal, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// underlyingTimeType gets the underlying type that can be converted to time.Time
|
||||
func underlyingTimeType(val interface{}) (interface{}, bool) {
|
||||
refVal := reflect.ValueOf(val)
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
// Network address family is dependent on server socket.h value for AF_INET.
|
||||
// In practice, all platforms appear to have the same value. See
|
||||
// src/include/utils/inet.h for more information.
|
||||
const (
|
||||
defaultAFInet = 2
|
||||
defaultAFInet6 = 3
|
||||
)
|
||||
|
||||
// Inet represents both inet and cidr PostgreSQL types.
|
||||
type Inet struct {
|
||||
IPNet *net.IPNet
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *Inet) ConvertFrom(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case Inet:
|
||||
*dst = value
|
||||
case net.IPNet:
|
||||
*dst = Inet{IPNet: &value, Status: Present}
|
||||
case *net.IPNet:
|
||||
*dst = Inet{IPNet: value, Status: Present}
|
||||
case net.IP:
|
||||
bitCount := len(value) * 8
|
||||
mask := net.CIDRMask(bitCount, bitCount)
|
||||
*dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present}
|
||||
case string:
|
||||
_, ipnet, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = Inet{IPNet: ipnet, Status: Present}
|
||||
default:
|
||||
if originalSrc, ok := underlyingPtrType(src); ok {
|
||||
return dst.ConvertFrom(originalSrc)
|
||||
}
|
||||
return fmt.Errorf("cannot convert %v to Inet", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *Inet) AssignTo(dst interface{}) error {
|
||||
switch v := dst.(type) {
|
||||
case *net.IPNet:
|
||||
if src.Status != Present {
|
||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
*v = *src.IPNet
|
||||
case *net.IP:
|
||||
if src.Status == Present {
|
||||
|
||||
if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount {
|
||||
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||
}
|
||||
*v = src.IPNet.IP
|
||||
} else {
|
||||
*v = nil
|
||||
}
|
||||
default:
|
||||
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
switch el.Kind() {
|
||||
// if dst is a pointer to pointer, strip the pointer and try again
|
||||
case reflect.Ptr:
|
||||
if src.Status == Null {
|
||||
el.Set(reflect.Zero(el.Type()))
|
||||
return nil
|
||||
}
|
||||
if el.IsNil() {
|
||||
// allocate destination
|
||||
el.Set(reflect.New(el.Type().Elem()))
|
||||
}
|
||||
return src.AssignTo(el.Interface())
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("cannot decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Inet) DecodeText(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*dst = Inet{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := make([]byte, int(size))
|
||||
_, err = io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ipnet *net.IPNet
|
||||
|
||||
if ip := net.ParseIP(string(buf)); ip != nil {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 != nil {
|
||||
ip = ipv4
|
||||
}
|
||||
bitCount := len(ip) * 8
|
||||
mask := net.CIDRMask(bitCount, bitCount)
|
||||
ipnet = &net.IPNet{Mask: mask, IP: ip}
|
||||
} else {
|
||||
_, ipnet, err = net.ParseCIDR(string(buf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*dst = Inet{IPNet: ipnet, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *Inet) DecodeBinary(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*dst = Inet{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
if size != 8 && size != 20 {
|
||||
return fmt.Errorf("Received an invalid size for a inet: %d", size)
|
||||
}
|
||||
|
||||
// ignore family
|
||||
_, err = pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bits, err := pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ignore is_cidr
|
||||
_, err = pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addressLength, err := pgio.ReadByte(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ipnet net.IPNet
|
||||
ipnet.IP = make(net.IP, int(addressLength))
|
||||
_, err = r.Read(ipnet.IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
||||
|
||||
*dst = Inet{IPNet: &ipnet, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src Inet) EncodeText(w io.Writer) error {
|
||||
if done, err := encodeNotPresent(w, src.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
s := src.IPNet.String()
|
||||
_, err := pgio.WriteInt32(w, int32(len(s)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err = w.Write([]byte(s))
|
||||
return err
|
||||
}
|
||||
|
||||
// EncodeBinary encodes src into w.
|
||||
func (src Inet) EncodeBinary(w io.Writer) error {
|
||||
if done, err := encodeNotPresent(w, src.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
var size int32
|
||||
var family byte
|
||||
switch len(src.IPNet.IP) {
|
||||
case net.IPv4len:
|
||||
size = 8
|
||||
family = defaultAFInet
|
||||
case net.IPv6len:
|
||||
size = 20
|
||||
family = defaultAFInet6
|
||||
default:
|
||||
return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP))
|
||||
}
|
||||
|
||||
if _, err := pgio.WriteInt32(w, size); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pgio.WriteByte(w, family); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ones, _ := src.IPNet.Mask.Size()
|
||||
if err := pgio.WriteByte(w, byte(ones)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// is_cidr is ignored on server
|
||||
if err := pgio.WriteByte(w, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := w.Write(src.IPNet.IP)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestInetTranscode(t *testing.T) {
|
||||
for _, pgTypeName := range []string{"inet", "cidr"} {
|
||||
testSuccessfulTranscode(t, pgTypeName, []interface{}{
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
|
||||
pgtype.Inet{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInetConvertFrom(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.Inet
|
||||
}{
|
||||
{source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}},
|
||||
{source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
{source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
{source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.Inet
|
||||
err := r.ConvertFrom(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInetAssignTo(t *testing.T) {
|
||||
var ipnet net.IPNet
|
||||
var pipnet *net.IPNet
|
||||
var ip net.IP
|
||||
var pip *net.IP
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.Inet
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")},
|
||||
{src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
{src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))},
|
||||
{src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
pointerAllocTests := []struct {
|
||||
src pgtype.Inet
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")},
|
||||
{src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
}
|
||||
|
||||
for i, tt := range pointerAllocTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
src pgtype.Inet
|
||||
dst interface{}
|
||||
}{
|
||||
{src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip},
|
||||
{src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,320 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type InetArray struct {
|
||||
Elements []Inet
|
||||
Dimensions []ArrayDimension
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (dst *InetArray) ConvertFrom(src interface{}) error {
|
||||
switch value := src.(type) {
|
||||
case InetArray:
|
||||
*dst = value
|
||||
case CidrArray:
|
||||
*dst = InetArray(value)
|
||||
case []*net.IPNet:
|
||||
if value == nil {
|
||||
*dst = InetArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = InetArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]Inet, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = InetArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
case []net.IP:
|
||||
if value == nil {
|
||||
*dst = InetArray{Status: Null}
|
||||
} else if len(value) == 0 {
|
||||
*dst = InetArray{Status: Present}
|
||||
} else {
|
||||
elements := make([]Inet, len(value))
|
||||
for i := range value {
|
||||
if err := elements[i].ConvertFrom(value[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*dst = InetArray{
|
||||
Elements: elements,
|
||||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||
Status: Present,
|
||||
}
|
||||
}
|
||||
default:
|
||||
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||
return dst.ConvertFrom(originalSrc)
|
||||
}
|
||||
return fmt.Errorf("cannot convert %v to Inet", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *InetArray) AssignTo(dst interface{}) error {
|
||||
switch v := dst.(type) {
|
||||
|
||||
case *[]*net.IPNet:
|
||||
if src.Status == Present {
|
||||
*v = make([]*net.IPNet, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*v = nil
|
||||
}
|
||||
case *[]net.IP:
|
||||
if src.Status == Present {
|
||||
*v = make([]net.IP, len(src.Elements))
|
||||
for i := range src.Elements {
|
||||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*v = nil
|
||||
}
|
||||
|
||||
default:
|
||||
if originalDst, ok := underlyingPtrSliceType(dst); ok {
|
||||
return src.AssignTo(originalDst)
|
||||
}
|
||||
return fmt.Errorf("cannot put decode %v into %T", src, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *InetArray) DecodeText(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*dst = InetArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := make([]byte, int(size))
|
||||
_, err = io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uta, err := ParseUntypedTextArray(string(buf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
textElementReader := NewTextElementReader(r)
|
||||
var elements []Inet
|
||||
|
||||
if len(uta.Elements) > 0 {
|
||||
elements = make([]Inet, len(uta.Elements))
|
||||
|
||||
for i, s := range uta.Elements {
|
||||
var elem Inet
|
||||
textElementReader.Reset(s)
|
||||
err = elem.DecodeText(textElementReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elements[i] = elem
|
||||
}
|
||||
}
|
||||
|
||||
*dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dst *InetArray) DecodeBinary(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
*dst = InetArray{Status: Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
var arrayHeader ArrayHeader
|
||||
err = arrayHeader.DecodeBinary(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(arrayHeader.Dimensions) == 0 {
|
||||
*dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
elementCount := arrayHeader.Dimensions[0].Length
|
||||
for _, d := range arrayHeader.Dimensions[1:] {
|
||||
elementCount *= d.Length
|
||||
}
|
||||
|
||||
elements := make([]Inet, elementCount)
|
||||
|
||||
for i := range elements {
|
||||
err = elements[i].DecodeBinary(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *InetArray) EncodeText(w io.Writer) error {
|
||||
if done, err := encodeNotPresent(w, src.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(src.Dimensions) == 0 {
|
||||
_, err := pgio.WriteInt32(w, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = w.Write([]byte("{}"))
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
err := EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||
// or '}'.
|
||||
dimElemCounts := make([]int, len(src.Dimensions))
|
||||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||
}
|
||||
|
||||
textElementWriter := NewTextElementWriter(buf)
|
||||
|
||||
for i, elem := range src.Elements {
|
||||
if i > 0 {
|
||||
err = pgio.WriteByte(buf, ',')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if i%dec == 0 {
|
||||
err = pgio.WriteByte(buf, '{')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textElementWriter.Reset()
|
||||
err = elem.EncodeText(textElementWriter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, dec := range dimElemCounts {
|
||||
if (i+1)%dec == 0 {
|
||||
err = pgio.WriteByte(buf, '}')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err = pgio.WriteInt32(w, int32(buf.Len()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
func (src *InetArray) EncodeBinary(w io.Writer) error {
|
||||
return src.encodeBinary(w, InetOID)
|
||||
}
|
||||
|
||||
func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error {
|
||||
if done, err := encodeNotPresent(w, src.Status); done {
|
||||
return err
|
||||
}
|
||||
|
||||
var arrayHeader ArrayHeader
|
||||
|
||||
// TODO - consider how to avoid having to buffer array before writing length -
|
||||
// or how not pay allocations for the byte order conversions.
|
||||
elemBuf := &bytes.Buffer{}
|
||||
|
||||
for i := range src.Elements {
|
||||
err := src.Elements[i].EncodeBinary(elemBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if src.Elements[i].Status == Null {
|
||||
arrayHeader.ContainsNull = true
|
||||
}
|
||||
}
|
||||
|
||||
arrayHeader.ElementOID = elementOID
|
||||
arrayHeader.Dimensions = src.Dimensions
|
||||
|
||||
// TODO - consider how to avoid having to buffer array before writing length -
|
||||
// or how not pay allocations for the byte order conversions.
|
||||
headerBuf := &bytes.Buffer{}
|
||||
err := arrayHeader.EncodeBinary(headerBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = headerBuf.WriteTo(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = elemBuf.WriteTo(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func TestInetArrayTranscode(t *testing.T) {
|
||||
testSuccessfulTranscode(t, "inet[]", []interface{}{
|
||||
&pgtype.InetArray{
|
||||
Elements: nil,
|
||||
Dimensions: nil,
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{Status: pgtype.Null},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.InetArray{Status: pgtype.Null},
|
||||
&pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
|
||||
pgtype.Inet{Status: pgtype.Null},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
&pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present},
|
||||
pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
|
||||
},
|
||||
Dimensions: []pgtype.ArrayDimension{
|
||||
{Length: 2, LowerBound: 4},
|
||||
{Length: 2, LowerBound: 2},
|
||||
},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestInetArrayConvertFrom(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
result pgtype.InetArray
|
||||
}{
|
||||
{
|
||||
source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
|
||||
result: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]*net.IPNet)(nil)),
|
||||
result: pgtype.InetArray{Status: pgtype.Null},
|
||||
},
|
||||
{
|
||||
source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
result: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: (([]net.IP)(nil)),
|
||||
result: pgtype.InetArray{Status: pgtype.Null},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
var r pgtype.InetArray
|
||||
err := r.ConvertFrom(tt.source)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(r, tt.result) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInetArrayAssignTo(t *testing.T) {
|
||||
var ipnetSlice []*net.IPNet
|
||||
var ipSlice []net.IP
|
||||
|
||||
simpleTests := []struct {
|
||||
src pgtype.InetArray
|
||||
dst interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
src: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipnetSlice,
|
||||
expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")},
|
||||
},
|
||||
{
|
||||
src: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipnetSlice,
|
||||
expected: []*net.IPNet{nil},
|
||||
},
|
||||
{
|
||||
src: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipSlice,
|
||||
expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP},
|
||||
},
|
||||
{
|
||||
src: pgtype.InetArray{
|
||||
Elements: []pgtype.Inet{{Status: pgtype.Null}},
|
||||
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
dst: &ipSlice,
|
||||
expected: []net.IP{nil},
|
||||
},
|
||||
{
|
||||
src: pgtype.InetArray{Status: pgtype.Null},
|
||||
dst: &ipnetSlice,
|
||||
expected: (([]*net.IPNet)(nil)),
|
||||
},
|
||||
{
|
||||
src: pgtype.InetArray{Status: pgtype.Null},
|
||||
dst: &ipSlice,
|
||||
expected: (([]net.IP)(nil)),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range simpleTests {
|
||||
err := tt.src.AssignTo(tt.dst)
|
||||
if err != nil {
|
||||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
|
||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package pgtype_test
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -44,6 +45,15 @@ func mustClose(t testing.TB, conn interface {
|
|||
}
|
||||
}
|
||||
|
||||
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ipnet
|
||||
}
|
||||
|
||||
type forceTextEncoder struct {
|
||||
e pgtype.TextEncoder
|
||||
}
|
||||
|
|
|
@ -7,3 +7,4 @@ erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_
|
|||
erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go
|
||||
erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go
|
||||
erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go
|
||||
erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go
|
||||
|
|
28
values.go
28
values.go
|
@ -1088,14 +1088,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error {
|
|||
// The name data type goes over the wire using the same format as string,
|
||||
// so just cast to string and use encodeString
|
||||
return encodeString(wbuf, oid, string(arg))
|
||||
case net.IP:
|
||||
return encodeIP(wbuf, oid, arg)
|
||||
case []net.IP:
|
||||
return encodeIPSlice(wbuf, oid, arg)
|
||||
case net.IPNet:
|
||||
return encodeIPNet(wbuf, oid, arg)
|
||||
case []net.IPNet:
|
||||
return encodeIPNetSlice(wbuf, oid, arg)
|
||||
case OID:
|
||||
return encodeOID(wbuf, oid, arg)
|
||||
case Xid:
|
||||
|
@ -1195,26 +1187,6 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
*v = decodeByteaArray(vr)
|
||||
case *[]interface{}:
|
||||
*v = decodeRecord(vr)
|
||||
case *net.IP:
|
||||
ipnet := decodeInet(vr)
|
||||
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
|
||||
return fmt.Errorf("Cannot decode netmask into *net.IP")
|
||||
}
|
||||
*v = ipnet.IP
|
||||
case *[]net.IP:
|
||||
ipnets := decodeInetArray(vr)
|
||||
ips := make([]net.IP, len(ipnets))
|
||||
for i, ipnet := range ipnets {
|
||||
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
|
||||
return fmt.Errorf("Cannot decode netmask into *net.IP")
|
||||
}
|
||||
ips[i] = ipnet.IP
|
||||
}
|
||||
*v = ips
|
||||
case *net.IPNet:
|
||||
*v = decodeInet(vr)
|
||||
case *[]net.IPNet:
|
||||
*v = decodeInetArray(vr)
|
||||
default:
|
||||
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
|
|
|
@ -232,13 +232,13 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16)
|
|||
}
|
||||
}
|
||||
|
||||
func mustParseCIDR(t *testing.T, s string) net.IPNet {
|
||||
func mustParseCIDR(t *testing.T, s string) *net.IPNet {
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return *ipnet
|
||||
return ipnet
|
||||
}
|
||||
|
||||
func TestStringToNotTextTypeTranscode(t *testing.T) {
|
||||
|
@ -275,7 +275,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
sql string
|
||||
value net.IPNet
|
||||
value *net.IPNet
|
||||
}{
|
||||
{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
|
||||
{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
|
||||
|
@ -358,7 +358,7 @@ func TestInetCidrTranscodeIP(t *testing.T) {
|
|||
|
||||
failTests := []struct {
|
||||
sql string
|
||||
value net.IPNet
|
||||
value *net.IPNet
|
||||
}{
|
||||
{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
|
||||
{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
|
||||
|
@ -367,8 +367,8 @@ func TestInetCidrTranscodeIP(t *testing.T) {
|
|||
var actual net.IP
|
||||
|
||||
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||
if !strings.Contains(err.Error(), "Cannot decode netmask") {
|
||||
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Expected failure but got none", i)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -384,11 +384,11 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
sql string
|
||||
value []net.IPNet
|
||||
value []*net.IPNet
|
||||
}{
|
||||
{
|
||||
"select $1::inet[]",
|
||||
[]net.IPNet{
|
||||
[]*net.IPNet{
|
||||
mustParseCIDR(t, "0.0.0.0/32"),
|
||||
mustParseCIDR(t, "127.0.0.1/32"),
|
||||
mustParseCIDR(t, "12.34.56.0/32"),
|
||||
|
@ -403,7 +403,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
|
|||
},
|
||||
{
|
||||
"select $1::cidr[]",
|
||||
[]net.IPNet{
|
||||
[]*net.IPNet{
|
||||
mustParseCIDR(t, "0.0.0.0/32"),
|
||||
mustParseCIDR(t, "127.0.0.1/32"),
|
||||
mustParseCIDR(t, "12.34.56.0/32"),
|
||||
|
@ -419,7 +419,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var actual []net.IPNet
|
||||
var actual []*net.IPNet
|
||||
|
||||
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||
if err != nil {
|
||||
|
@ -485,18 +485,18 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) {
|
|||
|
||||
failTests := []struct {
|
||||
sql string
|
||||
value []net.IPNet
|
||||
value []*net.IPNet
|
||||
}{
|
||||
{
|
||||
"select $1::inet[]",
|
||||
[]net.IPNet{
|
||||
[]*net.IPNet{
|
||||
mustParseCIDR(t, "12.34.56.0/32"),
|
||||
mustParseCIDR(t, "192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"select $1::cidr[]",
|
||||
[]net.IPNet{
|
||||
[]*net.IPNet{
|
||||
mustParseCIDR(t, "12.34.56.0/32"),
|
||||
mustParseCIDR(t, "192.168.1.0/24"),
|
||||
},
|
||||
|
@ -507,8 +507,8 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) {
|
|||
var actual []net.IP
|
||||
|
||||
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||
if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") {
|
||||
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Expected failure but got none", i)
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue