mirror of https://github.com/jackc/pgx.git
Support using a custom dialer
For example I may want to use a dialer which retries transient network errors (e.g. DNS issues). Signed-off-by: Lewis Marshall <lewis@lmars.net>pull/80/head
parent
d46a762159
commit
784d12cbbc
33
conn.go
33
conn.go
|
@ -20,6 +20,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DialFunc func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// ConnConfig contains all the options used to establish a connection.
|
// ConnConfig contains all the options used to establish a connection.
|
||||||
type ConnConfig struct {
|
type ConnConfig struct {
|
||||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
@ -29,6 +31,7 @@ type ConnConfig struct {
|
||||||
Password string
|
Password string
|
||||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||||
Logger Logger
|
Logger Logger
|
||||||
|
Dial DialFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
|
@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
||||||
c.logger.Debug("Using default connection config", "Port", c.config.Port)
|
c.logger.Debug("Using default connection config", "Port", c.config.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
network := "tcp"
|
||||||
|
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||||
// See if host is a valid path, if yes connect with a socket
|
// See if host is a valid path, if yes connect with a socket
|
||||||
_, err = os.Stat(c.config.Host)
|
if _, err := os.Stat(c.config.Host); err == nil {
|
||||||
if err == nil {
|
|
||||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||||
socket := c.config.Host
|
network = "unix"
|
||||||
if !strings.Contains(socket, "/.s.PGSQL.") {
|
address = c.config.Host
|
||||||
socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||||
|
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket))
|
if c.config.Dial == nil {
|
||||||
c.conn, err = net.Dial("unix", socket)
|
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||||
|
}
|
||||||
|
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||||
|
c.conn, err = c.config.Dial(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port))
|
|
||||||
d := net.Dialer{KeepAlive: 5 * time.Minute}
|
|
||||||
c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port))
|
|
||||||
if err != nil {
|
|
||||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if c != nil && err != nil {
|
if c != nil && err != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
|
|
@ -14,6 +14,7 @@ var plainPasswordConnConfig *pgx.ConnConfig = nil
|
||||||
var noPasswordConnConfig *pgx.ConnConfig = nil
|
var noPasswordConnConfig *pgx.ConnConfig = nil
|
||||||
var invalidUserConnConfig *pgx.ConnConfig = nil
|
var invalidUserConnConfig *pgx.ConnConfig = nil
|
||||||
var tlsConnConfig *pgx.ConnConfig = nil
|
var tlsConnConfig *pgx.ConnConfig = nil
|
||||||
|
var customDialerConnConfig *pgx.ConnConfig = nil
|
||||||
|
|
||||||
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||||
|
@ -22,3 +23,4 @@ var tlsConnConfig *pgx.ConnConfig = nil
|
||||||
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
|
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
|
||||||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||||
|
|
|
@ -12,3 +12,4 @@ var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password
|
||||||
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||||
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||||
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||||
|
|
34
conn_test.go
34
conn_test.go
|
@ -3,6 +3,8 @@ package pgx_test
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -196,6 +198,34 @@ func TestConnectWithConnectionRefused(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectCustomDialer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if customDialerConnConfig == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dialled := false
|
||||||
|
conf := *customDialerConnConfig
|
||||||
|
conf.Dial = func(network, address string) (net.Conn, error) {
|
||||||
|
dialled = true
|
||||||
|
return net.Dial(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pgx.Connect(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to establish connection: %s", err)
|
||||||
|
}
|
||||||
|
if !dialled {
|
||||||
|
t.Fatal("Connect did not use custom dialer")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unable to close connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseURI(t *testing.T) {
|
func TestParseURI(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -249,7 +279,7 @@ func TestParseURI(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if connParams != tt.connParams {
|
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -298,7 +328,7 @@ func TestParseDSN(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if connParams != tt.connParams {
|
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue