Merge pull request #768 from georgysavva/fix-exposed-config

Fix exposed config
pull/783/head
Jack Christensen 2020-06-06 07:05:13 -05:00 committed by GitHub
commit 73d0ac206c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 202 additions and 9 deletions

17
conn.go
View File

@ -47,6 +47,16 @@ type ConnConfig struct {
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the tls.Config:
// according to the tls.Config docs it must not be modified after creation.
func (cc *ConnConfig) Copy() *ConnConfig {
newConfig := new(ConnConfig)
*newConfig = *cc
newConfig.Config = *newConfig.Config.Copy()
return newConfig
}
func (cc *ConnConfig) ConnString() string { return cc.connString }
// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection.
@ -174,6 +184,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
if !config.createdByParseConfig {
panic("config must be created by ParseConfig")
}
originalConfig := config
// This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting
// other connections with the same config. See https://github.com/jackc/pgx/issues/618.
@ -183,7 +194,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
}
c = &Conn{
config: config,
config: originalConfig,
connInfo: pgtype.NewConnInfo(),
logLevel: config.LogLevel,
logger: config.Logger,
@ -424,8 +435,8 @@ func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache }
// ConnInfo returns the connection info used for this connection.
func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo }
// Config returns config that was used to establish this connection.
func (c *Conn) Config() *ConnConfig { return c.config }
// Config returns a copy of config that was used to establish this connection.
func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.

View File

@ -51,7 +51,7 @@ func TestConnect(t *testing.T) {
t.Fatalf("Unable to establish connection: %v", err)
}
assert.Equal(t, connString, conn.Config().ConnString())
assertConfigsEqual(t, config, conn.Config(), "Conn.Config() returns original config")
var currentDB string
err = conn.QueryRow(context.Background(), "select current_database()").Scan(&currentDB)
@ -116,6 +116,27 @@ func TestConfigContainsConnStr(t *testing.T) {
assert.Equal(t, connStr, config.ConnString())
}
func TestConfigCopyReturnsEqualConfig(t *testing.T) {
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
original, err := pgx.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assertConfigsEqual(t, original, copied, t.Name())
}
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
connString := os.Getenv("PGX_TEST_DATABASE")
original, err := pgx.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assert.NotPanics(t, func() {
_, err = pgx.ConnectConfig(context.Background(), copied)
})
assert.NoError(t, err)
}
func TestParseConfigExtractsStatementCacheOptions(t *testing.T) {
t.Parallel()

3
go.mod
View File

@ -5,7 +5,7 @@ go 1.12
require (
github.com/cockroachdb/apd v1.1.0
github.com/gofrs/uuid v3.2.0+incompatible
github.com/jackc/pgconn v1.5.0
github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853
github.com/jackc/pgio v1.0.0
github.com/jackc/pgproto3/v2 v2.0.1
github.com/jackc/pgtype v1.3.1-0.20200513130519-238967ec4e4c
@ -14,7 +14,6 @@ require (
github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.5.1
go.uber.org/multierr v1.5.0 // indirect
go.uber.org/zap v1.10.0
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec

2
go.sum
View File

@ -28,6 +28,8 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU
github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk=
github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o=
github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI=
github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 h1:LRlrfJW9S99uiOCY8F/qLvX1yEY1TVAaCBHFb79yHBQ=
github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA=

View File

@ -2,6 +2,7 @@ package pgx_test
import (
"context"
"github.com/stretchr/testify/assert"
"os"
"testing"
@ -114,3 +115,52 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) {
t.Error("Wrong values returned")
}
}
func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
if !assert.NotNil(t, expected) {
return
}
if !assert.NotNil(t, actual) {
return
}
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
if expected.TLSConfig != nil {
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
}
}
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
for i := range expected.Fallbacks {
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
if expected.Fallbacks[i].TLSConfig != nil {
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
}
}
}
}
}

View File

@ -5,6 +5,8 @@ import (
"testing"
"time"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
@ -125,3 +127,79 @@ func testCopyFrom(t *testing.T, db interface {
assert.NoError(t, rows.Err())
assert.Equal(t, inputRows, outputRows)
}
func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) {
if !assert.NotNil(t, expected) {
return
}
if !assert.NotNil(t, actual) {
return
}
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)
assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName)
assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
}
func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
if !assert.NotNil(t, expected) {
return
}
if !assert.NotNil(t, actual) {
return
}
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
if expected.TLSConfig != nil {
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
}
}
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
for i := range expected.Fallbacks {
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
if expected.Fallbacks[i].TLSConfig != nil {
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
}
}
}
}
}

View File

@ -121,6 +121,16 @@ type Config struct {
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the tls.Config:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConfig := new(Config)
*newConfig = *c
newConfig.ConnConfig = c.ConnConfig.Copy()
return newConfig
}
func (c *Config) ConnString() string { return c.ConnConfig.ConnString() }
// Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial
@ -373,8 +383,8 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
return conns
}
// Config returns config that was used to initialize this pool.
func (p *Pool) Config() *Config { return p.config }
// Config returns a copy of config that was used to initialize this pool.
func (p *Pool) Config() *Config { return p.config.Copy() }
func (p *Pool) Stat() *Stat {
return &Stat{s: p.p.Stat()}

View File

@ -28,7 +28,7 @@ func TestConnectConfig(t *testing.T) {
require.NoError(t, err)
pool, err := pgxpool.ConnectConfig(context.Background(), config)
require.NoError(t, err)
assert.Equal(t, config, pool.Config())
assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config")
pool.Close()
}
@ -78,6 +78,28 @@ func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) {
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.ConnectConfig(context.Background(), config) })
}
func TestConfigCopyReturnsEqualConfig(t *testing.T) {
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
original, err := pgxpool.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assertConfigsEqual(t, original, copied, t.Name())
}
func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
connString := os.Getenv("PGX_TEST_DATABASE")
original, err := pgxpool.ParseConfig(connString)
require.NoError(t, err)
copied := original.Copy()
assert.NotPanics(t, func() {
_, err = pgxpool.ConnectConfig(context.Background(), copied)
})
assert.NoError(t, err)
}
func TestPoolAcquireAndConnRelease(t *testing.T) {
t.Parallel()