diff --git a/conn.go b/conn.go index 85705d7d..96add0fe 100644 --- a/conn.go +++ b/conn.go @@ -63,8 +63,8 @@ type Conn struct { logLevel int mr msgReader fp *fastpath - pgsql_af_inet byte - pgsql_af_inet6 byte + pgsql_af_inet *byte + pgsql_af_inet6 *byte busy bool poolResetCount int preallocatedRows []Rows @@ -137,9 +137,16 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { + return connect(config, nil, nil, nil) +} + +func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { c = new(Conn) c.config = config + c.PgTypes = pgTypes + c.pgsql_af_inet = pgsql_af_inet + c.pgsql_af_inet6 = pgsql_af_inet6 if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -283,14 +290,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.log(LogLevelInfo, "Connection established") } - err = c.loadPgTypes() - if err != nil { - return err + if c.PgTypes == nil { + err = c.loadPgTypes() + if err != nil { + return err + } } - err = c.loadInetConstants() - if err != nil { - return err + if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil { + err = c.loadInetConstants() + if err != nil { + return err + } } return nil @@ -336,8 +347,8 @@ func (c *Conn) loadInetConstants() error { return err } - c.pgsql_af_inet = ipv4[0] - c.pgsql_af_inet6 = ipv6[0] + c.pgsql_af_inet = &ipv4[0] + c.pgsql_af_inet6 = &ipv6[0] return nil } diff --git a/conn_pool.go b/conn_pool.go index 680be677..6695e0e3 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -26,6 +26,9 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration + pgTypes map[Oid]PgType + pgsql_af_inet *byte + pgsql_af_inet6 *byte } type ConnPoolStat struct { @@ -244,11 +247,15 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := Connect(p.config) + c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6) if err != nil { return nil, err } + p.pgTypes = c.PgTypes + p.pgsql_af_inet = c.pgsql_af_inet + p.pgsql_af_inet6 = c.pgsql_af_inet6 + if p.afterConnect != nil { err = p.afterConnect(c) if err != nil { diff --git a/values.go b/values.go index f46ce1df..3894bace 100644 --- a/values.go +++ b/values.go @@ -1576,10 +1576,10 @@ func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { switch len(value.IP) { case net.IPv4len: size = 8 - family = w.conn.pgsql_af_inet + family = *w.conn.pgsql_af_inet case net.IPv6len: size = 20 - family = w.conn.pgsql_af_inet6 + family = *w.conn.pgsql_af_inet6 default: return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) }