mirror of https://github.com/jackc/pgx.git
Parse connect_timeout into Dial func
Instead of adding Timeout field which could conflict with custom Dial func.pull/388/head
parent
9281f057ae
commit
2c07b03087
22
conn.go
22
conn.go
|
@ -72,7 +72,6 @@ type ConnConfig struct {
|
|||
Logger Logger
|
||||
LogLevel int
|
||||
Dial DialFunc
|
||||
Timeout time.Duration
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
||||
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
||||
|
@ -224,6 +223,10 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
return connect(config, minimalConnInfo)
|
||||
}
|
||||
|
||||
func defaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
|
||||
c = new(Conn)
|
||||
|
||||
|
@ -260,7 +263,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
|||
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
|
||||
d := defaultDialer()
|
||||
c.config.Dial = d.Dial
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
|
@ -692,7 +696,9 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Timeout = time.Duration(timeout) * time.Second
|
||||
d := defaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
cp.Dial = d.Dial
|
||||
}
|
||||
|
||||
err = configSSL(url.Query().Get("sslmode"), &cp)
|
||||
|
@ -761,11 +767,13 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||
case "sslmode":
|
||||
sslmode = b[2]
|
||||
case "connect_timeout":
|
||||
t, err := strconv.ParseInt(b[2], 10, 64)
|
||||
timeout, err := strconv.ParseInt(b[2], 10, 64)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Timeout = time.Duration(t) * time.Second
|
||||
d := defaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
cp.Dial = d.Dial
|
||||
default:
|
||||
cp.RuntimeParams[b[1]] = b[2]
|
||||
}
|
||||
|
@ -841,7 +849,9 @@ func ParseEnvLibpq() (ConnConfig, error) {
|
|||
|
||||
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
|
||||
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
|
||||
cc.Timeout = time.Duration(timeout) * time.Second
|
||||
d := defaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
cc.Dial = d.Dial
|
||||
} else {
|
||||
return cc, err
|
||||
}
|
||||
|
|
144
conn_test.go
144
conn_test.go
|
@ -576,7 +576,7 @@ func TestParseDSN(t *testing.T) {
|
|||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
|
@ -585,15 +585,13 @@ func TestParseDSN(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
connParams, err := pgx.ParseDSN(tt.url)
|
||||
actual, err := pgx.ParseDSN(tt.url)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||
}
|
||||
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -721,7 +719,7 @@ func TestParseConnectionString(t *testing.T) {
|
|||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
|
@ -819,16 +817,80 @@ func TestParseConnectionString(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
connParams, err := pgx.ParseConnectionString(tt.url)
|
||||
actual, err := pgx.ParseConnectionString(tt.url)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
||||
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
|
||||
func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) {
|
||||
if actual.Host != expected.Host {
|
||||
t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host)
|
||||
}
|
||||
if actual.Port != expected.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
||||
}
|
||||
if actual.Port != expected.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
||||
}
|
||||
if actual.User != expected.User {
|
||||
t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User)
|
||||
}
|
||||
if actual.Password != expected.Password {
|
||||
t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password)
|
||||
}
|
||||
// Cannot test value of underlying Dialer stuct but can at least test if Dial func is set.
|
||||
if (actual.Dial != nil) != (expected.Dial != nil) {
|
||||
t.Errorf("%s: expected Dial mismatch", testName)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) {
|
||||
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams)
|
||||
}
|
||||
|
||||
tlsTests := []struct {
|
||||
name string
|
||||
expected *tls.Config
|
||||
actual *tls.Config
|
||||
}{
|
||||
{
|
||||
name: "TLSConfig",
|
||||
expected: expected.TLSConfig,
|
||||
actual: actual.TLSConfig,
|
||||
},
|
||||
{
|
||||
name: "FallbackTLSConfig",
|
||||
expected: expected.FallbackTLSConfig,
|
||||
actual: actual.FallbackTLSConfig,
|
||||
},
|
||||
}
|
||||
for _, tlsTest := range tlsTests {
|
||||
name := tlsTest.name
|
||||
expected := tlsTest.expected
|
||||
actual := tlsTest.actual
|
||||
|
||||
if expected == nil && actual != nil {
|
||||
t.Errorf("%s / %s: expected nil, but it was set", testName, name)
|
||||
} else if expected != nil && actual == nil {
|
||||
t.Errorf("%s / %s: expected to be set, but got nil", testName, name)
|
||||
} else if expected != nil && actual != nil {
|
||||
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
||||
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
if actual.ServerName != expected.ServerName {
|
||||
t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if actual.UseFallbackTLS != expected.UseFallbackTLS {
|
||||
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEnvLibpq(t *testing.T) {
|
||||
|
@ -881,7 +943,7 @@ func TestParseEnvLibpq(t *testing.T) {
|
|||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
Timeout: 10 * time.Second,
|
||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
|
@ -997,71 +1059,13 @@ func TestParseEnvLibpq(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
config, err := pgx.ParseEnvLibpq()
|
||||
actual, err := pgx.ParseEnvLibpq()
|
||||
if err != nil {
|
||||
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if config.Host != tt.config.Host {
|
||||
t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
|
||||
}
|
||||
if config.Port != tt.config.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
||||
}
|
||||
if config.Port != tt.config.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
||||
}
|
||||
if config.User != tt.config.User {
|
||||
t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
|
||||
}
|
||||
if config.Password != tt.config.Password {
|
||||
t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
|
||||
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
|
||||
}
|
||||
|
||||
tlsTests := []struct {
|
||||
name string
|
||||
expected *tls.Config
|
||||
actual *tls.Config
|
||||
}{
|
||||
{
|
||||
name: "TLSConfig",
|
||||
expected: tt.config.TLSConfig,
|
||||
actual: config.TLSConfig,
|
||||
},
|
||||
{
|
||||
name: "FallbackTLSConfig",
|
||||
expected: tt.config.FallbackTLSConfig,
|
||||
actual: config.FallbackTLSConfig,
|
||||
},
|
||||
}
|
||||
for _, tlsTest := range tlsTests {
|
||||
name := tlsTest.name
|
||||
expected := tlsTest.expected
|
||||
actual := tlsTest.actual
|
||||
|
||||
if expected == nil && actual != nil {
|
||||
t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
|
||||
} else if expected != nil && actual == nil {
|
||||
t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
|
||||
} else if expected != nil && actual != nil {
|
||||
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
||||
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
if actual.ServerName != expected.ServerName {
|
||||
t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.UseFallbackTLS != tt.config.UseFallbackTLS {
|
||||
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
|
||||
}
|
||||
testConnConfigEquals(t, tt.config, actual, tt.name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue