Remove var args from Insert and Update, so they actually perform atomic operations in the database

pull/2/head
Vinícius Garcia 2021-03-12 11:01:36 -03:00
parent df7db29464
commit 35c1f42317
5 changed files with 51 additions and 108 deletions

View File

@ -16,9 +16,9 @@ var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be
// SQLProvider describes the public behavior of this ORM
type SQLProvider interface {
Insert(ctx context.Context, records ...interface{}) error
Delete(ctx context.Context, ids ...interface{}) error
Update(ctx context.Context, records ...interface{}) error
Insert(ctx context.Context, record interface{}) error
Update(ctx context.Context, record interface{}) error
Delete(ctx context.Context, idsOrRecords ...interface{}) error
Query(ctx context.Context, records interface{}, query string, params ...interface{}) error
QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error

View File

@ -36,10 +36,10 @@ func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder {
}
// Delete mocks base method.
func (m *MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error {
func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{ctx}
for _, a := range ids {
for _, a := range idsOrRecords {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Delete", varargs...)
@ -48,9 +48,9 @@ func (m *MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error
}
// Delete indicates an expected call of Delete.
func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, ids ...interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, idsOrRecords ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx}, ids...)
varargs := append([]interface{}{ctx}, idsOrRecords...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...)
}
@ -74,22 +74,17 @@ func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...in
}
// Insert mocks base method.
func (m *MockSQLProvider) Insert(ctx context.Context, records ...interface{}) error {
func (m *MockSQLProvider) Insert(ctx context.Context, record interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{ctx}
for _, a := range records {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Insert", varargs...)
ret := m.ctrl.Call(m, "Insert", ctx, record)
ret0, _ := ret[0].(error)
return ret0
}
// Insert indicates an expected call of Insert.
func (mr *MockSQLProviderMockRecorder) Insert(ctx interface{}, records ...interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Insert(ctx, record interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx}, records...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), varargs...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, record)
}
// Query mocks base method.
@ -159,20 +154,15 @@ func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.
}
// Update mocks base method.
func (m *MockSQLProvider) Update(ctx context.Context, records ...interface{}) error {
func (m *MockSQLProvider) Update(ctx context.Context, record interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{ctx}
for _, a := range records {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Update", varargs...)
ret := m.ctrl.Call(m, "Update", ctx, record)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockSQLProviderMockRecorder) Update(ctx interface{}, records ...interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Update(ctx, record interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx}, records...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), varargs...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, record)
}

61
ksql.go
View File

@ -330,37 +330,31 @@ func (c DB) QueryChunks(
// the ID is automatically updated after insertion is completed.
func (c DB) Insert(
ctx context.Context,
records ...interface{},
record interface{},
) error {
if c.tableName == "" {
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method")
}
for _, record := range records {
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
switch c.insertMethod {
case insertWithReturning:
err = c.insertWithReturningID(ctx, record, query, params, c.idCols)
case insertWithLastInsertID:
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
case insertWithNoIDRetrieval:
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
default:
// Unsupported drivers should be detected on the New() function,
// So we don't expect the code to ever get into this default case.
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
}
if err != nil {
return err
}
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
return nil
switch c.insertMethod {
case insertWithReturning:
err = c.insertWithReturningID(ctx, record, query, params, c.idCols)
case insertWithLastInsertID:
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
case insertWithNoIDRetrieval:
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
default:
// Unsupported drivers should be detected on the New() function,
// So we don't expect the code to ever get into this default case.
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
}
return err
}
func (c DB) insertWithReturningID(
@ -551,25 +545,20 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter
// Partial updates are supported, i.e. it will ignore nil pointer attributes
func (c DB) Update(
ctx context.Context,
records ...interface{},
record interface{},
) error {
if c.tableName == "" {
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method")
}
for _, record := range records {
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
_, err = c.db.ExecContext(ctx, query, params...)
if err != nil {
return err
}
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
return nil
_, err = c.db.ExecContext(ctx, query, params...)
return err
}
func buildInsertQuery(

View File

@ -326,17 +326,6 @@ func TestInsert(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of users", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
err = c.Insert(ctx)
assert.Equal(t, nil, err)
})
t.Run("should insert one user correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
@ -414,10 +403,10 @@ func TestInsert(t *testing.T) {
ctx := context.Background()
c := newTestDB(db, driver, "users")
err = c.Insert(ctx, "foo", "bar")
err = c.Insert(ctx, "foo")
assert.NotEqual(t, nil, err)
err = c.Insert(ctx, nullable.String("foo"), nullable.String("bar"))
err = c.Insert(ctx, nullable.String("foo"))
assert.NotEqual(t, nil, err)
err = c.Insert(ctx, map[string]interface{}{
@ -607,31 +596,6 @@ func TestUpdate(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of ids", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
u := User{
Name: "Thay",
}
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
// Empty update, should do nothing:
err = c.Update(ctx)
assert.Equal(t, nil, err)
result := User{}
err = getUserByID(c.db, c.dialect, &result, u.ID)
assert.Equal(t, nil, err)
assert.Equal(t, "Thay", result.Name)
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()

View File

@ -6,9 +6,9 @@ var _ SQLProvider = MockSQLProvider{}
// MockSQLProvider ...
type MockSQLProvider struct {
InsertFn func(ctx context.Context, records ...interface{}) error
InsertFn func(ctx context.Context, record interface{}) error
UpdateFn func(ctx context.Context, record interface{}) error
DeleteFn func(ctx context.Context, ids ...interface{}) error
UpdateFn func(ctx context.Context, records ...interface{}) error
QueryFn func(ctx context.Context, records interface{}, query string, params ...interface{}) error
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
@ -19,8 +19,13 @@ type MockSQLProvider struct {
}
// Insert ...
func (m MockSQLProvider) Insert(ctx context.Context, records ...interface{}) error {
return m.InsertFn(ctx, records...)
func (m MockSQLProvider) Insert(ctx context.Context, record interface{}) error {
return m.InsertFn(ctx, record)
}
// Update ...
func (m MockSQLProvider) Update(ctx context.Context, record interface{}) error {
return m.UpdateFn(ctx, record)
}
// Delete ...
@ -28,11 +33,6 @@ func (m MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error {
return m.DeleteFn(ctx, ids...)
}
// Update ...
func (m MockSQLProvider) Update(ctx context.Context, records ...interface{}) error {
return m.UpdateFn(ctx, records...)
}
// Query ...
func (m MockSQLProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error {
return m.QueryFn(ctx, records, query, params...)