diff --git a/conn.go b/conn.go index 7c6f261e..0330ddcc 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ import ( "io/ioutil" "net" "net/url" + "os" "os/user" "path/filepath" "strconv" @@ -34,8 +35,7 @@ const ( // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Socket string // path to unix domain socket directory (e.g. /private/tmp) - Host string // url (e.g. localhost) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 // default: 5432 Database string User string // default: OS user name @@ -118,9 +118,9 @@ func (e ProtocolError) Error() string { var NotificationTimeoutError = errors.New("Notification Timeout") var DeadConnError = errors.New("Connection is dead") -// Connect establishes a connection with a PostgreSQL server using config. One -// of config.Socket or config.Host must be specified. config.User -// will default to the OS user name. Other config fields are optional. +// Connect establishes a connection with a PostgreSQL server using config. +// config.Host must be specified. config.User will default to the OS user name. +// Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { c = new(Conn) @@ -150,9 +150,11 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.logger.Debug("Using default connection config", "MsgBufSize", c.config.MsgBufSize) } - if c.config.Socket != "" { + // See if host is a valid path, if yes connect with a socket + _, err = os.Stat(c.config.Host) + if err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - socket := c.config.Socket + socket := c.config.Host if !strings.Contains(socket, "/.s.PGSQL.") { socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) } @@ -163,7 +165,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) return nil, err } - } else if c.config.Host != "" { + } else { c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port)) c.conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)) if err != nil { diff --git a/conn_config_test.go.example b/conn_config_test.go.example index e94c5111..802eabb3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -15,7 +15,7 @@ var noPasswordConnConfig *pgx.ConnConfig = nil var invalidUserConnConfig *pgx.ConnConfig = nil // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Socket: "/private/tmp", User: "pgx_none", Database: "pgx_test"} +// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} // var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 7e753b5e..66156710 100644 --- a/conn_test.go +++ b/conn_test.go @@ -73,7 +73,7 @@ func TestConnectWithUnixSocketFile(t *testing.T) { } connParams := *unixSocketConnConfig - connParams.Socket = connParams.Socket + "/.s.PGSQL.5432" + connParams.Host = connParams.Host + "/.s.PGSQL.5432" conn, err := pgx.Connect(connParams) if err != nil { t.Fatalf("Unable to establish connection: %v", err)