mirror of https://github.com/VinGarcia/ksql.git
Remove gorm dependency from Insert()
Although this implementation was meant to work with sqlite and postgres it was not yet tested with postgres.pull/2/head
parent
4030768f22
commit
a7b3c12b95
115
kiss_orm.go
115
kiss_orm.go
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
// Client ...
|
||||
type Client struct {
|
||||
driver string
|
||||
tableName string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
@ -34,6 +35,7 @@ func NewClient(
|
|||
db.DB().SetMaxOpenConns(maxOpenConns)
|
||||
|
||||
return Client{
|
||||
driver: dbDriver,
|
||||
db: db,
|
||||
tableName: tableName,
|
||||
}, nil
|
||||
|
@ -252,20 +254,78 @@ func (c Client) Insert(
|
|||
ctx context.Context,
|
||||
records ...interface{},
|
||||
) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
r := c.db.Table(c.tableName).Create(record)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
query, params, err := buildInsertQuery(c.tableName, record, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.driver {
|
||||
case "postgres":
|
||||
err = c.insertOnPostgres(ctx, record, query, params)
|
||||
default:
|
||||
err = c.insertWithLastInsertID(ctx, record, query, params)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Client) insertOnPostgres(
|
||||
ctx context.Context,
|
||||
record interface{},
|
||||
query string,
|
||||
params []interface{},
|
||||
) error {
|
||||
query = query + " RETURNING id"
|
||||
|
||||
rows, err := c.db.DB().QueryContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
err := fmt.Errorf("unexpected error retrieving the id from the database")
|
||||
if rows.Err() != nil {
|
||||
err = rows.Err()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
info := getTagInfoWithCache(tagInfoCache, t.Elem())
|
||||
|
||||
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
|
||||
return rows.Scan(fieldAddr)
|
||||
}
|
||||
|
||||
func (c Client) insertWithLastInsertID(
|
||||
ctx context.Context,
|
||||
record interface{},
|
||||
query string,
|
||||
params []interface{},
|
||||
) error {
|
||||
result, err := c.db.DB().ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
info := getTagInfoWithCache(tagInfoCache, t.Elem())
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
|
||||
fieldAddr.Elem().Set(reflect.ValueOf(id).Convert(fieldAddr.Elem().Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes one or more instances from the database by id
|
||||
func (c Client) Delete(
|
||||
ctx context.Context,
|
||||
|
@ -304,6 +364,47 @@ func (c Client) Update(
|
|||
return nil
|
||||
}
|
||||
|
||||
func buildInsertQuery(
|
||||
tableName string,
|
||||
record interface{},
|
||||
idFieldNames ...string,
|
||||
) (query string, params []interface{}, err error) {
|
||||
recordMap, err := StructToMap(record)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
numAttrs := len(recordMap)
|
||||
params = make([]interface{}, numAttrs)
|
||||
|
||||
for _, fieldName := range idFieldNames {
|
||||
// Remove any ID field that was not set:
|
||||
if reflect.ValueOf(recordMap[fieldName]).IsZero() {
|
||||
delete(recordMap, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
columnNames := []string{}
|
||||
for col := range recordMap {
|
||||
columnNames = append(columnNames, col)
|
||||
}
|
||||
|
||||
var valuesQuery []string
|
||||
for i, col := range columnNames {
|
||||
params[i] = recordMap[col]
|
||||
valuesQuery = append(valuesQuery, "?")
|
||||
}
|
||||
|
||||
query = fmt.Sprintf(
|
||||
"INSERT INTO `%s` (`%s`) VALUES (%s)",
|
||||
tableName,
|
||||
strings.Join(columnNames, "`, `"),
|
||||
strings.Join(valuesQuery, ", "),
|
||||
)
|
||||
|
||||
return query, params, nil
|
||||
}
|
||||
|
||||
func buildUpdateQuery(
|
||||
tableName string,
|
||||
record interface{},
|
||||
|
|
|
@ -2,7 +2,6 @@ package kissorm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -230,6 +229,7 @@ func TestInsert(t *testing.T) {
|
|||
|
||||
err := c.Insert(ctx, &u)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, 0, u.ID)
|
||||
|
||||
result := User{}
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
|
@ -1009,9 +1009,7 @@ func TestScanRows(t *testing.T) {
|
|||
assert.Equal(t, nil, err)
|
||||
|
||||
var u map[string]interface{}
|
||||
fmt.Println("before scan")
|
||||
err = scanRows(rows, &u)
|
||||
fmt.Println("after scan")
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue