diff --git a/stdlib/opendb.go b/stdlib/opendb.go deleted file mode 100644 index ad4e979b..00000000 --- a/stdlib/opendb.go +++ /dev/null @@ -1,64 +0,0 @@ -// +build go1.10 - -package stdlib - -import ( - "context" - "database/sql" - "database/sql/driver" - - "github.com/jackc/pgx/v4" -) - -// OptionOpenDB options for configuring the driver when opening a new db pool. -type OptionOpenDB func(*connector) - -// OptionAfterConnect provide a callback for after connect. -func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { - return func(dc *connector) { - dc.AfterConnect = ac - } -} - -func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { - c := connector{ - ConnConfig: config, - AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default - driver: pgxDriver, - } - - for _, opt := range opts { - opt(&c) - } - - return sql.OpenDB(c) -} - -type connector struct { - pgx.ConnConfig - AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection - driver *Driver -} - -// Connect implement driver.Connector interface -func (c connector) Connect(ctx context.Context) (driver.Conn, error) { - var ( - err error - conn *pgx.Conn - ) - - if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil { - return nil, err - } - - if err = c.AfterConnect(ctx, conn); err != nil { - return nil, err - } - - return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil -} - -// Driver implement driver.Connector interface -func (c connector) Driver() driver.Driver { - return c.driver -} diff --git a/stdlib/sql.go b/stdlib/sql.go index db8eefc0..f69196b8 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -114,6 +114,59 @@ var ( fakeTxConns map[*pgx.Conn]*sql.Tx ) +// OptionOpenDB options for configuring the driver when opening a new db pool. +type OptionOpenDB func(*connector) + +// OptionAfterConnect provide a callback for after connect. +func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.AfterConnect = ac + } +} + +func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { + c := connector{ + ConnConfig: config, + AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return sql.OpenDB(c) +} + +type connector struct { + pgx.ConnConfig + AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection + driver *Driver +} + +// Connect implement driver.Connector interface +func (c connector) Connect(ctx context.Context) (driver.Conn, error) { + var ( + err error + conn *pgx.Conn + ) + + if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(ctx, conn); err != nil { + return nil, err + } + + return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil +} + +// Driver implement driver.Connector interface +func (c connector) Driver() driver.Driver { + return c.driver +} + // GetDefaultDriver returns the driver initialized in the init function // and used when the pgx driver is registered. func GetDefaultDriver() driver.Driver { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index d387f25b..4629f8d5 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -19,6 +19,15 @@ import ( "github.com/stretchr/testify/require" ) +func openDB(t testing.TB) *sql.DB { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatalf("pgx.ParseConnectionString failed: %v", err) + } + + return stdlib.OpenDB(*config) +} + func closeDB(t testing.TB, db *sql.DB) { err := db.Close() if err != nil { diff --git a/stdlib/stdlibutil110_test.go b/stdlib/stdlibutil110_test.go deleted file mode 100644 index 6eec51e0..00000000 --- a/stdlib/stdlibutil110_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// +build go1.10 - -package stdlib_test - -import ( - "database/sql" - "os" - "testing" - - "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/stdlib" -) - -func openDB(t testing.TB) *sql.DB { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatalf("pgx.ParseConnectionString failed: %v", err) - } - - return stdlib.OpenDB(*config) -} diff --git a/stdlib/stdlibutil_test.go b/stdlib/stdlibutil_test.go deleted file mode 100644 index 58b3ff5b..00000000 --- a/stdlib/stdlibutil_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build !go1.10 - -package stdlib_test - -import ( - "database/sql" - "os" - "testing" -) - -// this file contains utility functions for tests that differ between versions. -func openDB(t *testing.T) *sql.DB { - db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } - - return db -}