From 8b296b9d5840d0624dd2f61c990b5304a5e6fcaa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 9 Sep 2015 18:49:20 -0500 Subject: [PATCH] Encode from net.IP to inet and cidr --- CHANGELOG.md | 1 + doc.go | 6 ++++++ values.go | 4 ++++ values_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53691d56..5fcea692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Master +* Encode from net.IP to inet and cidr * Generalize encoding pointer to string to any PostgreSQL type * Add UUID encoding from pointer to string (Joseph Glanville) * Add null mapping to pointer to pointer (Jonathan Rudenberg) diff --git a/doc.go b/doc.go index 93145107..edbfd79d 100644 --- a/doc.go +++ b/doc.go @@ -150,6 +150,12 @@ JSON and JSONB Mapping pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. +Inet and Cidr Mapping + +pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In +addition, as a convenience pgx will encode from a net.IP; it will assume a /32 +netmask for IPv4 and a /128 for IPv6. + Custom Type Support pgx includes support for the common data types like integers, floats, strings, diff --git a/values.go b/values.go index 2fdf32ae..2a1ca8b8 100644 --- a/values.go +++ b/values.go @@ -1196,6 +1196,10 @@ func encodeInet(w *WriteBuf, value interface{}) error { switch value := value.(type) { case net.IPNet: ipnet = value + case net.IP: + ipnet.IP = value + bitCount := len(value) * 8 + ipnet.Mask = net.CIDRMask(bitCount, bitCount) default: return fmt.Errorf("Expected net.IPNet, received %T %v", value, value) } diff --git a/values_test.go b/values_test.go index b93542d5..490a358e 100644 --- a/values_test.go +++ b/values_test.go @@ -274,6 +274,48 @@ func TestInetCidrTranscode(t *testing.T) { } } +func TestInetCidrTranscodeWithJustIP(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value string + }{ + {"select $1::inet", "0.0.0.0/32"}, + {"select $1::inet", "127.0.0.1/32"}, + {"select $1::inet", "12.34.56.0/32"}, + {"select $1::inet", "255.255.255.255/32"}, + {"select $1::inet", "::/128"}, + {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, + {"select $1::cidr", "0.0.0.0/32"}, + {"select $1::cidr", "127.0.0.1/32"}, + {"select $1::cidr", "12.34.56.0/32"}, + {"select $1::cidr", "255.255.255.255/32"}, + {"select $1::cidr", "::/128"}, + {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, + } + + for i, tt := range tests { + expected := mustParseCIDR(t, tt.value) + var actual net.IPNet + + err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if actual.String() != expected.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + func TestNullX(t *testing.T) { t.Parallel()