diff --git a/contracts.go b/contracts.go index 3de2e8e..4195b1b 100644 --- a/contracts.go +++ b/contracts.go @@ -8,6 +8,7 @@ import ( // ErrRecordNotFound ... var ErrRecordNotFound error = fmt.Errorf("ksql: the query returned no results: %w", sql.ErrNoRows) +var ErrNoValuesToUpdate error = fmt.Errorf("ksql: the input struct contains no values to update") // ErrAbortIteration ... var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function") diff --git a/internal/modifiers/global_modifiers.go b/internal/modifiers/global_modifiers.go index c27fb68..3fb7722 100644 --- a/internal/modifiers/global_modifiers.go +++ b/internal/modifiers/global_modifiers.go @@ -9,10 +9,21 @@ import ( var modifiers sync.Map func init() { - // These are the builtin modifiers + // These are the builtin modifiers: + + // This one is useful for serializing/desserializing structs: modifiers.Store("json", jsonModifier) + + // This next two are useful for the UpdatedAt and Created fields respectively: + // They only work on time.Time attributes and will set the attribute to time.Now(). modifiers.Store("timeNowUTC", timeNowUTCModifier) modifiers.Store("timeNowUTC/skipUpdates", timeNowUTCSkipUpdatesModifier) + + // These are mostly example modifiers and they are also used + // to test the feature of skipping updates, inserts and queries. + modifiers.Store("skipUpdates", skipUpdatesModifier) + modifiers.Store("skipInserts", skipInsertsModifier) + modifiers.Store("skipQueries", skipQueriesModifier) } // RegisterAttrModifier allow users to add custom modifiers on startup diff --git a/internal/modifiers/skip_modifiers.go b/internal/modifiers/skip_modifiers.go new file mode 100644 index 0000000..150b027 --- /dev/null +++ b/internal/modifiers/skip_modifiers.go @@ -0,0 +1,13 @@ +package modifiers + +var skipInsertsModifier = AttrModifier{ + SkipOnInsert: true, +} + +var skipUpdatesModifier = AttrModifier{ + SkipOnUpdate: true, +} + +var skipQueriesModifier = AttrModifier{ + SkipOnQuery: true, +} diff --git a/ksql.go b/ksql.go index 7aec5f2..a346b25 100644 --- a/ksql.go +++ b/ksql.go @@ -258,8 +258,8 @@ func (c DB) QueryOne( defer rows.Close() if !rows.Next() { - if rows.Err() != nil { - return rows.Err() + if err := rows.Err(); err != nil { + return err } return ErrRecordNotFound } @@ -709,11 +709,15 @@ func buildInsertQuery( columnNames := []string{} for col := range recordMap { + if info.ByName(col).Modifier.SkipOnInsert { + continue + } + columnNames = append(columnNames, col) } - params = make([]interface{}, len(recordMap)) - valuesQuery := make([]string, len(recordMap)) + params = make([]interface{}, len(columnNames)) + valuesQuery := make([]string, len(columnNames)) for i, col := range columnNames { recordValue := recordMap[col] params[i] = recordValue @@ -770,6 +774,16 @@ func buildInsertQuery( } } + if len(columnNames) == 0 && dialect.DriverName() != "mysql" { + query = fmt.Sprintf( + "INSERT INTO %s%s DEFAULT VALUES%s", + dialect.Escape(table.name), + outputQuery, + returningQuery, + ) + return query, params, scanValues, nil + } + // Note that the outputQuery and the returningQuery depend // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( @@ -807,6 +821,10 @@ func buildUpdateQuery( numNonIDArgs := numAttrs - len(idFieldNames) whereArgs := args[numNonIDArgs:] + if numNonIDArgs == 0 { + return "", nil, ErrNoValuesToUpdate + } + err = validateIfAllIdsArePresent(idFieldNames, recordMap) if err != nil { return "", nil, err diff --git a/test_adapters.go b/test_adapters.go index ca52fe3..beadd82 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -896,6 +896,33 @@ func InsertTest( tt.AssertNoErr(t, err) tt.AssertEqual(t, inserted.Age, 5455) }) + + t.Run("should work and retrieve the ID for structs with no attributes", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + type tsUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipInserts"` + } + u := tsUser{ + Name: "Letícia", + } + err := c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, 0) + + var untaggedUser struct { + ID uint `ksql:"id"` + Name *string `ksql:"name"` + } + err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser.Name, (*string)(nil)) + }) }) t.Run("composite key tables", func(t *testing.T) { @@ -1714,6 +1741,24 @@ func PatchTest( tt.AssertNotEqual(t, err, nil) }) + t.Run("should report error if the struct has no fields to update", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err = c.Update(ctx, usersTable, struct { + ID uint `ksql:"id"` // ID fields are not updated + Name string `ksql:"name,skipUpdates"` // the skipUpdate modifier should rule this one out + Age *int `ksql:"age"` // Age is a nil pointer so it would not be updated + }{ + ID: 1, + Name: "some name", + }) + tt.AssertEqual(t, err, ErrNoValuesToUpdate) + }) + t.Run("should report error if the id is missing", func(t *testing.T) { t.Run("with a single primary key", func(t *testing.T) { db, closer := newDBAdapter(t) @@ -2749,6 +2794,199 @@ func ModifiersTest( tt.AssertEqual(t, taggedUser.CreatedAt, tt.ParseTime(t, "2000-08-05T14:00:00Z")) }) }) + + t.Run("skipInserts modifier", func(t *testing.T) { + t.Run("should ignore the field during insertions", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type tsUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipInserts"` + Age int `ksql:"age"` + } + u := tsUser{ + Name: "Letícia", + Age: 22, + } + err := c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, 0) + + var untaggedUser struct { + ID uint `ksql:"id"` + Name *string `ksql:"name"` + Age int `ksql:"age"` + } + err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser.Name, (*string)(nil)) + tt.AssertEqual(t, untaggedUser.Age, 22) + }) + + t.Run("should have no effect on updates", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + Age int `ksql:"age"` + } + untaggedUser := userWithNoTags{ + Name: "Laurinha Ribeiro", + Age: 11, + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + type taggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipInserts"` + Age int `ksql:"age"` + } + u := taggedUser{ + ID: untaggedUser.ID, + Name: "Laura Ribeiro", + Age: 12, + } + err = c.Patch(ctx, usersTable, u) + tt.AssertNoErr(t, err) + + var untaggedUser2 userWithNoTags + err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro") + tt.AssertEqual(t, untaggedUser2.Age, 12) + }) + + t.Run("should not alter the value on queries", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + } + untaggedUser := userWithNoTags{ + Name: "Marta Ribeiro", + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + var taggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipInserts"` + } + err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID) + tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro") + }) + }) + + t.Run("skipUpdates modifier", func(t *testing.T) { + t.Run("should set the field on insertion", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type tsUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipUpdates"` + } + u := tsUser{ + Name: "Letícia", + } + err := c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, 0) + + var untaggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + } + err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), u.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser.Name, "Letícia") + }) + + t.Run("should be ignored on updates", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + Age int `ksql:"age"` + } + untaggedUser := userWithNoTags{ + Name: "Laurinha Ribeiro", + Age: 11, + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + type taggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipUpdates"` + Age int `ksql:"age"` + } + u := taggedUser{ + ID: untaggedUser.ID, + Name: "Laura Ribeiro", + Age: 12, + } + err = c.Patch(ctx, usersTable, u) + tt.AssertNoErr(t, err) + + var untaggedUser2 userWithNoTags + err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser2.Name, "Laurinha Ribeiro") + tt.AssertEqual(t, untaggedUser2.Age, 12) + }) + + t.Run("should not alter the value on queries", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + } + untaggedUser := userWithNoTags{ + Name: "Marta Ribeiro", + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + var taggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name,skipUpdates"` + } + err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID) + tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro") + }) + }) + }) }