diff --git a/config.go b/config.go index 859672ea..e141a2f8 100644 --- a/config.go +++ b/config.go @@ -100,10 +100,29 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// isAbsolutePath checks if the provided value is an absolute path either +// beginning with a forward slash (as on Linux-based systems) or with a capital +// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). +func isAbsolutePath(path string) bool { + isWindowsPath := func(p string) bool { + if len(p) < 3 { + return false + } + drive := p[0] + colon := p[1] + backslash := p[2] + if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { + return true + } + return false + } + return strings.HasPrefix(path, "/") || isWindowsPath(path) +} + // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with // net.Dial. func NetworkAddress(host string, port uint16) (network, address string) { - if strings.HasPrefix(host, "/") { + if isAbsolutePath(host) { network = "unix" address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { diff --git a/config_test.go b/config_test.go index da28782d..a28db3d6 100644 --- a/config_test.go +++ b/config_test.go @@ -231,6 +231,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url unix domain socket host on windows", + connString: "postgres:///foo?host=C:\\tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "C:\\tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "database url dbname", connString: "postgres://localhost/?dbname=foo&sslmode=disable", @@ -703,6 +715,55 @@ func TestConfigCopyCanBeUsedToConnect(t *testing.T) { assert.NoError(t, err) } +func TestNetworkAddress(t *testing.T) { + tests := []struct { + name string + host string + wantNet string + }{ + { + name: "Default Unix socket address", + host: "/var/run/postgresql", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (standard drive name)", + host: "C:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (first drive name)", + host: "A:\\tmp", + wantNet: "unix", + }, + { + name: "Windows Unix socket address (last drive name)", + host: "Z:\\tmp", + wantNet: "unix", + }, + { + name: "Assume TCP for unknown formats", + host: "a/tmp", + wantNet: "tcp", + }, + { + name: "loopback interface", + host: "localhost", + wantNet: "tcp", + }, + { + name: "IP address", + host: "127.0.0.1", + wantNet: "tcp", + }, + } + for i, tt := range tests { + gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) + + assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) + } +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return diff --git a/pgconn.go b/pgconn.go index f1304d08..ef5b76fd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -187,7 +187,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba for _, fb := range fallbacks { // skip resolve for unix sockets - if strings.HasPrefix(fb.Host, "/") { + if isAbsolutePath(fb.Host) { configs = append(configs, &FallbackConfig{ Host: fb.Host, Port: fb.Port,