From 3eceab0f382295901243f9b43973108c36ee4d1a Mon Sep 17 00:00:00 2001
From: Cameron Daniel <cam.daniel@gmail.com>
Date: Wed, 30 Jun 2021 14:22:26 +0200
Subject: [PATCH] Maintain host bits for inet types

---
 inet.go        | 15 +++++++++++----
 inet_test.go   | 45 +++++++++++++++++++++++++++++----------------
 pgtype_test.go | 14 ++++++++++++++
 3 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/inet.go b/inet.go
index 43f7252a..1645334e 100644
--- a/inet.go
+++ b/inet.go
@@ -132,18 +132,22 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
 	var err error
 
 	if ip := net.ParseIP(string(src)); ip != nil {
-		ipv4 := ip.To4()
-		if ipv4 != nil {
+		if ipv4 := ip.To4(); ipv4 != nil {
 			ip = ipv4
 		}
 		bitCount := len(ip) * 8
 		mask := net.CIDRMask(bitCount, bitCount)
 		ipnet = &net.IPNet{Mask: mask, IP: ip}
 	} else {
-		_, ipnet, err = net.ParseCIDR(string(src))
+		ip, ipnet, err = net.ParseCIDR(string(src))
 		if err != nil {
 			return err
 		}
+		if ipv4 := ip.To4(); ipv4 != nil {
+			ip = ipv4
+		}
+		ones, _ := ipnet.Mask.Size()
+		*ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)}
 	}
 
 	*dst = Inet{IPNet: ipnet, Status: Present}
@@ -168,7 +172,10 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
 	var ipnet net.IPNet
 	ipnet.IP = make(net.IP, int(addressLength))
 	copy(ipnet.IP, src[4:])
-	ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
+	if ipv4 := ipnet.IP.To4(); ipv4 != nil {
+		ipnet.IP = ipv4
+	}
+	ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8)
 
 	*dst = Inet{IPNet: &ipnet, Status: Present}
 
diff --git a/inet_test.go b/inet_test.go
index 08d73e4e..66fe777f 100644
--- a/inet_test.go
+++ b/inet_test.go
@@ -11,22 +11,35 @@ import (
 )
 
 func TestInetTranscode(t *testing.T) {
-	for _, pgTypeName := range []string{"inet", "cidr"} {
-		testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.50/24"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present},
-			&pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
-			&pgtype.Inet{Status: pgtype.Null},
-		})
-	}
+	testutil.TestSuccessfulTranscode(t, "inet", []interface{}{
+		&pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "127.0.0.1/8"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "12.34.56.65/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e/64"), Status: pgtype.Present},
+		&pgtype.Inet{Status: pgtype.Null},
+	})
+}
+
+func TestCidrTranscode(t *testing.T) {
+	testutil.TestSuccessfulTranscode(t, "cidr", []interface{}{
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present},
+		&pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present},
+		&pgtype.Inet{Status: pgtype.Null},
+	})
 }
 
 func TestInetSet(t *testing.T) {
diff --git a/pgtype_test.go b/pgtype_test.go
index f46ec12a..75e1909f 100644
--- a/pgtype_test.go
+++ b/pgtype_test.go
@@ -35,6 +35,20 @@ func mustParseCIDR(t testing.TB, s string) *net.IPNet {
 	return ipnet
 }
 
+func mustParseInet(t testing.TB, s string) *net.IPNet {
+	ip, ipnet, err := net.ParseCIDR(s)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if ipv4 := ip.To4(); ipv4 != nil {
+		ip = ipv4
+	}
+
+	ipnet.IP = ip
+
+	return ipnet
+}
+
 func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
 	addr, err := net.ParseMAC(s)
 	if err != nil {