diff --git a/README.md b/README.md index 9adf50c4..d998d01e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Pgx supports many additional features beyond what is available through database/ * PostgreSQL array to Go slice mapping for integers, floats, and strings * Hstore support * JSON and JSONB support -* Maps inet and cidr PostgreSQL types to net.IPNet +* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP * Large object support * Null mapping to Null* struct or pointer to pointer. * Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types diff --git a/values.go b/values.go index 032d774a..06f49a9c 100644 --- a/values.go +++ b/values.go @@ -764,6 +764,22 @@ func Decode(vr *ValueReader, d interface{}) error { default: return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) } + case *net.IP: + ipnet := decodeInet(vr) + if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("Cannot decode netmask into *net.IP") + } + *v = ipnet.IP + case *[]net.IP: + ipnets := decodeInetArray(vr) + ips := make([]net.IP, len(ipnets)) + for i, ipnet := range ipnets { + if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("Cannot decode netmask into *net.IP") + } + ips[i] = ipnet.IP + } + *v = ips case *net.IPNet: *v = decodeInet(vr) case *[]net.IPNet: @@ -1436,15 +1452,14 @@ func decodeInet(vr *ValueReader) net.IPNet { } pgType := vr.Type() - if vr.Len() != 8 && vr.Len() != 20 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len()))) - return zero - } - if pgType.DataType != InetOid && pgType.DataType != CidrOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) return zero } + if vr.Len() != 8 && vr.Len() != 20 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len()))) + return zero + } vr.ReadByte() // ignore family bits := vr.ReadByte() diff --git a/values_test.go b/values_test.go index f6ddc623..a729dac3 100644 --- a/values_test.go +++ b/values_test.go @@ -1,12 +1,13 @@ package pgx_test import ( - "github.com/jackc/pgx" "net" "reflect" "strings" "testing" "time" + + "github.com/jackc/pgx" ) func TestDateTranscode(t *testing.T) { @@ -258,7 +259,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) { } } -func TestInetCidrTranscode(t *testing.T) { +func TestInetCidrTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -307,7 +308,67 @@ func TestInetCidrTranscode(t *testing.T) { } } -func TestInetCidrArrayTranscode(t *testing.T) { +func TestInetCidrTranscodeIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value net.IP + }{ + {"select $1::inet", net.ParseIP("0.0.0.0")}, + {"select $1::inet", net.ParseIP("127.0.0.1")}, + {"select $1::inet", net.ParseIP("12.34.56.0")}, + {"select $1::inet", net.ParseIP("255.255.255.255")}, + {"select $1::inet", net.ParseIP("::1")}, + {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, + {"select $1::cidr", net.ParseIP("0.0.0.0")}, + {"select $1::cidr", net.ParseIP("127.0.0.1")}, + {"select $1::cidr", net.ParseIP("12.34.56.0")}, + {"select $1::cidr", net.ParseIP("255.255.255.255")}, + {"select $1::cidr", net.ParseIP("::1")}, + {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, + } + + for i, tt := range tests { + var actual net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !actual.Equal(tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + + failTests := []struct { + sql string + value net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + } + for i, tt := range failTests { + var actual net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if !strings.Contains(err.Error(), "Cannot decode netmask") { + t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + ensureConnValid(t, conn) + } +} + +func TestInetCidrArrayTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -366,6 +427,87 @@ func TestInetCidrArrayTranscode(t *testing.T) { } } +func TestInetCidrArrayTranscodeIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value []net.IP + }{ + { + "select $1::inet[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, + }, + { + "select $1::cidr[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, + }, + } + + for i, tt := range tests { + var actual []net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + + failTests := []struct { + sql string + value []net.IPNet + }{ + { + "select $1::inet[]", + []net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, + }, + { + "select $1::cidr[]", + []net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, + }, + } + + for i, tt := range failTests { + var actual []net.IP + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") { + t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + ensureConnValid(t, conn) + } +} + func TestInetCidrTranscodeWithJustIP(t *testing.T) { t.Parallel()