diff --git a/conn.go b/conn.go index 8b8ac7a8..ac1e56f1 100644 --- a/conn.go +++ b/conn.go @@ -57,6 +57,8 @@ type Conn struct { logger Logger mr msgReader fp *fastpath + pgsql_af_inet byte + pgsql_af_inet6 byte } type PreparedStatement struct { @@ -222,6 +224,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } + err = c.loadInetConstants() + if err != nil { + return err + } + return nil default: if err = c.processContextFreeMsg(t, r); err != nil { @@ -254,6 +261,23 @@ func (c *Conn) loadPgTypes() error { return rows.Err() } +// Family is needed for binary encoding of inet/cidr. The constant is based on +// the server's definition of AF_INET. In theory, this could differ between +// platforms, so request an IPv4 and an IPv6 inet and get the family from that. +func (c *Conn) loadInetConstants() error { + var ipv4, ipv6 []byte + + err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6) + if err != nil { + return err + } + + c.pgsql_af_inet = ipv4[0] + c.pgsql_af_inet6 = ipv6[0] + + return nil +} + // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { @@ -261,7 +285,7 @@ func (c *Conn) Close() (err error) { return nil } - wbuf := newWriteBuf(c.wbuf[0:0], 'X') + wbuf := newWriteBuf(c, 'X') wbuf.closeMsg() _, err = c.conn.Write(wbuf.buf) @@ -442,7 +466,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { }() // parse - wbuf := newWriteBuf(c.wbuf[0:0], 'P') + wbuf := newWriteBuf(c, 'P') wbuf.WriteCString(name) wbuf.WriteCString(sql) wbuf.WriteInt16(0) @@ -509,7 +533,7 @@ func (c *Conn) Deallocate(name string) (err error) { delete(c.preparedStatements, name) // close - wbuf := newWriteBuf(c.wbuf[0:0], 'C') + wbuf := newWriteBuf(c, 'C') wbuf.WriteByte('S') wbuf.WriteCString(name) @@ -667,7 +691,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { if len(args) == 0 { - wbuf := newWriteBuf(c.wbuf[0:0], 'Q') + wbuf := newWriteBuf(c, 'Q') wbuf.WriteCString(sql) wbuf.closeMsg() @@ -694,7 +718,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } // bind - wbuf := newWriteBuf(c.wbuf[0:0], 'B') + wbuf := newWriteBuf(c, 'B') wbuf.WriteByte(0) wbuf.WriteCString(ps.Name) @@ -707,7 +731,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -1050,7 +1074,7 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error { } func (c *Conn) txPasswordMessage(password string) (err error) { - wbuf := newWriteBuf(c.wbuf[0:0], 'p') + wbuf := newWriteBuf(c, 'p') wbuf.WriteCString(password) wbuf.closeMsg() diff --git a/fastpath.go b/fastpath.go index 5eee1ea1..8814e559 100644 --- a/fastpath.go +++ b/fastpath.go @@ -50,11 +50,11 @@ func fpInt64Arg(n int64) fpArg { } func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { - wbuf := newWriteBuf(f.cn.wbuf[:0], 'F') // function call - wbuf.WriteInt32(int32(oid)) // function object id - wbuf.WriteInt16(1) // # of argument format codes - wbuf.WriteInt16(1) // format code: binary - wbuf.WriteInt16(int16(len(args))) // # of arguments + wbuf := newWriteBuf(f.cn, 'F') // function call + wbuf.WriteInt32(int32(oid)) // function object id + wbuf.WriteInt16(1) // # of argument format codes + wbuf.WriteInt16(1) // format code: binary + wbuf.WriteInt16(int16(len(args))) // # of arguments for _, arg := range args { wbuf.WriteInt32(int32(len(arg))) // length of argument wbuf.WriteBytes(arg) // argument value diff --git a/messages.go b/messages.go index 16d7301e..7cc244c8 100644 --- a/messages.go +++ b/messages.go @@ -89,9 +89,9 @@ func (self PgError) Error() string { return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")" } -func newWriteBuf(buf []byte, t byte) *WriteBuf { - buf = append(buf, t, 0, 0, 0, 0) - return &WriteBuf{buf: buf, sizeIdx: 1} +func newWriteBuf(c *Conn, t byte) *WriteBuf { + buf := append(c.wbuf[0:0], t, 0, 0, 0, 0) + return &WriteBuf{buf: buf, sizeIdx: 1, conn: c} } // WrifeBuf is used build messages to send to the PostgreSQL server. It is used @@ -99,6 +99,7 @@ func newWriteBuf(buf []byte, t byte) *WriteBuf { type WriteBuf struct { buf []byte sizeIdx int + conn *Conn } func (wb *WriteBuf) startMsg(t byte) { diff --git a/query_test.go b/query_test.go index b9e80f8c..206b8a82 100644 --- a/query_test.go +++ b/query_test.go @@ -422,8 +422,8 @@ func TestQueryRowUnknownType(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - sql := "select $1::inet" - expected := "127.0.0.1" + sql := "select $1::point" + expected := "(1,0)" var actual string err := conn.QueryRow(sql, expected).Scan(&actual) diff --git a/values.go b/values.go index 22797d81..e1c5cd2c 100644 --- a/values.go +++ b/values.go @@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = make(map[string]int16) + DefaultTypeFormats["_bool"] = BinaryFormatCode DefaultTypeFormats["_float4"] = BinaryFormatCode DefaultTypeFormats["_float8"] = BinaryFormatCode - DefaultTypeFormats["_bool"] = BinaryFormatCode DefaultTypeFormats["_int2"] = BinaryFormatCode DefaultTypeFormats["_int4"] = BinaryFormatCode DefaultTypeFormats["_int8"] = BinaryFormatCode DefaultTypeFormats["_text"] = BinaryFormatCode - DefaultTypeFormats["_varchar"] = BinaryFormatCode DefaultTypeFormats["_timestamp"] = BinaryFormatCode DefaultTypeFormats["_timestamptz"] = BinaryFormatCode + DefaultTypeFormats["_varchar"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode + DefaultTypeFormats["cidr"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode DefaultTypeFormats["float4"] = BinaryFormatCode DefaultTypeFormats["float8"] = BinaryFormatCode + DefaultTypeFormats["inet"] = BinaryFormatCode DefaultTypeFormats["int2"] = BinaryFormatCode DefaultTypeFormats["int4"] = BinaryFormatCode DefaultTypeFormats["int8"] = BinaryFormatCode @@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet { return zero } + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return zero + } + pgType := vr.Type() + if vr.Len() != 8 && vr.Len() != 20 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len()))) + return zero + } + if pgType.DataType != InetOid && pgType.DataType != CidrOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) return zero } - s := vr.ReadString(vr.Len()) - hasNetmask := strings.ContainsRune(s, '/') - if !hasNetmask { - isIpv6 := strings.ContainsRune(s, ':') - if isIpv6 { - s += "/128" - } else { - s += "/32" - } - } + vr.ReadByte() // ignore family + bits := vr.ReadByte() + vr.ReadByte() // ignore is_cidr + addressLength := vr.ReadByte() - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - vr.Fatal(err) - return zero - } - - // if vr.Type().FormatCode != BinaryFormatCode { - // vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - // return zero - // } - - // if vr.Len() != 4 { - // vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - // return zero - // } - - return *ipnet + var ipnet net.IPNet + ipnet.IP = vr.ReadBytes(int32(addressLength)) + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + return ipnet } func encodeInet(w *WriteBuf, value interface{}) error { @@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected net.IPNet, received %T %v", value, value) } - s := ipnet.String() + var size int32 + var family byte + switch len(ipnet.IP) { + case net.IPv4len: + size = 8 + family = w.conn.pgsql_af_inet + case net.IPv6len: + size = 20 + family = w.conn.pgsql_af_inet6 + default: + return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP)) + } - w.WriteInt32(int32(len(s))) - w.WriteBytes([]byte(s)) + w.WriteInt32(size) + w.WriteByte(family) + ones, _ := ipnet.Mask.Size() + w.WriteByte(byte(ones)) + w.WriteByte(0) // is_cidr is ignored on server + w.WriteByte(byte(len(ipnet.IP))) + w.WriteBytes(ipnet.IP) return nil } diff --git a/values_test.go b/values_test.go index bf866fef..f191d479 100644 --- a/values_test.go +++ b/values_test.go @@ -137,6 +137,7 @@ func TestInetCidrTranscode(t *testing.T) { err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue } if actual.String() != tt.value.String() {