From 784d12cbbcfc11c637d8f7931dd87a3f236e5c4d Mon Sep 17 00:00:00 2001 From: Lewis Marshall Date: Sat, 18 Apr 2015 22:38:15 +0100 Subject: [PATCH] Support using a custom dialer For example I may want to use a dialer which retries transient network errors (e.g. DNS issues). Signed-off-by: Lewis Marshall --- conn.go | 39 ++++++++++++++++++------------------- conn_config_test.go.example | 2 ++ conn_config_test.go.travis | 1 + conn_test.go | 34 ++++++++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 22 deletions(-) diff --git a/conn.go b/conn.go index dd205f78..f51bdb6d 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,8 @@ import ( "time" ) +type DialFunc func(network, addr string) (net.Conn, error) + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) @@ -29,6 +31,7 @@ type ConnConfig struct { Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS Logger Logger + Dial DialFunc } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.logger.Debug("Using default connection config", "Port", c.config.Port) } + network := "tcp" + address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) // See if host is a valid path, if yes connect with a socket - _, err = os.Stat(c.config.Host) - if err == nil { + if _, err := os.Stat(c.config.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - socket := c.config.Host - if !strings.Contains(socket, "/.s.PGSQL.") { - socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - - c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket)) - c.conn, err = net.Dial("unix", socket) - if err != nil { - c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) - return nil, err - } - } else { - c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port)) - d := net.Dialer{KeepAlive: 5 * time.Minute} - c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)) - if err != nil { - c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) - return nil, err + network = "unix" + address = c.config.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) } } + if c.config.Dial == nil { + c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial + } + c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + c.conn, err = c.config.Dial(network, address) + if err != nil { + c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) + return nil, err + } defer func() { if c != nil && err != nil { c.conn.Close() diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 70dcce93..358e0247 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -14,6 +14,7 @@ var plainPasswordConnConfig *pgx.ConnConfig = nil var noPasswordConnConfig *pgx.ConnConfig = nil var invalidUserConnConfig *pgx.ConnConfig = nil var tlsConnConfig *pgx.ConnConfig = nil +var customDialerConnConfig *pgx.ConnConfig = nil // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} @@ -22,3 +23,4 @@ var tlsConnConfig *pgx.ConnConfig = nil // var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} +// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index 39da315d..2b2691de 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -12,3 +12,4 @@ var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} +var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 5340be62..72d8174b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,8 @@ package pgx_test import ( "fmt" "github.com/jackc/pgx" + "net" + "reflect" "strconv" "strings" "sync" @@ -196,6 +198,34 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } +func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + + if customDialerConnConfig == nil { + return + } + + dialled := false + conf := *customDialerConnConfig + conf.Dial = func(network, address string) (net.Conn, error) { + dialled = true + return net.Dial(network, address) + } + + conn, err := pgx.Connect(conf) + if err != nil { + t.Fatalf("Unable to establish connection: %s", err) + } + if !dialled { + t.Fatal("Connect did not use custom dialer") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestParseURI(t *testing.T) { t.Parallel() @@ -249,7 +279,7 @@ func TestParseURI(t *testing.T) { continue } - if connParams != tt.connParams { + if !reflect.DeepEqual(connParams, tt.connParams) { t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) } } @@ -298,7 +328,7 @@ func TestParseDSN(t *testing.T) { continue } - if connParams != tt.connParams { + if !reflect.DeepEqual(connParams, tt.connParams) { t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) } }