mirror of https://github.com/VinGarcia/ksql.git
Add tests and simplify the Delete function with composite keys
parent
d1e97489ef
commit
07c6065a5b
60
ksql.go
60
ksql.go
|
@ -558,11 +558,8 @@ func (c DB) Delete(
|
||||||
|
|
||||||
var query string
|
var query string
|
||||||
var params []interface{}
|
var params []interface{}
|
||||||
if len(table.idColumns) == 1 {
|
query, params = buildDeleteQuery(c.dialect, table, idMaps[0])
|
||||||
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps)
|
fmt.Println("query:", query, "params:", params)
|
||||||
} else {
|
|
||||||
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := c.db.ExecContext(ctx, query, params...)
|
result, err := c.db.ExecContext(ctx, query, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -997,52 +994,23 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info
|
||||||
return scanArgs
|
return scanArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildSingleKeyDeleteQuery(
|
func buildDeleteQuery(
|
||||||
dialect Dialect,
|
dialect Dialect,
|
||||||
table string,
|
table Table,
|
||||||
idName string,
|
idMap map[string]interface{},
|
||||||
idMaps []map[string]interface{},
|
|
||||||
) (query string, params []interface{}) {
|
) (query string, params []interface{}) {
|
||||||
values := []string{}
|
whereQuery := []string{}
|
||||||
for i, m := range idMaps {
|
for i, idName := range table.idColumns {
|
||||||
values = append(values, dialect.Placeholder(i))
|
whereQuery = append(whereQuery, fmt.Sprintf(
|
||||||
params = append(params, m[idName])
|
"%s = %s", dialect.Escape(idName), dialect.Placeholder(i),
|
||||||
|
))
|
||||||
|
params = append(params, idMap[idName])
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"DELETE FROM %s WHERE %s IN (%s)",
|
"DELETE FROM %s WHERE %s",
|
||||||
dialect.Escape(table),
|
dialect.Escape(table.name),
|
||||||
dialect.Escape(idName),
|
strings.Join(whereQuery, " AND "),
|
||||||
strings.Join(values, ","),
|
|
||||||
), params
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCompositeKeyDeleteQuery(
|
|
||||||
dialect Dialect,
|
|
||||||
table string,
|
|
||||||
idNames []string,
|
|
||||||
idMaps []map[string]interface{},
|
|
||||||
) (query string, params []interface{}) {
|
|
||||||
escapedNames := []string{}
|
|
||||||
for _, name := range idNames {
|
|
||||||
escapedNames = append(escapedNames, dialect.Escape(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
values := []string{}
|
|
||||||
for _, m := range idMaps {
|
|
||||||
tuple := []string{}
|
|
||||||
for _, name := range idNames {
|
|
||||||
params = append(params, m[name])
|
|
||||||
tuple = append(tuple, dialect.Placeholder(len(values)))
|
|
||||||
}
|
|
||||||
values = append(values, "("+strings.Join(tuple, ",")+")")
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
|
|
||||||
dialect.Escape(table),
|
|
||||||
strings.Join(escapedNames, ","),
|
|
||||||
strings.Join(values, ","),
|
|
||||||
), params
|
), params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
55
ksql_test.go
55
ksql_test.go
|
@ -47,9 +47,9 @@ type Post struct {
|
||||||
Title string `ksql:"title"`
|
Title string `ksql:"title"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id")
|
var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id")
|
||||||
|
|
||||||
type UserPermissions struct {
|
type UserPermission struct {
|
||||||
ID int `ksql:"id"`
|
ID int `ksql:"id"`
|
||||||
UserID int `ksql:"user_id"`
|
UserID int `ksql:"user_id"`
|
||||||
PostID int `ksql:"post_id"`
|
PostID int `ksql:"post_id"`
|
||||||
|
@ -787,7 +787,8 @@ func TestInsert(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c := newTestDB(db, config.driver)
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
err = c.Insert(ctx, UserPermissionsTable, &UserPermissions{
|
table := NewTable("user_permissions", "id", "user_id", "post_id")
|
||||||
|
err = c.Insert(ctx, table, &UserPermission{
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
PostID: 42,
|
PostID: 42,
|
||||||
})
|
})
|
||||||
|
@ -807,11 +808,14 @@ func TestInsert(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c := newTestDB(db, config.driver)
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
permission := UserPermissions{
|
// Table defined with 3 values, but we'll provide only 2,
|
||||||
|
// the third will be generated for the purposes of this test:
|
||||||
|
table := NewTable("user_permissions", "id", "user_id", "post_id")
|
||||||
|
permission := UserPermission{
|
||||||
UserID: 2,
|
UserID: 2,
|
||||||
PostID: 42,
|
PostID: 42,
|
||||||
}
|
}
|
||||||
err = c.Insert(ctx, UserPermissionsTable, &permission)
|
err = c.Insert(ctx, table, &permission)
|
||||||
tt.AssertNoErr(t, err)
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
|
userPerms, err := getUserPermissionsByUser(db, config.driver, 2)
|
||||||
|
@ -1005,7 +1009,7 @@ func TestDelete(t *testing.T) {
|
||||||
t.Fatal("could not create test table!, reason:", err.Error())
|
t.Fatal("could not create test table!, reason:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("should delete one id correctly", func(t *testing.T) {
|
t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
desc string
|
desc string
|
||||||
deletionKeyForUser func(u User) interface{}
|
deletionKeyForUser func(u User) interface{}
|
||||||
|
@ -1081,6 +1085,41 @@ func TestDelete(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) {
|
||||||
|
t.Run("using structs", func(t *testing.T) {
|
||||||
|
db, closer := connectDB(t, config)
|
||||||
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
|
// This permission should not be deleted, we'll use the id to check it:
|
||||||
|
p0 := UserPermission{
|
||||||
|
UserID: 1,
|
||||||
|
PostID: 44,
|
||||||
|
}
|
||||||
|
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
tt.AssertNotEqual(t, p0.ID, 0)
|
||||||
|
|
||||||
|
p1 := UserPermission{
|
||||||
|
UserID: 1,
|
||||||
|
PostID: 42,
|
||||||
|
}
|
||||||
|
err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
err = c.Delete(ctx, UserPermissionsTable, p1)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
userPerms, err := getUserPermissionsByUser(db, config.driver, 1)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
tt.AssertEqual(t, len(userPerms), 1)
|
||||||
|
tt.AssertEqual(t, userPerms[0].UserID, 1)
|
||||||
|
tt.AssertEqual(t, userPerms[0].PostID, 44)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
|
t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) {
|
||||||
db, closer := connectDB(t, config)
|
db, closer := connectDB(t, config)
|
||||||
defer closer.Close()
|
defer closer.Close()
|
||||||
|
@ -2321,7 +2360,7 @@ func getUserByName(db DBAdapter, driver string, result *User, name string) error
|
||||||
return json.Unmarshal(rawAddr, &result.Address)
|
return json.Unmarshal(rawAddr, &result.Address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermissions, _ error) {
|
func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermission, _ error) {
|
||||||
dialect := supportedDialects[driver]
|
dialect := supportedDialects[driver]
|
||||||
|
|
||||||
rows, err := db.QueryContext(context.TODO(),
|
rows, err := db.QueryContext(context.TODO(),
|
||||||
|
@ -2334,7 +2373,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var userPerm UserPermissions
|
var userPerm UserPermission
|
||||||
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
|
err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Reference in New Issue