diff --git a/contracts.go b/contracts.go index f854f7c..d90d673 100644 --- a/contracts.go +++ b/contracts.go @@ -22,6 +22,7 @@ type ORMProvider interface { QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error + Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) } // ChunkParser stores the arguments of the QueryChunks function diff --git a/kiss_orm.go b/kiss_orm.go index afd9918..1cf5def 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -503,6 +503,40 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error return err } +// Transaction just runs an SQL command on the database returning no rows. +func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) { + switch db := c.db.(type) { + case *sql.Tx: + return fn(c) + case *sql.DB: + var tx *sql.Tx + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if r := recover(); r != nil { + _ = tx.Rollback() + panic(r) + } + }() + + ormCopy := c + ormCopy.db = tx + + err = fn(ormCopy) + if err != nil { + _ = tx.Rollback() + return err + } + + return tx.Commit() + + default: + return fmt.Errorf("unexpected error on kissorm: db has an invalid type") + } +} + // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here