diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 52577bb0..3d583dca 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -11,6 +11,7 @@ import ( "io" "math" "net" + "strconv" "strings" "sync" "time" @@ -44,7 +45,8 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) -// LookupFunc is a function that can be used to lookup IPs addrs from host. +// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be +// returned in order to override the connection string's port. type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. @@ -196,11 +198,24 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } for _, ip := range ips { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, - }) + splitIP, splitPort, err := net.SplitHostPort(ip) + if err == nil { + port, err := strconv.ParseUint(splitPort, 10, 16) + if err != nil { + return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + } + configs = append(configs, &FallbackConfig{ + Host: splitIP, + Port: uint16(port), + TLSConfig: fb.TLSConfig, + }) + } else { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } } } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 22f0c26f..4350452c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -237,6 +237,40 @@ func TestConnectCustomLookup(t *testing.T) { closeConn(t, conn) } +func TestConnectCustomLookupWithPort(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + origPort := config.Port + // Chnage the config an invalid port so it will fail if used + config.Port = 0 + + looked := false + config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + looked = true + addrs, err := net.LookupHost(host) + if err != nil { + return nil, err + } + for i := range addrs { + addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10)) + } + return addrs, nil + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel()