diff --git a/pgconn/config.go b/pgconn/config.go index 515d6356..d2001dc5 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -55,21 +55,23 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. -// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment -// variables. connString may be a URL or a DSN. It also may be empty to only read from the -// environment. If a password is not supplied it will attempt to read the .pgpass file. +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. +// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the +// .pgpass file. // -// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// Example DSN: "user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca" // -// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// Example URL: "postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca" // -// Multiple configs may be returned due to sslmode settings with fallback options (e.g. -// sslmode=prefer). Future implementations may also support multiple hosts -// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// ParseConfig currently recognizes the following environment variable and their parameter key word -// equivalents passed via database URL or DSN: +// Example URL: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb" +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: // // PGHOST // PGPORT @@ -84,20 +86,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGAPPNAME // PGCONNECT_TIMEOUT // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of -// environment variables. +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key -// word names. They are usually but not always the environment variable name downcased and without -// the "PG" prefix. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. // // Important TLS Security Notes: // -// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to -// "prefer" behavior if not set. +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on -// what level of security each sslmode provides. +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. // // "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger // security guarantees than it would with libpq. Do not rely on this behavior as it @@ -110,12 +110,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") { - url, err := url.Parse(connString) - if err != nil { - return nil, err - } - - err = addURLSettings(settings, url) + err := addURLSettings(settings, connString) if err != nil { return nil, err } @@ -128,19 +123,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Host: settings["host"], Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), } - if port, err := parsePort(settings["port"]); err == nil { - config.Port = port - } else { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) - } - if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { @@ -173,28 +161,50 @@ func ParseConfig(connString string) (*Config, error) { config.RuntimeParams[k] = v } - var tlsConfigs []*tls.Config + fallbacks := []*FallbackConfig{} - // Ignore TLS settings if Unix domain socket like libpq - if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { - tlsConfigs = append(tlsConfigs, nil) - } else { - var err error - tlsConfigs, err = configTLS(settings) + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) } } - config.TLSConfig = tlsConfigs[0] - - for _, tlsConfig := range tlsConfigs[1:] { - config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: tlsConfig, - }) - } + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -272,7 +282,12 @@ func addEnvSettings(settings map[string]string) { } } -func addURLSettings(settings map[string]string, url *url.URL) error { +func addURLSettings(settings map[string]string, connString string) error { + url, err := url.Parse(connString) + if err != nil { + return err + } + if url.User != nil { settings["user"] = url.User.Username() if password, present := url.User.Password(); present { @@ -280,12 +295,23 @@ func addURLSettings(settings map[string]string, url *url.URL) error { } } - parts := strings.SplitN(url.Host, ":", 2) - if parts[0] != "" { - settings["host"] = parts[0] + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + parts := strings.SplitN(host, ":", 2) + if parts[0] != "" { + hosts = append(hosts, parts[0]) + } + if len(parts) == 2 { + ports = append(ports, parts[1]) + } } - if len(parts) == 2 { - settings["port"] = parts[1] + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") } database := strings.TrimLeft(url.Path, "/") diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 796876f2..566a44f0 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -230,6 +230,150 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tsl", + connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "foo", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, } for i, tt := range tests { @@ -243,6 +387,13 @@ func TestParseConfig(t *testing.T) { } func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) @@ -257,12 +408,12 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName } } - if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) - if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 37a205dc..09860eb2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -127,7 +127,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.Config = config var err error - network, address := NetworkAddress(config.Host, config.Port) + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 9e16e925..d53bbc09 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -157,6 +157,37 @@ func TestConnectWithRuntimeParams(t *testing.T) { assert.Nil(t, err) } +func TestConnectWithFallback(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) diff --git a/v4.md b/v4.md index 65c1a2cc..326668a8 100644 --- a/v4.md +++ b/v4.md @@ -42,3 +42,7 @@ Test configuration now done with environment variables instead of `.gitignore`'d * Connect method now takes context and connection string. * ConnectConfig takes context and config object. * `RuntimeParams` `pgx.Conn`. Server reported status can now be queried with the `ParameterStatus` method. The rename aligns with the PostgreSQL protocol and standard libpq naming. Access via a method instead of direct access to the map protects against outside modification. + +## New Features + +* Specifying multiple hosts for connecting to HA systems.