mirror of https://github.com/jackc/pgx.git
parent
9af068add0
commit
fd39261551
38
conn.go
38
conn.go
|
@ -57,6 +57,8 @@ type Conn struct {
|
|||
logger Logger
|
||||
mr msgReader
|
||||
fp *fastpath
|
||||
pgsql_af_inet byte
|
||||
pgsql_af_inet6 byte
|
||||
}
|
||||
|
||||
type PreparedStatement struct {
|
||||
|
@ -222,6 +224,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
return err
|
||||
}
|
||||
|
||||
err = c.loadInetConstants()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
|
@ -254,6 +261,23 @@ func (c *Conn) loadPgTypes() error {
|
|||
return rows.Err()
|
||||
}
|
||||
|
||||
// Family is needed for binary encoding of inet/cidr. The constant is based on
|
||||
// the server's definition of AF_INET. In theory, this could differ between
|
||||
// platforms, so request an IPv4 and an IPv6 inet and get the family from that.
|
||||
func (c *Conn) loadInetConstants() error {
|
||||
var ipv4, ipv6 []byte
|
||||
|
||||
err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.pgsql_af_inet = ipv4[0]
|
||||
c.pgsql_af_inet6 = ipv6[0]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes a connection. It is safe to call Close on a already closed
|
||||
// connection.
|
||||
func (c *Conn) Close() (err error) {
|
||||
|
@ -261,7 +285,7 @@ func (c *Conn) Close() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'X')
|
||||
wbuf := newWriteBuf(c, 'X')
|
||||
wbuf.closeMsg()
|
||||
|
||||
_, err = c.conn.Write(wbuf.buf)
|
||||
|
@ -442,7 +466,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
}()
|
||||
|
||||
// parse
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'P')
|
||||
wbuf := newWriteBuf(c, 'P')
|
||||
wbuf.WriteCString(name)
|
||||
wbuf.WriteCString(sql)
|
||||
wbuf.WriteInt16(0)
|
||||
|
@ -509,7 +533,7 @@ func (c *Conn) Deallocate(name string) (err error) {
|
|||
delete(c.preparedStatements, name)
|
||||
|
||||
// close
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'C')
|
||||
wbuf := newWriteBuf(c, 'C')
|
||||
wbuf.WriteByte('S')
|
||||
wbuf.WriteCString(name)
|
||||
|
||||
|
@ -667,7 +691,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||
|
||||
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||
if len(args) == 0 {
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'Q')
|
||||
wbuf := newWriteBuf(c, 'Q')
|
||||
wbuf.WriteCString(sql)
|
||||
wbuf.closeMsg()
|
||||
|
||||
|
@ -694,7 +718,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
}
|
||||
|
||||
// bind
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'B')
|
||||
wbuf := newWriteBuf(c, 'B')
|
||||
wbuf.WriteByte(0)
|
||||
wbuf.WriteCString(ps.Name)
|
||||
|
||||
|
@ -707,7 +731,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
@ -1050,7 +1074,7 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
|
|||
}
|
||||
|
||||
func (c *Conn) txPasswordMessage(password string) (err error) {
|
||||
wbuf := newWriteBuf(c.wbuf[0:0], 'p')
|
||||
wbuf := newWriteBuf(c, 'p')
|
||||
wbuf.WriteCString(password)
|
||||
wbuf.closeMsg()
|
||||
|
||||
|
|
10
fastpath.go
10
fastpath.go
|
@ -50,11 +50,11 @@ func fpInt64Arg(n int64) fpArg {
|
|||
}
|
||||
|
||||
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
|
||||
wbuf := newWriteBuf(f.cn.wbuf[:0], 'F') // function call
|
||||
wbuf.WriteInt32(int32(oid)) // function object id
|
||||
wbuf.WriteInt16(1) // # of argument format codes
|
||||
wbuf.WriteInt16(1) // format code: binary
|
||||
wbuf.WriteInt16(int16(len(args))) // # of arguments
|
||||
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||
wbuf.WriteInt32(int32(oid)) // function object id
|
||||
wbuf.WriteInt16(1) // # of argument format codes
|
||||
wbuf.WriteInt16(1) // format code: binary
|
||||
wbuf.WriteInt16(int16(len(args))) // # of arguments
|
||||
for _, arg := range args {
|
||||
wbuf.WriteInt32(int32(len(arg))) // length of argument
|
||||
wbuf.WriteBytes(arg) // argument value
|
||||
|
|
|
@ -89,9 +89,9 @@ func (self PgError) Error() string {
|
|||
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
|
||||
}
|
||||
|
||||
func newWriteBuf(buf []byte, t byte) *WriteBuf {
|
||||
buf = append(buf, t, 0, 0, 0, 0)
|
||||
return &WriteBuf{buf: buf, sizeIdx: 1}
|
||||
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
|
||||
return &WriteBuf{buf: buf, sizeIdx: 1, conn: c}
|
||||
}
|
||||
|
||||
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
|
||||
|
@ -99,6 +99,7 @@ func newWriteBuf(buf []byte, t byte) *WriteBuf {
|
|||
type WriteBuf struct {
|
||||
buf []byte
|
||||
sizeIdx int
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
func (wb *WriteBuf) startMsg(t byte) {
|
||||
|
|
|
@ -422,8 +422,8 @@ func TestQueryRowUnknownType(t *testing.T) {
|
|||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := "select $1::inet"
|
||||
expected := "127.0.0.1"
|
||||
sql := "select $1::point"
|
||||
expected := "(1,0)"
|
||||
var actual string
|
||||
|
||||
err := conn.QueryRow(sql, expected).Scan(&actual)
|
||||
|
|
75
values.go
75
values.go
|
@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16
|
|||
|
||||
func init() {
|
||||
DefaultTypeFormats = make(map[string]int16)
|
||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_float4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_float8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int2"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_text"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bytea"] = BinaryFormatCode
|
||||
DefaultTypeFormats["cidr"] = BinaryFormatCode
|
||||
DefaultTypeFormats["date"] = BinaryFormatCode
|
||||
DefaultTypeFormats["float4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["float8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["inet"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int2"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int8"] = BinaryFormatCode
|
||||
|
@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet {
|
|||
return zero
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return zero
|
||||
}
|
||||
|
||||
pgType := vr.Type()
|
||||
if vr.Len() != 8 && vr.Len() != 20 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
||||
return zero
|
||||
}
|
||||
|
||||
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name)))
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
||||
return zero
|
||||
}
|
||||
|
||||
s := vr.ReadString(vr.Len())
|
||||
hasNetmask := strings.ContainsRune(s, '/')
|
||||
if !hasNetmask {
|
||||
isIpv6 := strings.ContainsRune(s, ':')
|
||||
if isIpv6 {
|
||||
s += "/128"
|
||||
} else {
|
||||
s += "/32"
|
||||
}
|
||||
}
|
||||
vr.ReadByte() // ignore family
|
||||
bits := vr.ReadByte()
|
||||
vr.ReadByte() // ignore is_cidr
|
||||
addressLength := vr.ReadByte()
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return zero
|
||||
}
|
||||
|
||||
// if vr.Type().FormatCode != BinaryFormatCode {
|
||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
// return zero
|
||||
// }
|
||||
|
||||
// if vr.Len() != 4 {
|
||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
|
||||
// return zero
|
||||
// }
|
||||
|
||||
return *ipnet
|
||||
var ipnet net.IPNet
|
||||
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
||||
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
||||
|
||||
return ipnet
|
||||
}
|
||||
|
||||
func encodeInet(w *WriteBuf, value interface{}) error {
|
||||
|
@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error {
|
|||
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
|
||||
}
|
||||
|
||||
s := ipnet.String()
|
||||
var size int32
|
||||
var family byte
|
||||
switch len(ipnet.IP) {
|
||||
case net.IPv4len:
|
||||
size = 8
|
||||
family = w.conn.pgsql_af_inet
|
||||
case net.IPv6len:
|
||||
size = 20
|
||||
family = w.conn.pgsql_af_inet6
|
||||
default:
|
||||
return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP))
|
||||
}
|
||||
|
||||
w.WriteInt32(int32(len(s)))
|
||||
w.WriteBytes([]byte(s))
|
||||
w.WriteInt32(size)
|
||||
w.WriteByte(family)
|
||||
ones, _ := ipnet.Mask.Size()
|
||||
w.WriteByte(byte(ones))
|
||||
w.WriteByte(0) // is_cidr is ignored on server
|
||||
w.WriteByte(byte(len(ipnet.IP)))
|
||||
w.WriteBytes(ipnet.IP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -137,6 +137,7 @@ func TestInetCidrTranscode(t *testing.T) {
|
|||
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||
continue
|
||||
}
|
||||
|
||||
if actual.String() != tt.value.String() {
|
||||
|
|
Loading…
Reference in New Issue