diff --git a/stdlib/sql.go b/stdlib/sql.go index e4c53ea7..e4565227 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -81,17 +81,10 @@ var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver -type ctxKey int - -var ctxKeyFakeTx ctxKey = 0 - -var ErrNotPgx = errors.New("not pgx *sql.DB") - func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } - fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ @@ -111,11 +104,6 @@ func init() { } } -var ( - fakeTxMutex sync.Mutex - fakeTxConns map[*pgx.Conn]*sql.Tx -) - // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) @@ -367,11 +355,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, driver.ErrBadConn } - if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { - *pconn = c.conn - return fakeTx{}, nil - } - var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -786,55 +769,3 @@ type wrapTx struct { func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } - -type fakeTx struct{} - -func (fakeTx) Commit() error { return nil } - -func (fakeTx) Rollback() error { return nil } - -// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn. -// -// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method. -func AcquireConn(db *sql.DB) (*pgx.Conn, error) { - var conn *pgx.Conn - ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - if conn == nil { - tx.Rollback() - return nil, ErrNotPgx - } - - fakeTxMutex.Lock() - fakeTxConns[conn] = tx - fakeTxMutex.Unlock() - - return conn, nil -} - -// ReleaseConn releases a *pgx.Conn acquired with AcquireConn. -func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { - var tx *sql.Tx - var ok bool - - if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - conn.Close(ctx) - } - - fakeTxMutex.Lock() - tx, ok = fakeTxConns[conn] - if ok { - delete(fakeTxConns, conn) - fakeTxMutex.Unlock() - } else { - fakeTxMutex.Unlock() - return fmt.Errorf("can't release conn that is not acquired") - } - - return tx.Rollback() -} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 30cea7d6..faa4a0cb 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -640,42 +640,6 @@ func TestBeginTxContextCancel(t *testing.T) { }) } -func TestAcquireConn(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - var conns []*pgx.Conn - - for i := 1; i < 6; i++ { - conn, err := stdlib.AcquireConn(db) - if err != nil { - t.Errorf("%d. AcquireConn failed: %v", i, err) - continue - } - - var n int32 - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) - } - if n != 1 { - t.Errorf("%d. n => %d, want %d", i, n, 1) - } - - stats := db.Stats() - if stats.OpenConnections != i { - t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) - } - - conns = append(conns, conn) - } - - for i, conn := range conns { - if err := stdlib.ReleaseConn(db, conn); err != nil { - t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) - } - } - }) -} - func TestConnRaw(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { conn, err := db.Conn(context.Background()) @@ -691,38 +655,6 @@ func TestConnRaw(t *testing.T) { }) } -// https://github.com/jackc/pgx/issues/673 -func TestReleaseConnWithTxInProgress(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - skipCockroachDB(t, db, "Server does not support backend PID") - - c1, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - _, err = c1.Exec(context.Background(), "begin") - require.NoError(t, err) - - c1PID := c1.PgConn().PID() - - err = stdlib.ReleaseConn(db, c1) - require.NoError(t, err) - - c2, err := stdlib.AcquireConn(db) - require.NoError(t, err) - - c2PID := c2.PgConn().PID() - - err = stdlib.ReleaseConn(db, c2) - require.NoError(t, err) - - require.NotEqual(t, c1PID, c2PID) - - // Releasing a conn with a tx in progress should close the connection - stats := db.Stats() - require.Equal(t, 1, stats.OpenConnections) - }) -} - func TestConnPingContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { err := db.PingContext(context.Background()) @@ -746,23 +678,6 @@ func TestConnExecContextSuccess(t *testing.T) { }) } -func TestConnExecContextFailureRetry(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - _, err = conn.ExecContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestConnQueryContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") @@ -777,24 +692,6 @@ func TestConnQueryContextSuccess(t *testing.T) { }) } -func TestConnQueryContextFailureRetry(t *testing.T) { - testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { - // We get a connection, immediately close it, and then get it back; - // DB.Conn along with Conn.ResetSession does the retry for us. - { - conn, err := stdlib.AcquireConn(db) - require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - - _, err = conn.QueryContext(context.Background(), "select 1") - require.NoError(t, err) - }) -} - func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select 42::bigint")