diff --git a/kiss_orm.go b/kiss_orm.go index 15019cf..fcd671c 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "strings" "github.com/jinzhu/gorm" ) @@ -241,20 +242,62 @@ func (c Client) Update( records ...interface{}, ) error { for _, record := range records { - m, err := StructToMap(record) + query, params, err := buildUpdateQuery(c.tableName, record, "id") if err != nil { return err } - delete(m, "id") - r := c.db.Table(c.tableName).Model(record).Updates(m) - if r.Error != nil { - return r.Error + + _, err = c.db.DB().Exec(query, params...) + if err != nil { + return err } } return nil } +func buildUpdateQuery( + tableName string, + record interface{}, + idFieldNames ...string, +) (query string, args []interface{}, err error) { + recordMap, err := StructToMap(record) + if err != nil { + return "", nil, err + } + numAttrs := len(recordMap) + numIDs := len(idFieldNames) + args = make([]interface{}, numAttrs+numIDs) + whereArgs := args[numAttrs-1:] + + var whereQuery []string + for i, fieldName := range idFieldNames { + whereArgs[i] = recordMap[fieldName] + whereQuery = append(whereQuery, fmt.Sprintf("`%s` = ?", fieldName)) + delete(recordMap, fieldName) + } + + keys := []string{} + for key := range recordMap { + keys = append(keys, key) + } + + var setQuery []string + for i, k := range keys { + args[i] = recordMap[k] + setQuery = append(setQuery, fmt.Sprintf("`%s` = ?", k)) + } + + query = fmt.Sprintf( + "UPDATE `%s` SET %s WHERE %s", + tableName, + strings.Join(setQuery, ", "), + strings.Join(whereQuery, ", "), + ) + + return query, args, nil +} + // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here