mirror of
https://github.com/jackc/pgx.git
synced 2025-05-25 17:00:21 +00:00
Remove AF_INET fetching system
Also remove old encode/decode inet/cidr code. This removed some functionality from Rows.Values, but that entire system will soon change anyway.
This commit is contained in:
parent
005916166a
commit
b1fc8109db
39
conn.go
39
conn.go
@ -87,8 +87,6 @@ type Conn struct {
|
|||||||
logLevel int
|
logLevel int
|
||||||
mr msgReader
|
mr msgReader
|
||||||
fp *fastpath
|
fp *fastpath
|
||||||
pgsqlAfInet *byte
|
|
||||||
pgsqlAfInet6 *byte
|
|
||||||
poolResetCount int
|
poolResetCount int
|
||||||
preallocatedRows []Rows
|
preallocatedRows []Rows
|
||||||
|
|
||||||
@ -179,10 +177,10 @@ func (e ProtocolError) Error() string {
|
|||||||
// config.Host must be specified. config.User will default to the OS user name.
|
// config.Host must be specified. config.User will default to the OS user name.
|
||||||
// Other config fields are optional.
|
// Other config fields are optional.
|
||||||
func Connect(config ConnConfig) (c *Conn, err error) {
|
func Connect(config ConnConfig) (c *Conn, err error) {
|
||||||
return connect(config, nil, nil, nil)
|
return connect(config, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) {
|
func connect(config ConnConfig, pgTypes map[OID]PgType) (c *Conn, err error) {
|
||||||
c = new(Conn)
|
c = new(Conn)
|
||||||
|
|
||||||
c.config = config
|
c.config = config
|
||||||
@ -194,15 +192,6 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if pgsqlAfInet != nil {
|
|
||||||
c.pgsqlAfInet = new(byte)
|
|
||||||
*c.pgsqlAfInet = *pgsqlAfInet
|
|
||||||
}
|
|
||||||
if pgsqlAfInet6 != nil {
|
|
||||||
c.pgsqlAfInet6 = new(byte)
|
|
||||||
*c.pgsqlAfInet6 = *pgsqlAfInet6
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.LogLevel != 0 {
|
if c.config.LogLevel != 0 {
|
||||||
c.logLevel = c.config.LogLevel
|
c.logLevel = c.config.LogLevel
|
||||||
} else {
|
} else {
|
||||||
@ -372,13 +361,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil {
|
|
||||||
err = c.loadInetConstants()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
@ -418,23 +400,6 @@ where (
|
|||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Family is needed for binary encoding of inet/cidr. The constant is based on
|
|
||||||
// the server's definition of AF_INET. In theory, this could differ between
|
|
||||||
// platforms, so request an IPv4 and an IPv6 inet and get the family from that.
|
|
||||||
func (c *Conn) loadInetConstants() error {
|
|
||||||
var ipv4, ipv6 []byte
|
|
||||||
|
|
||||||
err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.pgsqlAfInet = &ipv4[0]
|
|
||||||
c.pgsqlAfInet6 = &ipv6[0]
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PID returns the backend PID for this connection.
|
// PID returns the backend PID for this connection.
|
||||||
func (c *Conn) PID() int32 {
|
func (c *Conn) PID() int32 {
|
||||||
return c.pid
|
return c.pid
|
||||||
|
@ -29,8 +29,6 @@ type ConnPool struct {
|
|||||||
preparedStatements map[string]*PreparedStatement
|
preparedStatements map[string]*PreparedStatement
|
||||||
acquireTimeout time.Duration
|
acquireTimeout time.Duration
|
||||||
pgTypes map[OID]PgType
|
pgTypes map[OID]PgType
|
||||||
pgsqlAfInet *byte
|
|
||||||
pgsqlAfInet6 *byte
|
|
||||||
txAfterClose func(tx *Tx)
|
txAfterClose func(tx *Tx)
|
||||||
rowsAfterClose func(rows *Rows)
|
rowsAfterClose func(rows *Rows)
|
||||||
}
|
}
|
||||||
@ -294,7 +292,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||||
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
|
c, err := connect(p.config, p.pgTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -330,8 +328,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
|||||||
// all the known statements for the new connection.
|
// all the known statements for the new connection.
|
||||||
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||||
p.pgTypes = c.PgTypes
|
p.pgTypes = c.PgTypes
|
||||||
p.pgsqlAfInet = c.pgsqlAfInet
|
|
||||||
p.pgsqlAfInet6 = c.pgsqlAfInet6
|
|
||||||
|
|
||||||
if p.afterConnect != nil {
|
if p.afterConnect != nil {
|
||||||
err := p.afterConnect(c)
|
err := p.afterConnect(c)
|
||||||
|
4
query.go
4
query.go
@ -410,8 +410,6 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||||||
values = append(values, decodeTimestampTz(vr))
|
values = append(values, decodeTimestampTz(vr))
|
||||||
case TimestampOID:
|
case TimestampOID:
|
||||||
values = append(values, decodeTimestamp(vr))
|
values = append(values, decodeTimestamp(vr))
|
||||||
case InetOID, CidrOID:
|
|
||||||
values = append(values, decodeInet(vr))
|
|
||||||
case JSONOID:
|
case JSONOID:
|
||||||
var d interface{}
|
var d interface{}
|
||||||
decodeJSON(vr, &d)
|
decodeJSON(vr, &d)
|
||||||
@ -503,8 +501,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) {
|
|||||||
values = append(values, decodeTimestampTz(vr))
|
values = append(values, decodeTimestampTz(vr))
|
||||||
case TimestampOID:
|
case TimestampOID:
|
||||||
values = append(values, decodeTimestamp(vr))
|
values = append(values, decodeTimestamp(vr))
|
||||||
case InetOID, CidrOID:
|
|
||||||
values = append(values, decodeInet(vr))
|
|
||||||
case JSONOID:
|
case JSONOID:
|
||||||
var d interface{}
|
var d interface{}
|
||||||
decodeJSON(vr, &d)
|
decodeJSON(vr, &d)
|
||||||
|
183
values.go
183
values.go
@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -1934,82 +1933,6 @@ func decodeTimestamp(vr *ValueReader) time.Time {
|
|||||||
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
|
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeInet(vr *ValueReader) net.IPNet {
|
|
||||||
var zero net.IPNet
|
|
||||||
|
|
||||||
if vr.Len() == -1 {
|
|
||||||
vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
|
|
||||||
return zero
|
|
||||||
}
|
|
||||||
|
|
||||||
if vr.Type().FormatCode != BinaryFormatCode {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
|
||||||
return zero
|
|
||||||
}
|
|
||||||
|
|
||||||
pgType := vr.Type()
|
|
||||||
if pgType.DataType != InetOID && pgType.DataType != CidrOID {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
|
||||||
return zero
|
|
||||||
}
|
|
||||||
if vr.Len() != 8 && vr.Len() != 20 {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
|
||||||
return zero
|
|
||||||
}
|
|
||||||
|
|
||||||
vr.ReadByte() // ignore family
|
|
||||||
bits := vr.ReadByte()
|
|
||||||
vr.ReadByte() // ignore is_cidr
|
|
||||||
addressLength := vr.ReadByte()
|
|
||||||
|
|
||||||
var ipnet net.IPNet
|
|
||||||
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
|
||||||
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
|
||||||
|
|
||||||
return ipnet
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeIPNet(w *WriteBuf, oid OID, value net.IPNet) error {
|
|
||||||
if oid != InetOID && oid != CidrOID {
|
|
||||||
return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
var size int32
|
|
||||||
var family byte
|
|
||||||
switch len(value.IP) {
|
|
||||||
case net.IPv4len:
|
|
||||||
size = 8
|
|
||||||
family = *w.conn.pgsqlAfInet
|
|
||||||
case net.IPv6len:
|
|
||||||
size = 20
|
|
||||||
family = *w.conn.pgsqlAfInet6
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteInt32(size)
|
|
||||||
w.WriteByte(family)
|
|
||||||
ones, _ := value.Mask.Size()
|
|
||||||
w.WriteByte(byte(ones))
|
|
||||||
w.WriteByte(0) // is_cidr is ignored on server
|
|
||||||
w.WriteByte(byte(len(value.IP)))
|
|
||||||
w.WriteBytes(value.IP)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeIP(w *WriteBuf, oid OID, value net.IP) error {
|
|
||||||
if oid != InetOID && oid != CidrOID {
|
|
||||||
return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ipnet net.IPNet
|
|
||||||
ipnet.IP = value
|
|
||||||
bitCount := len(value) * 8
|
|
||||||
ipnet.Mask = net.CIDRMask(bitCount, bitCount)
|
|
||||||
return encodeIPNet(w, oid, ipnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeRecord(vr *ValueReader) []interface{} {
|
func decodeRecord(vr *ValueReader) []interface{} {
|
||||||
if vr.Len() == -1 {
|
if vr.Len() == -1 {
|
||||||
return nil
|
return nil
|
||||||
@ -2058,8 +1981,6 @@ func decodeRecord(vr *ValueReader) []interface{} {
|
|||||||
record = append(record, decodeTimestampTz(&fieldVR))
|
record = append(record, decodeTimestampTz(&fieldVR))
|
||||||
case TimestampOID:
|
case TimestampOID:
|
||||||
record = append(record, decodeTimestamp(&fieldVR))
|
record = append(record, decodeTimestamp(&fieldVR))
|
||||||
case InetOID, CidrOID:
|
|
||||||
record = append(record, decodeInet(&fieldVR))
|
|
||||||
case TextOID, VarcharOID, UnknownOID:
|
case TextOID, VarcharOID, UnknownOID:
|
||||||
record = append(record, decodeTextAllowBinary(&fieldVR))
|
record = append(record, decodeTextAllowBinary(&fieldVR))
|
||||||
default:
|
default:
|
||||||
@ -2983,110 +2904,6 @@ func encodeTimeSlice(w *WriteBuf, oid OID, slice []time.Time) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeInetArray(vr *ValueReader) []net.IPNet {
|
|
||||||
if vr.Len() == -1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if vr.Type().DataType != InetArrayOID && vr.Type().DataType != CidrArrayOID {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if vr.Type().FormatCode != BinaryFormatCode {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
numElems, err := decode1dArrayHeader(vr)
|
|
||||||
if err != nil {
|
|
||||||
vr.Fatal(err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
a := make([]net.IPNet, int(numElems))
|
|
||||||
for i := 0; i < len(a); i++ {
|
|
||||||
elSize := vr.ReadInt32()
|
|
||||||
if elSize == -1 {
|
|
||||||
vr.Fatal(ProtocolError("Cannot decode null element"))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
vr.ReadByte() // ignore family
|
|
||||||
bits := vr.ReadByte()
|
|
||||||
vr.ReadByte() // ignore is_cidr
|
|
||||||
addressLength := vr.ReadByte()
|
|
||||||
|
|
||||||
var ipnet net.IPNet
|
|
||||||
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
|
||||||
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
|
||||||
|
|
||||||
a[i] = ipnet
|
|
||||||
}
|
|
||||||
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeIPNetSlice(w *WriteBuf, oid OID, slice []net.IPNet) error {
|
|
||||||
var elOID OID
|
|
||||||
switch oid {
|
|
||||||
case InetArrayOID:
|
|
||||||
elOID = InetOID
|
|
||||||
case CidrArrayOID:
|
|
||||||
elOID = CidrOID
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
size := int32(20) // array header size
|
|
||||||
for _, ipnet := range slice {
|
|
||||||
size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes
|
|
||||||
}
|
|
||||||
w.WriteInt32(int32(size))
|
|
||||||
|
|
||||||
w.WriteInt32(1) // number of dimensions
|
|
||||||
w.WriteInt32(0) // no nulls
|
|
||||||
w.WriteInt32(int32(elOID)) // type of elements
|
|
||||||
w.WriteInt32(int32(len(slice))) // number of elements
|
|
||||||
w.WriteInt32(1) // index of first element
|
|
||||||
|
|
||||||
for _, ipnet := range slice {
|
|
||||||
encodeIPNet(w, elOID, ipnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeIPSlice(w *WriteBuf, oid OID, slice []net.IP) error {
|
|
||||||
var elOID OID
|
|
||||||
switch oid {
|
|
||||||
case InetArrayOID:
|
|
||||||
elOID = InetOID
|
|
||||||
case CidrArrayOID:
|
|
||||||
elOID = CidrOID
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
|
|
||||||
}
|
|
||||||
|
|
||||||
size := int32(20) // array header size
|
|
||||||
for _, ip := range slice {
|
|
||||||
size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes
|
|
||||||
}
|
|
||||||
w.WriteInt32(int32(size))
|
|
||||||
|
|
||||||
w.WriteInt32(1) // number of dimensions
|
|
||||||
w.WriteInt32(0) // no nulls
|
|
||||||
w.WriteInt32(int32(elOID)) // type of elements
|
|
||||||
w.WriteInt32(int32(len(slice))) // number of elements
|
|
||||||
w.WriteInt32(1) // index of first element
|
|
||||||
|
|
||||||
for _, ip := range slice {
|
|
||||||
encodeIP(w, elOID, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
|
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
|
||||||
w.WriteInt32(int32(20 + length*sizePerItem))
|
w.WriteInt32(int32(20 + length*sizePerItem))
|
||||||
w.WriteInt32(1) // number of dimensions
|
w.WriteInt32(1) // number of dimensions
|
||||||
|
Loading…
x
Reference in New Issue
Block a user