Add tests and simplify the Delete function with composite keys

pull/16/head
Vinícius Garcia 2022-02-11 23:59:29 -03:00
parent d1e97489ef
commit 07c6065a5b
2 changed files with 61 additions and 54 deletions

60
ksql.go
View File

@ -558,11 +558,8 @@ func (c DB) Delete(
var query string var query string
var params []interface{} var params []interface{}
if len(table.idColumns) == 1 { query, params = buildDeleteQuery(c.dialect, table, idMaps[0])
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps) fmt.Println("query:", query, "params:", params)
} else {
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
}
result, err := c.db.ExecContext(ctx, query, params...) result, err := c.db.ExecContext(ctx, query, params...)
if err != nil { if err != nil {
@ -997,52 +994,23 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
return scanArgs return scanArgs
} }
func buildSingleKeyDeleteQuery( func buildDeleteQuery(
dialect Dialect, dialect Dialect,
table string, table Table,
idName string, idMap map[string]interface{},
idMaps []map[string]interface{},
) (query string, params []interface{}) { ) (query string, params []interface{}) {
values := []string{} whereQuery := []string{}
for i, m := range idMaps { for i, idName := range table.idColumns {
values = append(values, dialect.Placeholder(i)) whereQuery = append(whereQuery, fmt.Sprintf(
params = append(params, m[idName]) "%s = %s", dialect.Escape(idName), dialect.Placeholder(i),
))
params = append(params, idMap[idName])
} }
return fmt.Sprintf( return fmt.Sprintf(
"DELETE FROM %s WHERE %s IN (%s)", "DELETE FROM %s WHERE %s",
dialect.Escape(table), dialect.Escape(table.name),
dialect.Escape(idName), strings.Join(whereQuery, " AND "),
strings.Join(values, ","),
), params
}
func buildCompositeKeyDeleteQuery(
dialect Dialect,
table string,
idNames []string,
idMaps []map[string]interface{},
) (query string, params []interface{}) {
escapedNames := []string{}
for _, name := range idNames {
escapedNames = append(escapedNames, dialect.Escape(name))
}
values := []string{}
for _, m := range idMaps {
tuple := []string{}
for _, name := range idNames {
params = append(params, m[name])
tuple = append(tuple, dialect.Placeholder(len(values)))
}
values = append(values, "("+strings.Join(tuple, ",")+")")
}
return fmt.Sprintf(
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
dialect.Escape(table),
strings.Join(escapedNames, ","),
strings.Join(values, ","),
), params ), params
} }

View File

@ -47,9 +47,9 @@ type Post struct {
Title string `ksql:"title"` Title string `ksql:"title"`
} }
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id") var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
type UserPermissions struct { type UserPermission struct {
ID int `ksql:"id"` ID int `ksql:"id"`
UserID int `ksql:"user_id"` UserID int `ksql:"user_id"`
PostID int `ksql:"post_id"` PostID int `ksql:"post_id"`
@ -787,7 +787,8 @@ func TestInsert(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c := newTestDB(db, config.driver) c := newTestDB(db, config.driver)
err = c.Insert(ctx, UserPermissionsTable, &UserPermissions{ table := NewTable("user_permissions", "id", "user_id", "post_id")
err = c.Insert(ctx, table, &UserPermission{
UserID: 1, UserID: 1,
PostID: 42, PostID: 42,
}) })
@ -807,11 +808,14 @@ func TestInsert(t *testing.T) {
ctx := context.Background() ctx := context.Background()
c := newTestDB(db, config.driver) c := newTestDB(db, config.driver)
permission := UserPermissions{ // Table defined with 3 values, but we'll provide only 2,
// the third will be generated for the purposes of this test:
table := NewTable("user_permissions", "id", "user_id", "post_id")
permission := UserPermission{
UserID: 2, UserID: 2,
PostID: 42, PostID: 42,
} }
err = c.Insert(ctx, UserPermissionsTable, &permission) err = c.Insert(ctx, table, &permission)
tt.AssertNoErr(t, err) tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, config.driver, 2) userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
@ -1005,7 +1009,7 @@ func TestDelete(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error()) t.Fatal("could not create test table!, reason:", err.Error())
} }
t.Run("should delete one id correctly", func(t *testing.T) { t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
tests := []struct { tests := []struct {
desc string desc string
deletionKeyForUser func(u User) interface{} deletionKeyForUser func(u User) interface{}
@ -1081,6 +1085,41 @@ func TestDelete(t *testing.T) {
} }
}) })
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
t.Run("using structs", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
// This permission should not be deleted, we'll use the id to check it:
p0 := UserPermission{
UserID: 1,
PostID: 44,
}
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, p0.ID, 0)
p1 := UserPermission{
UserID: 1,
PostID: 42,
}
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1)
tt.AssertNoErr(t, err)
err = c.Delete(ctx, UserPermissionsTable, p1)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 1)
tt.AssertEqual(t, userPerms[0].PostID, 44)
})
})
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) { t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
db, closer := connectDB(t, config) db, closer := connectDB(t, config)
defer closer.Close() defer closer.Close()
@ -2321,7 +2360,7 @@ func getUserByName(db DBAdapter, driver string, result *User, name string) error
return json.Unmarshal(rawAddr, &result.Address) return json.Unmarshal(rawAddr, &result.Address)
} }
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermissions, _ error) { func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermission, _ error) {
dialect := supportedDialects[driver] dialect := supportedDialects[driver]
rows, err := db.QueryContext(context.TODO(), rows, err := db.QueryContext(context.TODO(),
@ -2334,7 +2373,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var userPerm UserPermissions var userPerm UserPermission
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID) err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
if err != nil { if err != nil {
return nil, err return nil, err