Use binary transcoding for inet/cidr

fixes #87
pull/78/merge
Jack Christensen 2015-09-03 11:39:32 -05:00
parent 9af068add0
commit fd39261551
6 changed files with 85 additions and 50 deletions

38
conn.go
View File

@ -57,6 +57,8 @@ type Conn struct {
logger Logger
mr msgReader
fp *fastpath
pgsql_af_inet byte
pgsql_af_inet6 byte
}
type PreparedStatement struct {
@ -222,6 +224,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
return err
}
err = c.loadInetConstants()
if err != nil {
return err
}
return nil
default:
if err = c.processContextFreeMsg(t, r); err != nil {
@ -254,6 +261,23 @@ func (c *Conn) loadPgTypes() error {
return rows.Err()
}
// Family is needed for binary encoding of inet/cidr. The constant is based on
// the server's definition of AF_INET. In theory, this could differ between
// platforms, so request an IPv4 and an IPv6 inet and get the family from that.
func (c *Conn) loadInetConstants() error {
var ipv4, ipv6 []byte
err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6)
if err != nil {
return err
}
c.pgsql_af_inet = ipv4[0]
c.pgsql_af_inet6 = ipv6[0]
return nil
}
// Close closes a connection. It is safe to call Close on a already closed
// connection.
func (c *Conn) Close() (err error) {
@ -261,7 +285,7 @@ func (c *Conn) Close() (err error) {
return nil
}
wbuf := newWriteBuf(c.wbuf[0:0], 'X')
wbuf := newWriteBuf(c, 'X')
wbuf.closeMsg()
_, err = c.conn.Write(wbuf.buf)
@ -442,7 +466,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
}()
// parse
wbuf := newWriteBuf(c.wbuf[0:0], 'P')
wbuf := newWriteBuf(c, 'P')
wbuf.WriteCString(name)
wbuf.WriteCString(sql)
wbuf.WriteInt16(0)
@ -509,7 +533,7 @@ func (c *Conn) Deallocate(name string) (err error) {
delete(c.preparedStatements, name)
// close
wbuf := newWriteBuf(c.wbuf[0:0], 'C')
wbuf := newWriteBuf(c, 'C')
wbuf.WriteByte('S')
wbuf.WriteCString(name)
@ -667,7 +691,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
if len(args) == 0 {
wbuf := newWriteBuf(c.wbuf[0:0], 'Q')
wbuf := newWriteBuf(c, 'Q')
wbuf.WriteCString(sql)
wbuf.closeMsg()
@ -694,7 +718,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
}
// bind
wbuf := newWriteBuf(c.wbuf[0:0], 'B')
wbuf := newWriteBuf(c, 'B')
wbuf.WriteByte(0)
wbuf.WriteCString(ps.Name)
@ -707,7 +731,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode)
default:
switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid:
wbuf.WriteInt16(BinaryFormatCode)
default:
wbuf.WriteInt16(TextFormatCode)
@ -1050,7 +1074,7 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
}
func (c *Conn) txPasswordMessage(password string) (err error) {
wbuf := newWriteBuf(c.wbuf[0:0], 'p')
wbuf := newWriteBuf(c, 'p')
wbuf.WriteCString(password)
wbuf.closeMsg()

View File

@ -50,11 +50,11 @@ func fpInt64Arg(n int64) fpArg {
}
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
wbuf := newWriteBuf(f.cn.wbuf[:0], 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes
wbuf.WriteInt16(1) // format code: binary
wbuf.WriteInt16(int16(len(args))) // # of arguments
wbuf := newWriteBuf(f.cn, 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes
wbuf.WriteInt16(1) // format code: binary
wbuf.WriteInt16(int16(len(args))) // # of arguments
for _, arg := range args {
wbuf.WriteInt32(int32(len(arg))) // length of argument
wbuf.WriteBytes(arg) // argument value

View File

@ -89,9 +89,9 @@ func (self PgError) Error() string {
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
}
func newWriteBuf(buf []byte, t byte) *WriteBuf {
buf = append(buf, t, 0, 0, 0, 0)
return &WriteBuf{buf: buf, sizeIdx: 1}
func newWriteBuf(c *Conn, t byte) *WriteBuf {
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
return &WriteBuf{buf: buf, sizeIdx: 1, conn: c}
}
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
@ -99,6 +99,7 @@ func newWriteBuf(buf []byte, t byte) *WriteBuf {
type WriteBuf struct {
buf []byte
sizeIdx int
conn *Conn
}
func (wb *WriteBuf) startMsg(t byte) {

View File

@ -422,8 +422,8 @@ func TestQueryRowUnknownType(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := "select $1::inet"
expected := "127.0.0.1"
sql := "select $1::point"
expected := "(1,0)"
var actual string
err := conn.QueryRow(sql, expected).Scan(&actual)

View File

@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16
func init() {
DefaultTypeFormats = make(map[string]int16)
DefaultTypeFormats["_bool"] = BinaryFormatCode
DefaultTypeFormats["_float4"] = BinaryFormatCode
DefaultTypeFormats["_float8"] = BinaryFormatCode
DefaultTypeFormats["_bool"] = BinaryFormatCode
DefaultTypeFormats["_int2"] = BinaryFormatCode
DefaultTypeFormats["_int4"] = BinaryFormatCode
DefaultTypeFormats["_int8"] = BinaryFormatCode
DefaultTypeFormats["_text"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["bool"] = BinaryFormatCode
DefaultTypeFormats["bytea"] = BinaryFormatCode
DefaultTypeFormats["cidr"] = BinaryFormatCode
DefaultTypeFormats["date"] = BinaryFormatCode
DefaultTypeFormats["float4"] = BinaryFormatCode
DefaultTypeFormats["float8"] = BinaryFormatCode
DefaultTypeFormats["inet"] = BinaryFormatCode
DefaultTypeFormats["int2"] = BinaryFormatCode
DefaultTypeFormats["int4"] = BinaryFormatCode
DefaultTypeFormats["int8"] = BinaryFormatCode
@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet {
return zero
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return zero
}
pgType := vr.Type()
if vr.Len() != 8 && vr.Len() != 20 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
return zero
}
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name)))
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
return zero
}
s := vr.ReadString(vr.Len())
hasNetmask := strings.ContainsRune(s, '/')
if !hasNetmask {
isIpv6 := strings.ContainsRune(s, ':')
if isIpv6 {
s += "/128"
} else {
s += "/32"
}
}
vr.ReadByte() // ignore family
bits := vr.ReadByte()
vr.ReadByte() // ignore is_cidr
addressLength := vr.ReadByte()
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
vr.Fatal(err)
return zero
}
// if vr.Type().FormatCode != BinaryFormatCode {
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
// return zero
// }
// if vr.Len() != 4 {
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
// return zero
// }
return *ipnet
var ipnet net.IPNet
ipnet.IP = vr.ReadBytes(int32(addressLength))
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
return ipnet
}
func encodeInet(w *WriteBuf, value interface{}) error {
@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error {
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
}
s := ipnet.String()
var size int32
var family byte
switch len(ipnet.IP) {
case net.IPv4len:
size = 8
family = w.conn.pgsql_af_inet
case net.IPv6len:
size = 20
family = w.conn.pgsql_af_inet6
default:
return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP))
}
w.WriteInt32(int32(len(s)))
w.WriteBytes([]byte(s))
w.WriteInt32(size)
w.WriteByte(family)
ones, _ := ipnet.Mask.Size()
w.WriteByte(byte(ones))
w.WriteByte(0) // is_cidr is ignored on server
w.WriteByte(byte(len(ipnet.IP)))
w.WriteBytes(ipnet.IP)
return nil
}

View File

@ -137,6 +137,7 @@ func TestInetCidrTranscode(t *testing.T) {
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
continue
}
if actual.String() != tt.value.String() {