mirror of https://github.com/jackc/pgx.git
parent
9af068add0
commit
fd39261551
38
conn.go
38
conn.go
|
@ -57,6 +57,8 @@ type Conn struct {
|
||||||
logger Logger
|
logger Logger
|
||||||
mr msgReader
|
mr msgReader
|
||||||
fp *fastpath
|
fp *fastpath
|
||||||
|
pgsql_af_inet byte
|
||||||
|
pgsql_af_inet6 byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStatement struct {
|
type PreparedStatement struct {
|
||||||
|
@ -222,6 +224,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = c.loadInetConstants()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
|
@ -254,6 +261,23 @@ func (c *Conn) loadPgTypes() error {
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Family is needed for binary encoding of inet/cidr. The constant is based on
|
||||||
|
// the server's definition of AF_INET. In theory, this could differ between
|
||||||
|
// platforms, so request an IPv4 and an IPv6 inet and get the family from that.
|
||||||
|
func (c *Conn) loadInetConstants() error {
|
||||||
|
var ipv4, ipv6 []byte
|
||||||
|
|
||||||
|
err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.pgsql_af_inet = ipv4[0]
|
||||||
|
c.pgsql_af_inet6 = ipv6[0]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes a connection. It is safe to call Close on a already closed
|
// Close closes a connection. It is safe to call Close on a already closed
|
||||||
// connection.
|
// connection.
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() (err error) {
|
||||||
|
@ -261,7 +285,7 @@ func (c *Conn) Close() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'X')
|
wbuf := newWriteBuf(c, 'X')
|
||||||
wbuf.closeMsg()
|
wbuf.closeMsg()
|
||||||
|
|
||||||
_, err = c.conn.Write(wbuf.buf)
|
_, err = c.conn.Write(wbuf.buf)
|
||||||
|
@ -442,7 +466,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// parse
|
// parse
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'P')
|
wbuf := newWriteBuf(c, 'P')
|
||||||
wbuf.WriteCString(name)
|
wbuf.WriteCString(name)
|
||||||
wbuf.WriteCString(sql)
|
wbuf.WriteCString(sql)
|
||||||
wbuf.WriteInt16(0)
|
wbuf.WriteInt16(0)
|
||||||
|
@ -509,7 +533,7 @@ func (c *Conn) Deallocate(name string) (err error) {
|
||||||
delete(c.preparedStatements, name)
|
delete(c.preparedStatements, name)
|
||||||
|
|
||||||
// close
|
// close
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'C')
|
wbuf := newWriteBuf(c, 'C')
|
||||||
wbuf.WriteByte('S')
|
wbuf.WriteByte('S')
|
||||||
wbuf.WriteCString(name)
|
wbuf.WriteCString(name)
|
||||||
|
|
||||||
|
@ -667,7 +691,7 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
||||||
|
|
||||||
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'Q')
|
wbuf := newWriteBuf(c, 'Q')
|
||||||
wbuf.WriteCString(sql)
|
wbuf.WriteCString(sql)
|
||||||
wbuf.closeMsg()
|
wbuf.closeMsg()
|
||||||
|
|
||||||
|
@ -694,7 +718,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// bind
|
// bind
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'B')
|
wbuf := newWriteBuf(c, 'B')
|
||||||
wbuf.WriteByte(0)
|
wbuf.WriteByte(0)
|
||||||
wbuf.WriteCString(ps.Name)
|
wbuf.WriteCString(ps.Name)
|
||||||
|
|
||||||
|
@ -707,7 +731,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
default:
|
default:
|
||||||
switch oid {
|
switch oid {
|
||||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
|
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid:
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
wbuf.WriteInt16(BinaryFormatCode)
|
||||||
default:
|
default:
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
wbuf.WriteInt16(TextFormatCode)
|
||||||
|
@ -1050,7 +1074,7 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) txPasswordMessage(password string) (err error) {
|
func (c *Conn) txPasswordMessage(password string) (err error) {
|
||||||
wbuf := newWriteBuf(c.wbuf[0:0], 'p')
|
wbuf := newWriteBuf(c, 'p')
|
||||||
wbuf.WriteCString(password)
|
wbuf.WriteCString(password)
|
||||||
wbuf.closeMsg()
|
wbuf.closeMsg()
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ func fpInt64Arg(n int64) fpArg {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
|
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
|
||||||
wbuf := newWriteBuf(f.cn.wbuf[:0], 'F') // function call
|
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||||
wbuf.WriteInt32(int32(oid)) // function object id
|
wbuf.WriteInt32(int32(oid)) // function object id
|
||||||
wbuf.WriteInt16(1) // # of argument format codes
|
wbuf.WriteInt16(1) // # of argument format codes
|
||||||
wbuf.WriteInt16(1) // format code: binary
|
wbuf.WriteInt16(1) // format code: binary
|
||||||
|
|
|
@ -89,9 +89,9 @@ func (self PgError) Error() string {
|
||||||
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
|
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWriteBuf(buf []byte, t byte) *WriteBuf {
|
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||||
buf = append(buf, t, 0, 0, 0, 0)
|
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
|
||||||
return &WriteBuf{buf: buf, sizeIdx: 1}
|
return &WriteBuf{buf: buf, sizeIdx: 1, conn: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
|
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
|
||||||
|
@ -99,6 +99,7 @@ func newWriteBuf(buf []byte, t byte) *WriteBuf {
|
||||||
type WriteBuf struct {
|
type WriteBuf struct {
|
||||||
buf []byte
|
buf []byte
|
||||||
sizeIdx int
|
sizeIdx int
|
||||||
|
conn *Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wb *WriteBuf) startMsg(t byte) {
|
func (wb *WriteBuf) startMsg(t byte) {
|
||||||
|
|
|
@ -422,8 +422,8 @@ func TestQueryRowUnknownType(t *testing.T) {
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
sql := "select $1::inet"
|
sql := "select $1::point"
|
||||||
expected := "127.0.0.1"
|
expected := "(1,0)"
|
||||||
var actual string
|
var actual string
|
||||||
|
|
||||||
err := conn.QueryRow(sql, expected).Scan(&actual)
|
err := conn.QueryRow(sql, expected).Scan(&actual)
|
||||||
|
|
75
values.go
75
values.go
|
@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultTypeFormats = make(map[string]int16)
|
DefaultTypeFormats = make(map[string]int16)
|
||||||
|
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_float4"] = BinaryFormatCode
|
DefaultTypeFormats["_float4"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_float8"] = BinaryFormatCode
|
DefaultTypeFormats["_float8"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
|
||||||
DefaultTypeFormats["_int2"] = BinaryFormatCode
|
DefaultTypeFormats["_int2"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_int4"] = BinaryFormatCode
|
DefaultTypeFormats["_int4"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_int8"] = BinaryFormatCode
|
DefaultTypeFormats["_int8"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_text"] = BinaryFormatCode
|
DefaultTypeFormats["_text"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
|
||||||
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
|
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
|
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
|
||||||
|
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["bool"] = BinaryFormatCode
|
DefaultTypeFormats["bool"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["bytea"] = BinaryFormatCode
|
DefaultTypeFormats["bytea"] = BinaryFormatCode
|
||||||
|
DefaultTypeFormats["cidr"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["date"] = BinaryFormatCode
|
DefaultTypeFormats["date"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["float4"] = BinaryFormatCode
|
DefaultTypeFormats["float4"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["float8"] = BinaryFormatCode
|
DefaultTypeFormats["float8"] = BinaryFormatCode
|
||||||
|
DefaultTypeFormats["inet"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["int2"] = BinaryFormatCode
|
DefaultTypeFormats["int2"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["int4"] = BinaryFormatCode
|
DefaultTypeFormats["int4"] = BinaryFormatCode
|
||||||
DefaultTypeFormats["int8"] = BinaryFormatCode
|
DefaultTypeFormats["int8"] = BinaryFormatCode
|
||||||
|
@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet {
|
||||||
return zero
|
return zero
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
pgType := vr.Type()
|
pgType := vr.Type()
|
||||||
|
if vr.Len() != 8 && vr.Len() != 20 {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
||||||
return zero
|
return zero
|
||||||
}
|
}
|
||||||
|
|
||||||
s := vr.ReadString(vr.Len())
|
vr.ReadByte() // ignore family
|
||||||
hasNetmask := strings.ContainsRune(s, '/')
|
bits := vr.ReadByte()
|
||||||
if !hasNetmask {
|
vr.ReadByte() // ignore is_cidr
|
||||||
isIpv6 := strings.ContainsRune(s, ':')
|
addressLength := vr.ReadByte()
|
||||||
if isIpv6 {
|
|
||||||
s += "/128"
|
|
||||||
} else {
|
|
||||||
s += "/32"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ipnet, err := net.ParseCIDR(s)
|
var ipnet net.IPNet
|
||||||
if err != nil {
|
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
||||||
vr.Fatal(err)
|
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
||||||
return zero
|
|
||||||
}
|
|
||||||
|
|
||||||
// if vr.Type().FormatCode != BinaryFormatCode {
|
|
||||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
|
||||||
// return zero
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if vr.Len() != 4 {
|
|
||||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
|
|
||||||
// return zero
|
|
||||||
// }
|
|
||||||
|
|
||||||
return *ipnet
|
|
||||||
|
|
||||||
|
return ipnet
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeInet(w *WriteBuf, value interface{}) error {
|
func encodeInet(w *WriteBuf, value interface{}) error {
|
||||||
|
@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error {
|
||||||
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
|
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := ipnet.String()
|
var size int32
|
||||||
|
var family byte
|
||||||
|
switch len(ipnet.IP) {
|
||||||
|
case net.IPv4len:
|
||||||
|
size = 8
|
||||||
|
family = w.conn.pgsql_af_inet
|
||||||
|
case net.IPv6len:
|
||||||
|
size = 20
|
||||||
|
family = w.conn.pgsql_af_inet6
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP))
|
||||||
|
}
|
||||||
|
|
||||||
w.WriteInt32(int32(len(s)))
|
w.WriteInt32(size)
|
||||||
w.WriteBytes([]byte(s))
|
w.WriteByte(family)
|
||||||
|
ones, _ := ipnet.Mask.Size()
|
||||||
|
w.WriteByte(byte(ones))
|
||||||
|
w.WriteByte(0) // is_cidr is ignored on server
|
||||||
|
w.WriteByte(byte(len(ipnet.IP)))
|
||||||
|
w.WriteBytes(ipnet.IP)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,6 +137,7 @@ func TestInetCidrTranscode(t *testing.T) {
|
||||||
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if actual.String() != tt.value.String() {
|
if actual.String() != tt.value.String() {
|
||||||
|
|
Loading…
Reference in New Issue