diff --git a/ksql.go b/ksql.go index 9a1a7ba..2c52a49 100644 --- a/ksql.go +++ b/ksql.go @@ -558,11 +558,8 @@ func (c DB) Delete( var query string var params []interface{} - if len(table.idColumns) == 1 { - query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps) - } else { - query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps) - } + query, params = buildDeleteQuery(c.dialect, table, idMaps[0]) + fmt.Println("query:", query, "params:", params) result, err := c.db.ExecContext(ctx, query, params...) if err != nil { @@ -997,52 +994,23 @@ func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info return scanArgs } -func buildSingleKeyDeleteQuery( +func buildDeleteQuery( dialect Dialect, - table string, - idName string, - idMaps []map[string]interface{}, + table Table, + idMap map[string]interface{}, ) (query string, params []interface{}) { - values := []string{} - for i, m := range idMaps { - values = append(values, dialect.Placeholder(i)) - params = append(params, m[idName]) + whereQuery := []string{} + for i, idName := range table.idColumns { + whereQuery = append(whereQuery, fmt.Sprintf( + "%s = %s", dialect.Escape(idName), dialect.Placeholder(i), + )) + params = append(params, idMap[idName]) } return fmt.Sprintf( - "DELETE FROM %s WHERE %s IN (%s)", - dialect.Escape(table), - dialect.Escape(idName), - strings.Join(values, ","), - ), params -} - -func buildCompositeKeyDeleteQuery( - dialect Dialect, - table string, - idNames []string, - idMaps []map[string]interface{}, -) (query string, params []interface{}) { - escapedNames := []string{} - for _, name := range idNames { - escapedNames = append(escapedNames, dialect.Escape(name)) - } - - values := []string{} - for _, m := range idMaps { - tuple := []string{} - for _, name := range idNames { - params = append(params, m[name]) - tuple = append(tuple, dialect.Placeholder(len(values))) - } - values = append(values, "("+strings.Join(tuple, ",")+")") - } - - return fmt.Sprintf( - "DELETE FROM %s WHERE (%s) IN (VALUES %s)", - dialect.Escape(table), - strings.Join(escapedNames, ","), - strings.Join(values, ","), + "DELETE FROM %s WHERE %s", + dialect.Escape(table.name), + strings.Join(whereQuery, " AND "), ), params } diff --git a/ksql_test.go b/ksql_test.go index bc9ec82..03ee07a 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -47,9 +47,9 @@ type Post struct { Title string `ksql:"title"` } -var UserPermissionsTable = NewTable("user_permissions", "id", "user_id", "post_id") +var UserPermissionsTable = NewTable("user_permissions", "user_id", "post_id") -type UserPermissions struct { +type UserPermission struct { ID int `ksql:"id"` UserID int `ksql:"user_id"` PostID int `ksql:"post_id"` @@ -787,7 +787,8 @@ func TestInsert(t *testing.T) { ctx := context.Background() c := newTestDB(db, config.driver) - err = c.Insert(ctx, UserPermissionsTable, &UserPermissions{ + table := NewTable("user_permissions", "id", "user_id", "post_id") + err = c.Insert(ctx, table, &UserPermission{ UserID: 1, PostID: 42, }) @@ -807,11 +808,14 @@ func TestInsert(t *testing.T) { ctx := context.Background() c := newTestDB(db, config.driver) - permission := UserPermissions{ + // Table defined with 3 values, but we'll provide only 2, + // the third will be generated for the purposes of this test: + table := NewTable("user_permissions", "id", "user_id", "post_id") + permission := UserPermission{ UserID: 2, PostID: 42, } - err = c.Insert(ctx, UserPermissionsTable, &permission) + err = c.Insert(ctx, table, &permission) tt.AssertNoErr(t, err) userPerms, err := getUserPermissionsByUser(db, config.driver, 2) @@ -1005,7 +1009,7 @@ func TestDelete(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } - t.Run("should delete one id correctly", func(t *testing.T) { + t.Run("should delete from tables with a single primary key correctly", func(t *testing.T) { tests := []struct { desc string deletionKeyForUser func(u User) interface{} @@ -1081,6 +1085,41 @@ func TestDelete(t *testing.T) { } }) + t.Run("should delete from tables with composite primary keys correctly", func(t *testing.T) { + t.Run("using structs", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + // This permission should not be deleted, we'll use the id to check it: + p0 := UserPermission{ + UserID: 1, + PostID: 44, + } + err = c.Insert(ctx, NewTable("user_permissions", "id"), &p0) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, p0.ID, 0) + + p1 := UserPermission{ + UserID: 1, + PostID: 42, + } + err = c.Insert(ctx, NewTable("user_permissions", "id"), &p1) + tt.AssertNoErr(t, err) + + err = c.Delete(ctx, UserPermissionsTable, p1) + tt.AssertNoErr(t, err) + + userPerms, err := getUserPermissionsByUser(db, config.driver, 1) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, len(userPerms), 1) + tt.AssertEqual(t, userPerms[0].UserID, 1) + tt.AssertEqual(t, userPerms[0].PostID, 44) + }) + }) + t.Run("should return ErrRecordNotFound if no rows were deleted", func(t *testing.T) { db, closer := connectDB(t, config) defer closer.Close() @@ -2321,7 +2360,7 @@ func getUserByName(db DBAdapter, driver string, result *User, name string) error return json.Unmarshal(rawAddr, &result.Address) } -func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermissions, _ error) { +func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []UserPermission, _ error) { dialect := supportedDialects[driver] rows, err := db.QueryContext(context.TODO(), @@ -2334,7 +2373,7 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results defer rows.Close() for rows.Next() { - var userPerm UserPermissions + var userPerm UserPermission err := rows.Scan(&userPerm.ID, &userPerm.UserID, &userPerm.PostID) if err != nil { return nil, err