From 23efe488698ec36ebe171d5f7080cda3146c92cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= <vingarcia00@gmail.com>
Date: Fri, 31 Dec 2021 01:04:34 -0300
Subject: [PATCH] Add final tests to the Mock struct

---
 mocks_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/mocks_test.go b/mocks_test.go
index f2e1556..0488d3c 100644
--- a/mocks_test.go
+++ b/mocks_test.go
@@ -146,7 +146,7 @@ func TestMock(t *testing.T) {
 
 	t.Run("should call the user provided behavior correctly", func(t *testing.T) {
 		t.Run("Insert", func(t *testing.T) {
-			ctx := context.WithValue(context.Background(), "key", "value")
+			ctx := context.Background()
 			var capturedArgs struct {
 				ctx    context.Context
 				table  ksql.Table
@@ -372,4 +372,59 @@ func TestMock(t *testing.T) {
 			tt.AssertEqual(t, executed, false)
 		})
 	})
+
+	t.Run("SetFallbackDatabase", func(t *testing.T) {
+		testMock := ksql.Mock{}
+		dbMock := ksql.Mock{
+			InsertFn: func(ctx context.Context, table ksql.Table, record interface{}) error {
+				return fmt.Errorf("called from InsertFn")
+			},
+			UpdateFn: func(ctx context.Context, table ksql.Table, record interface{}) error {
+				return fmt.Errorf("called from UpdateFn")
+			},
+			DeleteFn: func(ctx context.Context, table ksql.Table, record interface{}) error {
+				return fmt.Errorf("called from DeleteFn")
+			},
+			QueryFn: func(ctx context.Context, records interface{}, query string, params ...interface{}) error {
+				return fmt.Errorf("called from QueryFn")
+			},
+			QueryOneFn: func(ctx context.Context, record interface{}, query string, params ...interface{}) error {
+				return fmt.Errorf("called from QueryOneFn")
+			},
+			QueryChunksFn: func(ctx context.Context, parser ksql.ChunkParser) error {
+				return fmt.Errorf("called from QueryChunksFn")
+			},
+			ExecFn: func(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) {
+				return 0, fmt.Errorf("called from ExecFn")
+			},
+			TransactionFn: func(ctx context.Context, fn func(db ksql.Provider) error) error {
+				return fmt.Errorf("called from TransactionFn")
+			},
+		}
+
+		ctx := context.Background()
+		testMock = testMock.SetFallbackDatabase(dbMock)
+
+		var user User
+		err := testMock.Insert(ctx, UsersTable, &user)
+		tt.AssertErrContains(t, err, "called from InsertFn")
+		err = testMock.Update(ctx, UsersTable, &user)
+		tt.AssertErrContains(t, err, "called from UpdateFn")
+		err = testMock.Delete(ctx, UsersTable, &user)
+		tt.AssertErrContains(t, err, "called from DeleteFn")
+
+		var users []User
+		err = testMock.Query(ctx, &users, "fake-query")
+		tt.AssertErrContains(t, err, "called from QueryFn")
+		err = testMock.QueryOne(ctx, &user, "fake-query")
+		tt.AssertErrContains(t, err, "called from QueryOneFn")
+		err = testMock.QueryChunks(ctx, ksql.ChunkParser{})
+		tt.AssertErrContains(t, err, "called from QueryChunksFn")
+		_, err = testMock.Exec(ctx, "fake-query")
+		tt.AssertErrContains(t, err, "called from ExecFn")
+		err = testMock.Transaction(ctx, func(db ksql.Provider) error {
+			return nil
+		})
+		tt.AssertErrContains(t, err, "called from TransactionFn")
+	})
 }