diff --git a/kiss_orm.go b/kiss_orm.go index a619cbb..a4f0f5d 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -12,6 +12,7 @@ import ( // Client ... type Client struct { + driver string tableName string db *gorm.DB } @@ -34,6 +35,7 @@ func NewClient( db.DB().SetMaxOpenConns(maxOpenConns) return Client{ + driver: dbDriver, db: db, tableName: tableName, }, nil @@ -252,20 +254,78 @@ func (c Client) Insert( ctx context.Context, records ...interface{}, ) error { - if len(records) == 0 { - return nil - } - for _, record := range records { - r := c.db.Table(c.tableName).Create(record) - if r.Error != nil { - return r.Error + query, params, err := buildInsertQuery(c.tableName, record, "id") + if err != nil { + return err + } + + switch c.driver { + case "postgres": + err = c.insertOnPostgres(ctx, record, query, params) + default: + err = c.insertWithLastInsertID(ctx, record, query, params) } } return nil } +func (c Client) insertOnPostgres( + ctx context.Context, + record interface{}, + query string, + params []interface{}, +) error { + query = query + " RETURNING id" + + rows, err := c.db.DB().QueryContext(ctx, query, params...) + if err != nil { + return err + } + + if !rows.Next() { + err := fmt.Errorf("unexpected error retrieving the id from the database") + if rows.Err() != nil { + err = rows.Err() + } + + return err + } + + v := reflect.ValueOf(record) + t := v.Type() + info := getTagInfoWithCache(tagInfoCache, t.Elem()) + + fieldAddr := v.Elem().Field(info.Index["id"]).Addr() + return rows.Scan(fieldAddr) +} + +func (c Client) insertWithLastInsertID( + ctx context.Context, + record interface{}, + query string, + params []interface{}, +) error { + result, err := c.db.DB().ExecContext(ctx, query, params...) + if err != nil { + return err + } + + v := reflect.ValueOf(record) + t := v.Type() + info := getTagInfoWithCache(tagInfoCache, t.Elem()) + + id, err := result.LastInsertId() + if err != nil { + return err + } + + fieldAddr := v.Elem().Field(info.Index["id"]).Addr() + fieldAddr.Elem().Set(reflect.ValueOf(id).Convert(fieldAddr.Elem().Type())) + return nil +} + // Delete deletes one or more instances from the database by id func (c Client) Delete( ctx context.Context, @@ -304,6 +364,47 @@ func (c Client) Update( return nil } +func buildInsertQuery( + tableName string, + record interface{}, + idFieldNames ...string, +) (query string, params []interface{}, err error) { + recordMap, err := StructToMap(record) + if err != nil { + return "", nil, err + } + + numAttrs := len(recordMap) + params = make([]interface{}, numAttrs) + + for _, fieldName := range idFieldNames { + // Remove any ID field that was not set: + if reflect.ValueOf(recordMap[fieldName]).IsZero() { + delete(recordMap, fieldName) + } + } + + columnNames := []string{} + for col := range recordMap { + columnNames = append(columnNames, col) + } + + var valuesQuery []string + for i, col := range columnNames { + params[i] = recordMap[col] + valuesQuery = append(valuesQuery, "?") + } + + query = fmt.Sprintf( + "INSERT INTO `%s` (`%s`) VALUES (%s)", + tableName, + strings.Join(columnNames, "`, `"), + strings.Join(valuesQuery, ", "), + ) + + return query, params, nil +} + func buildUpdateQuery( tableName string, record interface{}, diff --git a/kiss_orm_test.go b/kiss_orm_test.go index f965edb..299d67d 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -2,7 +2,6 @@ package kissorm import ( "context" - "fmt" "testing" "time" @@ -230,6 +229,7 @@ func TestInsert(t *testing.T) { err := c.Insert(ctx, &u) assert.Equal(t, nil, err) + assert.NotEqual(t, 0, u.ID) result := User{} it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) @@ -1009,9 +1009,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, nil, err) var u map[string]interface{} - fmt.Println("before scan") err = scanRows(rows, &u) - fmt.Println("after scan") assert.NotEqual(t, nil, err) }) }