Support using a custom dialer

For example I may want to use a dialer which retries transient network
errors (e.g. DNS issues).

Signed-off-by: Lewis Marshall <lewis@lmars.net>
pull/80/head
Lewis Marshall 2015-04-18 22:38:15 +01:00
parent d46a762159
commit 784d12cbbc
4 changed files with 54 additions and 22 deletions

39
conn.go
View File

@ -20,6 +20,8 @@ import (
"time"
)
type DialFunc func(network, addr string) (net.Conn, error)
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
@ -29,6 +31,7 @@ type ConnConfig struct {
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
Logger Logger
Dial DialFunc
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.logger.Debug("Using default connection config", "Port", c.config.Port)
}
network := "tcp"
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
// See if host is a valid path, if yes connect with a socket
_, err = os.Stat(c.config.Host)
if err == nil {
if _, err := os.Stat(c.config.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
socket := c.config.Host
if !strings.Contains(socket, "/.s.PGSQL.") {
socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket))
c.conn, err = net.Dial("unix", socket)
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
}
} else {
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port))
d := net.Dialer{KeepAlive: 5 * time.Minute}
c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port))
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
network = "unix"
address = c.config.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
}
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
c.conn, err = c.config.Dial(network, address)
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
}
defer func() {
if c != nil && err != nil {
c.conn.Close()

View File

@ -14,6 +14,7 @@ var plainPasswordConnConfig *pgx.ConnConfig = nil
var noPasswordConnConfig *pgx.ConnConfig = nil
var invalidUserConnConfig *pgx.ConnConfig = nil
var tlsConnConfig *pgx.ConnConfig = nil
var customDialerConnConfig *pgx.ConnConfig = nil
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
@ -22,3 +23,4 @@ var tlsConnConfig *pgx.ConnConfig = nil
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}

View File

@ -12,3 +12,4 @@ var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}

View File

@ -3,6 +3,8 @@ package pgx_test
import (
"fmt"
"github.com/jackc/pgx"
"net"
"reflect"
"strconv"
"strings"
"sync"
@ -196,6 +198,34 @@ func TestConnectWithConnectionRefused(t *testing.T) {
}
}
func TestConnectCustomDialer(t *testing.T) {
t.Parallel()
if customDialerConnConfig == nil {
return
}
dialled := false
conf := *customDialerConnConfig
conf.Dial = func(network, address string) (net.Conn, error) {
dialled = true
return net.Dial(network, address)
}
conn, err := pgx.Connect(conf)
if err != nil {
t.Fatalf("Unable to establish connection: %s", err)
}
if !dialled {
t.Fatal("Connect did not use custom dialer")
}
err = conn.Close()
if err != nil {
t.Fatal("Unable to close connection")
}
}
func TestParseURI(t *testing.T) {
t.Parallel()
@ -249,7 +279,7 @@ func TestParseURI(t *testing.T) {
continue
}
if connParams != tt.connParams {
if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
}
}
@ -298,7 +328,7 @@ func TestParseDSN(t *testing.T) {
continue
}
if connParams != tt.connParams {
if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
}
}