mirror of https://github.com/VinGarcia/ksql.git
Add tests and simplify the Delete function with composite keys
parent
d1e97489ef
commit
07c6065a5b
60
ksql.go
60
ksql.go
|
@ -558,11 +558,8 @@ func (c DB) Delete(
|
|||
|
||||
var query string
|
||||
var params []interface{}
|
||||
if len(table.idColumns) == 1 {
|
||||
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps)
|
||||
} else {
|
||||
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
|
||||
}
|
||||
query, params = buildDeleteQuery(c.dialect, table, idMaps[0])
|
||||
fmt.Println("query:", query, "params:", params)
|
||||
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
|
@ -997,52 +994,23 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
|
|||
return scanArgs
|
||||
}
|
||||
|
||||
func buildSingleKeyDeleteQuery(
|
||||
func buildDeleteQuery(
|
||||
dialect Dialect,
|
||||
table string,
|
||||
idName string,
|
||||
idMaps []map[string]interface{},
|
||||
table Table,
|
||||
idMap map[string]interface{},
|
||||
) (query string, params []interface{}) {
|
||||
values := []string{}
|
||||
for i, m := range idMaps {
|
||||
values = append(values, dialect.Placeholder(i))
|
||||
params = append(params, m[idName])
|
||||
whereQuery := []string{}
|
||||
for i, idName := range table.idColumns {
|
||||
whereQuery = append(whereQuery, fmt.Sprintf(
|
||||
"%s = %s", dialect.Escape(idName), dialect.Placeholder(i),
|
||||
))
|
||||
params = append(params, idMap[idName])
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE %s IN (%s)",
|
||||
dialect.Escape(table),
|
||||
dialect.Escape(idName),
|
||||
strings.Join(values, ","),
|
||||
), params
|
||||
}
|
||||
|
||||
func buildCompositeKeyDeleteQuery(
|
||||
dialect Dialect,
|
||||
table string,
|
||||
idNames []string,
|
||||
idMaps []map[string]interface{},
|
||||
) (query string, params []interface{}) {
|
||||
escapedNames := []string{}
|
||||
for _, name := range idNames {
|
||||
escapedNames = append(escapedNames, dialect.Escape(name))
|
||||
}
|
||||
|
||||
values := []string{}
|
||||
for _, m := range idMaps {
|
||||
tuple := []string{}
|
||||
for _, name := range idNames {
|
||||
params = append(params, m[name])
|
||||
tuple = append(tuple, dialect.Placeholder(len(values)))
|
||||
}
|
||||
values = append(values, "("+strings.Join(tuple, ",")+")")
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
|
||||
dialect.Escape(table),
|
||||
strings.Join(escapedNames, ","),
|
||||
strings.Join(values, ","),
|
||||
"DELETE FROM %s WHERE %s",
|
||||
dialect.Escape(table.name),
|
||||
strings.Join(whereQuery, " AND "),
|
||||
), params
|
||||
}
|
||||
|
||||
|
|
55
ksql_test.go
55
ksql_test.go
|
@ -47,9 +47,9 @@ type Post struct {
|
|||
Title string `ksql:"title"`
|
||||
}
|
||||
|
||||
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
|
||||
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
|
||||
|
||||
type UserPermissions struct {
|
||||
type UserPermission struct {
|
||||
ID int `ksql:"id"`
|
||||
UserID int `ksql:"user_id"`
|
||||
PostID int `ksql:"post_id"`
|
||||
|
@ -787,7 +787,8 @@ func TestInsert(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
err = c.Insert(ctx, UserPermissionsTable, &UserPermissions{
|
||||
table := NewTable("user_permissions", "id", "user_id", "post_id")
|
||||
err = c.Insert(ctx, table, &UserPermission{
|
||||
UserID: 1,
|
||||
PostID: 42,
|
||||
})
|
||||
|
@ -807,11 +808,14 @@ func TestInsert(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
permission := UserPermissions{
|
||||
// Table defined with 3 values, but we'll provide only 2,
|
||||
// the third will be generated for the purposes of this test:
|
||||
table := NewTable("user_permissions", "id", "user_id", "post_id")
|
||||
permission := UserPermission{
|
||||
UserID: 2,
|
||||
PostID: 42,
|
||||
}
|
||||
err = c.Insert(ctx, UserPermissionsTable, &permission)
|
||||
err = c.Insert(ctx, table, &permission)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
|
||||
|
@ -1005,7 +1009,7 @@ func TestDelete(t *testing.T) {
|
|||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
||||
t.Run("should delete one id correctly", func(t *testing.T) {
|
||||
t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
deletionKeyForUser func(u User) interface{}
|
||||
|
@ -1081,6 +1085,41 @@ func TestDelete(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
|
||||
t.Run("using structs", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
|
||||
// This permission should not be deleted, we'll use the id to check it:
|
||||
p0 := UserPermission{
|
||||
UserID: 1,
|
||||
PostID: 44,
|
||||
}
|
||||
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, p0.ID, 0)
|
||||
|
||||
p1 := UserPermission{
|
||||
UserID: 1,
|
||||
PostID: 42,
|
||||
}
|
||||
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
err = c.Delete(ctx, UserPermissionsTable, p1)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertEqual(t, len(userPerms), 1)
|
||||
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
||||
tt.AssertEqual(t, userPerms[0].PostID, 44)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
@ -2321,7 +2360,7 @@ func getUserByName(db DBAdapter, driver string, result *User, name string) error
|
|||
return json.Unmarshal(rawAddr, &result.Address)
|
||||
}
|
||||
|
||||
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermissions, _ error) {
|
||||
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermission, _ error) {
|
||||
dialect := supportedDialects[driver]
|
||||
|
||||
rows, err := db.QueryContext(context.TODO(),
|
||||
|
@ -2334,7 +2373,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
|||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var userPerm UserPermissions
|
||||
var userPerm UserPermission
|
||||
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Reference in New Issue