From 19a9154d237e9058211fedc3ba1185ea58a174d1 Mon Sep 17 00:00:00 2001 From: James Lawrence Date: Sun, 25 Feb 2018 14:17:38 -0500 Subject: [PATCH] implement driver.Connector --- stdlib/opendb.go | 64 ++++++++++++++++++++++++++++++++++++ stdlib/sql.go | 7 ++-- stdlib/sql_test.go | 17 +++++----- stdlib/stdlibutil110_test.go | 20 +++++++++++ stdlib/stdlibutil_test.go | 18 ++++++++++ 5 files changed, 115 insertions(+), 11 deletions(-) create mode 100644 stdlib/opendb.go create mode 100644 stdlib/stdlibutil110_test.go create mode 100644 stdlib/stdlibutil_test.go diff --git a/stdlib/opendb.go b/stdlib/opendb.go new file mode 100644 index 00000000..cb3703ab --- /dev/null +++ b/stdlib/opendb.go @@ -0,0 +1,64 @@ +// +build go1.10 + +package stdlib + +import ( + "context" + "database/sql" + "database/sql/driver" + + "github.com/jackc/pgx" +) + +// OptionOpenDB options for configuring the driver when opening a new db pool. +type OptionOpenDB func(*connector) + +// OptionAfterConnect provide a callback for after connect. +func OptionAfterConnect(ac func(*pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.AfterConnect = ac + } +} + +func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { + c := connector{ + ConnConfig: config, + AfterConnect: func(*pgx.Conn) error { return nil }, // noop after connect by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return sql.OpenDB(c) +} + +type connector struct { + pgx.ConnConfig + AfterConnect func(*pgx.Conn) error // function to call on every new connection + driver *Driver +} + +// Connect implement driver.Connector interface +func (c connector) Connect(ctx context.Context) (driver.Conn, error) { + var ( + err error + conn *pgx.Conn + ) + + if conn, err = pgx.Connect(c.ConnConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(conn); err != nil { + return nil, err + } + + return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil +} + +// Driver implement driver.Connector interface +func (c connector) Driver() driver.Driver { + return c.driver +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 2d4930ee..aefe7dbf 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -130,8 +130,11 @@ type Driver struct { } func (d *Driver) Open(name string) (driver.Conn, error) { - var connConfig pgx.ConnConfig - var afterConnect func(*pgx.Conn) error + var ( + connConfig pgx.ConnConfig + afterConnect func(*pgx.Conn) error + ) + if len(name) >= 9 && name[0] == 0 { idBuf := []byte(name)[1:9] id := int64(binary.BigEndian.Uint64(idBuf)) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index a4a99971..d7c627ef 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -16,15 +16,6 @@ import ( "github.com/jackc/pgx/stdlib" ) -func openDB(t *testing.T) *sql.DB { - db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - - return db -} - func closeDB(t *testing.T, db *sql.DB) { err := db.Close() if err != nil { @@ -81,6 +72,14 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) { } } +func TestSQLOpen(t *testing.T) { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + closeDB(t, db) +} + func TestNormalLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) diff --git a/stdlib/stdlibutil110_test.go b/stdlib/stdlibutil110_test.go new file mode 100644 index 00000000..c83b645b --- /dev/null +++ b/stdlib/stdlibutil110_test.go @@ -0,0 +1,20 @@ +// +build go1.10 + +package stdlib_test + +import ( + "database/sql" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/stdlib" +) + +func openDB(t *testing.T) *sql.DB { + config, err := pgx.ParseConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") + if err != nil { + t.Fatalf("pgx.ParseConnectionString failed: %v", err) + } + + return stdlib.OpenDB(config) +} diff --git a/stdlib/stdlibutil_test.go b/stdlib/stdlibutil_test.go new file mode 100644 index 00000000..31ea391c --- /dev/null +++ b/stdlib/stdlibutil_test.go @@ -0,0 +1,18 @@ +// +build !go1.10 + +package stdlib_test + +import ( + "database/sql" + "testing" +) + +// this file contains utility functions for tests that differ between versions. +func openDB(t *testing.T) *sql.DB { + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + return db +}