mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgconn/master' into v5-dev
commit
14b5053209
|
@ -11,6 +11,7 @@ import (
|
|||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -44,7 +45,8 @@ type Notification struct {
|
|||
// DialFunc is a function that can be used to connect to a PostgreSQL server.
|
||||
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// LookupFunc is a function that can be used to lookup IPs addrs from host.
|
||||
// LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be
|
||||
// returned in order to override the connection string's port.
|
||||
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
|
||||
|
||||
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
||||
|
@ -196,11 +198,24 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
|||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
configs = append(configs, &FallbackConfig{
|
||||
Host: ip,
|
||||
Port: fb.Port,
|
||||
TLSConfig: fb.TLSConfig,
|
||||
})
|
||||
splitIP, splitPort, err := net.SplitHostPort(ip)
|
||||
if err == nil {
|
||||
port, err := strconv.ParseUint(splitPort, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
|
||||
}
|
||||
configs = append(configs, &FallbackConfig{
|
||||
Host: splitIP,
|
||||
Port: uint16(port),
|
||||
TLSConfig: fb.TLSConfig,
|
||||
})
|
||||
} else {
|
||||
configs = append(configs, &FallbackConfig{
|
||||
Host: ip,
|
||||
Port: fb.Port,
|
||||
TLSConfig: fb.TLSConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -237,6 +237,40 @@ func TestConnectCustomLookup(t *testing.T) {
|
|||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectCustomLookupWithPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
||||
if connString == "" {
|
||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
||||
}
|
||||
|
||||
config, err := pgconn.ParseConfig(connString)
|
||||
require.NoError(t, err)
|
||||
|
||||
origPort := config.Port
|
||||
// Chnage the config an invalid port so it will fail if used
|
||||
config.Port = 0
|
||||
|
||||
looked := false
|
||||
config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
|
||||
looked = true
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range addrs {
|
||||
addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.True(t, looked)
|
||||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
func TestConnectWithRuntimeParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue