Add TransactionCtx method

add-transactionCtx-method
Vinícius Garcia 2022-07-31 12:18:19 -03:00
parent 9161634e7b
commit 46ebf63c0f
1 changed files with 79 additions and 10 deletions

89
ksql.go
View File

@ -161,7 +161,8 @@ func (c DB) Query(
query = selectPrefix + query
}
rows, err := c.db.QueryContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
rows, err := db.QueryContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("error running query: %s", err)
}
@ -251,7 +252,8 @@ func (c DB) QueryOne(
query = selectPrefix + query
}
rows, err := c.db.QueryContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
rows, err := db.QueryContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("error running query: %s", err)
}
@ -324,7 +326,8 @@ func (c DB) QueryChunks(
parser.Query = selectPrefix + parser.Query
}
rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
db := c.getAdapterFromCtx(ctx)
rows, err := db.QueryContext(ctx, parser.Query, parser.Params...)
if err != nil {
return err
}
@ -448,7 +451,8 @@ func (c DB) insertReturningIDs(
scanValues []interface{},
idNames []string,
) error {
rows, err := c.db.QueryContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
rows, err := db.QueryContext(ctx, query, params...)
if err != nil {
return err
}
@ -481,7 +485,8 @@ func (c DB) insertWithLastInsertID(
params []interface{},
idName string,
) error {
result, err := c.db.ExecContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
result, err := db.ExecContext(ctx, query, params...)
if err != nil {
return err
}
@ -514,7 +519,8 @@ func (c DB) insertWithNoIDRetrieval(
query string,
params []interface{},
) error {
_, err := c.db.ExecContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
_, err := db.ExecContext(ctx, query, params...)
return err
}
@ -569,7 +575,8 @@ func (c DB) Delete(
var params []interface{}
query, params = buildDeleteQuery(c.dialect, table, idMap)
result, err := c.db.ExecContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
result, err := db.ExecContext(ctx, query, params...)
if err != nil {
return err
}
@ -662,7 +669,8 @@ func (c DB) Patch(
return err
}
result, err := c.db.ExecContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
result, err := db.ExecContext(ctx, query, params...)
if err != nil {
return err
}
@ -858,10 +866,19 @@ func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{})
// Exec just runs an SQL command on the database returning no rows.
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) {
return c.db.ExecContext(ctx, query, params...)
db := c.getAdapterFromCtx(ctx)
return db.ExecContext(ctx, query, params...)
}
// Transaction just runs an SQL command on the database returning no rows.
// Transaction encapsulates several queries into a single transaction.
// All these queries should be made inside the input callback `fn`
// and they should use the input ksql.Provider.
//
// If the callback returns any errors the transaction will be rolled back,
// otherwise the transaction will me committed.
//
// If it happens that a second transaction is started inside a transaction
// callback the same transaction will be reused with no errors.
func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
switch txBeginner := c.db.(type) {
case Tx:
@ -904,6 +921,58 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
}
}
type txCtxKey struct{}
func (c DB) getAdapterFromCtx(ctx context.Context) DBAdapter {
if tx, ok := ctx.Value(txCtxKey{}).(Tx); ok {
return tx
}
return c.db
}
// TransactionCtx just runs an SQL command on the database returning no rows.
func (c DB) TransactionCtx(ctx context.Context, fn func(ctx context.Context) error) error {
if _, ok := ctx.Value(txCtxKey{}).(Tx); ok {
return fn(ctx)
}
txBeginner, ok := c.db.(TxBeginner)
if !ok {
return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
}
tx, err := txBeginner.BeginTx(ctx)
if err != nil {
return err
}
defer func() {
if r := recover(); r != nil {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil {
r = errors.Wrap(rollbackErr,
fmt.Sprintf("unable to rollback after panic with value: %v", r),
)
}
panic(r)
}
}()
ctx = context.WithValue(ctx, txCtxKey{}, tx)
err = fn(ctx)
if err != nil {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil {
err = errors.Wrap(rollbackErr,
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
)
}
return err
}
return tx.Commit(ctx)
}
// Close implements the io.Closer interface
func (c DB) Close() error {
closer, ok := c.db.(io.Closer)