From 1bec450326120060e476f0e7dcab78312569ad35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Peignier?= Date: Sat, 16 Dec 2017 19:10:22 -0800 Subject: [PATCH] Handle timeout parameters --- conn.go | 29 +++++++++++++++++++++++++++-- conn_test.go | 50 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 68 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index d86c4025..f3e1657d 100644 --- a/conn.go +++ b/conn.go @@ -72,6 +72,7 @@ type ConnConfig struct { Logger Logger LogLevel int Dial DialFunc + Timeout time.Duration RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) OnNotice NoticeHandler // Callback function called when a notice response is received. } @@ -247,7 +248,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) network, address := c.config.networkAddress() if c.config.Dial == nil { - c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial + c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial } if c.shouldLog(LogLevelInfo) { @@ -655,13 +656,22 @@ func ParseURI(uri string) (ConnConfig, error) { } cp.Database = strings.TrimLeft(url.Path, "/") + if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { + timeout, err := strconv.ParseInt(pgtimeout, 10, 64) + if err != nil { + return cp, err + } + cp.Timeout = time.Duration(timeout) * time.Second + } + err = configSSL(url.Query().Get("sslmode"), &cp) if err != nil { return cp, err } ignoreKeys := map[string]struct{}{ - "sslmode": {}, + "sslmode": {}, + "connect_timeout": {}, } cp.RuntimeParams = make(map[string]string) @@ -719,6 +729,12 @@ func ParseDSN(s string) (ConnConfig, error) { cp.Database = b[2] case "sslmode": sslmode = b[2] + case "connect_timeout": + t, err := strconv.ParseInt(b[2], 10, 64) + if err != nil { + return cp, err + } + cp.Timeout = time.Duration(t) * time.Second default: cp.RuntimeParams[b[1]] = b[2] } @@ -756,6 +772,7 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGPASSWORD // PGSSLMODE // PGAPPNAME +// PGCONNECT_TIMEOUT // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -791,6 +808,14 @@ func ParseEnvLibpq() (ConnConfig, error) { cc.User = os.Getenv("PGUSER") cc.Password = os.Getenv("PGPASSWORD") + if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { + if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { + cc.Timeout = time.Duration(timeout) * time.Second + } else { + return cc, err + } + } + sslmode := os.Getenv("PGSSLMODE") err := configSSL(sslmode, &cc) diff --git a/conn_test.go b/conn_test.go index e117a953..f7a39fdf 100644 --- a/conn_test.go +++ b/conn_test.go @@ -542,6 +542,21 @@ func TestParseDSN(t *testing.T) { }, }, }, + { + url: "user=jack host=localhost dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Timeout: 10 * time.Second, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests { @@ -672,6 +687,21 @@ func TestParseConnectionString(t *testing.T) { }, }, }, + { + url: "postgres://jack@localhost/mydb?connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Timeout: 10 * time.Second, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", connParams: pgx.ConnConfig{ @@ -777,7 +807,7 @@ func TestParseConnectionString(t *testing.T) { } func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"} + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { @@ -810,11 +840,12 @@ func TestParseEnvLibpq(t *testing.T) { { name: "Normal PG vars", envvars: map[string]string{ - "PGHOST": "123.123.123.123", - "PGPORT": "7777", - "PGDATABASE": "foo", - "PGUSER": "bar", - "PGPASSWORD": "baz", + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", }, config: pgx.ConnConfig{ Host: "123.123.123.123", @@ -825,6 +856,7 @@ func TestParseEnvLibpq(t *testing.T) { TLSConfig: &tls.Config{InsecureSkipVerify: true}, UseFallbackTLS: true, FallbackTLSConfig: nil, + Timeout: 10 * time.Second, RuntimeParams: map[string]string{}, }, }, @@ -1988,9 +2020,9 @@ func TestConnInitConnInfo(t *testing.T) { // spot check that the standard postgres type names aren't qualified nameOIDs := map[string]pgtype.OID{ "_int8": pgtype.Int8ArrayOID, - "int8": pgtype.Int8OID, - "json": pgtype.JSONOID, - "text": pgtype.TextOID, + "int8": pgtype.Int8OID, + "json": pgtype.JSONOID, + "text": pgtype.TextOID, } for name, oid := range nameOIDs { dtByName, ok := conn.ConnInfo.DataTypeForName(name)