diff --git a/stdlib/sql.go b/stdlib/sql.go index c43450f6..68ae2697 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -8,12 +8,16 @@ import ( "io" ) +var openFromConnPoolCount int + func init() { d := &Driver{} sql.Register("pgx", d) } -type Driver struct{} +type Driver struct { + Pool *pgx.ConnPool +} func (d *Driver) Open(name string) (driver.Conn, error) { connConfig, err := pgx.ParseURI(name) @@ -30,6 +34,29 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return c, nil } +// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB +// with pool as the backend. This enables full control over the connection +// process and configuration while maintaining compatibility with the +// database/sql interface. In addition, by calling Driver() on the returned +// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can +// be reaquired later. This allows fast paths targeting pgx to be used while +// still maintaining compatibility with other databases and drivers. +func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { + d := &Driver{Pool: pool} + name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) + openFromConnPoolCount++ + sql.Register(name, d) + db, err := sql.Open(name, "") + if err != nil { + return nil, err + } + + db.SetMaxIdleConns(0) + db.SetMaxOpenConns(pool.MaxConnectionCount()) + + return db, nil +} + type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 9ae4b808..4378f3af 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -2,7 +2,8 @@ package stdlib_test import ( "database/sql" - _ "github.com/jackc/pgx/stdlib" + "github.com/jackc/pgx" + "github.com/jackc/pgx/stdlib" "testing" ) @@ -85,6 +86,43 @@ func TestNormalLifeCycle(t *testing.T) { } } +func TestSqlOpenDoesNotHavePool(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + driver := db.Driver().(*stdlib.Driver) + if driver.Pool != nil { + t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did") + } +} + +func TestOpenFromConnPool(t *testing.T) { + connConfig := pgx.ConnConfig{ + Host: "localhost", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db)