Remove gorm dependency from Insert()

Although this implementation was meant to work with sqlite and
postgres it was not yet tested with postgres.
pull/2/head
Vinícius Garcia 2020-12-27 14:26:27 -03:00
parent 4030768f22
commit a7b3c12b95
2 changed files with 109 additions and 10 deletions

View File

@ -12,6 +12,7 @@ import (
// Client ...
type Client struct {
driver string
tableName string
db *gorm.DB
}
@ -34,6 +35,7 @@ func NewClient(
db.DB().SetMaxOpenConns(maxOpenConns)
return Client{
driver: dbDriver,
db: db,
tableName: tableName,
}, nil
@ -252,20 +254,78 @@ func (c Client) Insert(
ctx context.Context,
records ...interface{},
) error {
if len(records) == 0 {
return nil
}
for _, record := range records {
r := c.db.Table(c.tableName).Create(record)
if r.Error != nil {
return r.Error
query, params, err := buildInsertQuery(c.tableName, record, "id")
if err != nil {
return err
}
switch c.driver {
case "postgres":
err = c.insertOnPostgres(ctx, record, query, params)
default:
err = c.insertWithLastInsertID(ctx, record, query, params)
}
}
return nil
}
func (c Client) insertOnPostgres(
ctx context.Context,
record interface{},
query string,
params []interface{},
) error {
query = query + " RETURNING id"
rows, err := c.db.DB().QueryContext(ctx, query, params...)
if err != nil {
return err
}
if !rows.Next() {
err := fmt.Errorf("unexpected error retrieving the id from the database")
if rows.Err() != nil {
err = rows.Err()
}
return err
}
v := reflect.ValueOf(record)
t := v.Type()
info := getTagInfoWithCache(tagInfoCache, t.Elem())
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
return rows.Scan(fieldAddr)
}
func (c Client) insertWithLastInsertID(
ctx context.Context,
record interface{},
query string,
params []interface{},
) error {
result, err := c.db.DB().ExecContext(ctx, query, params...)
if err != nil {
return err
}
v := reflect.ValueOf(record)
t := v.Type()
info := getTagInfoWithCache(tagInfoCache, t.Elem())
id, err := result.LastInsertId()
if err != nil {
return err
}
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
fieldAddr.Elem().Set(reflect.ValueOf(id).Convert(fieldAddr.Elem().Type()))
return nil
}
// Delete deletes one or more instances from the database by id
func (c Client) Delete(
ctx context.Context,
@ -304,6 +364,47 @@ func (c Client) Update(
return nil
}
func buildInsertQuery(
tableName string,
record interface{},
idFieldNames ...string,
) (query string, params []interface{}, err error) {
recordMap, err := StructToMap(record)
if err != nil {
return "", nil, err
}
numAttrs := len(recordMap)
params = make([]interface{}, numAttrs)
for _, fieldName := range idFieldNames {
// Remove any ID field that was not set:
if reflect.ValueOf(recordMap[fieldName]).IsZero() {
delete(recordMap, fieldName)
}
}
columnNames := []string{}
for col := range recordMap {
columnNames = append(columnNames, col)
}
var valuesQuery []string
for i, col := range columnNames {
params[i] = recordMap[col]
valuesQuery = append(valuesQuery, "?")
}
query = fmt.Sprintf(
"INSERT INTO `%s` (`%s`) VALUES (%s)",
tableName,
strings.Join(columnNames, "`, `"),
strings.Join(valuesQuery, ", "),
)
return query, params, nil
}
func buildUpdateQuery(
tableName string,
record interface{},

View File

@ -2,7 +2,6 @@ package kissorm
import (
"context"
"fmt"
"testing"
"time"
@ -230,6 +229,7 @@ func TestInsert(t *testing.T) {
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
@ -1009,9 +1009,7 @@ func TestScanRows(t *testing.T) {
assert.Equal(t, nil, err)
var u map[string]interface{}
fmt.Println("before scan")
err = scanRows(rows, &u)
fmt.Println("after scan")
assert.NotEqual(t, nil, err)
})
}