diff --git a/internal_mocks.go b/internal_mocks.go index e77bd36..e022a92 100644 --- a/internal_mocks.go +++ b/internal_mocks.go @@ -12,6 +12,20 @@ func (b mockTxBeginner) BeginTx(ctx context.Context) (Tx, error) { return b.BeginTxFn(ctx) } +// mockDBAdapter mocks the ksql.DBAdapter interface +type mockDBAdapter struct { + ExecContextFn func(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContextFn func(ctx context.Context, query string, args ...interface{}) (Rows, error) +} + +func (m mockDBAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return m.ExecContextFn(ctx, query, args...) +} + +func (m mockDBAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return m.QueryContextFn(ctx, query, args...) +} + // mockTx mocks the ksql.Tx interface type mockTx struct { DBAdapter diff --git a/test_adapters.go b/test_adapters.go index e42ce92..01581a3 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -2370,6 +2370,26 @@ func TransactionTest( }) tt.AssertErrContains(t, err, "KSQL", "fakeErrMsg") }) + + t.Run("should report error if DBAdapter can't create transactions", func(t *testing.T) { + err := createTables(driver, connStr) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + c.db = mockDBAdapter{} + + err = c.Transaction(ctx, func(db Provider) error { + return nil + }) + tt.AssertErrContains(t, err, "KSQL", "can't start transaction", "DBAdapter", "TxBeginner") + }) }) }