Improve reflection checks on Insert() function

pull/2/head
Vinícius Garcia 2021-01-17 19:25:21 -03:00
parent 5d083e35f0
commit 0b97dbcff7
2 changed files with 32 additions and 2 deletions

View File

@ -197,7 +197,7 @@ This library has a few helper functions for helping your tests:
- `kissorm.StructToMap(struct interface{}) (map[string]interface{}, error)` - `kissorm.StructToMap(struct interface{}) (map[string]interface{}, error)`
If you want to see examples (we have examples for all the public functions) just If you want to see examples (we have examples for all the public functions) just
read the example tests available on the our [example service](./examples/example_service) read the example tests available on our [example service](./examples/example_service)
### TODO List ### TODO List

View File

@ -327,6 +327,9 @@ func (c DB) insertOnPostgres(
v := reflect.ValueOf(record) v := reflect.ValueOf(record)
t := v.Type() t := v.Type()
if err = assertStructPtr(t); err != nil {
return errors.Wrap(err, "can't write id field")
}
info := getCachedTagInfo(tagInfoCache, t.Elem()) info := getCachedTagInfo(tagInfoCache, t.Elem())
fieldAddr := v.Elem().Field(info.Index["id"]).Addr() fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
@ -351,6 +354,10 @@ func (c DB) insertWithLastInsertID(
v := reflect.ValueOf(record) v := reflect.ValueOf(record)
t := v.Type() t := v.Type()
if err = assertStructPtr(t); err != nil {
return errors.Wrap(err, "can't write id field")
}
info := getCachedTagInfo(tagInfoCache, t.Elem()) info := getCachedTagInfo(tagInfoCache, t.Elem())
id, err := result.LastInsertId() id, err := result.LastInsertId()
@ -358,8 +365,31 @@ func (c DB) insertWithLastInsertID(
return err return err
} }
vID := reflect.ValueOf(id)
tID := vID.Type()
fieldAddr := v.Elem().Field(info.Index["id"]).Addr() fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
fieldAddr.Elem().Set(reflect.ValueOf(id).Convert(fieldAddr.Elem().Type())) fieldType := fieldAddr.Type().Elem()
if !tID.ConvertibleTo(fieldType) {
return fmt.Errorf(
"Can't convert last insert id of type int64 into field `%s` of type %s",
"id",
fieldType,
)
}
fieldAddr.Elem().Set(vID.Convert(fieldType))
return nil
}
func assertStructPtr(t reflect.Type) error {
if t.Kind() != reflect.Ptr {
return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
}
if t.Elem().Kind() != reflect.Struct {
return fmt.Errorf("expected a Kind of Ptr to Struct but got: %s", t)
}
return nil return nil
} }