diff --git a/stdlib/sql.go b/stdlib/sql.go index affa93b6..e3d46cab 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -50,12 +50,16 @@ import ( "errors" "fmt" "io" + "sync" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" ) -var openFromConnPoolCount int +var ( + openFromConnPoolCountMu sync.Mutex + openFromConnPoolCount int +) // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format @@ -120,8 +124,12 @@ func (d *Driver) Open(name string) (driver.Conn, error) { // pool connection size must be at least 2. func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { d := &Driver{Pool: pool} + + openFromConnPoolCountMu.Lock() name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) openFromConnPoolCount++ + openFromConnPoolCountMu.Unlock() + sql.Register(name, d) db, err := sql.Open(name, "") if err != nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index c8062c61..641ba9fe 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -3,6 +3,7 @@ package stdlib_test import ( "bytes" "database/sql" + "sync" "testing" "github.com/jackc/pgx" @@ -164,6 +165,43 @@ func TestOpenFromConnPool(t *testing.T) { } } +func TestOpenFromConnPoolRace(t *testing.T) { + wg := &sync.WaitGroup{} + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // Can get pgx.ConnPool from driver + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } + }() + } + + wg.Wait() +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db)