diff --git a/inet.go b/inet.go index 25e56170..a343f5e2 100644 --- a/inet.go +++ b/inet.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "fmt" "net" + "strings" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -52,17 +53,17 @@ func (dst *Inet) Set(src interface{}) error { return fmt.Errorf("unable to parse inet address: %s", value) } - if ipv4 := ip.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil { ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} } else { ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} } } else { ipnet.IP = ip - if ipv4 := ipnet.IP.To4(); ipv4 != nil { + if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil { ipnet.IP = ipv4 if len(ipnet.Mask) == 16 { - ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed. } } } @@ -96,6 +97,25 @@ func (dst *Inet) Set(src interface{}) error { return nil } +// Convert the net.IP to IPv4, if appropriate. +// +// When parsing a string to a net.IP using net.ParseIP() and the like, we get a +// 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function +// calls To4() to convert them to a 4 byte slice. This is useful as it allows +// users of the net.IP check for IPv4 addresses based on the length and makes +// it clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6 +// addresses. +func maybeGetIPv4(input string, ip net.IP) net.IP { + // Do not do this if the provided input looks like IPv6. This is because + // To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behave + // different in some cases. + if strings.Contains(input, ":") { + return nil + } + + return ip.To4() +} + func (dst Inet) Get() interface{} { switch dst.Status { case Present: diff --git a/inet_test.go b/inet_test.go index badbf82e..52759371 100644 --- a/inet_test.go +++ b/inet_test.go @@ -57,7 +57,7 @@ func TestInetSet(t *testing.T) { {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, - {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("::ffff:0.0.0.0"), Mask: net.CIDRMask(104, 128)}, Status: pgtype.Present}}, } for i, tt := range successfulTests {