diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 591f32e..eea4862 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,12 @@ jobs: tests: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.20' + + - uses: actions/checkout@v3 - name: Pull Postgres run: docker pull postgres:14.0 - name: Pull MariaDB diff --git a/contracts.go b/contracts.go index f1557f2..fa2a6e0 100644 --- a/contracts.go +++ b/contracts.go @@ -24,6 +24,10 @@ var ErrRecordNotFound error = fmt.Errorf("ksql: the query returned no results: % // (2) If the attribute is using a modifier that contains the SkipUpdates flag. var ErrNoValuesToUpdate error = fmt.Errorf("ksql: the input struct contains no values to update") +// ErrRecordMissingIDs is returned by the Update or Delete functions if an input record does +// not have all of the IDs described on the input table. +var ErrRecordMissingIDs error = fmt.Errorf("ksql: missing required ID fields") + // ErrAbortIteration should be used inside the QueryChunks function to inform QueryChunks it should stop querying, // close the connection and return with no errors. var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function") diff --git a/ksql.go b/ksql.go index 5972ed5..0230c1d 100644 --- a/ksql.go +++ b/ksql.go @@ -845,18 +845,18 @@ func buildUpdateQuery( numAttrs := len(recordMap) args = make([]interface{}, numAttrs) - numNonIDArgs := numAttrs - len(idFieldNames) - whereArgs := args[numNonIDArgs:] - - if numNonIDArgs == 0 { - return "", nil, ErrNoValuesToUpdate - } err = validateIfAllIdsArePresent(idFieldNames, recordMap) if err != nil { return "", nil, err } + numNonIDArgs := numAttrs - len(idFieldNames) + whereArgs := args[numNonIDArgs:] + if numNonIDArgs == 0 { + return "", nil, ErrNoValuesToUpdate + } + whereQuery := make([]string, len(idFieldNames)) for i, fieldName := range idFieldNames { whereArgs[i] = recordMap[fieldName] @@ -912,11 +912,11 @@ func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{}) for _, idName := range idNames { id, found := idMap[idName] if !found { - return fmt.Errorf("missing required id field `%s` on input record", idName) + return fmt.Errorf("missing required id field `%s` on input record: %w", idName, ErrRecordMissingIDs) } if id == nil || reflect.ValueOf(id).IsZero() { - return fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName) + return fmt.Errorf("invalid value '%v' received for id column: '%s': %w", id, idName, ErrRecordMissingIDs) } } diff --git a/test_adapters.go b/test_adapters.go index e1c4f01..6e37572 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -1807,7 +1807,7 @@ func PatchTest( ID: 1, Name: "some name", }) - tt.AssertEqual(t, err, ErrNoValuesToUpdate) + tt.AssertErrContains(t, err, "struct", "no values to update") }) t.Run("should report error if the id is missing", func(t *testing.T) { @@ -1839,6 +1839,40 @@ func PatchTest( }) }) + t.Run("should report error if the struct has no id", func(t *testing.T) { + t.Run("with a single primary key", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, dialect) + + err := c.Update(ctx, usersTable, &struct { + // Missing ID + Name string `ksql:"name"` + }{ + Name: "Jane", + }) + tt.AssertErrContains(t, err, "missing", "ID fields", "id") + }) + + t.Run("with composite keys", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, dialect) + + err := c.Update(ctx, NewTable("user_permissions", "id", "user_id", "perm_id"), &struct { + ID int `ksql:"id"` + // Missing UserID + PermID int `ksql:"perm_id"` + }{ + ID: 1, + PermID: 42, + }) + tt.AssertErrContains(t, err, "missing", "ID fields", "user_id") + }) + }) + t.Run("should report error context.Canceled the context has been canceled", func(t *testing.T) { db, closer := newDBAdapter(t) defer closer.Close()