From 46ebf63c0f481ca39dfde36ad356fdd03cf804e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 31 Jul 2022 12:18:19 -0300 Subject: [PATCH] Add TransactionCtx method --- ksql.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/ksql.go b/ksql.go index 4ed4121..fa2c03f 100644 --- a/ksql.go +++ b/ksql.go @@ -161,7 +161,8 @@ func (c DB) Query( query = selectPrefix + query } - rows, err := c.db.QueryContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + rows, err := db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %s", err) } @@ -251,7 +252,8 @@ func (c DB) QueryOne( query = selectPrefix + query } - rows, err := c.db.QueryContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + rows, err := db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %s", err) } @@ -324,7 +326,8 @@ func (c DB) QueryChunks( parser.Query = selectPrefix + parser.Query } - rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...) + db := c.getAdapterFromCtx(ctx) + rows, err := db.QueryContext(ctx, parser.Query, parser.Params...) if err != nil { return err } @@ -448,7 +451,8 @@ func (c DB) insertReturningIDs( scanValues []interface{}, idNames []string, ) error { - rows, err := c.db.QueryContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + rows, err := db.QueryContext(ctx, query, params...) if err != nil { return err } @@ -481,7 +485,8 @@ func (c DB) insertWithLastInsertID( params []interface{}, idName string, ) error { - result, err := c.db.ExecContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + result, err := db.ExecContext(ctx, query, params...) if err != nil { return err } @@ -514,7 +519,8 @@ func (c DB) insertWithNoIDRetrieval( query string, params []interface{}, ) error { - _, err := c.db.ExecContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + _, err := db.ExecContext(ctx, query, params...) return err } @@ -569,7 +575,8 @@ func (c DB) Delete( var params []interface{} query, params = buildDeleteQuery(c.dialect, table, idMap) - result, err := c.db.ExecContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + result, err := db.ExecContext(ctx, query, params...) if err != nil { return err } @@ -662,7 +669,8 @@ func (c DB) Patch( return err } - result, err := c.db.ExecContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + result, err := db.ExecContext(ctx, query, params...) if err != nil { return err } @@ -858,10 +866,19 @@ func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{}) // Exec just runs an SQL command on the database returning no rows. func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) { - return c.db.ExecContext(ctx, query, params...) + db := c.getAdapterFromCtx(ctx) + return db.ExecContext(ctx, query, params...) } -// Transaction just runs an SQL command on the database returning no rows. +// Transaction encapsulates several queries into a single transaction. +// All these queries should be made inside the input callback `fn` +// and they should use the input ksql.Provider. +// +// If the callback returns any errors the transaction will be rolled back, +// otherwise the transaction will me committed. +// +// If it happens that a second transaction is started inside a transaction +// callback the same transaction will be reused with no errors. func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { switch txBeginner := c.db.(type) { case Tx: @@ -904,6 +921,58 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { } } +type txCtxKey struct{} + +func (c DB) getAdapterFromCtx(ctx context.Context) DBAdapter { + if tx, ok := ctx.Value(txCtxKey{}).(Tx); ok { + return tx + } + + return c.db +} + +// TransactionCtx just runs an SQL command on the database returning no rows. +func (c DB) TransactionCtx(ctx context.Context, fn func(ctx context.Context) error) error { + if _, ok := ctx.Value(txCtxKey{}).(Tx); ok { + return fn(ctx) + } + + txBeginner, ok := c.db.(TxBeginner) + if !ok { + return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface") + } + + tx, err := txBeginner.BeginTx(ctx) + if err != nil { + return err + } + defer func() { + if r := recover(); r != nil { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil { + r = errors.Wrap(rollbackErr, + fmt.Sprintf("unable to rollback after panic with value: %v", r), + ) + } + panic(r) + } + }() + + ctx = context.WithValue(ctx, txCtxKey{}, tx) + err = fn(ctx) + if err != nil { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil { + err = errors.Wrap(rollbackErr, + fmt.Sprintf("unable to rollback after error: %s", err.Error()), + ) + } + return err + } + + return tx.Commit(ctx) +} + // Close implements the io.Closer interface func (c DB) Close() error { closer, ok := c.db.(io.Closer)