diff --git a/postgres.go b/postgres.go index 7fd2d8b..7480c57 100644 --- a/postgres.go +++ b/postgres.go @@ -3,6 +3,7 @@ package gpostgres import ( "context" "fmt" + "reflect" "github.com/jinzhu/gorm" ) @@ -80,3 +81,86 @@ func (c Client) Delete( return nil } + +// Update updates the given instances on the database by id. +// +// Partial updates are supported, i.e. it will ignore nil pointer attributes +func (c Client) Update( + ctx context.Context, + items ...interface{}, +) error { + for _, item := range items { + m, err := structToMap(item) + if err != nil { + return err + } + + r := c.db.Table(c.tableName).Updates(m) + if r.Error != nil { + return r.Error + } + } + + return nil +} + +// This cache is kept as a pkg variable +// because the total number of types on a program +// should be finite. So keeping a single cache here +// works fine. +var tagNamesCache = map[reflect.Type]map[int]string{} + +// structToMap converts any type to a map based on the +// tag named `sql`, i.e. `sql:"map_key_name"` +// +// This function is efficient in the fact that it caches +// the slower steps of the reflection required to do perform +// this task. +func structToMap(obj interface{}) (map[string]interface{}, error) { + v := reflect.ValueOf(obj) + t := v.Type() + + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("input must be a struct or struct pointer") + } + + names, found := tagNamesCache[t] + if !found { + names = getTagNames(t) + tagNamesCache[t] = names + } + + m := map[string]interface{}{} + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + ft := field.Type() + if ft.Kind() == reflect.Ptr && field.IsNil() { + continue + } + m[names[i]] = field.Interface() + } + + return m, nil +} + +// This function collects only the names +// that will be used from the input type. +// +// This should save several calls to `Field(i).Tag.Get("foo")` +// which improves performance by a lot. +func getTagNames(t reflect.Type) map[int]string { + resp := map[int]string{} + for i := 0; i < t.NumField(); i++ { + name := t.Field(i).Tag.Get("sql") + if name == "" { + continue + } + resp[i] = name + } + + return resp +}