diff --git a/kiss_orm.go b/kiss_orm.go index b7963d1..8a2efcc 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -363,7 +363,7 @@ func (c Client) Update( records ...interface{}, ) error { for _, record := range records { - query, params, err := buildUpdateQuery(c.tableName, record, "id") + query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, "id") if err != nil { return err } @@ -424,6 +424,7 @@ func buildInsertQuery( } func buildUpdateQuery( + dialect dialect, tableName string, record interface{}, idFieldNames ...string, @@ -433,14 +434,19 @@ func buildUpdateQuery( return "", nil, err } numAttrs := len(recordMap) - numIDs := len(idFieldNames) - args = make([]interface{}, numAttrs+numIDs) - whereArgs := args[numAttrs-len(idFieldNames):] + args = make([]interface{}, numAttrs) + numNonIDArgs := numAttrs - len(idFieldNames) + whereArgs := args[numNonIDArgs:] - var whereQuery []string + whereQuery := make([]string, len(idFieldNames)) for i, fieldName := range idFieldNames { whereArgs[i] = recordMap[fieldName] - whereQuery = append(whereQuery, fmt.Sprintf("`%s` = ?", fieldName)) + whereQuery[i] = fmt.Sprintf( + "%s = %s", + dialect.Escape(fieldName), + dialect.Placeholder(i+numNonIDArgs), + ) + delete(recordMap, fieldName) } @@ -452,12 +458,16 @@ func buildUpdateQuery( var setQuery []string for i, k := range keys { args[i] = recordMap[k] - setQuery = append(setQuery, fmt.Sprintf("`%s` = ?", k)) + setQuery = append(setQuery, fmt.Sprintf( + "%s = %s", + dialect.Escape(k), + dialect.Placeholder(i), + )) } query = fmt.Sprintf( - "UPDATE `%s` SET %s WHERE %s", - tableName, + "UPDATE %s SET %s WHERE %s", + dialect.Escape(tableName), strings.Join(setQuery, ", "), strings.Join(whereQuery, ", "), ) diff --git a/kiss_orm_test.go b/kiss_orm_test.go index d3b3616..50e32d7 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -364,178 +364,182 @@ func TestDelete(t *testing.T) { } func TestUpdate(t *testing.T) { - err := createTable("sqlite3") - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) + for _, driver := range []string{"sqlite3", "postgres"} { + t.Run(driver, func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + t.Run("should ignore empty lists of ids", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "users") + + u := User{ + Name: "Thay", + } + err := c.Insert(ctx, &u) + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + + // Empty update, should do nothing: + err = c.Update(ctx) + assert.Equal(t, nil, err) + + result := User{} + it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) + it.Scan(&result) + it.Close() + assert.Equal(t, nil, err) + + assert.Equal(t, "Thay", result.Name) + }) + + t.Run("should update one user correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "users") + + u := User{ + Name: "Letícia", + } + r := c.db.Table(c.tableName).Create(&u) + assert.Equal(t, nil, r.Error) + assert.NotEqual(t, uint(0), u.ID) + + err = c.Update(ctx, User{ + ID: u.ID, + Name: "Thayane", + }) + assert.Equal(t, nil, err) + + var result User + it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) + it.Scan(&result) + assert.Equal(t, nil, it.Error) + assert.Equal(t, "Thayane", result.Name) + }) + + t.Run("should update one user correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "users") + + u := User{ + Name: "Letícia", + } + r := c.db.Table(c.tableName).Create(&u) + assert.Equal(t, nil, r.Error) + assert.NotEqual(t, uint(0), u.ID) + + err = c.Update(ctx, User{ + ID: u.ID, + Name: "Thayane", + }) + assert.Equal(t, nil, err) + + var result User + it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) + it.Scan(&result) + assert.Equal(t, nil, it.Error) + assert.Equal(t, "Thayane", result.Name) + }) + + t.Run("should ignore null pointers on partial updates", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "users") + + type partialUser struct { + ID uint `gorm:"id"` + Name string `gorm:"name"` + Age *int `gorm:"age"` + } + u := partialUser{ + Name: "Letícia", + Age: nullable.Int(22), + } + r := c.db.Table(c.tableName).Create(&u) + assert.Equal(t, nil, r.Error) + assert.NotEqual(t, uint(0), u.ID) + + err = c.Update(ctx, partialUser{ + ID: u.ID, + // Should be updated because it is not null, just empty: + Name: "", + // Should not be updated because it is null: + Age: nil, + }) + assert.Equal(t, nil, err) + + var result User + it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) + it.Scan(&result) + assert.Equal(t, nil, it.Error) + assert.Equal(t, "", result.Name) + assert.Equal(t, 22, result.Age) + }) + + t.Run("should update valid pointers on partial updates", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "users") + + type partialUser struct { + ID uint `gorm:"id"` + Name string `gorm:"name"` + Age *int `gorm:"age"` + } + u := partialUser{ + Name: "Letícia", + Age: nullable.Int(22), + } + r := c.db.Table(c.tableName).Create(&u) + assert.Equal(t, nil, r.Error) + assert.NotEqual(t, uint(0), u.ID) + + // Should update all fields: + err = c.Update(ctx, partialUser{ + ID: u.ID, + Name: "Thay", + Age: nullable.Int(42), + }) + assert.Equal(t, nil, err) + + var result User + it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) + it.Scan(&result) + assert.Equal(t, nil, it.Error) + assert.Equal(t, "Thay", result.Name) + assert.Equal(t, 42, result.Age) + }) + + t.Run("should report database errors correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestClient(db, driver, "non_existing_table") + + err = c.Update(ctx, User{ + ID: 1, + Name: "Thayane", + }) + assert.NotEqual(t, nil, err) + }) + }) } - - t.Run("should ignore empty lists of ids", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "users") - - u := User{ - Name: "Thay", - } - err := c.Insert(ctx, &u) - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - - // Empty update, should do nothing: - err = c.Update(ctx) - assert.Equal(t, nil, err) - - result := User{} - it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) - it.Scan(&result) - it.Close() - assert.Equal(t, nil, err) - - assert.Equal(t, "Thay", result.Name) - }) - - t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "users") - - u := User{ - Name: "Letícia", - } - r := c.db.Table(c.tableName).Create(&u) - assert.Equal(t, nil, r.Error) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, User{ - ID: u.ID, - Name: "Thayane", - }) - assert.Equal(t, nil, err) - - var result User - it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) - it.Scan(&result) - assert.Equal(t, nil, it.Error) - assert.Equal(t, "Thayane", result.Name) - }) - - t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "users") - - u := User{ - Name: "Letícia", - } - r := c.db.Table(c.tableName).Create(&u) - assert.Equal(t, nil, r.Error) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, User{ - ID: u.ID, - Name: "Thayane", - }) - assert.Equal(t, nil, err) - - var result User - it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) - it.Scan(&result) - assert.Equal(t, nil, it.Error) - assert.Equal(t, "Thayane", result.Name) - }) - - t.Run("should ignore null pointers on partial updates", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "users") - - type partialUser struct { - ID uint `gorm:"id"` - Name string `gorm:"name"` - Age *int `gorm:"age"` - } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - r := c.db.Table(c.tableName).Create(&u) - assert.Equal(t, nil, r.Error) - assert.NotEqual(t, uint(0), u.ID) - - err = c.Update(ctx, partialUser{ - ID: u.ID, - // Should be updated because it is not null, just empty: - Name: "", - // Should not be updated because it is null: - Age: nil, - }) - assert.Equal(t, nil, err) - - var result User - it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) - it.Scan(&result) - assert.Equal(t, nil, it.Error) - assert.Equal(t, "", result.Name) - assert.Equal(t, 22, result.Age) - }) - - t.Run("should update valid pointers on partial updates", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "users") - - type partialUser struct { - ID uint `gorm:"id"` - Name string `gorm:"name"` - Age *int `gorm:"age"` - } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - r := c.db.Table(c.tableName).Create(&u) - assert.Equal(t, nil, r.Error) - assert.NotEqual(t, uint(0), u.ID) - - // Should update all fields: - err = c.Update(ctx, partialUser{ - ID: u.ID, - Name: "Thay", - Age: nullable.Int(42), - }) - assert.Equal(t, nil, err) - - var result User - it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) - it.Scan(&result) - assert.Equal(t, nil, it.Error) - assert.Equal(t, "Thay", result.Name) - assert.Equal(t, 42, result.Age) - }) - - t.Run("should report database errors correctly", func(t *testing.T) { - db := connectDB(t, "sqlite3") - defer db.Close() - - ctx := context.Background() - c := newTestClient(db, "sqlite3", "non_existing_table") - - err = c.Update(ctx, User{ - ID: 1, - Name: "Thayane", - }) - assert.NotEqual(t, nil, err) - }) } func TestStructToMap(t *testing.T) {