Merge branch 'james-lawrence-implement-opendb'

pull/436/head
Jack Christensen 2018-07-14 09:58:35 -05:00
commit 3cbe92ebb5
5 changed files with 115 additions and 11 deletions

64
stdlib/opendb.go Normal file
View File

@ -0,0 +1,64 @@
// +build go1.10
package stdlib
import (
"context"
"database/sql"
"database/sql/driver"
"github.com/jackc/pgx"
)
// OptionOpenDB options for configuring the driver when opening a new db pool.
type OptionOpenDB func(*connector)
// OptionAfterConnect provide a callback for after connect.
func OptionAfterConnect(ac func(*pgx.Conn) error) OptionOpenDB {
return func(dc *connector) {
dc.AfterConnect = ac
}
}
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
c := connector{
ConnConfig: config,
AfterConnect: func(*pgx.Conn) error { return nil }, // noop after connect by default
driver: pgxDriver,
}
for _, opt := range opts {
opt(&c)
}
return sql.OpenDB(c)
}
type connector struct {
pgx.ConnConfig
AfterConnect func(*pgx.Conn) error // function to call on every new connection
driver *Driver
}
// Connect implement driver.Connector interface
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
var (
err error
conn *pgx.Conn
)
if conn, err = pgx.Connect(c.ConnConfig); err != nil {
return nil, err
}
if err = c.AfterConnect(conn); err != nil {
return nil, err
}
return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil
}
// Driver implement driver.Connector interface
func (c connector) Driver() driver.Driver {
return c.driver
}

View File

@ -132,8 +132,11 @@ type Driver struct {
}
func (d *Driver) Open(name string) (driver.Conn, error) {
var connConfig pgx.ConnConfig
var afterConnect func(*pgx.Conn) error
var (
connConfig pgx.ConnConfig
afterConnect func(*pgx.Conn) error
)
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))

View File

@ -17,15 +17,6 @@ import (
"github.com/jackc/pgx/stdlib"
)
func openDB(t *testing.T) *sql.DB {
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
return db
}
func closeDB(t *testing.T, db *sql.DB) {
err := db.Close()
if err != nil {
@ -82,6 +73,14 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) {
}
}
func TestSQLOpen(t *testing.T) {
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
closeDB(t, db)
}
func TestNormalLifeCycle(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)

View File

@ -0,0 +1,20 @@
// +build go1.10
package stdlib_test
import (
"database/sql"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/stdlib"
)
func openDB(t *testing.T) *sql.DB {
config, err := pgx.ParseConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
}
return stdlib.OpenDB(config)
}

18
stdlib/stdlibutil_test.go Normal file
View File

@ -0,0 +1,18 @@
// +build !go1.10
package stdlib_test
import (
"database/sql"
"testing"
)
// this file contains utility functions for tests that differ between versions.
func openDB(t *testing.T) *sql.DB {
db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
return db
}