From 9d717717d404195b186d6eb45d627c90273f38d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= <vingarcia00@gmail.com>
Date: Thu, 14 Jan 2021 23:56:15 -0300
Subject: [PATCH] Add initial version of Transaction() function

---
 contracts.go |  1 +
 kiss_orm.go  | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/contracts.go b/contracts.go
index f854f7c..d90d673 100644
--- a/contracts.go
+++ b/contracts.go
@@ -22,6 +22,7 @@ type ORMProvider interface {
 	QueryChunks(ctx context.Context, parser ChunkParser) error
 
 	Exec(ctx context.Context, query string, params ...interface{}) error
+	Transaction(ctx context.Context, fn func(ORMProvider) error) (err error)
 }
 
 // ChunkParser stores the arguments of the QueryChunks function
diff --git a/kiss_orm.go b/kiss_orm.go
index afd9918..1cf5def 100644
--- a/kiss_orm.go
+++ b/kiss_orm.go
@@ -503,6 +503,40 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error
 	return err
 }
 
+// Transaction just runs an SQL command on the database returning no rows.
+func (c DB) Transaction(ctx context.Context, fn func(ORMProvider) error) (err error) {
+	switch db := c.db.(type) {
+	case *sql.Tx:
+		return fn(c)
+	case *sql.DB:
+		var tx *sql.Tx
+		tx, err = db.BeginTx(ctx, nil)
+		if err != nil {
+			return err
+		}
+		defer func() {
+			if r := recover(); r != nil {
+				_ = tx.Rollback()
+				panic(r)
+			}
+		}()
+
+		ormCopy := c
+		ormCopy.db = tx
+
+		err = fn(ormCopy)
+		if err != nil {
+			_ = tx.Rollback()
+			return err
+		}
+
+		return tx.Commit()
+
+	default:
+		return fmt.Errorf("unexpected error on kissorm: db has an invalid type")
+	}
+}
+
 // This cache is kept as a pkg variable
 // because the total number of types on a program
 // should be finite. So keeping a single cache here