diff --git a/inet.go b/inet.go index 43f7252a..1645334e 100644 --- a/inet.go +++ b/inet.go @@ -132,18 +132,22 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { var err error if ip := net.ParseIP(string(src)); ip != nil { - ipv4 := ip.To4() - if ipv4 != nil { + if ipv4 := ip.To4(); ipv4 != nil { ip = ipv4 } bitCount := len(ip) * 8 mask := net.CIDRMask(bitCount, bitCount) ipnet = &net.IPNet{Mask: mask, IP: ip} } else { - _, ipnet, err = net.ParseCIDR(string(src)) + ip, ipnet, err = net.ParseCIDR(string(src)) if err != nil { return err } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + ones, _ := ipnet.Mask.Size() + *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} } *dst = Inet{IPNet: ipnet, Status: Present} @@ -168,7 +172,10 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { var ipnet net.IPNet ipnet.IP = make(net.IP, int(addressLength)) copy(ipnet.IP, src[4:]) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) *dst = Inet{IPNet: &ipnet, Status: Present} diff --git a/inet_test.go b/inet_test.go index 08d73e4e..66fe777f 100644 --- a/inet_test.go +++ b/inet_test.go @@ -11,22 +11,35 @@ import ( ) func TestInetTranscode(t *testing.T) { - for _, pgTypeName := range []string{"inet", "cidr"} { - testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.50/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - &pgtype.Inet{Status: pgtype.Null}, - }) - } + testutil.TestSuccessfulTranscode(t, "inet", []interface{}{ + &pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) +} + +func TestCidrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{ + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) } func TestInetSet(t *testing.T) { diff --git a/pgtype_test.go b/pgtype_test.go index f46ec12a..75e1909f 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -35,6 +35,20 @@ func mustParseCIDR(t testing.TB, s string) *net.IPNet { return ipnet } +func mustParseInet(t testing.TB, s string) *net.IPNet { + ip, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + + ipnet.IP = ip + + return ipnet +} + func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { addr, err := net.ParseMAC(s) if err != nil {