Parse connect_timeout into Dial func

Instead of adding Timeout field which could conflict with custom Dial
func.
pull/388/head
Jack Christensen 2018-01-13 18:02:13 -06:00
parent 9281f057ae
commit 2c07b03087
2 changed files with 90 additions and 76 deletions

22
conn.go
View File

@ -72,7 +72,6 @@ type ConnConfig struct {
Logger Logger
LogLevel int
Dial DialFunc
Timeout time.Duration
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received.
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
@ -224,6 +223,10 @@ func Connect(config ConnConfig) (c *Conn, err error) {
return connect(config, minimalConnInfo)
}
func defaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
c = new(Conn)
@ -260,7 +263,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
network, address := c.config.networkAddress()
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
d := defaultDialer()
c.config.Dial = d.Dial
}
if c.shouldLog(LogLevelInfo) {
@ -692,7 +696,9 @@ func ParseURI(uri string) (ConnConfig, error) {
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(timeout) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
}
err = configSSL(url.Query().Get("sslmode"), &cp)
@ -761,11 +767,13 @@ func ParseDSN(s string) (ConnConfig, error) {
case "sslmode":
sslmode = b[2]
case "connect_timeout":
t, err := strconv.ParseInt(b[2], 10, 64)
timeout, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(t) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
default:
cp.RuntimeParams[b[1]] = b[2]
}
@ -841,7 +849,9 @@ func ParseEnvLibpq() (ConnConfig, error) {
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
cc.Timeout = time.Duration(timeout) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cc.Dial = d.Dial
} else {
return cc, err
}

View File

@ -576,7 +576,7 @@ func TestParseDSN(t *testing.T) {
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
@ -585,15 +585,13 @@ func TestParseDSN(t *testing.T) {
}
for i, tt := range tests {
connParams, err := pgx.ParseDSN(tt.url)
actual, err := pgx.ParseDSN(tt.url)
if err != nil {
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
continue
}
if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
}
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
}
}
@ -721,7 +719,7 @@ func TestParseConnectionString(t *testing.T) {
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
@ -819,16 +817,80 @@ func TestParseConnectionString(t *testing.T) {
}
for i, tt := range tests {
connParams, err := pgx.ParseConnectionString(tt.url)
actual, err := pgx.ParseConnectionString(tt.url)
if err != nil {
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
continue
}
if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
}
}
func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) {
if actual.Host != expected.Host {
t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host)
}
if actual.Port != expected.Port {
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
}
if actual.Port != expected.Port {
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
}
if actual.User != expected.User {
t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User)
}
if actual.Password != expected.Password {
t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password)
}
// Cannot test value of underlying Dialer stuct but can at least test if Dial func is set.
if (actual.Dial != nil) != (expected.Dial != nil) {
t.Errorf("%s: expected Dial mismatch", testName)
}
if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) {
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams)
}
tlsTests := []struct {
name string
expected *tls.Config
actual *tls.Config
}{
{
name: "TLSConfig",
expected: expected.TLSConfig,
actual: actual.TLSConfig,
},
{
name: "FallbackTLSConfig",
expected: expected.FallbackTLSConfig,
actual: actual.FallbackTLSConfig,
},
}
for _, tlsTest := range tlsTests {
name := tlsTest.name
expected := tlsTest.expected
actual := tlsTest.actual
if expected == nil && actual != nil {
t.Errorf("%s / %s: expected nil, but it was set", testName, name)
} else if expected != nil && actual == nil {
t.Errorf("%s / %s: expected to be set, but got nil", testName, name)
} else if expected != nil && actual != nil {
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
}
if actual.ServerName != expected.ServerName {
t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName)
}
}
}
if actual.UseFallbackTLS != expected.UseFallbackTLS {
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS)
}
}
func TestParseEnvLibpq(t *testing.T) {
@ -881,7 +943,7 @@ func TestParseEnvLibpq(t *testing.T) {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
RuntimeParams: map[string]string{},
},
},
@ -997,71 +1059,13 @@ func TestParseEnvLibpq(t *testing.T) {
}
}
config, err := pgx.ParseEnvLibpq()
actual, err := pgx.ParseEnvLibpq()
if err != nil {
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
continue
}
if config.Host != tt.config.Host {
t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
}
if config.Port != tt.config.Port {
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
}
if config.Port != tt.config.Port {
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
}
if config.User != tt.config.User {
t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
}
if config.Password != tt.config.Password {
t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
}
if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
}
tlsTests := []struct {
name string
expected *tls.Config
actual *tls.Config
}{
{
name: "TLSConfig",
expected: tt.config.TLSConfig,
actual: config.TLSConfig,
},
{
name: "FallbackTLSConfig",
expected: tt.config.FallbackTLSConfig,
actual: config.FallbackTLSConfig,
},
}
for _, tlsTest := range tlsTests {
name := tlsTest.name
expected := tlsTest.expected
actual := tlsTest.actual
if expected == nil && actual != nil {
t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
} else if expected != nil && actual == nil {
t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
} else if expected != nil && actual != nil {
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
}
if actual.ServerName != expected.ServerName {
t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
}
}
}
if config.UseFallbackTLS != tt.config.UseFallbackTLS {
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
}
testConnConfigEquals(t, tt.config, actual, tt.name)
}
}