From 608451a215bad64d389e28cc912266b15a217693 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 1 Jun 2020 19:38:12 +0300 Subject: [PATCH 1/2] Store original config in Conn before updating it. --- conn.go | 3 ++- conn_test.go | 2 +- go.mod | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index 0c4cfbe2..000c0f24 100644 --- a/conn.go +++ b/conn.go @@ -174,6 +174,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { if !config.createdByParseConfig { panic("config must be created by ParseConfig") } + originalConfig := config // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting // other connections with the same config. See https://github.com/jackc/pgx/issues/618. @@ -183,7 +184,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { } c = &Conn{ - config: config, + config: originalConfig, connInfo: pgtype.NewConnInfo(), logLevel: config.LogLevel, logger: config.Logger, diff --git a/conn_test.go b/conn_test.go index 72022b21..8a52dbcb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -51,7 +51,7 @@ func TestConnect(t *testing.T) { t.Fatalf("Unable to establish connection: %v", err) } - assert.Equal(t, connString, conn.Config().ConnString()) + assert.Equal(t, config, conn.Config()) var currentDB string err = conn.QueryRow(context.Background(), "select current_database()").Scan(¤tDB) diff --git a/go.mod b/go.mod index b5967ad4..3e45c3e0 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.5.1 - go.uber.org/multierr v1.5.0 // indirect go.uber.org/zap v1.10.0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec From e29ce9f6d6e3ef1d48eb0484167335d4291e10bb Mon Sep 17 00:00:00 2001 From: georgysavva Date: Tue, 2 Jun 2020 13:35:05 +0300 Subject: [PATCH 2/2] Add Config.Copy() in pgx and pgxpool packages. Conn.Config() and Pool.Config() return copy of the original config. --- conn.go | 14 ++++++-- conn_test.go | 23 ++++++++++++- go.mod | 2 +- go.sum | 2 ++ helper_test.go | 50 +++++++++++++++++++++++++++ pgxpool/common_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++ pgxpool/pool.go | 14 ++++++-- pgxpool/pool_test.go | 24 ++++++++++++- 8 files changed, 200 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 000c0f24..6b9a2b20 100644 --- a/conn.go +++ b/conn.go @@ -47,6 +47,16 @@ type ConnConfig struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (cc *ConnConfig) Copy() *ConnConfig { + newConfig := new(ConnConfig) + *newConfig = *cc + newConfig.Config = *newConfig.Config.Copy() + return newConfig +} + func (cc *ConnConfig) ConnString() string { return cc.connString } // BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. @@ -425,8 +435,8 @@ func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } // ConnInfo returns the connection info used for this connection. func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } -// Config returns config that was used to establish this connection. -func (c *Conn) Config() *ConnConfig { return c.config } +// Config returns a copy of config that was used to establish this connection. +func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. diff --git a/conn_test.go b/conn_test.go index 8a52dbcb..ba4eda95 100644 --- a/conn_test.go +++ b/conn_test.go @@ -51,7 +51,7 @@ func TestConnect(t *testing.T) { t.Fatalf("Unable to establish connection: %v", err) } - assert.Equal(t, config, conn.Config()) + assertConfigsEqual(t, config, conn.Config(), "Conn.Config() returns original config") var currentDB string err = conn.QueryRow(context.Background(), "select current_database()").Scan(¤tDB) @@ -116,6 +116,27 @@ func TestConfigContainsConnStr(t *testing.T) { assert.Equal(t, connStr, config.ConnString()) } +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgx.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, t.Name()) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_DATABASE") + original, err := pgx.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgx.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 3e45c3e0..a390a967 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/cockroachdb/apd v1.1.0 github.com/gofrs/uuid v3.2.0+incompatible - github.com/jackc/pgconn v1.5.0 + github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgtype v1.3.1-0.20200513130519-238967ec4e4c diff --git a/go.sum b/go.sum index 1f52822e..6da235de 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 h1:LRlrfJW9S99uiOCY8F/qLvX1yEY1TVAaCBHFb79yHBQ= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= diff --git a/helper_test.go b/helper_test.go index fde4cbfa..6c532035 100644 --- a/helper_test.go +++ b/helper_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "context" + "github.com/stretchr/testify/assert" "os" "testing" @@ -114,3 +115,52 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } + +func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) + assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) + assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index 839796a9..68e13a77 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" @@ -125,3 +127,79 @@ func testCopyFrom(t *testing.T, db interface { assert.NoError(t, rows.Err()) assert.Equal(t, inputRows, outputRows) } + +func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) + assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) + + assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) + assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName) + assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) + assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) + assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) + assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName) + + assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) +} + +func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + + assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) + assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) + + assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) + + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 12856244..fb7d3017 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -121,6 +121,16 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConfig := new(Config) + *newConfig = *c + newConfig.ConnConfig = c.ConnConfig.Copy() + return newConfig +} + func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } // Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial @@ -373,8 +383,8 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { return conns } -// Config returns config that was used to initialize this pool. -func (p *Pool) Config() *Config { return p.config } +// Config returns a copy of config that was used to initialize this pool. +func (p *Pool) Config() *Config { return p.config.Copy() } func (p *Pool) Stat() *Stat { return &Stat{s: p.p.Stat()} diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index ebf6bd26..17be648f 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -28,7 +28,7 @@ func TestConnectConfig(t *testing.T) { require.NoError(t, err) pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) - assert.Equal(t, config, pool.Config()) + assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") pool.Close() } @@ -78,6 +78,28 @@ func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.ConnectConfig(context.Background(), config) }) } +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgxpool.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + + assertConfigsEqual(t, original, copied, t.Name()) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_DATABASE") + original, err := pgxpool.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgxpool.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + func TestPoolAcquireAndConnRelease(t *testing.T) { t.Parallel()