From 752e6bb0a1c23574197c340510ed26c8c3a4a8f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Mon, 25 Jul 2022 23:47:06 -0300 Subject: [PATCH] Add some tests for the Patch function with composite keys --- ksql.go | 33 +++++++++++++++++++++------------ test_adapters.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/ksql.go b/ksql.go index f1a9f23..4ed4121 100644 --- a/ksql.go +++ b/ksql.go @@ -618,18 +618,7 @@ func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string] } } - for _, idName := range idNames { - id, found := idMap[idName] - if !found { - return nil, fmt.Errorf("missing required id field `%s` on input record", idName) - } - - if id == nil || reflect.ValueOf(id).IsZero() { - return nil, fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName) - } - } - - return idMap, nil + return idMap, validateIfAllIdsArePresent(idNames, idMap) } // Update updates the given instances on the database by id. @@ -803,6 +792,11 @@ func buildUpdateQuery( numNonIDArgs := numAttrs - len(idFieldNames) whereArgs := args[numNonIDArgs:] + err = validateIfAllIdsArePresent(idFieldNames, recordMap) + if err != nil { + return "", nil, err + } + whereQuery := make([]string, len(idFieldNames)) for i, fieldName := range idFieldNames { whereArgs[i] = recordMap[fieldName] @@ -847,6 +841,21 @@ func buildUpdateQuery( return query, args, nil } +func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{}) error { + for _, idName := range idNames { + id, found := idMap[idName] + if !found { + return fmt.Errorf("missing required id field `%s` on input record", idName) + } + + if id == nil || reflect.ValueOf(id).IsZero() { + return fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName) + } + } + + return nil +} + // Exec just runs an SQL command on the database returning no rows. func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) { return c.db.ExecContext(ctx, query, params...) diff --git a/test_adapters.go b/test_adapters.go index 57bd64d..9db6dd5 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -1619,6 +1619,37 @@ func PatchTest( err := c.Update(ctx, usersTable, u) tt.AssertNotEqual(t, err, nil) }) + + t.Run("should report error if the id is missing", func(t *testing.T) { + t.Run("with a single primary key", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err := c.Update(ctx, usersTable, &user{ + // Missing ID + Name: "Jane", + }) + tt.AssertErrContains(t, err, "invalid value", "0", "'id'") + }) + + t.Run("with composite keys", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err := c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &userPermission{ + ID: 1, + // Missing UserID + PermID: 42, + }) + tt.AssertErrContains(t, err, "invalid value", "0", "'user_id'") + }) + }) }) }