mirror of https://github.com/jackc/pgx.git
Merge branch 'timeout' of https://github.com/cyberdelia/pgx into cyberdelia-timeout
commit
9281f057ae
29
conn.go
29
conn.go
|
@ -72,6 +72,7 @@ type ConnConfig struct {
|
||||||
Logger Logger
|
Logger Logger
|
||||||
LogLevel int
|
LogLevel int
|
||||||
Dial DialFunc
|
Dial DialFunc
|
||||||
|
Timeout time.Duration
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
||||||
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
||||||
|
@ -259,7 +260,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||||
|
|
||||||
network, address := c.config.networkAddress()
|
network, address := c.config.networkAddress()
|
||||||
if c.config.Dial == nil {
|
if c.config.Dial == nil {
|
||||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
|
@ -686,13 +687,22 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||||
}
|
}
|
||||||
cp.Database = strings.TrimLeft(url.Path, "/")
|
cp.Database = strings.TrimLeft(url.Path, "/")
|
||||||
|
|
||||||
|
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
||||||
|
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
cp.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
err = configSSL(url.Query().Get("sslmode"), &cp)
|
err = configSSL(url.Query().Get("sslmode"), &cp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cp, err
|
return cp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ignoreKeys := map[string]struct{}{
|
ignoreKeys := map[string]struct{}{
|
||||||
"sslmode": {},
|
"sslmode": {},
|
||||||
|
"connect_timeout": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.RuntimeParams = make(map[string]string)
|
cp.RuntimeParams = make(map[string]string)
|
||||||
|
@ -750,6 +760,12 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||||
cp.Database = b[2]
|
cp.Database = b[2]
|
||||||
case "sslmode":
|
case "sslmode":
|
||||||
sslmode = b[2]
|
sslmode = b[2]
|
||||||
|
case "connect_timeout":
|
||||||
|
t, err := strconv.ParseInt(b[2], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
cp.Timeout = time.Duration(t) * time.Second
|
||||||
default:
|
default:
|
||||||
cp.RuntimeParams[b[1]] = b[2]
|
cp.RuntimeParams[b[1]] = b[2]
|
||||||
}
|
}
|
||||||
|
@ -787,6 +803,7 @@ func ParseConnectionString(s string) (ConnConfig, error) {
|
||||||
// PGPASSWORD
|
// PGPASSWORD
|
||||||
// PGSSLMODE
|
// PGSSLMODE
|
||||||
// PGAPPNAME
|
// PGAPPNAME
|
||||||
|
// PGCONNECT_TIMEOUT
|
||||||
//
|
//
|
||||||
// Important TLS Security Notes:
|
// Important TLS Security Notes:
|
||||||
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
|
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
|
||||||
|
@ -822,6 +839,14 @@ func ParseEnvLibpq() (ConnConfig, error) {
|
||||||
cc.User = os.Getenv("PGUSER")
|
cc.User = os.Getenv("PGUSER")
|
||||||
cc.Password = os.Getenv("PGPASSWORD")
|
cc.Password = os.Getenv("PGPASSWORD")
|
||||||
|
|
||||||
|
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
|
||||||
|
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
|
||||||
|
cc.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
} else {
|
||||||
|
return cc, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sslmode := os.Getenv("PGSSLMODE")
|
sslmode := os.Getenv("PGSSLMODE")
|
||||||
|
|
||||||
err := configSSL(sslmode, &cc)
|
err := configSSL(sslmode, &cc)
|
||||||
|
|
44
conn_test.go
44
conn_test.go
|
@ -567,6 +567,21 @@ func TestParseDSN(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
url: "user=jack host=localhost dbname=mydb connect_timeout=10",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
@ -697,6 +712,21 @@ func TestParseConnectionString(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
url: "postgres://jack@localhost/mydb?connect_timeout=10",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
|
||||||
connParams: pgx.ConnConfig{
|
connParams: pgx.ConnConfig{
|
||||||
|
@ -802,7 +832,7 @@ func TestParseConnectionString(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseEnvLibpq(t *testing.T) {
|
func TestParseEnvLibpq(t *testing.T) {
|
||||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"}
|
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||||
|
|
||||||
savedEnv := make(map[string]string)
|
savedEnv := make(map[string]string)
|
||||||
for _, n := range pgEnvvars {
|
for _, n := range pgEnvvars {
|
||||||
|
@ -835,11 +865,12 @@ func TestParseEnvLibpq(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "Normal PG vars",
|
name: "Normal PG vars",
|
||||||
envvars: map[string]string{
|
envvars: map[string]string{
|
||||||
"PGHOST": "123.123.123.123",
|
"PGHOST": "123.123.123.123",
|
||||||
"PGPORT": "7777",
|
"PGPORT": "7777",
|
||||||
"PGDATABASE": "foo",
|
"PGDATABASE": "foo",
|
||||||
"PGUSER": "bar",
|
"PGUSER": "bar",
|
||||||
"PGPASSWORD": "baz",
|
"PGPASSWORD": "baz",
|
||||||
|
"PGCONNECT_TIMEOUT": "10",
|
||||||
},
|
},
|
||||||
config: pgx.ConnConfig{
|
config: pgx.ConnConfig{
|
||||||
Host: "123.123.123.123",
|
Host: "123.123.123.123",
|
||||||
|
@ -850,6 +881,7 @@ func TestParseEnvLibpq(t *testing.T) {
|
||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
UseFallbackTLS: true,
|
UseFallbackTLS: true,
|
||||||
FallbackTLSConfig: nil,
|
FallbackTLSConfig: nil,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue