From 1a99c0e5c478b5c14472cfef7924caff4947e2d0 Mon Sep 17 00:00:00 2001 From: Terin Stock Date: Mon, 20 Mar 2017 13:24:44 -0700 Subject: [PATCH] fix(stdlib): lock openFromConnPoolCount while using Locks the `openFromConnPoolCount` counter while formatting the driver name and incrementing to avoid a data race of multiple goroutines modifying the counter and registering the same name. `sql.Register` panics if a driver name has already been registered. --- stdlib/sql.go | 10 +++++++++- stdlib/sql_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index affa93b6..e3d46cab 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -50,12 +50,16 @@ import ( "errors" "fmt" "io" + "sync" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" ) -var openFromConnPoolCount int +var ( + openFromConnPoolCountMu sync.Mutex + openFromConnPoolCount int +) // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format @@ -120,8 +124,12 @@ func (d *Driver) Open(name string) (driver.Conn, error) { // pool connection size must be at least 2. func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { d := &Driver{Pool: pool} + + openFromConnPoolCountMu.Lock() name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) openFromConnPoolCount++ + openFromConnPoolCountMu.Unlock() + sql.Register(name, d) db, err := sql.Open(name, "") if err != nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index c8062c61..641ba9fe 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -3,6 +3,7 @@ package stdlib_test import ( "bytes" "database/sql" + "sync" "testing" "github.com/jackc/pgx" @@ -164,6 +165,43 @@ func TestOpenFromConnPool(t *testing.T) { } } +func TestOpenFromConnPoolRace(t *testing.T) { + wg := &sync.WaitGroup{} + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // Can get pgx.ConnPool from driver + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } + }() + } + + wg.Wait() +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db)