mirror of https://github.com/jackc/pgx.git
Support using a custom dialer
For example I may want to use a dialer which retries transient network errors (e.g. DNS issues). Signed-off-by: Lewis Marshall <lewis@lmars.net>pull/80/head
parent
d46a762159
commit
784d12cbbc
39
conn.go
39
conn.go
|
@ -20,6 +20,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnConfig contains all the options used to establish a connection.
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
|
@ -29,6 +31,7 @@ type ConnConfig struct {
|
|||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
Logger Logger
|
||||
Dial DialFunc
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
|
@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
c.logger.Debug("Using default connection config", "Port", c.config.Port)
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
_, err = os.Stat(c.config.Host)
|
||||
if err == nil {
|
||||
if _, err := os.Stat(c.config.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
socket := c.config.Host
|
||||
if !strings.Contains(socket, "/.s.PGSQL.") {
|
||||
socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket))
|
||||
c.conn, err = net.Dial("unix", socket)
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port))
|
||||
d := net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port))
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
network = "unix"
|
||||
address = c.config.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
}
|
||||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
c.conn, err = c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if c != nil && err != nil {
|
||||
c.conn.Close()
|
||||
|
|
|
@ -14,6 +14,7 @@ var plainPasswordConnConfig *pgx.ConnConfig = nil
|
|||
var noPasswordConnConfig *pgx.ConnConfig = nil
|
||||
var invalidUserConnConfig *pgx.ConnConfig = nil
|
||||
var tlsConnConfig *pgx.ConnConfig = nil
|
||||
var customDialerConnConfig *pgx.ConnConfig = nil
|
||||
|
||||
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||
|
@ -22,3 +23,4 @@ var tlsConnConfig *pgx.ConnConfig = nil
|
|||
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
|
||||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
|
|
@ -12,3 +12,4 @@ var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password
|
|||
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
|
34
conn_test.go
34
conn_test.go
|
@ -3,6 +3,8 @@ package pgx_test
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/jackc/pgx"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -196,6 +198,34 @@ func TestConnectWithConnectionRefused(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectCustomDialer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if customDialerConnConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
dialled := false
|
||||
conf := *customDialerConnConfig
|
||||
conf.Dial = func(network, address string) (net.Conn, error) {
|
||||
dialled = true
|
||||
return net.Dial(network, address)
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(conf)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %s", err)
|
||||
}
|
||||
if !dialled {
|
||||
t.Fatal("Connect did not use custom dialer")
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Unable to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseURI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -249,7 +279,7 @@ func TestParseURI(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
if connParams != tt.connParams {
|
||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||
}
|
||||
}
|
||||
|
@ -298,7 +328,7 @@ func TestParseDSN(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
|
||||
if connParams != tt.connParams {
|
||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue