diff --git a/mocks_test.go b/mocks_test.go index f2e1556..0488d3c 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -146,7 +146,7 @@ func TestMock(t *testing.T) { t.Run("should call the user provided behavior correctly", func(t *testing.T) { t.Run("Insert", func(t *testing.T) { - ctx := context.WithValue(context.Background(), "key", "value") + ctx := context.Background() var capturedArgs struct { ctx context.Context table ksql.Table @@ -372,4 +372,59 @@ func TestMock(t *testing.T) { tt.AssertEqual(t, executed, false) }) }) + + t.Run("SetFallbackDatabase", func(t *testing.T) { + testMock := ksql.Mock{} + dbMock := ksql.Mock{ + InsertFn: func(ctx context.Context, table ksql.Table, record interface{}) error { + return fmt.Errorf("called from InsertFn") + }, + UpdateFn: func(ctx context.Context, table ksql.Table, record interface{}) error { + return fmt.Errorf("called from UpdateFn") + }, + DeleteFn: func(ctx context.Context, table ksql.Table, record interface{}) error { + return fmt.Errorf("called from DeleteFn") + }, + QueryFn: func(ctx context.Context, records interface{}, query string, params ...interface{}) error { + return fmt.Errorf("called from QueryFn") + }, + QueryOneFn: func(ctx context.Context, record interface{}, query string, params ...interface{}) error { + return fmt.Errorf("called from QueryOneFn") + }, + QueryChunksFn: func(ctx context.Context, parser ksql.ChunkParser) error { + return fmt.Errorf("called from QueryChunksFn") + }, + ExecFn: func(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) { + return 0, fmt.Errorf("called from ExecFn") + }, + TransactionFn: func(ctx context.Context, fn func(db ksql.Provider) error) error { + return fmt.Errorf("called from TransactionFn") + }, + } + + ctx := context.Background() + testMock = testMock.SetFallbackDatabase(dbMock) + + var user User + err := testMock.Insert(ctx, UsersTable, &user) + tt.AssertErrContains(t, err, "called from InsertFn") + err = testMock.Update(ctx, UsersTable, &user) + tt.AssertErrContains(t, err, "called from UpdateFn") + err = testMock.Delete(ctx, UsersTable, &user) + tt.AssertErrContains(t, err, "called from DeleteFn") + + var users []User + err = testMock.Query(ctx, &users, "fake-query") + tt.AssertErrContains(t, err, "called from QueryFn") + err = testMock.QueryOne(ctx, &user, "fake-query") + tt.AssertErrContains(t, err, "called from QueryOneFn") + err = testMock.QueryChunks(ctx, ksql.ChunkParser{}) + tt.AssertErrContains(t, err, "called from QueryChunksFn") + _, err = testMock.Exec(ctx, "fake-query") + tt.AssertErrContains(t, err, "called from ExecFn") + err = testMock.Transaction(ctx, func(db ksql.Provider) error { + return nil + }) + tt.AssertErrContains(t, err, "called from TransactionFn") + }) }