diff --git a/conn.go b/conn.go index f1bbf3bf..210f57e3 100644 --- a/conn.go +++ b/conn.go @@ -4,17 +4,11 @@ import ( "context" "crypto/md5" "crypto/tls" - "crypto/x509" "encoding/binary" "encoding/hex" - "fmt" "io" - "io/ioutil" "net" - "net/url" - "os" "reflect" - "regexp" "strconv" "strings" "sync" @@ -57,26 +51,14 @@ func init() { // aware that this is distinct from LISTEN/NOTIFY notification. type NoticeHandler func(*Conn, *Notice) -// DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) - // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa - FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS - Logger Logger - LogLevel int - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) - OnNotice NoticeHandler // Callback function called when a notice response is received. - CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. - CustomCancel func(*Conn) error // Callback function used to override cancellation behavior + pgconn.Config + Logger Logger + LogLevel int + OnNotice NoticeHandler // Callback function called when a notice response is received. + CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. + CustomCancel func(*Conn) error // Callback function used to override cancellation behavior // PreferSimpleProtocol disables implicit prepared statement usage. By default // pgx automatically uses the unnamed prepared statement for Query and @@ -96,7 +78,7 @@ type ConnConfig struct { type Conn struct { pgConn *pgconn.PgConn wbuf []byte - config ConnConfig // config used when establishing this connection + config *ConnConfig // config used when establishing this connection preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification @@ -195,18 +177,30 @@ func (e ProtocolError) Error() string { return string(e) } -// Connect establishes a connection with a PostgreSQL server using config. -// config.Host must be specified. config.User will default to the OS user name. -// Other config fields are optional. -func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, minimalConnInfo) +// Connect establishes a connection with a PostgreSQL server with a connection string. See +// pgconn.Connect for details. +func Connect(ctx context.Context, connString string) (*Conn, error) { + config, err := pgconn.ParseConfig(connString) + if err != nil { + return nil, err + } + connConfig := &ConnConfig{ + Config: *config, + } + + return connect(ctx, connConfig, minimalConnInfo) +} + +// Connect establishes a connection with a PostgreSQL server with a configuration struct. +func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + return connect(ctx, connConfig, minimalConnInfo) } func defaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } -func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { +func connect(ctx context.Context, config *ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) c.config = config @@ -223,14 +217,11 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.onNotice = config.OnNotice if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Host}) + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) } - err = c.connect(config, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) - } - err = c.connect(config, config.FallbackTLSConfig) + c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) + if err != nil { + return nil, err } if err != nil { @@ -240,34 +231,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, err } - return c, nil -} - -func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) { - cc := pgconn.ConnConfig{ - Host: config.Host, - Port: config.Port, - Database: config.Database, - User: config.User, - Password: config.Password, - TLSConfig: tlsConfig, - Dial: pgconn.DialFunc(config.Dial), - RuntimeParams: config.RuntimeParams, - } - - c.pgConn, err = pgconn.Connect(cc) - if err != nil { - return err - } - defer func() { - if c != nil && err != nil { - c.pgConn.NetConn.Close() - c.mux.Lock() - c.status = connStatusClosed - c.mux.Unlock() - } - }() - c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) c.cancelQueryCompleted = make(chan struct{}) @@ -275,25 +238,23 @@ func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) { c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) - - c.mux.Lock() c.status = connStatusIdle - c.mux.Unlock() // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := config.RuntimeParams["replication"]; ok { - return nil + if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok { + return c, nil } if c.ConnInfo == minimalConnInfo { err = c.initConnInfo() if err != nil { - return err + c.Close() + return nil, err } } - return nil + return c, nil } func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { @@ -524,7 +485,7 @@ func (c *Conn) LocalAddr() (net.Addr, error) { // Close closes a connection. It is safe to call Close on a already closed // connection. -func (c *Conn) Close() (err error) { +func (c *Conn) Close() error { c.mux.Lock() defer c.mux.Unlock() @@ -533,404 +494,12 @@ func (c *Conn) Close() (err error) { } c.status = connStatusClosed - defer func() { - c.pgConn.NetConn.Close() - c.causeOfDeath = errors.New("Closed") - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "closed connection", nil) - } - }() - - err = c.pgConn.NetConn.SetDeadline(time.Time{}) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) - return err + err := c.pgConn.Close() + c.causeOfDeath = errors.New("Closed") + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "closed connection", nil) } - - _, err = c.pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) - return err - } - - err = c.pgConn.NetConn.SetReadDeadline(time.Now().Add(5 * time.Second)) - if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) - return err - } - - _, err = c.pgConn.NetConn.Read(make([]byte, 1)) - if err != io.EOF { - return err - } - - return nil -} - -// Merge returns a new ConnConfig with the attributes of old and other -// combined. When an attribute is set on both, other takes precedence. -// -// As a security precaution, if the other TLSConfig is nil, all old TLS -// attributes will be preserved. -func (old ConnConfig) Merge(other ConnConfig) ConnConfig { - cc := old - - if other.Host != "" { - cc.Host = other.Host - } - if other.Port != 0 { - cc.Port = other.Port - } - if other.Database != "" { - cc.Database = other.Database - } - if other.User != "" { - cc.User = other.User - } - if other.Password != "" { - cc.Password = other.Password - } - - if other.TLSConfig != nil { - cc.TLSConfig = other.TLSConfig - cc.UseFallbackTLS = other.UseFallbackTLS - cc.FallbackTLSConfig = other.FallbackTLSConfig - } - - if other.Logger != nil { - cc.Logger = other.Logger - } - if other.LogLevel != 0 { - cc.LogLevel = other.LogLevel - } - - if other.Dial != nil { - cc.Dial = other.Dial - } - - cc.PreferSimpleProtocol = other.PreferSimpleProtocol - - cc.RuntimeParams = make(map[string]string) - for k, v := range old.RuntimeParams { - cc.RuntimeParams[k] = v - } - for k, v := range other.RuntimeParams { - cc.RuntimeParams[k] = v - } - - return cc -} - -// ParseURI parses a database URI into ConnConfig -// -// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams. -func ParseURI(uri string) (ConnConfig, error) { - var cp ConnConfig - - url, err := url.Parse(uri) - if err != nil { - return cp, err - } - - if url.User != nil { - cp.User = url.User.Username() - cp.Password, _ = url.User.Password() - } - - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) - } - cp.Database = strings.TrimLeft(url.Path, "/") - - if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { - timeout, err := strconv.ParseInt(pgtimeout, 10, 64) - if err != nil { - return cp, err - } - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cp.Dial = d.Dial - } - - tlsArgs := configTLSArgs{ - sslCert: url.Query().Get("sslcert"), - sslKey: url.Query().Get("sslkey"), - sslMode: url.Query().Get("sslmode"), - sslRootCert: url.Query().Get("sslrootcert"), - } - err = configTLS(tlsArgs, &cp) - if err != nil { - return cp, err - } - - ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, - } - - cp.RuntimeParams = make(map[string]string) - - for k, v := range url.Query() { - if _, ok := ignoreKeys[k]; ok { - continue - } - - if k == "host" { - cp.Host = v[0] - continue - } - - cp.RuntimeParams[k] = v[0] - } - if cp.Password == "" { - pgpass(&cp) - } - return cp, nil -} - -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) - -// ParseDSN parses a database DSN (data source name) into a ConnConfig -// -// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") -// -// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams. -// -// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb") -// -// ParseDSN tries to match libpq behavior with regard to sslmode. See comments -// for ParseEnvLibpq for more information on the security implications of -// sslmode options. -func ParseDSN(s string) (ConnConfig, error) { - var cp ConnConfig - - m := dsnRegexp.FindAllStringSubmatch(s, -1) - - tlsArgs := configTLSArgs{} - - cp.RuntimeParams = make(map[string]string) - - for _, b := range m { - switch b[1] { - case "user": - cp.User = b[2] - case "password": - cp.Password = b[2] - case "host": - cp.Host = b[2] - case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) - case "dbname": - cp.Database = b[2] - case "sslmode": - tlsArgs.sslMode = b[2] - case "sslrootcert": - tlsArgs.sslRootCert = b[2] - case "sslcert": - tlsArgs.sslCert = b[2] - case "sslkey": - tlsArgs.sslKey = b[2] - case "connect_timeout": - timeout, err := strconv.ParseInt(b[2], 10, 64) - if err != nil { - return cp, err - } - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cp.Dial = d.Dial - default: - cp.RuntimeParams[b[1]] = b[2] - } - } - - err := configTLS(tlsArgs, &cp) - if err != nil { - return cp, err - } - if cp.Password == "" { - pgpass(&cp) - } - return cp, nil -} - -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. -func ParseConnectionString(s string) (ConnConfig, error) { - if u, err := url.Parse(s); err == nil && u.Scheme != "" { - return ParseURI(s) - } - return ParseDSN(s) -} - -// ParseEnvLibpq parses the environment like libpq does into a ConnConfig -// -// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details -// on the meaning of environment variables. -// -// ParseEnvLibpq currently recognizes the following environment variables: -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT -// -// Important TLS Security Notes: -// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This -// includes defaulting to "prefer" behavior if no environment variable is set. -// -// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION -// for details on what level of security each sslmode provides. -// -// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger -// security guarantees than it would with libpq. Do not rely on this behavior as it -// may be possible to match libpq in the future. If you need full security use -// "verify-full". -// -// Several of the PGSSLMODE options (including the default behavior of "prefer") -// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or -// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is -// later set from a different source that UseFallbackTLS MUST be set false to -// avoid the possibility of falling back to weaker or disabled security. -func ParseEnvLibpq() (ConnConfig, error) { - var cc ConnConfig - - cc.Host = os.Getenv("PGHOST") - - if pgport := os.Getenv("PGPORT"); pgport != "" { - if port, err := strconv.ParseUint(pgport, 10, 16); err == nil { - cc.Port = uint16(port) - } else { - return cc, err - } - } - - cc.Database = os.Getenv("PGDATABASE") - cc.User = os.Getenv("PGUSER") - cc.Password = os.Getenv("PGPASSWORD") - - if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { - if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { - d := defaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - cc.Dial = d.Dial - } else { - return cc, err - } - } - - tlsArgs := configTLSArgs{ - sslMode: os.Getenv("PGSSLMODE"), - sslKey: os.Getenv("PGSSLKEY"), - sslCert: os.Getenv("PGSSLCERT"), - sslRootCert: os.Getenv("PGSSLROOTCERT"), - } - - err := configTLS(tlsArgs, &cc) - if err != nil { - return cc, err - } - - cc.RuntimeParams = make(map[string]string) - if appname := os.Getenv("PGAPPNAME"); appname != "" { - cc.RuntimeParams["application_name"] = appname - } - if cc.Password == "" { - pgpass(&cc) - } - return cc, nil -} - -type configTLSArgs struct { - sslMode string - sslRootCert string - sslCert string - sslKey string -} - -// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config. -// Inputs are parsed out and provided by ParseDSN() or ParseURI(). -func configTLS(args configTLSArgs, cc *ConnConfig) error { - // Match libpq default behavior - if args.sslMode == "" { - args.sslMode = "prefer" - } - - switch args.sslMode { - case "disable": - cc.UseFallbackTLS = false - cc.TLSConfig = nil - cc.FallbackTLSConfig = nil - return nil - case "allow": - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} - case "prefer": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = nil - case "require": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} - case "verify-ca", "verify-full": - cc.TLSConfig = &tls.Config{ - ServerName: cc.Host, - } - default: - return errors.New("sslmode is invalid") - } - - if args.sslRootCert != "" { - caCertPool := x509.NewCertPool() - - caPath := args.sslRootCert - caCert, err := ioutil.ReadFile(caPath) - if err != nil { - return errors.Wrapf(err, "unable to read CA file %q", caPath) - } - - if !caCertPool.AppendCertsFromPEM(caCert) { - return errors.Wrap(err, "unable to add CA to cert pool") - } - - cc.TLSConfig.RootCAs = caCertPool - cc.TLSConfig.ClientCAs = caCertPool - } - - sslcert := args.sslCert - sslkey := args.sslKey - - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return fmt.Errorf(`both "sslcert" and "sslkey" are required`) - } - - if sslcert != "" && sslkey != "" { - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - return errors.Wrap(err, "unable to read cert") - } - - cc.TLSConfig.Certificates = []tls.Certificate{cert} - } - - return nil + return err } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -1336,9 +905,9 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.config.Password) + err = c.txPasswordMessage(c.pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) + digestedPassword := "md5" + hexMD5(hexMD5(c.pgConn.Config.Password+c.pgConn.Config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") @@ -1554,8 +1123,11 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.pgConn.Config.NetworkAddress() - cancelConn, err := c.pgConn.Config.Dial(network, address) + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := c.pgConn.NetConn.RemoteAddr() + cancelConn, err := c.pgConn.Config.DialFunc(context.TODO(), serverAddr.Network(), serverAddr.String()) if err != nil { return err } diff --git a/conn_pool.go b/conn_pool.go index b9ae1d07..f857176d 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -297,7 +297,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.connInfo) + c, err := connect(context.TODO(), &p.config, p.connInfo) if err != nil { return nil, err } @@ -319,7 +319,8 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + // c, err := Connect(p.config) + c, err := Connect(context.TODO(), "TODO") p.cond.L.Lock() p.inProgressConnects-- diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..a7eebdd4 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -162,7 +162,7 @@ func TestPoolNonBlockingConnections(t *testing.T) { var dialCountLock sync.Mutex dialCount := 0 openTimeout := 1 * time.Second - testDialer := func(network, address string) (net.Conn, error) { + testDialer := func(ctx context.Context, network, address string) (net.Conn, error) { var firstDial bool dialCountLock.Lock() dialCount++ @@ -182,7 +182,7 @@ func TestPoolNonBlockingConnections(t *testing.T) { ConnConfig: *defaultConnConfig, MaxConnections: maxConnections, } - config.ConnConfig.Dial = testDialer + config.ConnConfig.Config.DialFunc = testDialer pool, err := pgx.NewConnPool(config) if err != nil { diff --git a/conn_test.go b/conn_test.go index db6cbc10..90da4a7d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,11 +2,8 @@ package pgx_test import ( "context" - "crypto/tls" "fmt" "net" - "os" - "reflect" "strconv" "strings" "sync" @@ -24,7 +21,7 @@ func TestCrateDBConnect(t *testing.T) { t.Skip("Skipping due to undefined cratedbConnConfig") } - conn, err := pgx.Connect(*cratedbConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), cratedbConnConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -47,7 +44,7 @@ func TestCrateDBConnect(t *testing.T) { func TestConnect(t *testing.T) { t.Parallel() - conn, err := pgx.Connect(*defaultConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -65,8 +62,8 @@ func TestConnect(t *testing.T) { if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } - if currentDB != defaultConnConfig.Database { - t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + if currentDB != defaultConnConfig.Config.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Config.Database) } var user string @@ -74,8 +71,8 @@ func TestConnect(t *testing.T) { if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } - if user != defaultConnConfig.User { - t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + if user != defaultConnConfig.Config.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.Config.User) } err = conn.Close() @@ -92,27 +89,7 @@ func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Skip("Skipping due to undefined unixSocketConnConfig") } - conn, err := pgx.Connect(*unixSocketConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithUnixSocketFile(t *testing.T) { - t.Parallel() - - if unixSocketConnConfig == nil { - t.Skip("Skipping due to undefined unixSocketConnConfig") - } - - connParams := *unixSocketConnConfig - connParams.Host = connParams.Host + "/.s.PGSQL.5432" - conn, err := pgx.Connect(connParams) + conn, err := pgx.ConnectConfig(context.Background(), unixSocketConnConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -130,7 +107,7 @@ func TestConnectWithTcp(t *testing.T) { t.Skip("Skipping due to undefined tcpConnConfig") } - conn, err := pgx.Connect(*tcpConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), tcpConnConfig) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) } @@ -148,7 +125,7 @@ func TestConnectWithTLS(t *testing.T) { t.Skip("Skipping due to undefined tlsConnConfig") } - conn, err := pgx.Connect(*tlsConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), tlsConnConfig) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) } @@ -166,7 +143,7 @@ func TestConnectWithInvalidUser(t *testing.T) { t.Skip("Skipping due to undefined invalidUserConnConfig") } - _, err := pgx.Connect(*invalidUserConnConfig) + _, err := pgx.ConnectConfig(context.Background(), invalidUserConnConfig) pgErr, ok := err.(pgx.PgError) if !ok { t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err) @@ -183,7 +160,7 @@ func TestConnectWithPlainTextPassword(t *testing.T) { t.Skip("Skipping due to undefined plainPasswordConnConfig") } - conn, err := pgx.Connect(*plainPasswordConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), plainPasswordConnConfig) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) } @@ -201,38 +178,7 @@ func TestConnectWithMD5Password(t *testing.T) { t.Skip("Skipping due to undefined md5ConnConfig") } - conn, err := pgx.Connect(*md5ConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithTLSFallback(t *testing.T) { - t.Parallel() - - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } - - connConfig := *tlsConnConfig - connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} // bogus ServerName should ensure certificate validation failure - - conn, err := pgx.Connect(connConfig) - if err == nil { - t.Fatal("Expected failed connection, but succeeded") - } - - connConfig = *tlsConnConfig - connConfig.TLSConfig = &tls.Config{ServerName: "bogus.local"} - connConfig.UseFallbackTLS = true - connConfig.FallbackTLSConfig = &tls.Config{ServerName: "bogus.local", InsecureSkipVerify: true} - - conn, err = pgx.Connect(connConfig) + conn, err := pgx.ConnectConfig(context.Background(), md5ConnConfig) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) } @@ -251,7 +197,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { bad.Host = "127.0.0.1" bad.Port = 1 - _, err := pgx.Connect(bad) + _, err := pgx.ConnectConfig(context.Background(), &bad) if err == nil { t.Fatal("Expected error establishing connection to bad port") } @@ -291,12 +237,12 @@ func TestConnectCustomDialer(t *testing.T) { dialled := false conf := *customDialerConnConfig - conf.Dial = func(network, address string) (net.Conn, error) { + conf.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { dialled = true return net.Dial(network, address) } - conn, err := pgx.Connect(conf) + conn, err := pgx.ConnectConfig(context.Background(), &conf) if err != nil { t.Fatalf("Unable to establish connection: %s", err) } @@ -319,7 +265,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { "search_path": "myschema", } - conn, err := pgx.Connect(connConfig) + conn, err := pgx.ConnectConfig(context.Background(), &connConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -343,750 +289,6 @@ func TestConnectWithRuntimeParams(t *testing.T) { } } -func TestParseURI(t *testing.T) { - t.Parallel() - - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - UseFallbackTLS: false, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgresql://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - url: "postgres:///foo?host=/tmp", - connParams: pgx.ConnConfig{ - Host: "/tmp", - Database: "foo", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - } - - for i, tt := range tests { - connParams, err := pgx.ParseURI(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseURL(%q) => %v", i, tt.url, err) - continue - } - - if !reflect.DeepEqual(connParams, tt.connParams) { - t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) - } - } -} - -func TestParseDSN(t *testing.T) { - t.Parallel() - - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - url: "user=jack host=localhost dbname=mydb connect_timeout=10", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - } - - for i, tt := range tests { - actual, err := pgx.ParseDSN(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) - continue - } - - testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) - } -} - -func TestParseConnectionString(t *testing.T) { - t.Parallel() - - tests := []struct { - url string - connParams pgx.ConnConfig - }{ - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - UseFallbackTLS: false, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgresql://jack:secret@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost:5432/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - { - url: "postgres://jack@localhost/mydb?connect_timeout=10", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost port=5432 dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", - connParams: pgx.ConnConfig{ - User: "jack", - Host: "localhost", - Database: "mydb", - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - }, - }, - }, - } - - for i, tt := range tests { - actual, err := pgx.ParseConnectionString(tt.url) - if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) - continue - } - - testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) - } -} - -func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) { - if actual.Host != expected.Host { - t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host) - } - if actual.Database != expected.Database { - t.Errorf("%s: expected Database to be %v got %v", testName, expected.Database, actual.Database) - } - if actual.Port != expected.Port { - t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) - } - if actual.Port != expected.Port { - t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) - } - if actual.User != expected.User { - t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User) - } - if actual.Password != expected.Password { - t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password) - } - // Cannot test value of underlying Dialer stuct but can at least test if Dial func is set. - if (actual.Dial != nil) != (expected.Dial != nil) { - t.Errorf("%s: expected Dial mismatch", testName) - } - - if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) { - t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams) - } - - tlsTests := []struct { - name string - expected *tls.Config - actual *tls.Config - }{ - { - name: "TLSConfig", - expected: expected.TLSConfig, - actual: actual.TLSConfig, - }, - { - name: "FallbackTLSConfig", - expected: expected.FallbackTLSConfig, - actual: actual.FallbackTLSConfig, - }, - } - for _, tlsTest := range tlsTests { - name := tlsTest.name - expected := tlsTest.expected - actual := tlsTest.actual - - if expected == nil && actual != nil { - t.Errorf("%s / %s: expected nil, but it was set", testName, name) - } else if expected != nil && actual == nil { - t.Errorf("%s / %s: expected to be set, but got nil", testName, name) - } else if expected != nil && actual != nil { - if actual.InsecureSkipVerify != expected.InsecureSkipVerify { - t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) - } - - if actual.ServerName != expected.ServerName { - t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName) - } - } - } - - if actual.UseFallbackTLS != expected.UseFallbackTLS { - t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS) - } -} - -func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} - - savedEnv := make(map[string]string) - for _, n := range pgEnvvars { - savedEnv[n] = os.Getenv(n) - } - defer func() { - for k, v := range savedEnv { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("Unable to restore environment: %v", err) - } - } - }() - - tests := []struct { - name string - envvars map[string]string - config pgx.ConnConfig - }{ - { - name: "No environment", - envvars: map[string]string{}, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "Normal PG vars", - envvars: map[string]string{ - "PGHOST": "123.123.123.123", - "PGPORT": "7777", - "PGDATABASE": "foo", - "PGUSER": "bar", - "PGPASSWORD": "baz", - "PGCONNECT_TIMEOUT": "10", - }, - config: pgx.ConnConfig{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "application_name", - envvars: map[string]string{ - "PGAPPNAME": "pgxtest", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, - }, - }, - { - name: "sslmode=disable", - envvars: map[string]string{ - "PGSSLMODE": "disable", - }, - config: pgx.ConnConfig{ - TLSConfig: nil, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=allow", - envvars: map[string]string{ - "PGSSLMODE": "allow", - }, - config: pgx.ConnConfig{ - TLSConfig: nil, - UseFallbackTLS: true, - FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true}, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=prefer", - envvars: map[string]string{ - "PGSSLMODE": "prefer", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=require", - envvars: map[string]string{ - "PGSSLMODE": "require", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-ca", - envvars: map[string]string{ - "PGSSLMODE": "verify-ca", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-full", - envvars: map[string]string{ - "PGSSLMODE": "verify-full", - }, - config: pgx.ConnConfig{ - TLSConfig: &tls.Config{}, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - { - name: "sslmode=verify-full with host", - envvars: map[string]string{ - "PGHOST": "pgx.example", - "PGSSLMODE": "verify-full", - }, - config: pgx.ConnConfig{ - Host: "pgx.example", - TLSConfig: &tls.Config{ - ServerName: "pgx.example", - }, - UseFallbackTLS: false, - RuntimeParams: map[string]string{}, - }, - }, - } - - for _, tt := range tests { - for _, n := range pgEnvvars { - err := os.Unsetenv(n) - if err != nil { - t.Fatalf("%s: Unable to clear environment: %v", tt.name, err) - } - } - - for k, v := range tt.envvars { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("%s: Unable to set environment: %v", tt.name, err) - } - } - - actual, err := pgx.ParseEnvLibpq() - if err != nil { - t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err) - continue - } - - testConnConfigEquals(t, tt.config, actual, tt.name) - } -} - func TestExec(t *testing.T) { t.Parallel() @@ -1863,7 +1065,7 @@ func TestFatalRxError(t *testing.T) { } }() - otherConn, err := pgx.Connect(*defaultConnConfig) + otherConn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -1889,7 +1091,7 @@ func TestFatalTxError(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - otherConn, err := pgx.Connect(*defaultConnConfig) + otherConn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } diff --git a/example_custom_type_test.go b/example_custom_type_test.go index d3cc9085..11417618 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "fmt" "regexp" "strconv" @@ -72,7 +73,7 @@ func (src *Point) String() string { } func Example_CustomType() { - conn, err := pgx.Connect(*defaultConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return diff --git a/example_json_test.go b/example_json_test.go index 09e27cff..c1670949 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -1,13 +1,14 @@ package pgx_test import ( + "context" "fmt" "github.com/jackc/pgx" ) func Example_JSON() { - conn, err := pgx.Connect(*defaultConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return diff --git a/helper_test.go b/helper_test.go index 78063107..f967d7a1 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,13 +1,14 @@ package pgx_test import ( + "context" "testing" "github.com/jackc/pgx" ) func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { - conn, err := pgx.Connect(config) + conn, err := pgx.ConnectConfig(context.Background(), &config) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } @@ -15,7 +16,7 @@ func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { } func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.ReplicationConn { - conn, err := pgx.ReplicationConnect(config) + conn, err := pgx.ReplicationConnect(&config) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } diff --git a/large_objects_test.go b/large_objects_test.go index 1d0a4f32..1fdd7627 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "io" "testing" @@ -10,7 +11,7 @@ import ( func TestLargeObjects(t *testing.T) { t.Parallel() - conn, err := pgx.Connect(*defaultConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { t.Fatal(err) } @@ -123,7 +124,7 @@ func TestLargeObjects(t *testing.T) { func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Parallel() - conn, err := pgx.Connect(*defaultConnConfig) + conn, err := pgx.ConnectConfig(context.Background(), defaultConnConfig) if err != nil { t.Fatal(err) } diff --git a/pgconn/config.go b/pgconn/config.go new file mode 100644 index 00000000..515d6356 --- /dev/null +++ b/pgconn/config.go @@ -0,0 +1,421 @@ +package pgconn + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "os/user" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgpassfile" + "github.com/pkg/errors" +) + +// Config is the settings used to establish a connection to a PostgreSQL server. +type Config struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", host, port) + } + return network, address +} + +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. +// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment +// variables. connString may be a URL or a DSN. It also may be empty to only read from the +// environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// +// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// +// Multiple configs may be returned due to sslmode settings with fallback options (e.g. +// sslmode=prefer). Future implementations may also support multiple hosts +// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// +// ParseConfig currently recognizes the following environment variable and their parameter key word +// equivalents passed via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of +// environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key +// word names. They are usually but not always the environment variable name downcased and without +// the "PG" prefix. +// +// Important TLS Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to +// "prefer" behavior if not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on +// what level of security each sslmode provides. +// +// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger +// security guarantees than it would with libpq. Do not rely on this behavior as it +// may be possible to match libpq in the future. If you need full security use +// "verify-full". +func ParseConfig(connString string) (*Config, error) { + settings := defaultSettings() + addEnvSettings(settings) + + if connString != "" { + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") { + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + err = addURLSettings(settings, url) + if err != nil { + return nil, err + } + } else { + err := addDSNSettings(settings, connString) + if err != nil { + return nil, err + } + } + } + + config := &Config{ + Host: settings["host"], + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + } + + if port, err := parsePort(settings["port"]); err == nil { + config.Port = port + } else { + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + if connectTimeout, present := settings["connect_timeout"]; present { + dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if err != nil { + return nil, err + } + config.DialFunc = dialFunc + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + config.TLSConfig = tlsConfigs[0] + + for _, tlsConfig := range tlsConfigs[1:] { + config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: tlsConfig, + }) + } + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + return config, nil +} + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + } + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} + +func addEnvSettings(settings map[string]string) { + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } +} + +func addURLSettings(settings map[string]string, url *url.URL) error { + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + parts := strings.SplitN(url.Host, ":", 2) + if parts[0] != "" { + settings["host"] = parts[0] + } + if len(parts) == 2 { + settings["port"] = parts[1] + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + for k, v := range url.Query() { + settings[k] = v[0] + } + + return nil +} + +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) + +func addDSNSettings(settings map[string]string, s string) error { + m := dsnRegexp.FindAllStringSubmatch(s, -1) + + for _, b := range m { + settings[b[1]] = b[2] + } + + return nil +} + +type pgTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string) ([]*tls.Config, error) { + host := settings["host"] + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + tlsConfig.InsecureSkipVerify = sslrootcert == "" + case "verify-ca", "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.Wrap(err, "unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, errors.Wrap(err, "unable to read cert") + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + if timeout < 0 { + return nil, errors.New("negative timeout") + } + + d := makeDefaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + return d.DialContext, nil +} diff --git a/pgconn/config_test.go b/pgconn/config_test.go new file mode 100644 index 00000000..796876f2 --- /dev/null +++ b/pgconn/config_test.go @@ -0,0 +1,392 @@ +package pgconn_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.Nil(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.Nil(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + tf, err := ioutil.TempFile("", "") + require.Nil(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.Nil(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.Nil(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index c9caef42..37a205dc 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1,20 +1,16 @@ package pgconn import ( + "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "errors" - "fmt" "io" "net" - "os" - "os/user" - "path/filepath" "strconv" "strings" - "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -23,7 +19,7 @@ import ( const batchBufferSize = 4096 // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -50,60 +46,12 @@ func (pe PgError) Error() string { } // DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") -type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) -} - -func (cc *ConnConfig) NetworkAddress() (network, address string) { - // If host is a valid path, then address is unix socket - if _, err := os.Stat(cc.Host); err == nil { - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } - } else { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - } - - return network, address -} - -func (cc *ConnConfig) assignDefaults() error { - if cc.User == "" { - user, err := user.Current() - if err != nil { - return err - } - cc.User = user.Username - } - - if cc.Port == 0 { - cc.Port = 5432 - } - - if cc.Dial == nil { - defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} - cc.Dial = defaultDialer.Dial - } - - return nil -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { NetConn net.Conn // the underlying TCP or unix domain socket connection @@ -113,7 +61,7 @@ type PgConn struct { TxStatus byte Frontend *pgproto3.Frontend - Config ConnConfig + Config *Config batchBuf []byte batchCount int32 @@ -123,24 +71,72 @@ type PgConn struct { closed bool } -func Connect(cc ConnConfig) (*PgConn, error) { - err := cc.assignDefaults() +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) if err != nil { return nil, err } - pgConn := new(PgConn) - pgConn.Config = cc + return ConnectConfig(ctx, config) +} - pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) +// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. + if config.Port == 0 { + config.Port = 5432 + } + if config.DialFunc == nil { + config.DialFunc = makeDefaultDialer().DialContext + } + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + return pgConn, nil + } else if err, ok := err.(PgError); ok { + return nil, err + } + } + + return nil, err +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.Config = config + + var err error + network, address := NetworkAddress(config.Host, config.Port) + pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } pgConn.parameterStatuses = make(map[string]string) - if cc.TLSConfig != nil { - if err := pgConn.startTLS(cc.TLSConfig); err != nil { + if config.TLSConfig != nil { + if err := pgConn.startTLS(config.TLSConfig); err != nil { return nil, err } } @@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) { } // Copy default run-time params - for k, v := range cc.RuntimeParams { + for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } - startupMsg.Parameters["user"] = cc.User - if cc.Database != "" { - startupMsg.Parameters["database"] = cc.Database + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index dbcf2704..f165786e 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1,16 +1,18 @@ package pgconn_test import ( - "github.com/jackc/pgx/pgconn" - + "context" + "os" "testing" + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSimple(t *testing.T) { - pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) pgConn.SendExec("select current_database()") diff --git a/pgpass.go b/pgpass.go deleted file mode 100644 index 34b9bdf5..00000000 --- a/pgpass.go +++ /dev/null @@ -1,106 +0,0 @@ -package pgx - -import ( - "bufio" - "fmt" - "os" - "os/user" - "path/filepath" - "strings" -) - -func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *string { - const ( - backslash = "\r" - colon = "\n" - ) - const ( - host int = iota - port - database - username - pw - ) - if strings.HasPrefix(line, "#") { - return nil - } - line = strings.Replace(line, `\:`, colon, -1) - line = strings.Replace(line, `\\`, backslash, -1) - parts := strings.Split(line, `:`) - if len(parts) != 5 { - return nil - } - for i := range parts { - if parts[i] == `*` { - continue - } - parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) - switch i { - case host: - if parts[i] != cfgHost { - return nil - } - case port: - if parts[i] != cfgPort { - return nil - } - case database: - if parts[i] != cfgDatabase { - return nil - } - case username: - if parts[i] != cfgUsername { - return nil - } - } - } - return &parts[4] -} - -func pgpass(cfg *ConnConfig) (found bool) { - passfile := os.Getenv("PGPASSFILE") - if passfile == "" { - u, err := user.Current() - if err != nil { - return - } - passfile = filepath.Join(u.HomeDir, ".pgpass") - } - f, err := os.Open(passfile) - if err != nil { - return - } - defer f.Close() - - host := cfg.Host - if _, err := os.Stat(host); err == nil { - host = "localhost" - } - port := fmt.Sprintf(`%v`, cfg.Port) - if port == "0" { - port = "5432" - } - username := cfg.User - if username == "" { - user, err := user.Current() - if err != nil { - return - } - username = user.Username - } - database := cfg.Database - if database == "" { - database = username - } - - scanner := bufio.NewScanner(f) - var pw *string - for scanner.Scan() { - pw = parsepgpass(scanner.Text(), host, port, database, username) - if pw != nil { - cfg.Password = *pw - return true - } - } - return false -} diff --git a/pgpass_test.go b/pgpass_test.go deleted file mode 100644 index 2c63f130..00000000 --- a/pgpass_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package pgx - -import ( - "fmt" - "io/ioutil" - "os" - "os/user" - "strings" - "testing" -) - -func unescape(s string) string { - s = strings.Replace(s, `\:`, `:`, -1) - s = strings.Replace(s, `\\`, `\`, -1) - return s -} - -var passfile = [][]string{ - {"test1", "5432", "larrydb", "larry", "whatstheidea"}, - {"test1", "5432", "moedb", "moe", "imbecile"}, - {"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, - {"test2", "5432", "*", "shemp", "heymoe"}, - {"test2", "5432", "*", "*", `test\\ing\:`}, - {"localhost", "*", "*", "*", "sesam"}, - {"test3", "*", "", "", "swordfish"}, // user will be filled later -} - -func TestPGPass(t *testing.T) { - tf, err := ioutil.TempFile("", "") - if err != nil { - t.Fatal(err) - } - user, err := user.Current() - if err != nil { - t.Fatal(err) - } - passfile[len(passfile)-1][2] = user.Username - passfile[len(passfile)-1][3] = user.Username - - defer tf.Close() - defer os.Remove(tf.Name()) - os.Setenv("PGPASSFILE", tf.Name()) - _, err = fmt.Fprintln(tf, "#some comment\n\n#more comment") - if err != nil { - t.Fatal(err) - } - for _, l := range passfile { - _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) - if err != nil { - t.Fatal(err) - } - } - if err = tf.Close(); err != nil { - t.Fatal(err) - } - for i, l := range passfile { - cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]} - found := pgpass(&cfg) - if !found { - t.Fatalf("Entry %v not found", i) - } - if cfg.Password != unescape(l[4]) { - t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) - } - if l[0] == "localhost" { - // using some existing path as socket - cfg := ConnConfig{Host: tf.Name(), Database: l[2], User: l[3]} - found := pgpass(&cfg) - if !found { - t.Fatalf("Entry %v not found", i) - } - if cfg.Password != unescape(l[4]) { - t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) - } - } - } - cfg := ConnConfig{Host: "test3"} - found := pgpass(&cfg) - if !found { - t.Fatalf("Entry for default user name") - } - if cfg.Password != "swordfish" { - t.Fatalf(`Password mismatch for default user entry, want %s got %s`, "swordfish", cfg.Password) - } - cfg = ConnConfig{Host: "derp", Database: "herp", User: "joe"} - found = pgpass(&cfg) - if found { - t.Fatal("bad found") - } -} diff --git a/pgpassfile/pgpass.go b/pgpassfile/pgpass.go new file mode 100644 index 00000000..cd249bde --- /dev/null +++ b/pgpassfile/pgpass.go @@ -0,0 +1,109 @@ +package pgpassfile + +import ( + "bufio" + "io" + "os" + "regexp" + "strings" +) + +// Entry represents a line in a PG passfile. +type Entry struct { + Hostname string + Port string + Database string + Username string + Password string +} + +// Passfile is the in memory data structure representing a PG passfile. +type Passfile struct { + Entries []*Entry +} + +// ReadPassfile reads the file at path and parses it into a Passfile. +func ReadPassfile(path string) (*Passfile, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return ParsePassfile(f) +} + +// ParsePassfile reads r and parses it into a Passfile. +func ParsePassfile(r io.Reader) (*Passfile, error) { + passfile := &Passfile{} + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + entry := parseLine(scanner.Text()) + if entry != nil { + passfile.Entries = append(passfile.Entries, entry) + } + } + + return passfile, scanner.Err() +} + +// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped +// colon. +var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+") + +// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)") + +// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable +// line. +func parseLine(line string) *Entry { + const ( + tmpBackslash = "\r" + tmpColon = "\n" + ) + + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "#") { + return nil + } + + line = strings.Replace(line, `\\`, tmpBackslash, -1) + line = strings.Replace(line, `\:`, tmpColon, -1) + + parts := strings.Split(line, ":") + if len(parts) != 5 { + return nil + } + + // Unescape escaped colons and backslashes + for i := range parts { + parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1) + parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1) + } + + return &Entry{ + Hostname: parts[0], + Port: parts[1], + Database: parts[2], + Username: parts[3], + Password: parts[4], + } +} + +// FindPassword finds the password for the provided hostname, port, database, and username. For a +// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no +// match is found. +// +// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information. +func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) { + for _, e := range pf.Entries { + if (e.Hostname == "*" || e.Hostname == hostname) && + (e.Port == "*" || e.Port == port) && + (e.Database == "*" || e.Database == database) && + (e.Username == "*" || e.Username == username) { + return e.Password + } + } + return "" +} diff --git a/pgpassfile/pgpass_test.go b/pgpassfile/pgpass_test.go new file mode 100644 index 00000000..adf7f2af --- /dev/null +++ b/pgpassfile/pgpass_test.go @@ -0,0 +1,52 @@ +package pgpassfile + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func unescape(s string) string { + s = strings.Replace(s, `\:`, `:`, -1) + s = strings.Replace(s, `\\`, `\`, -1) + return s +} + +var passfile = [][]string{ + {"test1", "5432", "larrydb", "larry", "whatstheidea"}, + {"test1", "5432", "moedb", "moe", "imbecile"}, + {"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, + {"test2", "5432", "*", "shemp", "heymoe"}, + {"test2", "5432", "*", "*", `test\\ing\:`}, + {"localhost", "*", "*", "*", "sesam"}, + {"test3", "*", "", "", "swordfish"}, // user will be filled later +} + +func TestParsePassFile(t *testing.T) { + buf := bytes.NewBufferString(`# A comment + test1:5432:larrydb:larry:whatstheidea + test1:5432:moedb:moe:imbecile + test1:5432:curlydb:curly:nyuknyuknyuk + test2:5432:*:shemp:heymoe + test2:5432:*:*:test\\ing\: + localhost:*:*:*:sesam + `) + + passfile, err := ParsePassfile(buf) + require.Nil(t, err) + + assert.Len(t, passfile.Entries, 6) + + assert.Equal(t, "whatstheidea", passfile.FindPassword("test1", "5432", "larrydb", "larry")) + assert.Equal(t, "imbecile", passfile.FindPassword("test1", "5432", "moedb", "moe")) + assert.Equal(t, `test\ing:`, passfile.FindPassword("test2", "5432", "something", "else")) + assert.Equal(t, "sesam", passfile.FindPassword("localhost", "9999", "foo", "bare")) + + assert.Equal(t, "", passfile.FindPassword("wrong", "5432", "larrydb", "larry")) + assert.Equal(t, "", passfile.FindPassword("test1", "wrong", "larrydb", "larry")) + assert.Equal(t, "", passfile.FindPassword("test1", "5432", "wrong", "larry")) + assert.Equal(t, "", passfile.FindPassword("test1", "5432", "larrydb", "wrong")) +} diff --git a/replication.go b/replication.go index 782051fc..25d21b48 100644 --- a/replication.go +++ b/replication.go @@ -158,13 +158,13 @@ func NewStandbyStatus(walPositions ...uint64) (status *StandbyStatus, err error) return } -func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) +func ReplicationConnect(config *ConnConfig) (r *ReplicationConn, err error) { + if config.Config.RuntimeParams == nil { + config.Config.RuntimeParams = make(map[string]string) } - config.RuntimeParams["replication"] = "database" + config.Config.RuntimeParams["replication"] = "database" - c, err := Connect(config) + c, err := ConnectConfig(context.TODO(), config) if err != nil { return } diff --git a/tx_test.go b/tx_test.go index eff5604e..633f2177 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgconn" "github.com/jackc/pgx/pgmock" "github.com/jackc/pgx/pgproto3" ) @@ -254,12 +255,12 @@ func TestConnBeginExContextCancel(t *testing.T) { errChan <- server.ServeOne() }() - mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) if err != nil { t.Fatal(err) } - conn := mustConnect(t, mockConfig) + conn := mustConnect(t, pgx.ConnConfig{Config: *pc}) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() @@ -303,12 +304,12 @@ func TestTxCommitExCancel(t *testing.T) { errChan <- server.ServeOne() }() - mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) if err != nil { t.Fatal(err) } - conn := mustConnect(t, mockConfig) + conn := mustConnect(t, pgx.ConnConfig{Config: *pc}) defer conn.Close() tx, err := conn.Begin() diff --git a/v4.md b/v4.md index 51c9e798..8cc8b752 100644 --- a/v4.md +++ b/v4.md @@ -38,4 +38,6 @@ Minor Potential Changes: ### Incompatible Changes +* Connect method now takes context and connection string. +* ConnectConfig takes context and config object. * `RuntimeParams` `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification.