Add tests and simplify the Delete function with composite keys

pull/16/head
Vinícius Garcia 2022-02-11 23:59:29 -03:00
parent d1e97489ef
commit 07c6065a5b
2 changed files with 61 additions and 54 deletions

60
ksql.go
View File

@ -558,11 +558,8 @@ func (c DB) Delete(
var query string
var params []interface{}
if len(table.idColumns) == 1 {
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps)
} else {
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
}
query, params = buildDeleteQuery(c.dialect, table, idMaps[0])
fmt.Println("query:", query, "params:", params)
result, err := c.db.ExecContext(ctx, query, params...)
if err != nil {
@ -997,52 +994,23 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
return scanArgs
}
func buildSingleKeyDeleteQuery(
func buildDeleteQuery(
dialect Dialect,
table string,
idName string,
idMaps []map[string]interface{},
table Table,
idMap map[string]interface{},
) (query string, params []interface{}) {
values := []string{}
for i, m := range idMaps {
values = append(values, dialect.Placeholder(i))
params = append(params, m[idName])
whereQuery := []string{}
for i, idName := range table.idColumns {
whereQuery = append(whereQuery, fmt.Sprintf(
"%s = %s", dialect.Escape(idName), dialect.Placeholder(i),
))
params = append(params, idMap[idName])
}
return fmt.Sprintf(
"DELETE FROM %s WHERE %s IN (%s)",
dialect.Escape(table),
dialect.Escape(idName),
strings.Join(values, ","),
), params
}
func buildCompositeKeyDeleteQuery(
dialect Dialect,
table string,
idNames []string,
idMaps []map[string]interface{},
) (query string, params []interface{}) {
escapedNames := []string{}
for _, name := range idNames {
escapedNames = append(escapedNames, dialect.Escape(name))
}
values := []string{}
for _, m := range idMaps {
tuple := []string{}
for _, name := range idNames {
params = append(params, m[name])
tuple = append(tuple, dialect.Placeholder(len(values)))
}
values = append(values, "("+strings.Join(tuple, ",")+")")
}
return fmt.Sprintf(
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
dialect.Escape(table),
strings.Join(escapedNames, ","),
strings.Join(values, ","),
"DELETE FROM %s WHERE %s",
dialect.Escape(table.name),
strings.Join(whereQuery, " AND "),
), params
}

View File

@ -47,9 +47,9 @@ type Post struct {
Title string `ksql:"title"`
}
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
type UserPermissions struct {
type UserPermission struct {
ID int `ksql:"id"`
UserID int `ksql:"user_id"`
PostID int `ksql:"post_id"`
@ -787,7 +787,8 @@ func TestInsert(t *testing.T) {
ctx := context.Background()
c := newTestDB(db, config.driver)
err = c.Insert(ctx, UserPermissionsTable, &UserPermissions{
table := NewTable("user_permissions", "id", "user_id", "post_id")
err = c.Insert(ctx, table, &UserPermission{
UserID: 1,
PostID: 42,
})
@ -807,11 +808,14 @@ func TestInsert(t *testing.T) {
ctx := context.Background()
c := newTestDB(db, config.driver)
permission := UserPermissions{
// Table defined with 3 values, but we'll provide only 2,
// the third will be generated for the purposes of this test:
table := NewTable("user_permissions", "id", "user_id", "post_id")
permission := UserPermission{
UserID: 2,
PostID: 42,
}
err = c.Insert(ctx, UserPermissionsTable, &permission)
err = c.Insert(ctx, table, &permission)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
@ -1005,7 +1009,7 @@ func TestDelete(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should delete one id correctly", func(t *testing.T) {
t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
tests := []struct {
desc string
deletionKeyForUser func(u User) interface{}
@ -1081,6 +1085,41 @@ func TestDelete(t *testing.T) {
}
})
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
t.Run("using structs", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
// This permission should not be deleted, we'll use the id to check it:
p0 := UserPermission{
UserID: 1,
PostID: 44,
}
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, p0.ID, 0)
p1 := UserPermission{
UserID: 1,
PostID: 42,
}
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1)
tt.AssertNoErr(t, err)
err = c.Delete(ctx, UserPermissionsTable, p1)
tt.AssertNoErr(t, err)
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, len(userPerms), 1)
tt.AssertEqual(t, userPerms[0].UserID, 1)
tt.AssertEqual(t, userPerms[0].PostID, 44)
})
})
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
@ -2321,7 +2360,7 @@ func getUserByName(db DBAdapter, driver string, result *User, name string) error
return json.Unmarshal(rawAddr, &result.Address)
}
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermissions, _ error) {
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermission, _ error) {
dialect := supportedDialects[driver]
rows, err := db.QueryContext(context.TODO(),
@ -2334,7 +2373,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
defer rows.Close()
for rows.Next() {
var userPerm UserPermissions
var userPerm UserPermission
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
if err != nil {
return nil, err