From 8f45498f58c5a38d904760e5bfa1dc59af5a38d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= <vingarcia00@gmail.com>
Date: Sun, 2 Apr 2023 11:29:33 -0300
Subject: [PATCH] Improve error message for structs missing required ID fields

---
 contracts.go     |  4 ++++
 ksql.go          | 16 ++++++++--------
 test_adapters.go | 36 +++++++++++++++++++++++++++++++++++-
 3 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/contracts.go b/contracts.go
index f1557f2..fa2a6e0 100644
--- a/contracts.go
+++ b/contracts.go
@@ -24,6 +24,10 @@ var ErrRecordNotFound error = fmt.Errorf("ksql: the query returned no results: %
 // (2) If the attribute is using a modifier that contains the SkipUpdates flag.
 var ErrNoValuesToUpdate error = fmt.Errorf("ksql: the input struct contains no values to update")
 
+// ErrRecordMissingIDs is returned by the Update or Delete functions if an input record does
+// not have all of the IDs described on the input table.
+var ErrRecordMissingIDs error = fmt.Errorf("ksql: missing required ID fields")
+
 // ErrAbortIteration should be used inside the QueryChunks function to inform QueryChunks it should stop querying,
 // close the connection and return with no errors.
 var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function")
diff --git a/ksql.go b/ksql.go
index 5972ed5..0230c1d 100644
--- a/ksql.go
+++ b/ksql.go
@@ -845,18 +845,18 @@ func buildUpdateQuery(
 
 	numAttrs := len(recordMap)
 	args = make([]interface{}, numAttrs)
-	numNonIDArgs := numAttrs - len(idFieldNames)
-	whereArgs := args[numNonIDArgs:]
-
-	if numNonIDArgs == 0 {
-		return "", nil, ErrNoValuesToUpdate
-	}
 
 	err = validateIfAllIdsArePresent(idFieldNames, recordMap)
 	if err != nil {
 		return "", nil, err
 	}
 
+	numNonIDArgs := numAttrs - len(idFieldNames)
+	whereArgs := args[numNonIDArgs:]
+	if numNonIDArgs == 0 {
+		return "", nil, ErrNoValuesToUpdate
+	}
+
 	whereQuery := make([]string, len(idFieldNames))
 	for i, fieldName := range idFieldNames {
 		whereArgs[i] = recordMap[fieldName]
@@ -912,11 +912,11 @@ func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{})
 	for _, idName := range idNames {
 		id, found := idMap[idName]
 		if !found {
-			return fmt.Errorf("missing required id field `%s` on input record", idName)
+			return fmt.Errorf("missing required id field `%s` on input record: %w", idName, ErrRecordMissingIDs)
 		}
 
 		if id == nil || reflect.ValueOf(id).IsZero() {
-			return fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName)
+			return fmt.Errorf("invalid value '%v' received for id column: '%s': %w", id, idName, ErrRecordMissingIDs)
 		}
 	}
 
diff --git a/test_adapters.go b/test_adapters.go
index bc6f5a5..00ff51b 100644
--- a/test_adapters.go
+++ b/test_adapters.go
@@ -1805,7 +1805,7 @@ func PatchTest(
 				ID:   1,
 				Name: "some name",
 			})
-			tt.AssertEqual(t, err, ErrNoValuesToUpdate)
+			tt.AssertErrContains(t, err, "struct", "no values to update")
 		})
 
 		t.Run("should report error if the id is missing", func(t *testing.T) {
@@ -1837,6 +1837,40 @@ func PatchTest(
 			})
 		})
 
+		t.Run("should report error if the struct has no id", func(t *testing.T) {
+			t.Run("with a single primary key", func(t *testing.T) {
+				db, closer := newDBAdapter(t)
+				defer closer.Close()
+
+				c := newTestDB(db, dialect)
+
+				err := c.Update(ctx, usersTable, &struct {
+					// Missing ID
+					Name string `ksql:"name"`
+				}{
+					Name: "Jane",
+				})
+				tt.AssertErrContains(t, err, "missing", "ID fields", "id")
+			})
+
+			t.Run("with composite keys", func(t *testing.T) {
+				db, closer := newDBAdapter(t)
+				defer closer.Close()
+
+				c := newTestDB(db, dialect)
+
+				err := c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &struct {
+					ID int `ksql:"id"`
+					// Missing UserID
+					PermID int `ksql:"perm_id"`
+				}{
+					ID:     1,
+					PermID: 42,
+				})
+				tt.AssertErrContains(t, err, "missing", "ID fields", "user_id")
+			})
+		})
+
 		t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) {
 			db, closer := newDBAdapter(t)
 			defer closer.Close()