From 0425eb11239b80acb71176856d443a76762f1c00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Mon, 25 Jul 2022 22:49:51 -0300 Subject: [PATCH] Add test for Patch with composite keys --- ksql.go | 2 +- test_adapters.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/ksql.go b/ksql.go index e53e54e..f1a9f23 100644 --- a/ksql.go +++ b/ksql.go @@ -841,7 +841,7 @@ func buildUpdateQuery( "UPDATE %s SET %s WHERE %s", dialect.Escape(tableName), strings.Join(setQuery, ", "), - strings.Join(whereQuery, ", "), + strings.Join(whereQuery, " AND "), ) return query, args, nil diff --git a/test_adapters.go b/test_adapters.go index 820713f..04c7fb3 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -44,9 +44,10 @@ type post struct { var userPermissionsTable = NewTable("user_permissions", "user_id", "perm_id") type userPermission struct { - ID int `ksql:"id"` - UserID int `ksql:"user_id"` - PermID int `ksql:"perm_id"` + ID int `ksql:"id"` + UserID int `ksql:"user_id"` + PermID int `ksql:"perm_id"` + Type string `ksql:"type"` } // RunTestsForAdapter will run all necessary tests for making sure @@ -1462,6 +1463,43 @@ func PatchTest( tt.AssertEqual(t, result.Name, "Thayane") }) + t.Run("should update tables with composite keys correctly", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err = createUserPermission(db, c.dialect, userPermission{ + UserID: 42, + PermID: 43, + Type: "existingFakeType", + }) + tt.AssertNoErr(t, err) + + existingPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, existingPerm.ID, 0) + tt.AssertEqual(t, existingPerm.Type, "existingFakeType") + + err = c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{ + ID: existingPerm.ID, + UserID: 42, + PermID: 43, + Type: "newFakeType", + }) + tt.AssertNoErr(t, err) + + newPerm, err := getUserPermissionBySecondaryKeys(db, c.dialect, 42, 43) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, newPerm, userPermission{ + ID: existingPerm.ID, + UserID: 42, + PermID: 43, + Type: "newFakeType", + }) + }) + t.Run("should ignore null pointers on partial updates", func(t *testing.T) { db, closer := newDBAdapter(t) defer closer.Close() @@ -2387,6 +2425,7 @@ func createTables(driver string, connStr string) error { id INTEGER PRIMARY KEY, user_id INTEGER, perm_id INTEGER, + type TEXT, UNIQUE (user_id, perm_id) )`) case "postgres": @@ -2394,6 +2433,7 @@ func createTables(driver string, connStr string) error { id serial PRIMARY KEY, user_id INT, perm_id INT, + type VARCHAR(50), UNIQUE (user_id, perm_id) )`) case "mysql": @@ -2401,6 +2441,7 @@ func createTables(driver string, connStr string) error { id INT AUTO_INCREMENT PRIMARY KEY, user_id INT, perm_id INT, + type VARCHAR(50), UNIQUE KEY (user_id, perm_id) )`) case "sqlserver": @@ -2408,6 +2449,7 @@ func createTables(driver string, connStr string) error { id INT IDENTITY(1,1) PRIMARY KEY, user_id INT, perm_id INT, + type VARCHAR(50), CONSTRAINT unique_1 UNIQUE (user_id, perm_id) )`) } @@ -2488,6 +2530,42 @@ func getUserByName(db DBAdapter, driver string, result *user, name string) error return json.Unmarshal(rawAddr, &result.Address) } +func createUserPermission(db DBAdapter, dialect Dialect, userPerm userPermission) error { + _, err := db.ExecContext(context.TODO(), + `INSERT INTO user_permissions (user_id, perm_id, type) + VALUES (`+dialect.Placeholder(0)+`, `+dialect.Placeholder(1)+`, `+dialect.Placeholder(2)+`)`, + userPerm.UserID, userPerm.PermID, userPerm.Type, + ) + + return err +} + +func getUserPermissionBySecondaryKeys(db DBAdapter, dialect Dialect, userID int, permID int) (userPermission, error) { + rows, err := db.QueryContext(context.TODO(), + `SELECT id, user_id, perm_id, type FROM user_permissions WHERE user_id=`+dialect.Placeholder(0)+` AND perm_id=`+dialect.Placeholder(1), + userID, permID, + ) + if err != nil { + return userPermission{}, err + } + defer rows.Close() + + if rows.Next() == false { + if rows.Err() != nil { + return userPermission{}, rows.Err() + } + return userPermission{}, sql.ErrNoRows + } + + var result userPermission + err = rows.Scan(&result.ID, &result.UserID, &result.PermID, &result.Type) + if err != nil { + return userPermission{}, err + } + + return result, nil +} + func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results []userPermission, _ error) { dialect := supportedDialects[driver]