Merge remote-tracking branch 'pgconn/master' into v5-dev

query-exec-mode
Jack Christensen 2021-12-18 08:20:53 -06:00
commit 14b5053209
2 changed files with 55 additions and 6 deletions

View File

@ -11,6 +11,7 @@ import (
"io"
"math"
"net"
"strconv"
"strings"
"sync"
"time"
@ -44,7 +45,8 @@ type Notification struct {
// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// LookupFunc is a function that can be used to lookup IPs addrs from host.
// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be
// returned in order to override the connection string's port.
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
@ -196,11 +198,24 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
}
for _, ip := range ips {
configs = append(configs, &FallbackConfig{
Host: ip,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
})
splitIP, splitPort, err := net.SplitHostPort(ip)
if err == nil {
port, err := strconv.ParseUint(splitPort, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
}
configs = append(configs, &FallbackConfig{
Host: splitIP,
Port: uint16(port),
TLSConfig: fb.TLSConfig,
})
} else {
configs = append(configs, &FallbackConfig{
Host: ip,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
})
}
}
}

View File

@ -237,6 +237,40 @@ func TestConnectCustomLookup(t *testing.T) {
closeConn(t, conn)
}
func TestConnectCustomLookupWithPort(t *testing.T) {
t.Parallel()
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
if connString == "" {
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
}
config, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
origPort := config.Port
// Chnage the config an invalid port so it will fail if used
config.Port = 0
looked := false
config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
looked = true
addrs, err := net.LookupHost(host)
if err != nil {
return nil, err
}
for i := range addrs {
addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
}
return addrs, nil
}
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
require.True(t, looked)
closeConn(t, conn)
}
func TestConnectWithRuntimeParams(t *testing.T) {
t.Parallel()