mirror of https://github.com/jackc/pgx.git
Restructure connect process
- Moved lots of connection logic to pgconn from pgx - Extracted pgpassfile packagepull/483/head
parent
9990e4894d
commit
b3c8a73dc7
522
conn.go
522
conn.go
|
@ -4,17 +4,11 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -57,26 +51,14 @@ func init() {
|
||||||
// aware that this is distinct from LISTEN/NOTIFY notification.
|
// aware that this is distinct from LISTEN/NOTIFY notification.
|
||||||
type NoticeHandler func(*Conn, *Notice)
|
type NoticeHandler func(*Conn, *Notice)
|
||||||
|
|
||||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
|
||||||
type DialFunc func(network, addr string) (net.Conn, error)
|
|
||||||
|
|
||||||
// ConnConfig contains all the options used to establish a connection.
|
// ConnConfig contains all the options used to establish a connection.
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
pgconn.Config
|
||||||
Port uint16 // default: 5432
|
Logger Logger
|
||||||
Database string
|
LogLevel int
|
||||||
User string // default: OS user name
|
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
||||||
Password string
|
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
||||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
CustomCancel func(*Conn) error // Callback function used to override cancellation behavior
|
||||||
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
|
|
||||||
FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
|
|
||||||
Logger Logger
|
|
||||||
LogLevel int
|
|
||||||
Dial DialFunc
|
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
|
||||||
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
|
||||||
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
|
||||||
CustomCancel func(*Conn) error // Callback function used to override cancellation behavior
|
|
||||||
|
|
||||||
// PreferSimpleProtocol disables implicit prepared statement usage. By default
|
// PreferSimpleProtocol disables implicit prepared statement usage. By default
|
||||||
// pgx automatically uses the unnamed prepared statement for Query and
|
// pgx automatically uses the unnamed prepared statement for Query and
|
||||||
|
@ -96,7 +78,7 @@ type ConnConfig struct {
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
pgConn *pgconn.PgConn
|
pgConn *pgconn.PgConn
|
||||||
wbuf []byte
|
wbuf []byte
|
||||||
config ConnConfig // config used when establishing this connection
|
config *ConnConfig // config used when establishing this connection
|
||||||
preparedStatements map[string]*PreparedStatement
|
preparedStatements map[string]*PreparedStatement
|
||||||
channels map[string]struct{}
|
channels map[string]struct{}
|
||||||
notifications []*Notification
|
notifications []*Notification
|
||||||
|
@ -195,18 +177,30 @@ func (e ProtocolError) Error() string {
|
||||||
return string(e)
|
return string(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection with a PostgreSQL server using config.
|
// Connect establishes a connection with a PostgreSQL server with a connection string. See
|
||||||
// config.Host must be specified. config.User will default to the OS user name.
|
// pgconn.Connect for details.
|
||||||
// Other config fields are optional.
|
func Connect(ctx context.Context, connString string) (*Conn, error) {
|
||||||
func Connect(config ConnConfig) (c *Conn, err error) {
|
config, err := pgconn.ParseConfig(connString)
|
||||||
return connect(config, minimalConnInfo)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connConfig := &ConnConfig{
|
||||||
|
Config: *config,
|
||||||
|
}
|
||||||
|
|
||||||
|
return connect(ctx, connConfig, minimalConnInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a connection with a PostgreSQL server with a configuration struct.
|
||||||
|
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
|
||||||
|
return connect(ctx, connConfig, minimalConnInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultDialer() *net.Dialer {
|
func defaultDialer() *net.Dialer {
|
||||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
|
func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
|
||||||
c = new(Conn)
|
c = new(Conn)
|
||||||
|
|
||||||
c.config = config
|
c.config = config
|
||||||
|
@ -223,14 +217,11 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||||
c.onNotice = config.OnNotice
|
c.onNotice = config.OnNotice
|
||||||
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Host})
|
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
|
||||||
}
|
}
|
||||||
err = c.connect(config, config.TLSConfig)
|
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
|
||||||
if err != nil && config.UseFallbackTLS {
|
if err != nil {
|
||||||
if c.shouldLog(LogLevelInfo) {
|
return nil, err
|
||||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
|
|
||||||
}
|
|
||||||
err = c.connect(config, config.FallbackTLSConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -240,34 +231,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) {
|
|
||||||
cc := pgconn.ConnConfig{
|
|
||||||
Host: config.Host,
|
|
||||||
Port: config.Port,
|
|
||||||
Database: config.Database,
|
|
||||||
User: config.User,
|
|
||||||
Password: config.Password,
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
Dial: pgconn.DialFunc(config.Dial),
|
|
||||||
RuntimeParams: config.RuntimeParams,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.pgConn, err = pgconn.Connect(cc)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if c != nil && err != nil {
|
|
||||||
c.pgConn.NetConn.Close()
|
|
||||||
c.mux.Lock()
|
|
||||||
c.status = connStatusClosed
|
|
||||||
c.mux.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||||
c.channels = make(map[string]struct{})
|
c.channels = make(map[string]struct{})
|
||||||
c.cancelQueryCompleted = make(chan struct{})
|
c.cancelQueryCompleted = make(chan struct{})
|
||||||
|
@ -275,25 +238,23 @@ func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) {
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
c.wbuf = make([]byte, 0, 1024)
|
c.wbuf = make([]byte, 0, 1024)
|
||||||
|
|
||||||
c.mux.Lock()
|
|
||||||
c.status = connStatusIdle
|
c.status = connStatusIdle
|
||||||
c.mux.Unlock()
|
|
||||||
|
|
||||||
// Replication connections can't execute the queries to
|
// Replication connections can't execute the queries to
|
||||||
// populate the c.PgTypes and c.pgsqlAfInet
|
// populate the c.PgTypes and c.pgsqlAfInet
|
||||||
if _, ok := config.RuntimeParams["replication"]; ok {
|
if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok {
|
||||||
return nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.ConnInfo == minimalConnInfo {
|
if c.ConnInfo == minimalConnInfo {
|
||||||
err = c.initConnInfo()
|
err = c.initConnInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
c.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
|
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
|
||||||
|
@ -524,7 +485,7 @@ func (c *Conn) LocalAddr() (net.Addr, error) {
|
||||||
|
|
||||||
// Close closes a connection. It is safe to call Close on a already closed
|
// Close closes a connection. It is safe to call Close on a already closed
|
||||||
// connection.
|
// connection.
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() error {
|
||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
defer c.mux.Unlock()
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
@ -533,404 +494,12 @@ func (c *Conn) Close() (err error) {
|
||||||
}
|
}
|
||||||
c.status = connStatusClosed
|
c.status = connStatusClosed
|
||||||
|
|
||||||
defer func() {
|
err := c.pgConn.Close()
|
||||||
c.pgConn.NetConn.Close()
|
c.causeOfDeath = errors.New("Closed")
|
||||||
c.causeOfDeath = errors.New("Closed")
|
if c.shouldLog(LogLevelInfo) {
|
||||||
if c.shouldLog(LogLevelInfo) {
|
c.log(LogLevelInfo, "closed connection", nil)
|
||||||
c.log(LogLevelInfo, "closed connection", nil)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = c.pgConn.NetConn.SetDeadline(time.Time{})
|
|
||||||
if err != nil && c.shouldLog(LogLevelWarn) {
|
|
||||||
c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err})
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
_, err = c.pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4})
|
|
||||||
if err != nil && c.shouldLog(LogLevelWarn) {
|
|
||||||
c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.pgConn.NetConn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
||||||
if err != nil && c.shouldLog(LogLevelWarn) {
|
|
||||||
c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.pgConn.NetConn.Read(make([]byte, 1))
|
|
||||||
if err != io.EOF {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge returns a new ConnConfig with the attributes of old and other
|
|
||||||
// combined. When an attribute is set on both, other takes precedence.
|
|
||||||
//
|
|
||||||
// As a security precaution, if the other TLSConfig is nil, all old TLS
|
|
||||||
// attributes will be preserved.
|
|
||||||
func (old ConnConfig) Merge(other ConnConfig) ConnConfig {
|
|
||||||
cc := old
|
|
||||||
|
|
||||||
if other.Host != "" {
|
|
||||||
cc.Host = other.Host
|
|
||||||
}
|
|
||||||
if other.Port != 0 {
|
|
||||||
cc.Port = other.Port
|
|
||||||
}
|
|
||||||
if other.Database != "" {
|
|
||||||
cc.Database = other.Database
|
|
||||||
}
|
|
||||||
if other.User != "" {
|
|
||||||
cc.User = other.User
|
|
||||||
}
|
|
||||||
if other.Password != "" {
|
|
||||||
cc.Password = other.Password
|
|
||||||
}
|
|
||||||
|
|
||||||
if other.TLSConfig != nil {
|
|
||||||
cc.TLSConfig = other.TLSConfig
|
|
||||||
cc.UseFallbackTLS = other.UseFallbackTLS
|
|
||||||
cc.FallbackTLSConfig = other.FallbackTLSConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
if other.Logger != nil {
|
|
||||||
cc.Logger = other.Logger
|
|
||||||
}
|
|
||||||
if other.LogLevel != 0 {
|
|
||||||
cc.LogLevel = other.LogLevel
|
|
||||||
}
|
|
||||||
|
|
||||||
if other.Dial != nil {
|
|
||||||
cc.Dial = other.Dial
|
|
||||||
}
|
|
||||||
|
|
||||||
cc.PreferSimpleProtocol = other.PreferSimpleProtocol
|
|
||||||
|
|
||||||
cc.RuntimeParams = make(map[string]string)
|
|
||||||
for k, v := range old.RuntimeParams {
|
|
||||||
cc.RuntimeParams[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range other.RuntimeParams {
|
|
||||||
cc.RuntimeParams[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return cc
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseURI parses a database URI into ConnConfig
|
|
||||||
//
|
|
||||||
// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams.
|
|
||||||
func ParseURI(uri string) (ConnConfig, error) {
|
|
||||||
var cp ConnConfig
|
|
||||||
|
|
||||||
url, err := url.Parse(uri)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if url.User != nil {
|
|
||||||
cp.User = url.User.Username()
|
|
||||||
cp.Password, _ = url.User.Password()
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(url.Host, ":", 2)
|
|
||||||
cp.Host = parts[0]
|
|
||||||
if len(parts) == 2 {
|
|
||||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
cp.Port = uint16(p)
|
|
||||||
}
|
|
||||||
cp.Database = strings.TrimLeft(url.Path, "/")
|
|
||||||
|
|
||||||
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
|
||||||
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
d := defaultDialer()
|
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
|
||||||
cp.Dial = d.Dial
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsArgs := configTLSArgs{
|
|
||||||
sslCert: url.Query().Get("sslcert"),
|
|
||||||
sslKey: url.Query().Get("sslkey"),
|
|
||||||
sslMode: url.Query().Get("sslmode"),
|
|
||||||
sslRootCert: url.Query().Get("sslrootcert"),
|
|
||||||
}
|
|
||||||
err = configTLS(tlsArgs, &cp)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ignoreKeys := map[string]struct{}{
|
|
||||||
"connect_timeout": {},
|
|
||||||
"sslcert": {},
|
|
||||||
"sslkey": {},
|
|
||||||
"sslmode": {},
|
|
||||||
"sslrootcert": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
cp.RuntimeParams = make(map[string]string)
|
|
||||||
|
|
||||||
for k, v := range url.Query() {
|
|
||||||
if _, ok := ignoreKeys[k]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if k == "host" {
|
|
||||||
cp.Host = v[0]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cp.RuntimeParams[k] = v[0]
|
|
||||||
}
|
|
||||||
if cp.Password == "" {
|
|
||||||
pgpass(&cp)
|
|
||||||
}
|
|
||||||
return cp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
|
||||||
|
|
||||||
// ParseDSN parses a database DSN (data source name) into a ConnConfig
|
|
||||||
//
|
|
||||||
// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable")
|
|
||||||
//
|
|
||||||
// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams.
|
|
||||||
//
|
|
||||||
// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb")
|
|
||||||
//
|
|
||||||
// ParseDSN tries to match libpq behavior with regard to sslmode. See comments
|
|
||||||
// for ParseEnvLibpq for more information on the security implications of
|
|
||||||
// sslmode options.
|
|
||||||
func ParseDSN(s string) (ConnConfig, error) {
|
|
||||||
var cp ConnConfig
|
|
||||||
|
|
||||||
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
|
||||||
|
|
||||||
tlsArgs := configTLSArgs{}
|
|
||||||
|
|
||||||
cp.RuntimeParams = make(map[string]string)
|
|
||||||
|
|
||||||
for _, b := range m {
|
|
||||||
switch b[1] {
|
|
||||||
case "user":
|
|
||||||
cp.User = b[2]
|
|
||||||
case "password":
|
|
||||||
cp.Password = b[2]
|
|
||||||
case "host":
|
|
||||||
cp.Host = b[2]
|
|
||||||
case "port":
|
|
||||||
p, err := strconv.ParseUint(b[2], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
cp.Port = uint16(p)
|
|
||||||
case "dbname":
|
|
||||||
cp.Database = b[2]
|
|
||||||
case "sslmode":
|
|
||||||
tlsArgs.sslMode = b[2]
|
|
||||||
case "sslrootcert":
|
|
||||||
tlsArgs.sslRootCert = b[2]
|
|
||||||
case "sslcert":
|
|
||||||
tlsArgs.sslCert = b[2]
|
|
||||||
case "sslkey":
|
|
||||||
tlsArgs.sslKey = b[2]
|
|
||||||
case "connect_timeout":
|
|
||||||
timeout, err := strconv.ParseInt(b[2], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
d := defaultDialer()
|
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
|
||||||
cp.Dial = d.Dial
|
|
||||||
default:
|
|
||||||
cp.RuntimeParams[b[1]] = b[2]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := configTLS(tlsArgs, &cp)
|
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
if cp.Password == "" {
|
|
||||||
pgpass(&cp)
|
|
||||||
}
|
|
||||||
return cp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseConnectionString parses either a URI or a DSN connection string.
|
|
||||||
// see ParseURI and ParseDSN for details.
|
|
||||||
func ParseConnectionString(s string) (ConnConfig, error) {
|
|
||||||
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
|
|
||||||
return ParseURI(s)
|
|
||||||
}
|
|
||||||
return ParseDSN(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseEnvLibpq parses the environment like libpq does into a ConnConfig
|
|
||||||
//
|
|
||||||
// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details
|
|
||||||
// on the meaning of environment variables.
|
|
||||||
//
|
|
||||||
// ParseEnvLibpq currently recognizes the following environment variables:
|
|
||||||
// PGHOST
|
|
||||||
// PGPORT
|
|
||||||
// PGDATABASE
|
|
||||||
// PGUSER
|
|
||||||
// PGPASSWORD
|
|
||||||
// PGSSLMODE
|
|
||||||
// PGSSLCERT
|
|
||||||
// PGSSLKEY
|
|
||||||
// PGSSLROOTCERT
|
|
||||||
// PGAPPNAME
|
|
||||||
// PGCONNECT_TIMEOUT
|
|
||||||
//
|
|
||||||
// Important TLS Security Notes:
|
|
||||||
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
|
|
||||||
// includes defaulting to "prefer" behavior if no environment variable is set.
|
|
||||||
//
|
|
||||||
// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
|
|
||||||
// for details on what level of security each sslmode provides.
|
|
||||||
//
|
|
||||||
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
|
|
||||||
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
|
||||||
// may be possible to match libpq in the future. If you need full security use
|
|
||||||
// "verify-full".
|
|
||||||
//
|
|
||||||
// Several of the PGSSLMODE options (including the default behavior of "prefer")
|
|
||||||
// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or
|
|
||||||
// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is
|
|
||||||
// later set from a different source that UseFallbackTLS MUST be set false to
|
|
||||||
// avoid the possibility of falling back to weaker or disabled security.
|
|
||||||
func ParseEnvLibpq() (ConnConfig, error) {
|
|
||||||
var cc ConnConfig
|
|
||||||
|
|
||||||
cc.Host = os.Getenv("PGHOST")
|
|
||||||
|
|
||||||
if pgport := os.Getenv("PGPORT"); pgport != "" {
|
|
||||||
if port, err := strconv.ParseUint(pgport, 10, 16); err == nil {
|
|
||||||
cc.Port = uint16(port)
|
|
||||||
} else {
|
|
||||||
return cc, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cc.Database = os.Getenv("PGDATABASE")
|
|
||||||
cc.User = os.Getenv("PGUSER")
|
|
||||||
cc.Password = os.Getenv("PGPASSWORD")
|
|
||||||
|
|
||||||
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
|
|
||||||
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
|
|
||||||
d := defaultDialer()
|
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
|
||||||
cc.Dial = d.Dial
|
|
||||||
} else {
|
|
||||||
return cc, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsArgs := configTLSArgs{
|
|
||||||
sslMode: os.Getenv("PGSSLMODE"),
|
|
||||||
sslKey: os.Getenv("PGSSLKEY"),
|
|
||||||
sslCert: os.Getenv("PGSSLCERT"),
|
|
||||||
sslRootCert: os.Getenv("PGSSLROOTCERT"),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := configTLS(tlsArgs, &cc)
|
|
||||||
if err != nil {
|
|
||||||
return cc, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cc.RuntimeParams = make(map[string]string)
|
|
||||||
if appname := os.Getenv("PGAPPNAME"); appname != "" {
|
|
||||||
cc.RuntimeParams["application_name"] = appname
|
|
||||||
}
|
|
||||||
if cc.Password == "" {
|
|
||||||
pgpass(&cc)
|
|
||||||
}
|
|
||||||
return cc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type configTLSArgs struct {
|
|
||||||
sslMode string
|
|
||||||
sslRootCert string
|
|
||||||
sslCert string
|
|
||||||
sslKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
|
|
||||||
// Inputs are parsed out and provided by ParseDSN() or ParseURI().
|
|
||||||
func configTLS(args configTLSArgs, cc *ConnConfig) error {
|
|
||||||
// Match libpq default behavior
|
|
||||||
if args.sslMode == "" {
|
|
||||||
args.sslMode = "prefer"
|
|
||||||
}
|
|
||||||
|
|
||||||
switch args.sslMode {
|
|
||||||
case "disable":
|
|
||||||
cc.UseFallbackTLS = false
|
|
||||||
cc.TLSConfig = nil
|
|
||||||
cc.FallbackTLSConfig = nil
|
|
||||||
return nil
|
|
||||||
case "allow":
|
|
||||||
cc.UseFallbackTLS = true
|
|
||||||
cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
case "prefer":
|
|
||||||
cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
cc.UseFallbackTLS = true
|
|
||||||
cc.FallbackTLSConfig = nil
|
|
||||||
case "require":
|
|
||||||
cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
case "verify-ca", "verify-full":
|
|
||||||
cc.TLSConfig = &tls.Config{
|
|
||||||
ServerName: cc.Host,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errors.New("sslmode is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.sslRootCert != "" {
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
|
|
||||||
caPath := args.sslRootCert
|
|
||||||
caCert, err := ioutil.ReadFile(caPath)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "unable to read CA file %q", caPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
|
||||||
return errors.Wrap(err, "unable to add CA to cert pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
cc.TLSConfig.RootCAs = caCertPool
|
|
||||||
cc.TLSConfig.ClientCAs = caCertPool
|
|
||||||
}
|
|
||||||
|
|
||||||
sslcert := args.sslCert
|
|
||||||
sslkey := args.sslKey
|
|
||||||
|
|
||||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
|
||||||
return fmt.Errorf(`both "sslcert" and "sslkey" are required`)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sslcert != "" && sslkey != "" {
|
|
||||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "unable to read cert")
|
|
||||||
}
|
|
||||||
|
|
||||||
cc.TLSConfig.Certificates = []tls.Certificate{cert}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
||||||
|
@ -1336,9 +905,9 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case pgproto3.AuthTypeOk:
|
case pgproto3.AuthTypeOk:
|
||||||
case pgproto3.AuthTypeCleartextPassword:
|
case pgproto3.AuthTypeCleartextPassword:
|
||||||
err = c.txPasswordMessage(c.config.Password)
|
err = c.txPasswordMessage(c.pgConn.Config.Password)
|
||||||
case pgproto3.AuthTypeMD5Password:
|
case pgproto3.AuthTypeMD5Password:
|
||||||
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
|
digestedPassword := "md5" + hexMD5(hexMD5(c.pgConn.Config.Password+c.pgConn.Config.User)+string(msg.Salt[:]))
|
||||||
err = c.txPasswordMessage(digestedPassword)
|
err = c.txPasswordMessage(digestedPassword)
|
||||||
default:
|
default:
|
||||||
err = errors.New("Received unknown authentication message")
|
err = errors.New("Received unknown authentication message")
|
||||||
|
@ -1554,8 +1123,11 @@ func quoteIdentifier(s string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func doCancel(c *Conn) error {
|
func doCancel(c *Conn) error {
|
||||||
network, address := c.pgConn.Config.NetworkAddress()
|
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
|
||||||
cancelConn, err := c.pgConn.Config.Dial(network, address)
|
// the connection config. This is important in high availability configurations where fallback connections may be
|
||||||
|
// specified or DNS may be used to load balance.
|
||||||
|
serverAddr := c.pgConn.NetConn.RemoteAddr()
|
||||||
|
cancelConn, err := c.pgConn.Config.DialFunc(context.TODO(), serverAddr.Network(), serverAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -297,7 +297,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||||
c, err := connect(p.config, p.connInfo)
|
c, err := connect(context.TODO(), &p.config, p.connInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -319,7 +319,8 @@ func (p *ConnPool) createConnection() (*Conn, error) {
|
||||||
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
||||||
p.inProgressConnects++
|
p.inProgressConnects++
|
||||||
p.cond.L.Unlock()
|
p.cond.L.Unlock()
|
||||||
c, err := Connect(p.config)
|
// c, err := Connect(p.config)
|
||||||
|
c, err := Connect(context.TODO(), "TODO")
|
||||||
p.cond.L.Lock()
|
p.cond.L.Lock()
|
||||||
p.inProgressConnects--
|
p.inProgressConnects--
|
||||||
|
|
||||||
|
|
|
@ -162,7 +162,7 @@ func TestPoolNonBlockingConnections(t *testing.T) {
|
||||||
var dialCountLock sync.Mutex
|
var dialCountLock sync.Mutex
|
||||||
dialCount := 0
|
dialCount := 0
|
||||||
openTimeout := 1 * time.Second
|
openTimeout := 1 * time.Second
|
||||||
testDialer := func(network, address string) (net.Conn, error) {
|
testDialer := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
var firstDial bool
|
var firstDial bool
|
||||||
dialCountLock.Lock()
|
dialCountLock.Lock()
|
||||||
dialCount++
|
dialCount++
|
||||||
|
@ -182,7 +182,7 @@ func TestPoolNonBlockingConnections(t *testing.T) {
|
||||||
ConnConfig: *defaultConnConfig,
|
ConnConfig: *defaultConnConfig,
|
||||||
MaxConnections: maxConnections,
|
MaxConnections: maxConnections,
|
||||||
}
|
}
|
||||||
config.ConnConfig.Dial = testDialer
|
config.ConnConfig.Config.DialFunc = testDialer
|
||||||
|
|
||||||
pool, err := pgx.NewConnPool(config)
|
pool, err := pgx.NewConnPool(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
834
conn_test.go
834
conn_test.go
|
@ -2,11 +2,8 @@ package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -24,7 +21,7 @@ func TestCrateDBConnect(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined cratedbConnConfig")
|
t.Skip("Skipping due to undefined cratedbConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*cratedbConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), cratedbConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -47,7 +44,7 @@ func TestCrateDBConnect(t *testing.T) {
|
||||||
func TestConnect(t *testing.T) {
|
func TestConnect(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn, err := pgx.Connect(*defaultConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -65,8 +62,8 @@ func TestConnect(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||||
}
|
}
|
||||||
if currentDB != defaultConnConfig.Database {
|
if currentDB != defaultConnConfig.Config.Database {
|
||||||
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
|
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Config.Database)
|
||||||
}
|
}
|
||||||
|
|
||||||
var user string
|
var user string
|
||||||
|
@ -74,8 +71,8 @@ func TestConnect(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||||
}
|
}
|
||||||
if user != defaultConnConfig.User {
|
if user != defaultConnConfig.Config.User {
|
||||||
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
|
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.Config.User)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
|
@ -92,27 +89,7 @@ func TestConnectWithUnixSocketDirectory(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined unixSocketConnConfig")
|
t.Skip("Skipping due to undefined unixSocketConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*unixSocketConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), unixSocketConnConfig)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Unable to close connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectWithUnixSocketFile(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if unixSocketConnConfig == nil {
|
|
||||||
t.Skip("Skipping due to undefined unixSocketConnConfig")
|
|
||||||
}
|
|
||||||
|
|
||||||
connParams := *unixSocketConnConfig
|
|
||||||
connParams.Host = connParams.Host + "/.s.PGSQL.5432"
|
|
||||||
conn, err := pgx.Connect(connParams)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -130,7 +107,7 @@ func TestConnectWithTcp(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined tcpConnConfig")
|
t.Skip("Skipping due to undefined tcpConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*tcpConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), tcpConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to establish connection: " + err.Error())
|
t.Fatal("Unable to establish connection: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -148,7 +125,7 @@ func TestConnectWithTLS(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined tlsConnConfig")
|
t.Skip("Skipping due to undefined tlsConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*tlsConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), tlsConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to establish connection: " + err.Error())
|
t.Fatal("Unable to establish connection: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -166,7 +143,7 @@ func TestConnectWithInvalidUser(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined invalidUserConnConfig")
|
t.Skip("Skipping due to undefined invalidUserConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := pgx.Connect(*invalidUserConnConfig)
|
_, err := pgx.ConnectConfig(context.Background(), invalidUserConnConfig)
|
||||||
pgErr, ok := err.(pgx.PgError)
|
pgErr, ok := err.(pgx.PgError)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err)
|
t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err)
|
||||||
|
@ -183,7 +160,7 @@ func TestConnectWithPlainTextPassword(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined plainPasswordConnConfig")
|
t.Skip("Skipping due to undefined plainPasswordConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*plainPasswordConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), plainPasswordConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to establish connection: " + err.Error())
|
t.Fatal("Unable to establish connection: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -201,38 +178,7 @@ func TestConnectWithMD5Password(t *testing.T) {
|
||||||
t.Skip("Skipping due to undefined md5ConnConfig")
|
t.Skip("Skipping due to undefined md5ConnConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(*md5ConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), md5ConnConfig)
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Unable to establish connection: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Unable to close connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectWithTLSFallback(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if tlsConnConfig == nil {
|
|
||||||
t.Skip("Skipping due to undefined tlsConnConfig")
|
|
||||||
}
|
|
||||||
|
|
||||||
connConfig := *tlsConnConfig
|
|
||||||
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure
|
|
||||||
|
|
||||||
conn, err := pgx.Connect(connConfig)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected failed connection, but succeeded")
|
|
||||||
}
|
|
||||||
|
|
||||||
connConfig = *tlsConnConfig
|
|
||||||
connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"}
|
|
||||||
connConfig.UseFallbackTLS = true
|
|
||||||
connConfig.FallbackTLSConfig = &tls.Config{ServerName: "bogus.local", InsecureSkipVerify: true}
|
|
||||||
|
|
||||||
conn, err = pgx.Connect(connConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to establish connection: " + err.Error())
|
t.Fatal("Unable to establish connection: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -251,7 +197,7 @@ func TestConnectWithConnectionRefused(t *testing.T) {
|
||||||
bad.Host = "127.0.0.1"
|
bad.Host = "127.0.0.1"
|
||||||
bad.Port = 1
|
bad.Port = 1
|
||||||
|
|
||||||
_, err := pgx.Connect(bad)
|
_, err := pgx.ConnectConfig(context.Background(), &bad)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error establishing connection to bad port")
|
t.Fatal("Expected error establishing connection to bad port")
|
||||||
}
|
}
|
||||||
|
@ -291,12 +237,12 @@ func TestConnectCustomDialer(t *testing.T) {
|
||||||
|
|
||||||
dialled := false
|
dialled := false
|
||||||
conf := *customDialerConnConfig
|
conf := *customDialerConnConfig
|
||||||
conf.Dial = func(network, address string) (net.Conn, error) {
|
conf.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
dialled = true
|
dialled = true
|
||||||
return net.Dial(network, address)
|
return net.Dial(network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(conf)
|
conn, err := pgx.ConnectConfig(context.Background(), &conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %s", err)
|
t.Fatalf("Unable to establish connection: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -319,7 +265,7 @@ func TestConnectWithRuntimeParams(t *testing.T) {
|
||||||
"search_path": "myschema",
|
"search_path": "myschema",
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgx.Connect(connConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), &connConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -343,750 +289,6 @@ func TestConnectWithRuntimeParams(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseURI(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
url string
|
|
||||||
connParams pgx.ConnConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgresql://jack:secret@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres:///foo?host=/tmp",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
Host: "/tmp",
|
|
||||||
Database: "foo",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
connParams, err := pgx.ParseURI(tt.url)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%d. Unexpected error from pgx.ParseURL(%q) => %v", i, tt.url, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
|
||||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDSN(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
url string
|
|
||||||
connParams pgx.ConnConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost port=5432 dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost dbname=mydb connect_timeout=10",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
actual, err := pgx.ParseDSN(tt.url)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseConnectionString(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
url string
|
|
||||||
connParams pgx.ConnConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: nil,
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack:secret@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgresql://jack:secret@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost:5432/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost/mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "postgres://jack@localhost/mydb?connect_timeout=10",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Password: "secret",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost port=5432 dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Port: 5432,
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost dbname=mydb",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema",
|
|
||||||
connParams: pgx.ConnConfig{
|
|
||||||
User: "jack",
|
|
||||||
Host: "localhost",
|
|
||||||
Database: "mydb",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{
|
|
||||||
"application_name": "pgxtest",
|
|
||||||
"search_path": "myschema",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
actual, err := pgx.ParseConnectionString(tt.url)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) {
|
|
||||||
if actual.Host != expected.Host {
|
|
||||||
t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host)
|
|
||||||
}
|
|
||||||
if actual.Database != expected.Database {
|
|
||||||
t.Errorf("%s: expected Database to be %v got %v", testName, expected.Database, actual.Database)
|
|
||||||
}
|
|
||||||
if actual.Port != expected.Port {
|
|
||||||
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
|
||||||
}
|
|
||||||
if actual.Port != expected.Port {
|
|
||||||
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
|
||||||
}
|
|
||||||
if actual.User != expected.User {
|
|
||||||
t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User)
|
|
||||||
}
|
|
||||||
if actual.Password != expected.Password {
|
|
||||||
t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password)
|
|
||||||
}
|
|
||||||
// Cannot test value of underlying Dialer stuct but can at least test if Dial func is set.
|
|
||||||
if (actual.Dial != nil) != (expected.Dial != nil) {
|
|
||||||
t.Errorf("%s: expected Dial mismatch", testName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) {
|
|
||||||
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsTests := []struct {
|
|
||||||
name string
|
|
||||||
expected *tls.Config
|
|
||||||
actual *tls.Config
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "TLSConfig",
|
|
||||||
expected: expected.TLSConfig,
|
|
||||||
actual: actual.TLSConfig,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FallbackTLSConfig",
|
|
||||||
expected: expected.FallbackTLSConfig,
|
|
||||||
actual: actual.FallbackTLSConfig,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tlsTest := range tlsTests {
|
|
||||||
name := tlsTest.name
|
|
||||||
expected := tlsTest.expected
|
|
||||||
actual := tlsTest.actual
|
|
||||||
|
|
||||||
if expected == nil && actual != nil {
|
|
||||||
t.Errorf("%s / %s: expected nil, but it was set", testName, name)
|
|
||||||
} else if expected != nil && actual == nil {
|
|
||||||
t.Errorf("%s / %s: expected to be set, but got nil", testName, name)
|
|
||||||
} else if expected != nil && actual != nil {
|
|
||||||
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
|
||||||
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
|
||||||
}
|
|
||||||
|
|
||||||
if actual.ServerName != expected.ServerName {
|
|
||||||
t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if actual.UseFallbackTLS != expected.UseFallbackTLS {
|
|
||||||
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseEnvLibpq(t *testing.T) {
|
|
||||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
|
||||||
|
|
||||||
savedEnv := make(map[string]string)
|
|
||||||
for _, n := range pgEnvvars {
|
|
||||||
savedEnv[n] = os.Getenv(n)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
for k, v := range savedEnv {
|
|
||||||
err := os.Setenv(k, v)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to restore environment: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
envvars map[string]string
|
|
||||||
config pgx.ConnConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "No environment",
|
|
||||||
envvars: map[string]string{},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Normal PG vars",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGHOST": "123.123.123.123",
|
|
||||||
"PGPORT": "7777",
|
|
||||||
"PGDATABASE": "foo",
|
|
||||||
"PGUSER": "bar",
|
|
||||||
"PGPASSWORD": "baz",
|
|
||||||
"PGCONNECT_TIMEOUT": "10",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
Host: "123.123.123.123",
|
|
||||||
Port: 7777,
|
|
||||||
Database: "foo",
|
|
||||||
User: "bar",
|
|
||||||
Password: "baz",
|
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "application_name",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGAPPNAME": "pgxtest",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=disable",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "disable",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: nil,
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=allow",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "allow",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: nil,
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=prefer",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "prefer",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
UseFallbackTLS: true,
|
|
||||||
FallbackTLSConfig: nil,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=require",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "require",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=verify-ca",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "verify-ca",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{},
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=verify-full",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGSSLMODE": "verify-full",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
TLSConfig: &tls.Config{},
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sslmode=verify-full with host",
|
|
||||||
envvars: map[string]string{
|
|
||||||
"PGHOST": "pgx.example",
|
|
||||||
"PGSSLMODE": "verify-full",
|
|
||||||
},
|
|
||||||
config: pgx.ConnConfig{
|
|
||||||
Host: "pgx.example",
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
ServerName: "pgx.example",
|
|
||||||
},
|
|
||||||
UseFallbackTLS: false,
|
|
||||||
RuntimeParams: map[string]string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
for _, n := range pgEnvvars {
|
|
||||||
err := os.Unsetenv(n)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%s: Unable to clear environment: %v", tt.name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range tt.envvars {
|
|
||||||
err := os.Setenv(k, v)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%s: Unable to set environment: %v", tt.name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
actual, err := pgx.ParseEnvLibpq()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
testConnConfigEquals(t, tt.config, actual, tt.name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExec(t *testing.T) {
|
func TestExec(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -1863,7 +1065,7 @@ func TestFatalRxError(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
otherConn, err := pgx.Connect(*defaultConnConfig)
|
otherConn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1889,7 +1091,7 @@ func TestFatalTxError(t *testing.T) {
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
defer closeConn(t, conn)
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
otherConn, err := pgx.Connect(*defaultConnConfig)
|
otherConn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -72,7 +73,7 @@ func (src *Point) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_CustomType() {
|
func Example_CustomType() {
|
||||||
conn, err := pgx.Connect(*defaultConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Unable to establish connection: %v", err)
|
fmt.Printf("Unable to establish connection: %v", err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Example_JSON() {
|
func Example_JSON() {
|
||||||
conn, err := pgx.Connect(*defaultConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Unable to establish connection: %v", err)
|
fmt.Printf("Unable to establish connection: %v", err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
||||||
conn, err := pgx.Connect(config)
|
conn, err := pgx.ConnectConfig(context.Background(), &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -15,7 +16,7 @@ func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn {
|
func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn {
|
||||||
conn, err := pgx.ReplicationConnect(config)
|
conn, err := pgx.ReplicationConnect(&config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to establish connection: %v", err)
|
t.Fatalf("Unable to establish connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -10,7 +11,7 @@ import (
|
||||||
func TestLargeObjects(t *testing.T) {
|
func TestLargeObjects(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn, err := pgx.Connect(*defaultConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -123,7 +124,7 @@ func TestLargeObjects(t *testing.T) {
|
||||||
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
func TestLargeObjectsMultipleTransactions(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn, err := pgx.Connect(*defaultConnConfig)
|
conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,421 @@
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgpassfile"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||||
|
type Config struct {
|
||||||
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
Database string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
|
Fallbacks []*FallbackConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||||
|
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||||
|
type FallbackConfig struct {
|
||||||
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||||
|
// net.Dial.
|
||||||
|
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||||
|
if strings.HasPrefix(host, "/") {
|
||||||
|
network = "unix"
|
||||||
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
|
} else {
|
||||||
|
network = "tcp"
|
||||||
|
address = fmt.Sprintf("%s:%d", host, port)
|
||||||
|
}
|
||||||
|
return network, address
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq.
|
||||||
|
// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment
|
||||||
|
// variables. connString may be a URL or a DSN. It also may be empty to only read from the
|
||||||
|
// environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||||
|
//
|
||||||
|
// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca"
|
||||||
|
//
|
||||||
|
// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca"
|
||||||
|
//
|
||||||
|
// Multiple configs may be returned due to sslmode settings with fallback options (e.g.
|
||||||
|
// sslmode=prefer). Future implementations may also support multiple hosts
|
||||||
|
// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS).
|
||||||
|
//
|
||||||
|
// ParseConfig currently recognizes the following environment variable and their parameter key word
|
||||||
|
// equivalents passed via database URL or DSN:
|
||||||
|
//
|
||||||
|
// PGHOST
|
||||||
|
// PGPORT
|
||||||
|
// PGDATABASE
|
||||||
|
// PGUSER
|
||||||
|
// PGPASSWORD
|
||||||
|
// PGPASSFILE
|
||||||
|
// PGSSLMODE
|
||||||
|
// PGSSLCERT
|
||||||
|
// PGSSLKEY
|
||||||
|
// PGSSLROOTCERT
|
||||||
|
// PGAPPNAME
|
||||||
|
// PGCONNECT_TIMEOUT
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of
|
||||||
|
// environment variables.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key
|
||||||
|
// word names. They are usually but not always the environment variable name downcased and without
|
||||||
|
// the "PG" prefix.
|
||||||
|
//
|
||||||
|
// Important TLS Security Notes:
|
||||||
|
//
|
||||||
|
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to
|
||||||
|
// "prefer" behavior if not set.
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on
|
||||||
|
// what level of security each sslmode provides.
|
||||||
|
//
|
||||||
|
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
|
||||||
|
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
||||||
|
// may be possible to match libpq in the future. If you need full security use
|
||||||
|
// "verify-full".
|
||||||
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
|
settings := defaultSettings()
|
||||||
|
addEnvSettings(settings)
|
||||||
|
|
||||||
|
if connString != "" {
|
||||||
|
// connString may be a database URL or a DSN
|
||||||
|
if strings.HasPrefix(connString, "postgres://") {
|
||||||
|
url, err := url.Parse(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addURLSettings(settings, url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := addDSNSettings(settings, connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Host: settings["host"],
|
||||||
|
Database: settings["database"],
|
||||||
|
User: settings["user"],
|
||||||
|
Password: settings["password"],
|
||||||
|
RuntimeParams: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
if port, err := parsePort(settings["port"]); err == nil {
|
||||||
|
config.Port = port
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid port: %v", settings["port"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||||
|
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
config.DialFunc = dialFunc
|
||||||
|
} else {
|
||||||
|
defaultDialer := makeDefaultDialer()
|
||||||
|
config.DialFunc = defaultDialer.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
notRuntimeParams := map[string]struct{}{
|
||||||
|
"host": struct{}{},
|
||||||
|
"port": struct{}{},
|
||||||
|
"database": struct{}{},
|
||||||
|
"user": struct{}{},
|
||||||
|
"password": struct{}{},
|
||||||
|
"passfile": struct{}{},
|
||||||
|
"connect_timeout": struct{}{},
|
||||||
|
"sslmode": struct{}{},
|
||||||
|
"sslkey": struct{}{},
|
||||||
|
"sslcert": struct{}{},
|
||||||
|
"sslrootcert": struct{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range settings {
|
||||||
|
if _, present := notRuntimeParams[k]; present {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
config.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfigs []*tls.Config
|
||||||
|
|
||||||
|
// Ignore TLS settings if Unix domain socket like libpq
|
||||||
|
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||||
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
tlsConfigs, err = configTLS(settings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.TLSConfig = tlsConfigs[0]
|
||||||
|
|
||||||
|
for _, tlsConfig := range tlsConfigs[1:] {
|
||||||
|
config.Fallbacks = append(config.Fallbacks, &FallbackConfig{
|
||||||
|
Host: config.Host,
|
||||||
|
Port: config.Port,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
|
if err == nil {
|
||||||
|
if config.Password == "" {
|
||||||
|
host := config.Host
|
||||||
|
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
settings["user"] = user.Username
|
||||||
|
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
candidatePaths := []string{
|
||||||
|
"/var/run/postgresql", // Debian
|
||||||
|
"/private/tmp", // OSX - homebrew
|
||||||
|
"/tmp", // standard PostgreSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range candidatePaths {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
func addEnvSettings(settings map[string]string) {
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"PGHOST": "host",
|
||||||
|
"PGPORT": "port",
|
||||||
|
"PGDATABASE": "database",
|
||||||
|
"PGUSER": "user",
|
||||||
|
"PGPASSWORD": "password",
|
||||||
|
"PGPASSFILE": "passfile",
|
||||||
|
"PGAPPNAME": "application_name",
|
||||||
|
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||||
|
"PGSSLMODE": "sslmode",
|
||||||
|
"PGSSLKEY": "sslkey",
|
||||||
|
"PGSSLCERT": "sslcert",
|
||||||
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
|
}
|
||||||
|
|
||||||
|
for envname, realname := range nameMap {
|
||||||
|
value := os.Getenv(envname)
|
||||||
|
if value != "" {
|
||||||
|
settings[realname] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addURLSettings(settings map[string]string, url *url.URL) error {
|
||||||
|
if url.User != nil {
|
||||||
|
settings["user"] = url.User.Username()
|
||||||
|
if password, present := url.User.Password(); present {
|
||||||
|
settings["password"] = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(url.Host, ":", 2)
|
||||||
|
if parts[0] != "" {
|
||||||
|
settings["host"] = parts[0]
|
||||||
|
}
|
||||||
|
if len(parts) == 2 {
|
||||||
|
settings["port"] = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
database := strings.TrimLeft(url.Path, "/")
|
||||||
|
if database != "" {
|
||||||
|
settings["database"] = database
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range url.Query() {
|
||||||
|
settings[k] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||||
|
|
||||||
|
func addDSNSettings(settings map[string]string, s string) error {
|
||||||
|
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
||||||
|
|
||||||
|
for _, b := range m {
|
||||||
|
settings[b[1]] = b[2]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pgTLSArgs struct {
|
||||||
|
sslMode string
|
||||||
|
sslRootCert string
|
||||||
|
sslCert string
|
||||||
|
sslKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
|
// "prefer" allow fallback.
|
||||||
|
func configTLS(settings map[string]string) ([]*tls.Config, error) {
|
||||||
|
host := settings["host"]
|
||||||
|
sslmode := settings["sslmode"]
|
||||||
|
sslrootcert := settings["sslrootcert"]
|
||||||
|
sslcert := settings["sslcert"]
|
||||||
|
sslkey := settings["sslkey"]
|
||||||
|
|
||||||
|
// Match libpq default behavior
|
||||||
|
if sslmode == "" {
|
||||||
|
sslmode = "prefer"
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "disable":
|
||||||
|
return []*tls.Config{nil}, nil
|
||||||
|
case "allow", "prefer":
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
case "require":
|
||||||
|
tlsConfig.InsecureSkipVerify = sslrootcert == ""
|
||||||
|
case "verify-ca", "verify-full":
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sslmode is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslrootcert != "" {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
caPath := sslrootcert
|
||||||
|
caCert, err := ioutil.ReadFile(caPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "unable to read CA file %q", caPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, errors.Wrap(err, "unable to add CA to cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||||
|
return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslcert != "" && sslkey != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "unable to read cert")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "allow":
|
||||||
|
return []*tls.Config{nil, tlsConfig}, nil
|
||||||
|
case "prefer":
|
||||||
|
return []*tls.Config{tlsConfig, nil}, nil
|
||||||
|
case "require", "verify-ca", "verify-full":
|
||||||
|
return []*tls.Config{tlsConfig}, nil
|
||||||
|
default:
|
||||||
|
panic("BUG: bad sslmode should already have been caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
||||||
|
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if timeout < 0 {
|
||||||
|
return nil, errors.New("negative timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
d := makeDefaultDialer()
|
||||||
|
d.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
return d.DialContext, nil
|
||||||
|
}
|
|
@ -0,0 +1,392 @@
|
||||||
|
package pgconn_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgconn"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var osUserName string
|
||||||
|
osUser, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
osUserName = osUser.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connString string
|
||||||
|
config *pgconn.Config
|
||||||
|
}{
|
||||||
|
// Test all sslmodes
|
||||||
|
{
|
||||||
|
name: "sslmode not set (prefer)",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode disable",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode allow",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode prefer",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode require",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode verify-ca",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sslmode verify-full",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url everything",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{
|
||||||
|
"application_name": "pgxtest",
|
||||||
|
"search_path": "myschema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing password",
|
||||||
|
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing user and password",
|
||||||
|
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url missing port",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "database url unix domain socket host",
|
||||||
|
connString: "postgres:///foo?host=/tmp",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "/tmp",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "foo",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DSN everything",
|
||||||
|
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{
|
||||||
|
"application_name": "pgxtest",
|
||||||
|
"search_path": "myschema",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
config, err := pgconn.ParseConfig(tt.connString)
|
||||||
|
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||||
|
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||||
|
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||||
|
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||||
|
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||||
|
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||||
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||||
|
|
||||||
|
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||||
|
if expected.TLSConfig != nil {
|
||||||
|
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||||
|
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) {
|
||||||
|
for i := range expected.Fallbacks {
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
||||||
|
|
||||||
|
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) {
|
||||||
|
if expected.Fallbacks[i].TLSConfig != nil {
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
||||||
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigEnvLibpq(t *testing.T) {
|
||||||
|
var osUserName string
|
||||||
|
osUser, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
osUserName = osUser.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||||
|
|
||||||
|
savedEnv := make(map[string]string)
|
||||||
|
for _, n := range pgEnvvars {
|
||||||
|
savedEnv[n] = os.Getenv(n)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
for k, v := range savedEnv {
|
||||||
|
err := os.Setenv(k, v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to restore environment: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envvars map[string]string
|
||||||
|
config *pgconn.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// not testing no environment at all as that would use default host and that can vary.
|
||||||
|
name: "PGHOST only",
|
||||||
|
envvars: map[string]string{"PGHOST": "123.123.123.123"},
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: osUserName,
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
Fallbacks: []*pgconn.FallbackConfig{
|
||||||
|
&pgconn.FallbackConfig{
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 5432,
|
||||||
|
TLSConfig: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All non-TLS environment",
|
||||||
|
envvars: map[string]string{
|
||||||
|
"PGHOST": "123.123.123.123",
|
||||||
|
"PGPORT": "7777",
|
||||||
|
"PGDATABASE": "foo",
|
||||||
|
"PGUSER": "bar",
|
||||||
|
"PGPASSWORD": "baz",
|
||||||
|
"PGCONNECT_TIMEOUT": "10",
|
||||||
|
"PGSSLMODE": "disable",
|
||||||
|
"PGAPPNAME": "pgxtest",
|
||||||
|
},
|
||||||
|
config: &pgconn.Config{
|
||||||
|
Host: "123.123.123.123",
|
||||||
|
Port: 7777,
|
||||||
|
Database: "foo",
|
||||||
|
User: "bar",
|
||||||
|
Password: "baz",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
for _, n := range pgEnvvars {
|
||||||
|
err := os.Unsetenv(n)
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range tt.envvars {
|
||||||
|
err := os.Setenv(k, v)
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig("")
|
||||||
|
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||||
|
tf, err := ioutil.TempFile("", "")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
defer tf.Close()
|
||||||
|
defer os.Remove(tf.Name())
|
||||||
|
|
||||||
|
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
||||||
|
expected := &pgconn.Config{
|
||||||
|
User: "curly",
|
||||||
|
Password: "nyuknyuknyuk",
|
||||||
|
Host: "test1",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "curlydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
actual, err := pgconn.ParseConfig(connString)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assertConfigsEqual(t, expected, actual, "passfile")
|
||||||
|
}
|
130
pgconn/pgconn.go
130
pgconn/pgconn.go
|
@ -1,20 +1,16 @@
|
||||||
package pgconn
|
package pgconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgx/pgio"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgx/pgproto3"
|
||||||
|
@ -23,7 +19,7 @@ import (
|
||||||
const batchBufferSize = 4096
|
const batchBufferSize = 4096
|
||||||
|
|
||||||
// PgError represents an error reported by the PostgreSQL server. See
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||||
// detailed field description.
|
// detailed field description.
|
||||||
type PgError struct {
|
type PgError struct {
|
||||||
Severity string
|
Severity string
|
||||||
|
@ -50,60 +46,12 @@ func (pe PgError) Error() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||||
type DialFunc func(network, addr string) (net.Conn, error)
|
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||||
// PostgreSQL server refuses to use TLS
|
// PostgreSQL server refuses to use TLS
|
||||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||||
|
|
||||||
type ConnConfig struct {
|
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
|
||||||
Port uint16 // default: 5432
|
|
||||||
Database string
|
|
||||||
User string // default: OS user name
|
|
||||||
Password string
|
|
||||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
|
||||||
Dial DialFunc
|
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cc *ConnConfig) NetworkAddress() (network, address string) {
|
|
||||||
// If host is a valid path, then address is unix socket
|
|
||||||
if _, err := os.Stat(cc.Host); err == nil {
|
|
||||||
network = "unix"
|
|
||||||
address = cc.Host
|
|
||||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
|
||||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
network = "tcp"
|
|
||||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
return network, address
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cc *ConnConfig) assignDefaults() error {
|
|
||||||
if cc.User == "" {
|
|
||||||
user, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cc.User = user.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
if cc.Port == 0 {
|
|
||||||
cc.Port = 5432
|
|
||||||
}
|
|
||||||
|
|
||||||
if cc.Dial == nil {
|
|
||||||
defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute}
|
|
||||||
cc.Dial = defaultDialer.Dial
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
type PgConn struct {
|
type PgConn struct {
|
||||||
NetConn net.Conn // the underlying TCP or unix domain socket connection
|
NetConn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
|
@ -113,7 +61,7 @@ type PgConn struct {
|
||||||
TxStatus byte
|
TxStatus byte
|
||||||
Frontend *pgproto3.Frontend
|
Frontend *pgproto3.Frontend
|
||||||
|
|
||||||
Config ConnConfig
|
Config *Config
|
||||||
|
|
||||||
batchBuf []byte
|
batchBuf []byte
|
||||||
batchCount int32
|
batchCount int32
|
||||||
|
@ -123,24 +71,72 @@ type PgConn struct {
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func Connect(cc ConnConfig) (*PgConn, error) {
|
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
|
||||||
err := cc.assignDefaults()
|
// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt.
|
||||||
|
func Connect(ctx context.Context, connString string) (*PgConn, error) {
|
||||||
|
config, err := ParseConfig(connString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn := new(PgConn)
|
return ConnectConfig(ctx, config)
|
||||||
pgConn.Config = cc
|
}
|
||||||
|
|
||||||
pgConn.NetConn, err = cc.Dial(cc.NetworkAddress())
|
// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt.
|
||||||
|
//
|
||||||
|
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
|
||||||
|
// authentication error will terminate the chain of attempts (like libpq:
|
||||||
|
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
|
||||||
|
// if all attempts fail the last error is returned.
|
||||||
|
func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
|
||||||
|
// For convenience set a few defaults if not already set. This makes it simpler to directly construct a config.
|
||||||
|
if config.Port == 0 {
|
||||||
|
config.Port = 5432
|
||||||
|
}
|
||||||
|
if config.DialFunc == nil {
|
||||||
|
config.DialFunc = makeDefaultDialer().DialContext
|
||||||
|
}
|
||||||
|
if config.RuntimeParams == nil {
|
||||||
|
config.RuntimeParams = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simplify usage by treating primary config and fallbacks the same.
|
||||||
|
fallbackConfigs := []*FallbackConfig{
|
||||||
|
{
|
||||||
|
Host: config.Host,
|
||||||
|
Port: config.Port,
|
||||||
|
TLSConfig: config.TLSConfig,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
|
||||||
|
|
||||||
|
for _, fc := range fallbackConfigs {
|
||||||
|
pgConn, err = connect(ctx, config, fc)
|
||||||
|
if err == nil {
|
||||||
|
return pgConn, nil
|
||||||
|
} else if err, ok := err.(PgError); ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
|
pgConn := new(PgConn)
|
||||||
|
pgConn.Config = config
|
||||||
|
|
||||||
|
var err error
|
||||||
|
network, address := NetworkAddress(config.Host, config.Port)
|
||||||
|
pgConn.NetConn, err = config.DialFunc(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.parameterStatuses = make(map[string]string)
|
pgConn.parameterStatuses = make(map[string]string)
|
||||||
|
|
||||||
if cc.TLSConfig != nil {
|
if config.TLSConfig != nil {
|
||||||
if err := pgConn.startTLS(cc.TLSConfig); err != nil {
|
if err := pgConn.startTLS(config.TLSConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy default run-time params
|
// Copy default run-time params
|
||||||
for k, v := range cc.RuntimeParams {
|
for k, v := range config.RuntimeParams {
|
||||||
startupMsg.Parameters[k] = v
|
startupMsg.Parameters[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
startupMsg.Parameters["user"] = cc.User
|
startupMsg.Parameters["user"] = config.User
|
||||||
if cc.Database != "" {
|
if config.Database != "" {
|
||||||
startupMsg.Parameters["database"] = cc.Database
|
startupMsg.Parameters["database"] = config.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
package pgconn_test
|
package pgconn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/jackc/pgx/pgconn"
|
"context"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgconn"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSimple(t *testing.T) {
|
func TestSimple(t *testing.T) {
|
||||||
pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"})
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
pgConn.SendExec("select current_database()")
|
pgConn.SendExec("select current_database()")
|
||||||
|
|
106
pgpass.go
106
pgpass.go
|
@ -1,106 +0,0 @@
|
||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *string {
|
|
||||||
const (
|
|
||||||
backslash = "\r"
|
|
||||||
colon = "\n"
|
|
||||||
)
|
|
||||||
const (
|
|
||||||
host int = iota
|
|
||||||
port
|
|
||||||
database
|
|
||||||
username
|
|
||||||
pw
|
|
||||||
)
|
|
||||||
if strings.HasPrefix(line, "#") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
line = strings.Replace(line, `\:`, colon, -1)
|
|
||||||
line = strings.Replace(line, `\\`, backslash, -1)
|
|
||||||
parts := strings.Split(line, `:`)
|
|
||||||
if len(parts) != 5 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for i := range parts {
|
|
||||||
if parts[i] == `*` {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1)
|
|
||||||
switch i {
|
|
||||||
case host:
|
|
||||||
if parts[i] != cfgHost {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case port:
|
|
||||||
if parts[i] != cfgPort {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case database:
|
|
||||||
if parts[i] != cfgDatabase {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case username:
|
|
||||||
if parts[i] != cfgUsername {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &parts[4]
|
|
||||||
}
|
|
||||||
|
|
||||||
func pgpass(cfg *ConnConfig) (found bool) {
|
|
||||||
passfile := os.Getenv("PGPASSFILE")
|
|
||||||
if passfile == "" {
|
|
||||||
u, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
passfile = filepath.Join(u.HomeDir, ".pgpass")
|
|
||||||
}
|
|
||||||
f, err := os.Open(passfile)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
host := cfg.Host
|
|
||||||
if _, err := os.Stat(host); err == nil {
|
|
||||||
host = "localhost"
|
|
||||||
}
|
|
||||||
port := fmt.Sprintf(`%v`, cfg.Port)
|
|
||||||
if port == "0" {
|
|
||||||
port = "5432"
|
|
||||||
}
|
|
||||||
username := cfg.User
|
|
||||||
if username == "" {
|
|
||||||
user, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
username = user.Username
|
|
||||||
}
|
|
||||||
database := cfg.Database
|
|
||||||
if database == "" {
|
|
||||||
database = username
|
|
||||||
}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(f)
|
|
||||||
var pw *string
|
|
||||||
for scanner.Scan() {
|
|
||||||
pw = parsepgpass(scanner.Text(), host, port, database, username)
|
|
||||||
if pw != nil {
|
|
||||||
cfg.Password = *pw
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,90 +0,0 @@
|
||||||
package pgx
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"os/user"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func unescape(s string) string {
|
|
||||||
s = strings.Replace(s, `\:`, `:`, -1)
|
|
||||||
s = strings.Replace(s, `\\`, `\`, -1)
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
var passfile = [][]string{
|
|
||||||
{"test1", "5432", "larrydb", "larry", "whatstheidea"},
|
|
||||||
{"test1", "5432", "moedb", "moe", "imbecile"},
|
|
||||||
{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
|
|
||||||
{"test2", "5432", "*", "shemp", "heymoe"},
|
|
||||||
{"test2", "5432", "*", "*", `test\\ing\:`},
|
|
||||||
{"localhost", "*", "*", "*", "sesam"},
|
|
||||||
{"test3", "*", "", "", "swordfish"}, // user will be filled later
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPGPass(t *testing.T) {
|
|
||||||
tf, err := ioutil.TempFile("", "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
user, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
passfile[len(passfile)-1][2] = user.Username
|
|
||||||
passfile[len(passfile)-1][3] = user.Username
|
|
||||||
|
|
||||||
defer tf.Close()
|
|
||||||
defer os.Remove(tf.Name())
|
|
||||||
os.Setenv("PGPASSFILE", tf.Name())
|
|
||||||
_, err = fmt.Fprintln(tf, "#some comment\n\n#more comment")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
for _, l := range passfile {
|
|
||||||
_, err := fmt.Fprintln(tf, strings.Join(l, `:`))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = tf.Close(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
for i, l := range passfile {
|
|
||||||
cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]}
|
|
||||||
found := pgpass(&cfg)
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("Entry %v not found", i)
|
|
||||||
}
|
|
||||||
if cfg.Password != unescape(l[4]) {
|
|
||||||
t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password)
|
|
||||||
}
|
|
||||||
if l[0] == "localhost" {
|
|
||||||
// using some existing path as socket
|
|
||||||
cfg := ConnConfig{Host: tf.Name(), Database: l[2], User: l[3]}
|
|
||||||
found := pgpass(&cfg)
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("Entry %v not found", i)
|
|
||||||
}
|
|
||||||
if cfg.Password != unescape(l[4]) {
|
|
||||||
t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cfg := ConnConfig{Host: "test3"}
|
|
||||||
found := pgpass(&cfg)
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("Entry for default user name")
|
|
||||||
}
|
|
||||||
if cfg.Password != "swordfish" {
|
|
||||||
t.Fatalf(`Password mismatch for default user entry, want %s got %s`, "swordfish", cfg.Password)
|
|
||||||
}
|
|
||||||
cfg = ConnConfig{Host: "derp", Database: "herp", User: "joe"}
|
|
||||||
found = pgpass(&cfg)
|
|
||||||
if found {
|
|
||||||
t.Fatal("bad found")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
package pgpassfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Entry represents a line in a PG passfile.
|
||||||
|
type Entry struct {
|
||||||
|
Hostname string
|
||||||
|
Port string
|
||||||
|
Database string
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Passfile is the in memory data structure representing a PG passfile.
|
||||||
|
type Passfile struct {
|
||||||
|
Entries []*Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPassfile reads the file at path and parses it into a Passfile.
|
||||||
|
func ReadPassfile(path string) (*Passfile, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return ParsePassfile(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePassfile reads r and parses it into a Passfile.
|
||||||
|
func ParsePassfile(r io.Reader) (*Passfile, error) {
|
||||||
|
passfile := &Passfile{}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
for scanner.Scan() {
|
||||||
|
entry := parseLine(scanner.Text())
|
||||||
|
if entry != nil {
|
||||||
|
passfile.Entries = append(passfile.Entries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return passfile, scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped
|
||||||
|
// colon.
|
||||||
|
var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+")
|
||||||
|
|
||||||
|
// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)")
|
||||||
|
|
||||||
|
// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable
|
||||||
|
// line.
|
||||||
|
func parseLine(line string) *Entry {
|
||||||
|
const (
|
||||||
|
tmpBackslash = "\r"
|
||||||
|
tmpColon = "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "#") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line = strings.Replace(line, `\\`, tmpBackslash, -1)
|
||||||
|
line = strings.Replace(line, `\:`, tmpColon, -1)
|
||||||
|
|
||||||
|
parts := strings.Split(line, ":")
|
||||||
|
if len(parts) != 5 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unescape escaped colons and backslashes
|
||||||
|
for i := range parts {
|
||||||
|
parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1)
|
||||||
|
parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Entry{
|
||||||
|
Hostname: parts[0],
|
||||||
|
Port: parts[1],
|
||||||
|
Database: parts[2],
|
||||||
|
Username: parts[3],
|
||||||
|
Password: parts[4],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindPassword finds the password for the provided hostname, port, database, and username. For a
|
||||||
|
// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no
|
||||||
|
// match is found.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information.
|
||||||
|
func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) {
|
||||||
|
for _, e := range pf.Entries {
|
||||||
|
if (e.Hostname == "*" || e.Hostname == hostname) &&
|
||||||
|
(e.Port == "*" || e.Port == port) &&
|
||||||
|
(e.Database == "*" || e.Database == database) &&
|
||||||
|
(e.Username == "*" || e.Username == username) {
|
||||||
|
return e.Password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package pgpassfile
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func unescape(s string) string {
|
||||||
|
s = strings.Replace(s, `\:`, `:`, -1)
|
||||||
|
s = strings.Replace(s, `\\`, `\`, -1)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
var passfile = [][]string{
|
||||||
|
{"test1", "5432", "larrydb", "larry", "whatstheidea"},
|
||||||
|
{"test1", "5432", "moedb", "moe", "imbecile"},
|
||||||
|
{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
|
||||||
|
{"test2", "5432", "*", "shemp", "heymoe"},
|
||||||
|
{"test2", "5432", "*", "*", `test\\ing\:`},
|
||||||
|
{"localhost", "*", "*", "*", "sesam"},
|
||||||
|
{"test3", "*", "", "", "swordfish"}, // user will be filled later
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePassFile(t *testing.T) {
|
||||||
|
buf := bytes.NewBufferString(`# A comment
|
||||||
|
test1:5432:larrydb:larry:whatstheidea
|
||||||
|
test1:5432:moedb:moe:imbecile
|
||||||
|
test1:5432:curlydb:curly:nyuknyuknyuk
|
||||||
|
test2:5432:*:shemp:heymoe
|
||||||
|
test2:5432:*:*:test\\ing\:
|
||||||
|
localhost:*:*:*:sesam
|
||||||
|
`)
|
||||||
|
|
||||||
|
passfile, err := ParsePassfile(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, passfile.Entries, 6)
|
||||||
|
|
||||||
|
assert.Equal(t, "whatstheidea", passfile.FindPassword("test1", "5432", "larrydb", "larry"))
|
||||||
|
assert.Equal(t, "imbecile", passfile.FindPassword("test1", "5432", "moedb", "moe"))
|
||||||
|
assert.Equal(t, `test\ing:`, passfile.FindPassword("test2", "5432", "something", "else"))
|
||||||
|
assert.Equal(t, "sesam", passfile.FindPassword("localhost", "9999", "foo", "bare"))
|
||||||
|
|
||||||
|
assert.Equal(t, "", passfile.FindPassword("wrong", "5432", "larrydb", "larry"))
|
||||||
|
assert.Equal(t, "", passfile.FindPassword("test1", "wrong", "larrydb", "larry"))
|
||||||
|
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "wrong", "larry"))
|
||||||
|
assert.Equal(t, "", passfile.FindPassword("test1", "5432", "larrydb", "wrong"))
|
||||||
|
}
|
|
@ -158,13 +158,13 @@ func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) {
|
func ReplicationConnect(config *ConnConfig) (r *ReplicationConn, err error) {
|
||||||
if config.RuntimeParams == nil {
|
if config.Config.RuntimeParams == nil {
|
||||||
config.RuntimeParams = make(map[string]string)
|
config.Config.RuntimeParams = make(map[string]string)
|
||||||
}
|
}
|
||||||
config.RuntimeParams["replication"] = "database"
|
config.Config.RuntimeParams["replication"] = "database"
|
||||||
|
|
||||||
c, err := Connect(config)
|
c, err := ConnectConfig(context.TODO(), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
|
"github.com/jackc/pgx/pgconn"
|
||||||
"github.com/jackc/pgx/pgmock"
|
"github.com/jackc/pgx/pgmock"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgx/pgproto3"
|
||||||
)
|
)
|
||||||
|
@ -254,12 +255,12 @@ func TestConnBeginExContextCancel(t *testing.T) {
|
||||||
errChan <- server.ServeOne()
|
errChan <- server.ServeOne()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := mustConnect(t, mockConfig)
|
conn := mustConnect(t, pgx.ConnConfig{Config: *pc})
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -303,12 +304,12 @@ func TestTxCommitExCancel(t *testing.T) {
|
||||||
errChan <- server.ServeOne()
|
errChan <- server.ServeOne()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := mustConnect(t, mockConfig)
|
conn := mustConnect(t, pgx.ConnConfig{Config: *pc})
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
tx, err := conn.Begin()
|
tx, err := conn.Begin()
|
||||||
|
|
2
v4.md
2
v4.md
|
@ -38,4 +38,6 @@ Minor Potential Changes:
|
||||||
|
|
||||||
### Incompatible Changes
|
### Incompatible Changes
|
||||||
|
|
||||||
|
* Connect method now takes context and connection string.
|
||||||
|
* ConnectConfig takes context and config object.
|
||||||
* `RuntimeParams` `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification.
|
* `RuntimeParams` `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification.
|
||||||
|
|
Loading…
Reference in New Issue