diff --git a/contracts.go b/contracts.go index eeddd62..5ed26d5 100644 --- a/contracts.go +++ b/contracts.go @@ -30,7 +30,7 @@ type Provider interface { QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunks(ctx context.Context, parser ChunkParser) error - Exec(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) + Exec(ctx context.Context, query string, params ...interface{}) (Result, error) Transaction(ctx context.Context, fn func(Provider) error) error } diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index 308570b..b340c28 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -50,14 +50,14 @@ func (mr *MockProviderMockRecorder) Delete(ctx, table, idOrRecord interface{}) * } // Exec mocks base method. -func (m *MockProvider) Exec(ctx context.Context, query string, params ...interface{}) (int64, error) { +func (m *MockProvider) Exec(ctx context.Context, query string, params ...interface{}) (ksql.Result, error) { m.ctrl.T.Helper() varargs := []interface{}{ctx, query} for _, a := range params { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(int64) + ret0, _ := ret[0].(ksql.Result) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/ksql.go b/ksql.go index b40c098..66260b5 100644 --- a/ksql.go +++ b/ksql.go @@ -839,10 +839,8 @@ func buildUpdateQuery( } // Exec just runs an SQL command on the database returning no rows. -func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) { - result, err := c.db.ExecContext(ctx, query, params...) - rowsAffected, _ = result.RowsAffected() - return rowsAffected, err +func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) { + return c.db.ExecContext(ctx, query, params...) } // Transaction just runs an SQL command on the database returning no rows. diff --git a/mocks.go b/mocks.go index 998891f..7b17f37 100644 --- a/mocks.go +++ b/mocks.go @@ -59,10 +59,22 @@ type Mock struct { QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunksFn func(ctx context.Context, parser ChunkParser) error - ExecFn func(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) + ExecFn func(ctx context.Context, query string, params ...interface{}) (Result, error) TransactionFn func(ctx context.Context, fn func(db Provider) error) error } +// MockResult implements the Result interface returned by the Exec function +// +// Use the constructor `NewMockResult(42, 42)` for a simpler instantiation of this mock. +// +// But if you want one of the functions to return an error you'll need +// to specify the desired behavior by overwriting one of the attributes +// of the struct. +type MockResult struct { + LastInsertIdFn func() (int64, error) + RowsAffectedFn func() (int64, error) +} + // SetFallbackDatabase will set all the Fn attributes to use // the function from the input database. // @@ -197,7 +209,7 @@ func (m Mock) QueryChunks(ctx context.Context, parser ChunkParser) error { // Exec mocks the behavior of the Exec method. // If ExecFn is set it will just call it returning the same return values. // If ExecFn is unset it will panic with an appropriate error message. -func (m Mock) Exec(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) { +func (m Mock) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) { if m.ExecFn == nil { panic(fmt.Errorf("ksql.Mock.Exec(ctx, %s, %v) called but the ksql.Mock.ExecFn() is not set", query, params)) } @@ -214,3 +226,27 @@ func (m Mock) Transaction(ctx context.Context, fn func(db Provider) error) error } return m.TransactionFn(ctx, fn) } + +// NewMockResult returns a simple implementation of the Result interface. +func NewMockResult(lastInsertID int64, rowsAffected int64) Result { + return MockResult{ + LastInsertIdFn: func() (int64, error) { return lastInsertID, nil }, + RowsAffectedFn: func() (int64, error) { return rowsAffected, nil }, + } +} + +// LastInsertId implements the Result interface +func (m MockResult) LastInsertId() (int64, error) { + if m.LastInsertIdFn == nil { + panic(fmt.Errorf("ksql.MockResult.LastInsertId() called but ksql.MockResult.LastInsertIdFn is not set")) + } + return m.LastInsertIdFn() +} + +// RowsAffected implements the Result interface +func (m MockResult) RowsAffected() (int64, error) { + if m.RowsAffectedFn == nil { + panic(fmt.Errorf("ksql.MockResult.RowsAffected() called but ksql.MockResult.RowsAffectedFn is not set")) + } + return m.RowsAffectedFn() +} diff --git a/mocks_test.go b/mocks_test.go index b50a5c5..197449f 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -383,17 +383,22 @@ func TestMock(t *testing.T) { params []interface{} } mock := ksql.Mock{ - ExecFn: func(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) { + ExecFn: func(ctx context.Context, query string, params ...interface{}) (ksql.Result, error) { capturedArgs.ctx = ctx capturedArgs.query = query capturedArgs.params = params - return 42, fmt.Errorf("fake-error") + return ksql.NewMockResult(42, 42), fmt.Errorf("fake-error") }, } - rowsAffected, err := mock.Exec(ctx, "INSERT INTO users_permissions(user_id, permission_id) VALUES (?, ?)", 4242, 4) + r, err := mock.Exec(ctx, "INSERT INTO users_permissions(user_id, permission_id) VALUES (?, ?)", 4242, 4) tt.AssertErrContains(t, err, "fake-error") + rowsAffected, err := r.RowsAffected() + tt.AssertNoErr(t, err) tt.AssertEqual(t, rowsAffected, int64(42)) + lastInsertID, err := r.LastInsertId() + tt.AssertNoErr(t, err) + tt.AssertEqual(t, lastInsertID, int64(42)) tt.AssertEqual(t, capturedArgs.ctx, ctx) tt.AssertEqual(t, capturedArgs.query, "INSERT INTO users_permissions(user_id, permission_id) VALUES (?, ?)") tt.AssertEqual(t, capturedArgs.params, []interface{}{4242, 4}) @@ -441,8 +446,8 @@ func TestMock(t *testing.T) { QueryChunksFn: func(ctx context.Context, parser ksql.ChunkParser) error { return fmt.Errorf("called from QueryChunksFn") }, - ExecFn: func(ctx context.Context, query string, params ...interface{}) (rowsAffected int64, _ error) { - return 0, fmt.Errorf("called from ExecFn") + ExecFn: func(ctx context.Context, query string, params ...interface{}) (ksql.Result, error) { + return nil, fmt.Errorf("called from ExecFn") }, TransactionFn: func(ctx context.Context, fn func(db ksql.Provider) error) error { return fmt.Errorf("called from TransactionFn")