mirror of https://github.com/jackc/pgx.git
fix(stdlib): lock openFromConnPoolCount while using
Locks the `openFromConnPoolCount` counter while formatting the driver name and incrementing to avoid a data race of multiple goroutines modifying the counter and registering the same name. `sql.Register` panics if a driver name has already been registered.v3-numeric-wip
parent
120da8df8f
commit
1a99c0e5c4
|
@ -50,12 +50,16 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
var openFromConnPoolCount int
|
||||
var (
|
||||
openFromConnPoolCountMu sync.Mutex
|
||||
openFromConnPoolCount int
|
||||
)
|
||||
|
||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||
// binary, anything else will be forced to text format
|
||||
|
@ -120,8 +124,12 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
|||
// pool connection size must be at least 2.
|
||||
func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) {
|
||||
d := &Driver{Pool: pool}
|
||||
|
||||
openFromConnPoolCountMu.Lock()
|
||||
name := fmt.Sprintf("pgx-%d", openFromConnPoolCount)
|
||||
openFromConnPoolCount++
|
||||
openFromConnPoolCountMu.Unlock()
|
||||
|
||||
sql.Register(name, d)
|
||||
db, err := sql.Open(name, "")
|
||||
if err != nil {
|
||||
|
|
|
@ -3,6 +3,7 @@ package stdlib_test
|
|||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
|
@ -164,6 +165,43 @@ func TestOpenFromConnPool(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOpenFromConnPoolRace(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
connConfig := pgx.ConnConfig{
|
||||
Host: "127.0.0.1",
|
||||
User: "pgx_md5",
|
||||
Password: "secret",
|
||||
Database: "pgx_test",
|
||||
}
|
||||
|
||||
config := pgx.ConnPoolConfig{ConnConfig: connConfig}
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
db, err := stdlib.OpenFromConnPool(pool)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
// Can get pgx.ConnPool from driver
|
||||
driver := db.Driver().(*stdlib.Driver)
|
||||
if driver.Pool == nil {
|
||||
t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestStmtExec(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
|
Loading…
Reference in New Issue