diff --git a/tx.go b/tx.go index f9607f70..81fcfa26 100644 --- a/tx.go +++ b/tx.go @@ -144,11 +144,16 @@ func (tx *Tx) CommitEx(ctx context.Context) error { // defer tx.Rollback() is safe even if tx.Commit() will be called first in a // non-error condition. func (tx *Tx) Rollback() error { + ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) + return tx.RollbackEx(ctx) +} + +// RollbackEx is the context version of Rollback +func (tx *Tx) RollbackEx(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } - ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) _, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) if tx.err == nil { tx.status = TxStatusRollbackSuccess @@ -167,11 +172,16 @@ func (tx *Tx) Rollback() error { // Exec delegates to the underlying *Conn func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + return tx.ExecEx(context.Background(), sql, nil, arguments...) +} + +// ExecEx delegates to the underlying *Conn +func (tx *Tx) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { if tx.status != TxStatusInProgress { return CommandTag(""), ErrTxClosed } - return tx.conn.Exec(sql, arguments...) + return tx.conn.ExecEx(ctx, sql, options, arguments...) } // Prepare delegates to the underlying *Conn @@ -190,13 +200,18 @@ func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOp // Query delegates to the underlying *Conn func (tx *Tx) Query(sql string, args ...interface{}) (*Rows, error) { + return tx.QueryEx(context.Background(), sql, nil, args...) +} + +// QueryEx delegates to the underlying *Conn +func (tx *Tx) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { if tx.status != TxStatusInProgress { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed return &Rows{closed: true, err: err}, err } - return tx.conn.Query(sql, args...) + return tx.conn.QueryEx(ctx, sql, options, args...) } // QueryRow delegates to the underlying *Conn @@ -205,6 +220,12 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +// QueryRowEx delegates to the underlying *Conn +func (tx *Tx) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { + rows, _ := tx.QueryEx(ctx, sql, options, args...) + return (*Row)(rows) +} + // CopyFrom delegates to the underlying *Conn func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { if tx.status != TxStatusInProgress {