mirror of https://github.com/VinGarcia/ksql.git
Add TransactionCtx method
parent
9161634e7b
commit
46ebf63c0f
89
ksql.go
89
ksql.go
|
@ -161,7 +161,8 @@ func (c DB) Query(
|
|||
query = selectPrefix + query
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
rows, err := db.QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running query: %s", err)
|
||||
}
|
||||
|
@ -251,7 +252,8 @@ func (c DB) QueryOne(
|
|||
query = selectPrefix + query
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
rows, err := db.QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running query: %s", err)
|
||||
}
|
||||
|
@ -324,7 +326,8 @@ func (c DB) QueryChunks(
|
|||
parser.Query = selectPrefix + parser.Query
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
rows, err := db.QueryContext(ctx, parser.Query, parser.Params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -448,7 +451,8 @@ func (c DB) insertReturningIDs(
|
|||
scanValues []interface{},
|
||||
idNames []string,
|
||||
) error {
|
||||
rows, err := c.db.QueryContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
rows, err := db.QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -481,7 +485,8 @@ func (c DB) insertWithLastInsertID(
|
|||
params []interface{},
|
||||
idName string,
|
||||
) error {
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
result, err := db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -514,7 +519,8 @@ func (c DB) insertWithNoIDRetrieval(
|
|||
query string,
|
||||
params []interface{},
|
||||
) error {
|
||||
_, err := c.db.ExecContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
_, err := db.ExecContext(ctx, query, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -569,7 +575,8 @@ func (c DB) Delete(
|
|||
var params []interface{}
|
||||
query, params = buildDeleteQuery(c.dialect, table, idMap)
|
||||
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
result, err := db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -662,7 +669,8 @@ func (c DB) Patch(
|
|||
return err
|
||||
}
|
||||
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
result, err := db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -858,10 +866,19 @@ func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{})
|
|||
|
||||
// Exec just runs an SQL command on the database returning no rows.
|
||||
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) {
|
||||
return c.db.ExecContext(ctx, query, params...)
|
||||
db := c.getAdapterFromCtx(ctx)
|
||||
return db.ExecContext(ctx, query, params...)
|
||||
}
|
||||
|
||||
// Transaction just runs an SQL command on the database returning no rows.
|
||||
// Transaction encapsulates several queries into a single transaction.
|
||||
// All these queries should be made inside the input callback `fn`
|
||||
// and they should use the input ksql.Provider.
|
||||
//
|
||||
// If the callback returns any errors the transaction will be rolled back,
|
||||
// otherwise the transaction will me committed.
|
||||
//
|
||||
// If it happens that a second transaction is started inside a transaction
|
||||
// callback the same transaction will be reused with no errors.
|
||||
func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||
switch txBeginner := c.db.(type) {
|
||||
case Tx:
|
||||
|
@ -904,6 +921,58 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
|||
}
|
||||
}
|
||||
|
||||
type txCtxKey struct{}
|
||||
|
||||
func (c DB) getAdapterFromCtx(ctx context.Context) DBAdapter {
|
||||
if tx, ok := ctx.Value(txCtxKey{}).(Tx); ok {
|
||||
return tx
|
||||
}
|
||||
|
||||
return c.db
|
||||
}
|
||||
|
||||
// TransactionCtx just runs an SQL command on the database returning no rows.
|
||||
func (c DB) TransactionCtx(ctx context.Context, fn func(ctx context.Context) error) error {
|
||||
if _, ok := ctx.Value(txCtxKey{}).(Tx); ok {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
txBeginner, ok := c.db.(TxBeginner)
|
||||
if !ok {
|
||||
return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
|
||||
}
|
||||
|
||||
tx, err := txBeginner.BeginTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rollbackErr := tx.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
r = errors.Wrap(rollbackErr,
|
||||
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
||||
)
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx = context.WithValue(ctx, txCtxKey{}, tx)
|
||||
err = fn(ctx)
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
err = errors.Wrap(rollbackErr,
|
||||
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface
|
||||
func (c DB) Close() error {
|
||||
closer, ok := c.db.(io.Closer)
|
||||
|
|
Loading…
Reference in New Issue