mirror of https://github.com/VinGarcia/ksql.git
546 lines
12 KiB
Go
546 lines
12 KiB
Go
package kissorm
|
|
|
|
import (
	"context"
	"fmt"
	"reflect"
	"sort"
	"strings"

	"github.com/jinzhu/gorm"
)
|
|
|
|
// Client ...
|
|
type Client struct {
|
|
tableName string
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewClient instantiates a new client
|
|
func NewClient(
|
|
dbDriver string,
|
|
connectionString string,
|
|
maxOpenConns int,
|
|
tableName string,
|
|
) (Client, error) {
|
|
db, err := gorm.Open(dbDriver, connectionString)
|
|
if err != nil {
|
|
return Client{}, err
|
|
}
|
|
if err = db.DB().Ping(); err != nil {
|
|
return Client{}, err
|
|
}
|
|
|
|
db.DB().SetMaxOpenConns(maxOpenConns)
|
|
|
|
return Client{
|
|
db: db,
|
|
tableName: tableName,
|
|
}, nil
|
|
}
|
|
|
|
// ChangeTable creates a new client configured to query on a different table
|
|
func (c Client) ChangeTable(ctx context.Context, tableName string) ORMProvider {
|
|
return &Client{
|
|
db: c.db,
|
|
tableName: tableName,
|
|
}
|
|
}
|
|
|
|
// Query queries several rows from the database,
|
|
// the input should be a slice of structs passed
|
|
// by reference and it will be filled with all the results.
|
|
//
|
|
// Note: it is very important to make sure the query will
|
|
// return a small number of results, otherwise you risk
|
|
// of overloading the available memory.
|
|
func (c Client) Query(
|
|
ctx context.Context,
|
|
records interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
t := reflect.TypeOf(records)
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
|
|
}
|
|
t = t.Elem()
|
|
if t.Kind() != reflect.Slice {
|
|
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
|
|
}
|
|
if t.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
|
|
}
|
|
|
|
it := c.db.Raw(query, params...)
|
|
if it.Error != nil {
|
|
return it.Error
|
|
}
|
|
it = it.Scan(records)
|
|
return it.Error
|
|
}
|
|
|
|
// QueryOne queries one instance from the database,
|
|
// the input struct must be passed by reference
|
|
// and the query should return only one result.
|
|
//
|
|
// QueryOne returns a ErrRecordNotFound if
|
|
// the query returns no results.
|
|
func (c Client) QueryOne(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
t := reflect.TypeOf(record)
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
t = t.Elem()
|
|
if t.Kind() != reflect.Struct {
|
|
return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
it := c.db.Raw(query, params...)
|
|
if it.Error != nil {
|
|
return it.Error
|
|
}
|
|
it = it.Scan(record)
|
|
if it.Error != nil && it.Error.Error() == "record not found" {
|
|
return ErrRecordNotFound
|
|
}
|
|
return it.Error
|
|
}
|
|
|
|
// QueryChunks is meant to perform queries that returns
|
|
// more results than would normally fit on memory,
|
|
// for others cases the Query and QueryOne functions are indicated.
|
|
//
|
|
// The ChunkParser argument has 4 attributes:
|
|
// (1) The Query;
|
|
// (2) The query args;
|
|
// (3) The chunk size;
|
|
// (4) A callback function called ForEachChunk, that will be called
|
|
// to process each chunk loaded from the database.
|
|
//
|
|
// Note that the signature of the ForEachChunk callback can be
|
|
// any function that receives a slice of structs or a slice of
|
|
// pointers to struct as its only argument and that reflection
|
|
// will be used to instantiate this argument and to fill it
|
|
// with the database rows.
|
|
func (c Client) QueryChunks(
|
|
ctx context.Context,
|
|
parser ChunkParser,
|
|
) error {
|
|
it := c.db.Raw(parser.Query, parser.Params...)
|
|
if it.Error != nil {
|
|
return it.Error
|
|
}
|
|
|
|
rows, err := it.Rows()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
fnValue := reflect.ValueOf(parser.ForEachChunk)
|
|
chunkType, err := parseInputFunc(parser.ForEachChunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
|
|
|
|
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(chunkType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var idx = 0
|
|
for rows.Next() {
|
|
// Allocate new slice elements
|
|
// only if they are not already allocated:
|
|
if chunk.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
chunk = reflect.Append(chunk, elemValue)
|
|
}
|
|
|
|
err = c.db.ScanRows(rows, chunk.Index(idx).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if idx < parser.ChunkSize-1 {
|
|
idx++
|
|
continue
|
|
}
|
|
|
|
idx = 0
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
// If no rows were found or idx was reset to 0
|
|
// on the last iteration skip this last call to ForEachChunk:
|
|
if idx > 0 {
|
|
chunk = chunk.Slice(0, idx)
|
|
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Insert one or more instances on the database
|
|
//
|
|
// If the original instances have been passed by reference
|
|
// the ID is automatically updated after insertion is completed.
|
|
func (c Client) Insert(
|
|
ctx context.Context,
|
|
records ...interface{},
|
|
) error {
|
|
if len(records) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for _, record := range records {
|
|
r := c.db.Table(c.tableName).Create(record)
|
|
if r.Error != nil {
|
|
return r.Error
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete deletes one or more instances from the database by id
|
|
func (c Client) Delete(
|
|
ctx context.Context,
|
|
ids ...interface{},
|
|
) error {
|
|
for _, id := range ids {
|
|
r := c.db.Table(c.tableName).Delete(id)
|
|
if r.Error != nil {
|
|
return r.Error
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Update updates the given instances on the database by id.
|
|
//
|
|
// Partial updates are supported, i.e. it will ignore nil pointer attributes
|
|
func (c Client) Update(
|
|
ctx context.Context,
|
|
records ...interface{},
|
|
) error {
|
|
for _, record := range records {
|
|
query, params, err := buildUpdateQuery(c.tableName, record, "id")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = c.db.DB().ExecContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildUpdateQuery(
|
|
tableName string,
|
|
record interface{},
|
|
idFieldNames ...string,
|
|
) (query string, args []interface{}, err error) {
|
|
recordMap, err := StructToMap(record)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
numAttrs := len(recordMap)
|
|
numIDs := len(idFieldNames)
|
|
args = make([]interface{}, numAttrs+numIDs)
|
|
whereArgs := args[numAttrs-len(idFieldNames):]
|
|
|
|
var whereQuery []string
|
|
for i, fieldName := range idFieldNames {
|
|
whereArgs[i] = recordMap[fieldName]
|
|
whereQuery = append(whereQuery, fmt.Sprintf("`%s` = ?", fieldName))
|
|
delete(recordMap, fieldName)
|
|
}
|
|
|
|
keys := []string{}
|
|
for key := range recordMap {
|
|
keys = append(keys, key)
|
|
}
|
|
|
|
var setQuery []string
|
|
for i, k := range keys {
|
|
args[i] = recordMap[k]
|
|
setQuery = append(setQuery, fmt.Sprintf("`%s` = ?", k))
|
|
}
|
|
|
|
query = fmt.Sprintf(
|
|
"UPDATE `%s` SET %s WHERE %s",
|
|
tableName,
|
|
strings.Join(setQuery, ", "),
|
|
strings.Join(whereQuery, ", "),
|
|
)
|
|
|
|
return query, args, nil
|
|
}
|
|
|
|
// structInfo holds the tag names collected for one struct type,
// indexed in both directions.
type structInfo struct {
	// Names maps a field position to its tag name.
	Names map[int]string
	// Index maps a tag name back to the field position.
	Index map[string]int
}

// tagInfoCache memoizes the structInfo computed for each struct type.
//
// It is kept as a package variable because the total number of types
// in a program should be finite, so a single shared cache works fine.
//
// NOTE(review): the map is read and written without any locking in this
// file — confirm it is never used from multiple goroutines at once.
var tagInfoCache = make(map[reflect.Type]structInfo)
|
|
|
|
// StructToMap converts any struct type to a map based on
|
|
// the tag named `gorm`, i.e. `gorm:"map_key_name"`
|
|
//
|
|
// Valid pointers are dereferenced and copied to the map,
|
|
// null pointers are ignored.
|
|
//
|
|
// This function is efficient in the fact that it caches
|
|
// the slower steps of the reflection required to perform
|
|
// this task.
|
|
func StructToMap(obj interface{}) (map[string]interface{}, error) {
|
|
v := reflect.ValueOf(obj)
|
|
t := v.Type()
|
|
|
|
if t.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
t = t.Elem()
|
|
}
|
|
if t.Kind() != reflect.Struct {
|
|
return nil, fmt.Errorf("input must be a struct or struct pointer")
|
|
}
|
|
|
|
info, found := tagInfoCache[t]
|
|
if !found {
|
|
info = getTagNames(t)
|
|
tagInfoCache[t] = info
|
|
}
|
|
|
|
m := map[string]interface{}{}
|
|
for i := 0; i < v.NumField(); i++ {
|
|
field := v.Field(i)
|
|
ft := field.Type()
|
|
if ft.Kind() == reflect.Ptr {
|
|
if field.IsNil() {
|
|
continue
|
|
}
|
|
|
|
field = field.Elem()
|
|
}
|
|
|
|
m[info.Names[i]] = field.Interface()
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
// This function collects only the names
|
|
// that will be used from the input type.
|
|
//
|
|
// This should save several calls to `Field(i).Tag.Get("foo")`
|
|
// which improves performance by a lot.
|
|
func getTagNames(t reflect.Type) structInfo {
|
|
info := structInfo{
|
|
Names: map[int]string{},
|
|
Index: map[string]int{},
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
name := t.Field(i).Tag.Get("gorm")
|
|
if name == "" {
|
|
continue
|
|
}
|
|
info.Names[i] = name
|
|
info.Index[name] = i
|
|
}
|
|
|
|
return info
|
|
}
|
|
|
|
// FillStructWith is meant to be used on unit tests to mock
|
|
// the response from the database.
|
|
//
|
|
// The first argument is any struct you are passing to a kissorm func,
|
|
// and the second is a map representing a database row you want
|
|
// to use to update this struct.
|
|
func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf(
|
|
"FillStructWith: expected input to be a pointer to struct but got %T",
|
|
record,
|
|
)
|
|
}
|
|
|
|
t = t.Elem()
|
|
v = v.Elem()
|
|
|
|
if t.Kind() != reflect.Struct {
|
|
return fmt.Errorf(
|
|
"FillStructWith: expected input kind to be a struct but got %T",
|
|
record,
|
|
)
|
|
}
|
|
|
|
info, found := tagInfoCache[t]
|
|
if !found {
|
|
info = getTagNames(t)
|
|
tagInfoCache[t] = info
|
|
}
|
|
|
|
for colName, attr := range dbRow {
|
|
attrValue := reflect.ValueOf(attr)
|
|
field := v.Field(info.Index[colName])
|
|
fieldType := t.Field(info.Index[colName]).Type
|
|
|
|
if !attrValue.Type().ConvertibleTo(fieldType) {
|
|
return fmt.Errorf(
|
|
"FillStructWith: cannot convert atribute %s of type %v to type %T",
|
|
colName,
|
|
fieldType,
|
|
record,
|
|
)
|
|
}
|
|
field.Set(attrValue.Convert(fieldType))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FillSliceWith is meant to be used on unit tests to mock
|
|
// the response from the database.
|
|
//
|
|
// The first argument is any slice of structs you are passing to a kissorm func,
|
|
// and the second is a slice of maps representing the database rows you want
|
|
// to use to update this struct.
|
|
func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error {
|
|
sliceRef := reflect.ValueOf(entities)
|
|
sliceType := sliceRef.Type()
|
|
if sliceType.Kind() != reflect.Ptr {
|
|
return fmt.Errorf(
|
|
"FillSliceWith: expected input to be a pointer to struct but got %v",
|
|
sliceType,
|
|
)
|
|
}
|
|
|
|
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(sliceType.Elem())
|
|
if err != nil {
|
|
return fmt.Errorf("FillSliceWith: %s", err.Error())
|
|
}
|
|
|
|
info, found := tagInfoCache[structType]
|
|
if !found {
|
|
info = getTagNames(structType)
|
|
tagInfoCache[structType] = info
|
|
}
|
|
|
|
slice := sliceRef.Elem()
|
|
for idx, row := range dbRows {
|
|
if slice.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
slice = reflect.Append(slice, elemValue)
|
|
}
|
|
|
|
err := FillStructWith(slice.Index(idx).Addr().Interface(), row)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
sliceRef.Elem().Set(slice)
|
|
|
|
return nil
|
|
}
|
|
|
|
// decodeAsSliceOfStructs inspects a slice type and reports the struct type
// of its elements, along with whether the elements are pointers to that
// struct rather than the struct itself.
//
// An error is returned if the input is not a slice, or if its elements are
// neither structs nor pointers to structs.
func decodeAsSliceOfStructs(slice reflect.Type) (
	structType reflect.Type,
	isSliceOfPtrs bool,
	err error,
) {
	if slice.Kind() != reflect.Slice {
		return nil, false, fmt.Errorf(
			"expected input kind to be a slice but got %v",
			slice,
		)
	}

	structType = slice.Elem()
	if structType.Kind() == reflect.Ptr {
		isSliceOfPtrs = true
		structType = structType.Elem()
	}

	if structType.Kind() != reflect.Struct {
		return nil, false, fmt.Errorf(
			"expected input to be a slice of structs but got %v",
			slice,
		)
	}

	return structType, isSliceOfPtrs, nil
}
|
|
|
|
// errType is the reflect.Type of the error interface, used to validate
// the return type of the ForEachChunk callback.
var errType = reflect.TypeOf((*error)(nil)).Elem()

// parseInputFunc validates the ForEachChunk callback and returns the type
// of its single argument.
//
// The callback must be a function that takes exactly one argument, whose
// kind is a slice, and returns exactly one value of type error. Anything
// else, including nil, results in an error.
//
// Only the slice kind of the argument is checked here; whether its elements
// are structs is not verified by this function.
func parseInputFunc(fn interface{}) (reflect.Type, error) {
	t := reflect.TypeOf(fn)

	// reflect.TypeOf(nil) yields a nil Type, on which Kind would panic.
	if t == nil || t.Kind() != reflect.Func {
		return nil, fmt.Errorf("the ForEachChunk callback must be a function")
	}
	if t.NumIn() != 1 {
		return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument")
	}

	if t.NumOut() != 1 {
		return nil, fmt.Errorf("the ForEachChunk callback must have a single return value")
	}

	if t.Out(0) != errType {
		return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error")
	}

	argsType := t.In(0)
	if argsType.Kind() != reflect.Slice {
		return nil, fmt.Errorf("the argument of the ForEachChunk callback must be a slice of structs")
	}

	return argsType, nil
}
|