mirror of https://github.com/VinGarcia/ksql.git
Replace Query&QueryNext for the new QueryChunks func
parent
76e5ad0f0f
commit
a3bf34146d
23
contracts.go
23
contracts.go
|
@ -10,11 +10,24 @@ type ORMProvider interface {
|
|||
Insert(ctx context.Context, items ...interface{}) error
|
||||
Delete(ctx context.Context, ids ...interface{}) error
|
||||
Update(ctx context.Context, items ...interface{}) error
|
||||
Query(ctx context.Context, query string, params ...interface{}) (Iterator, error)
|
||||
QueryNext(ctx context.Context, rawIt Iterator, item interface{}) (done bool, err error)
|
||||
QueryChunks(ctx context.Context, parser ChunkParser) error
|
||||
}
|
||||
|
||||
// Iterator ...
|
||||
type Iterator interface {
|
||||
Close() error
|
||||
type ChunkParser struct {
|
||||
// The Query and Params are used together to build a query with
|
||||
// protection from injection, just like when using the Find function.
|
||||
Query string
|
||||
Params []interface{}
|
||||
|
||||
ChunkSize int
|
||||
Chunk interface{} // Must be a pointer to a slice of structs
|
||||
|
||||
// The closure that will be called right after
|
||||
// filling the Chunk with ChunkSize items
|
||||
//
|
||||
// Each chunk consecutively parsed will overwrite the
|
||||
// same slice, so don't keep references to it, if you
|
||||
// need some data to be preserved after all chunks are
|
||||
// processed copy the items by value.
|
||||
ForEachChunk func() error
|
||||
}
|
||||
|
|
198
kiss_orm.go
198
kiss_orm.go
|
@ -2,7 +2,6 @@ package kissorm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
|
@ -16,7 +15,12 @@ type Client struct {
|
|||
}
|
||||
|
||||
// NewClient instantiates a new client
|
||||
func NewClient(dbDriver string, connectionString string, maxOpenConns int, tableName string) (Client, error) {
|
||||
func NewClient(
|
||||
dbDriver string,
|
||||
connectionString string,
|
||||
maxOpenConns int,
|
||||
tableName string,
|
||||
) (Client, error) {
|
||||
db, err := gorm.Open(dbDriver, connectionString)
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
|
@ -58,68 +62,77 @@ func (c Client) Find(
|
|||
return it.Error
|
||||
}
|
||||
|
||||
type iterator struct {
|
||||
isClosed bool
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
// Close ...
|
||||
func (i *iterator) Close() error {
|
||||
if i.isClosed {
|
||||
return nil
|
||||
}
|
||||
i.isClosed = true
|
||||
return i.rows.Close()
|
||||
}
|
||||
|
||||
var noopCloser = iterator{isClosed: true}
|
||||
|
||||
// Query builds an iterator for querying several
|
||||
// results from the database
|
||||
func (c Client) Query(
|
||||
// QueryChunks is meant to perform queries that returns
|
||||
// many results and should only be used for that purpose.
|
||||
//
|
||||
// It ChunkParser argument will inform the query and its params,
|
||||
// and the information that will be used to iterate on the results,
|
||||
// namely:
|
||||
// (1) The Chunk, which must be a pointer to a slice of structs where
|
||||
// the results of the query will be kept on each iteration.
|
||||
// (2) The ChunkSize that describes how many rows should be loaded
|
||||
// on the Chunk slice before running the iteration callback.
|
||||
// (3) The ForEachChunk function, which is the iteration callback
|
||||
// and will be called right after the Chunk is filled with rows
|
||||
// and/or after the last row is read from the database.
|
||||
func (c Client) QueryChunks(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
params ...interface{},
|
||||
) (Iterator, error) {
|
||||
it := c.db.Raw(query, params...)
|
||||
parser ChunkParser,
|
||||
) error {
|
||||
it := c.db.Raw(parser.Query, parser.Params...)
|
||||
if it.Error != nil {
|
||||
return &noopCloser, it.Error
|
||||
return it.Error
|
||||
}
|
||||
|
||||
rows, err := it.Rows()
|
||||
if err != nil {
|
||||
return &noopCloser, err
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return &iterator{
|
||||
isClosed: false,
|
||||
rows: rows,
|
||||
}, nil
|
||||
}
|
||||
slice := sliceRef.Elem()
|
||||
var idx = 0
|
||||
for ; rows.Next(); idx++ {
|
||||
if slice.Len() <= idx {
|
||||
var elemValue reflect.Value
|
||||
elemValue = reflect.New(structType)
|
||||
if !isSliceOfPtrs {
|
||||
elemValue = elemValue.Elem()
|
||||
}
|
||||
slice = reflect.Append(slice, elemValue)
|
||||
}
|
||||
|
||||
// QueryNext parses the next row of a query
|
||||
// and updates the item argument that must be
|
||||
// passed by reference.
|
||||
func (c Client) QueryNext(
|
||||
ctx context.Context,
|
||||
rawIt Iterator,
|
||||
item interface{},
|
||||
) (done bool, err error) {
|
||||
it, ok := rawIt.(*iterator)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid iterator received on QueryNext()")
|
||||
err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if idx == parser.ChunkSize-1 {
|
||||
idx = 0
|
||||
sliceRef.Elem().Set(slice)
|
||||
err = parser.ForEachChunk()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if it.isClosed {
|
||||
return false, fmt.Errorf("received closed iterator")
|
||||
// If no rows were found or idx was reset to 0
|
||||
// on the last iteration skip this last call to ForEachChunk:
|
||||
if idx > 0 {
|
||||
sliceRef.Elem().Set(slice.Slice(0, idx))
|
||||
err = parser.ForEachChunk()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !it.rows.Next() {
|
||||
it.Close()
|
||||
return true, it.rows.Err()
|
||||
}
|
||||
|
||||
return false, c.db.ScanRows(it.rows, item)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert one or more instances on the database
|
||||
|
@ -315,24 +328,64 @@ func FillStructWith(entity interface{}, dbRow map[string]interface{}) error {
|
|||
// and the second is a slice of maps representing the database rows you want
|
||||
// to use to update this struct.
|
||||
func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error {
|
||||
slicePtrValue := reflect.ValueOf(entities)
|
||||
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, found := tagInfoCache[structType]
|
||||
if !found {
|
||||
info = getTagNames(structType)
|
||||
tagInfoCache[structType] = info
|
||||
}
|
||||
|
||||
slice := sliceRef.Elem()
|
||||
for idx, row := range dbRows {
|
||||
if slice.Len() <= idx {
|
||||
var elemValue reflect.Value
|
||||
elemValue = reflect.New(structType)
|
||||
if !isSliceOfPtrs {
|
||||
elemValue = elemValue.Elem()
|
||||
}
|
||||
slice = reflect.Append(slice, elemValue)
|
||||
}
|
||||
|
||||
err := FillStructWith(slice.Index(idx).Addr().Interface(), row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sliceRef.Elem().Set(slice)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeAsSliceOfStructs(slice interface{}) (
|
||||
sliceRef reflect.Value,
|
||||
structType reflect.Type,
|
||||
isSliceOfPtrs bool,
|
||||
err error,
|
||||
) {
|
||||
slicePtrValue := reflect.ValueOf(slice)
|
||||
slicePtrType := slicePtrValue.Type()
|
||||
|
||||
if slicePtrType.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf(
|
||||
err = fmt.Errorf(
|
||||
"FillListWith: expected input to be a pointer to struct but got %T",
|
||||
entities,
|
||||
slice,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
t := slicePtrType.Elem()
|
||||
v := slicePtrValue.Elem()
|
||||
|
||||
if t.Kind() != reflect.Slice {
|
||||
return fmt.Errorf(
|
||||
err = fmt.Errorf(
|
||||
"FillListWith: expected input kind to be a slice but got %T",
|
||||
entities,
|
||||
slice,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
elemType := t.Elem()
|
||||
|
@ -342,36 +395,13 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
|
|||
elemType = elemType.Elem()
|
||||
}
|
||||
|
||||
info, found := tagInfoCache[elemType]
|
||||
if !found {
|
||||
info = getTagNames(elemType)
|
||||
tagInfoCache[elemType] = info
|
||||
}
|
||||
|
||||
if elemType.Kind() != reflect.Struct {
|
||||
return fmt.Errorf(
|
||||
err = fmt.Errorf(
|
||||
"FillListWith: expected input to be a slice of structs but got %T",
|
||||
entities,
|
||||
slice,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
for idx, row := range dbRows {
|
||||
if v.Len() <= idx {
|
||||
var elemValue reflect.Value
|
||||
elemValue = reflect.New(elemType)
|
||||
if !isPtr {
|
||||
elemValue = elemValue.Elem()
|
||||
}
|
||||
v = reflect.Append(v, elemValue)
|
||||
}
|
||||
|
||||
err := FillStructWith(v.Index(idx).Addr().Interface(), row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slicePtrValue.Elem().Set(v)
|
||||
|
||||
return nil
|
||||
return slicePtrValue, elemType, isPtr, nil
|
||||
}
|
||||
|
|
145
kiss_orm_test.go
145
kiss_orm_test.go
|
@ -305,8 +305,8 @@ func TestStructToMap(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
t.Run("should execute query one correctly", func(t *testing.T) {
|
||||
func TestQueryChunks(t *testing.T) {
|
||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
|
@ -323,136 +323,29 @@ func TestQuery(t *testing.T) {
|
|||
|
||||
_ = c.Insert(ctx, &User{Name: "User1"})
|
||||
|
||||
it, err := c.Query(ctx, `select * from users where name = ?;`, "User1")
|
||||
assert.Equal(t, nil, err)
|
||||
var length int
|
||||
var u User
|
||||
var users []User
|
||||
err = c.QueryChunks(ctx, ChunkParser{
|
||||
Query: `select * from users where name = ?;`,
|
||||
Params: []interface{}{"User1"},
|
||||
|
||||
u := User{}
|
||||
_, err = c.QueryNext(ctx, it, &u)
|
||||
it.Close()
|
||||
ChunkSize: 100,
|
||||
Chunk: &users,
|
||||
ForEachChunk: func() error {
|
||||
length = len(users)
|
||||
if length > 0 {
|
||||
u = users[0]
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, length)
|
||||
assert.NotEqual(t, 0, u.ID)
|
||||
assert.Equal(t, "User1", u.Name)
|
||||
})
|
||||
|
||||
t.Run("should execute query many correctly", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
_ = c.Insert(ctx, &User{Name: "User1"})
|
||||
_ = c.Insert(ctx, &User{Name: "User2"})
|
||||
_ = c.Insert(ctx, &User{Name: "User3"})
|
||||
|
||||
it, err := c.Query(ctx, `select * from users where name in (?,?);`, "User1", "User3")
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// var results []User
|
||||
u := User{}
|
||||
u2 := User{}
|
||||
u3 := User{}
|
||||
done, err := c.QueryNext(ctx, it, &u)
|
||||
assert.Equal(t, false, done)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
done, err = c.QueryNext(ctx, it, &u2)
|
||||
assert.Equal(t, false, done)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
done, err = c.QueryNext(ctx, it, &u3)
|
||||
assert.Equal(t, true, done)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.NotEqual(t, 0, u.ID)
|
||||
assert.Equal(t, "User1", u.Name)
|
||||
assert.NotEqual(t, 0, u2.ID)
|
||||
assert.Equal(t, "User3", u2.Name)
|
||||
})
|
||||
|
||||
t.Run("should return error for an invalid iterator", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
u := User{}
|
||||
_, err = c.QueryNext(ctx, Iterator(nil), &u)
|
||||
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("should return noop closer when syntax error occurs", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
it, err := c.Query(ctx, `select * users`)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, &noopCloser, it)
|
||||
})
|
||||
|
||||
t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
it, err := c.Query(ctx, `select * from users`)
|
||||
assert.Equal(t, nil, err)
|
||||
err = it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
u := User{}
|
||||
_, err = c.QueryNext(ctx, it, &u)
|
||||
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
t.Run("should return no errors if it's closed multiple times", func(t *testing.T) {
|
||||
it := iterator{isClosed: true}
|
||||
err := it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, it.isClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFillSliceWith(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue