Add func UpdateStructWith() for facilitating mocking this ORM

pull/2/head
Vinícius Garcia 2020-09-22 21:07:03 -03:00
parent 7b30856248
commit d77fd7d679
2 changed files with 94 additions and 21 deletions

View File

@ -1,4 +1,4 @@
package gpostgres package kissorm
import ( import (
"context" "context"
@ -8,19 +8,27 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
// ORMProvider describes the public behavior of this ORM
type ORMProvider interface {
Find(ctx context.Context, item interface{}, query string, params ...interface{}) error
GetByID(ctx context.Context, item interface{}, id interface{}) error
Insert(ctx context.Context, items ...interface{}) error
Delete(ctx context.Context, ids ...interface{}) error
Update(ctx context.Context, intems ...interface{}) error
}
// Client ... // Client ...
type Client struct { type Client struct {
tableName string tableName string
db *gorm.DB db *gorm.DB
} }
// NewClient ... // NewClient instantiates a new client
func NewClient(dbDriver string, connectionString string, maxOpenConns int) (Client, error) { func NewClient(dbDriver string, connectionString string, maxOpenConns int, tableName string) (Client, error) {
db, err := gorm.Open(dbDriver, connectionString) db, err := gorm.Open(dbDriver, connectionString)
if err != nil { if err != nil {
return Client{}, err return Client{}, err
} }
if err = db.DB().Ping(); err != nil { if err = db.DB().Ping(); err != nil {
return Client{}, err return Client{}, err
} }
@ -28,16 +36,17 @@ func NewClient(dbDriver string, connectionString string, maxOpenConns int) (Clie
db.DB().SetMaxOpenConns(maxOpenConns) db.DB().SetMaxOpenConns(maxOpenConns)
return Client{ return Client{
db: db, db: db,
tableName: tableName,
}, nil }, nil
} }
// ChangeTable returns a new Client configured to use a new table // ChangeTable creates a new client configured to query on a different table
func (c Client) ChangeTable(ctx context.Context, tableName string) (*Client, error) { func (c Client) ChangeTable(ctx context.Context, tableName string) ORMProvider {
return &Client{ return &Client{
db: c.db, db: c.db,
tableName: tableName, tableName: tableName,
}, nil }
} }
// Find one instance from the database, the input struct // Find one instance from the database, the input struct
@ -115,7 +124,7 @@ func (c Client) Update(
return err return err
} }
r := c.db.Table(c.tableName).Updates(m) r := c.db.Table(c.tableName).Model(item).Updates(m)
if r.Error != nil { if r.Error != nil {
return r.Error return r.Error
} }
@ -128,10 +137,15 @@ func (c Client) Update(
// because the total number of types on a program // because the total number of types on a program
// should be finite. So keeping a single cache here // should be finite. So keeping a single cache here
// works fine. // works fine.
var tagNamesCache = map[reflect.Type]map[int]string{} var tagInfoCache = map[reflect.Type]StructInfo{}
// structToMap converts any type to a map based on the type StructInfo struct {
// tag named `gorm`, i.e. `gorm:"map_key_name"` Names map[int]string
Index map[string]int
}
// structToMap converts any struct type to a map based on
// the tag named `gorm`, i.e. `gorm:"map_key_name"`
// //
// This function is efficient in the fact that it caches // This function is efficient in the fact that it caches
// the slower steps of the reflection required to do perform // the slower steps of the reflection required to do perform
@ -148,14 +162,17 @@ func structToMap(obj interface{}) (map[string]interface{}, error) {
return nil, fmt.Errorf("input must be a struct or struct pointer") return nil, fmt.Errorf("input must be a struct or struct pointer")
} }
names, found := tagNamesCache[t] info, found := tagInfoCache[t]
if !found { if !found {
names = getTagNames(t) info = getTagNames(t)
tagNamesCache[t] = names tagInfoCache[t] = info
} }
m := map[string]interface{}{} m := map[string]interface{}{}
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
if info.Names[i] == "id" {
continue
}
field := v.Field(i) field := v.Field(i)
ft := field.Type() ft := field.Type()
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Ptr {
@ -166,7 +183,7 @@ func structToMap(obj interface{}) (map[string]interface{}, error) {
field = field.Elem() field = field.Elem()
} }
m[names[i]] = field.Interface() m[info.Names[i]] = field.Interface()
} }
return m, nil return m, nil
@ -177,15 +194,71 @@ func structToMap(obj interface{}) (map[string]interface{}, error) {
// //
// This should save several calls to `Field(i).Tag.Get("foo")` // This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot. // which improves performance by a lot.
func getTagNames(t reflect.Type) map[int]string { func getTagNames(t reflect.Type) StructInfo {
resp := map[int]string{} info := StructInfo{
Names: map[int]string{},
Index: map[string]int{},
}
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Tag.Get("gorm") name := t.Field(i).Tag.Get("gorm")
if name == "" { if name == "" {
continue continue
} }
resp[i] = name info.Names[i] = name
info.Index[name] = i
} }
return resp return info
}
// UpdateStructWith is meant to be used on unit tests to mock
// the response from the database.
//
// The first argument is any struct you are passing to a kissorm func,
// and the second is a map representing a database row you want
// to use to update this struct.
func UpdateStructWith(entity interface{}, db_row map[string]interface{}) error {
v := reflect.ValueOf(entity)
t := v.Type()
if t.Kind() != reflect.Ptr {
return fmt.Errorf(
"UpdateStructWith: expected input to be a pointer to struct but got %T",
entity,
)
}
t = t.Elem()
v = v.Elem()
if t.Kind() != reflect.Struct {
return fmt.Errorf(
"UpdateStructWith: expected input to be a kind of struct but got %T",
entity,
)
}
info, found := tagInfoCache[t]
if !found {
info = getTagNames(t)
tagInfoCache[t] = info
}
for colName, attr := range db_row {
attrValue := reflect.ValueOf(attr)
field := v.Field(info.Index[colName])
fieldType := t.Field(info.Index[colName]).Type
if !attrValue.Type().ConvertibleTo(fieldType) {
return fmt.Errorf(
"UpdateStructWith: cannot convert atribute %s of type %v to type %T",
colName,
fieldType,
entity,
)
}
field.Set(attrValue.Convert(fieldType))
}
return nil
} }

View File

@ -1,4 +1,4 @@
package gpostgres package kissorm
import ( import (
"context" "context"