diff --git a/stdlib/sql.go b/stdlib/sql.go index 3dd92cbd..8ff3fd49 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -266,7 +266,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, err } - return wrapTx{tx: tx}, nil + return wrapTx{ctx: ctx, tx: tx}, nil } func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { @@ -539,11 +539,14 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { return args } -type wrapTx struct{ tx pgx.Tx } +type wrapTx struct { + ctx context.Context + tx pgx.Tx +} -func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) } +func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } -func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) } +func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } type fakeTx struct{}