Replace Query&QueryNext for the new QueryChunks func

pull/2/head
Vinícius Garcia 2020-10-09 15:26:00 -03:00
parent 76e5ad0f0f
commit a3bf34146d
3 changed files with 151 additions and 215 deletions

View File

@ -10,11 +10,24 @@ type ORMProvider interface {
Insert(ctx context.Context, items ...interface{}) error
Delete(ctx context.Context, ids ...interface{}) error
Update(ctx context.Context, items ...interface{}) error
Query(ctx context.Context, query string, params ...interface{}) (Iterator, error)
QueryNext(ctx context.Context, rawIt Iterator, item interface{}) (done bool, err error)
QueryChunks(ctx context.Context, parser ChunkParser) error
}
// Iterator ...
type Iterator interface {
Close() error
type ChunkParser struct {
// The Query and Params are used together to build a query with
// protection from injection, just like when using the Find function.
Query string
Params []interface{}
ChunkSize int
Chunk interface{} // Must be a pointer to a slice of structs
// The closure that will be called right after
// filling the Chunk with ChunkSize items
//
// Each chunk consecutively parsed will overwrite the
// same slice, so don't keep references to it, if you
// need some data to be preserved after all chunks are
// processed copy the items by value.
ForEachChunk func() error
}

View File

@ -2,7 +2,6 @@ package kissorm
import (
"context"
"database/sql"
"fmt"
"reflect"
@ -16,7 +15,12 @@ type Client struct {
}
// NewClient instantiates a new client
func NewClient(dbDriver string, connectionString string, maxOpenConns int, tableName string) (Client, error) {
func NewClient(
dbDriver string,
connectionString string,
maxOpenConns int,
tableName string,
) (Client, error) {
db, err := gorm.Open(dbDriver, connectionString)
if err != nil {
return Client{}, err
@ -58,68 +62,77 @@ func (c Client) Find(
return it.Error
}
type iterator struct {
isClosed bool
rows *sql.Rows
}
// Close ...
func (i *iterator) Close() error {
if i.isClosed {
return nil
}
i.isClosed = true
return i.rows.Close()
}
var noopCloser = iterator{isClosed: true}
// Query builds an iterator for querying several
// results from the database
func (c Client) Query(
// QueryChunks is meant to perform queries that returns
// many results and should only be used for that purpose.
//
// It ChunkParser argument will inform the query and its params,
// and the information that will be used to iterate on the results,
// namely:
// (1) The Chunk, which must be a pointer to a slice of structs where
// the results of the query will be kept on each iteration.
// (2) The ChunkSize that describes how many rows should be loaded
// on the Chunk slice before running the iteration callback.
// (3) The ForEachChunk function, which is the iteration callback
// and will be called right after the Chunk is filled with rows
// and/or after the last row is read from the database.
func (c Client) QueryChunks(
ctx context.Context,
query string,
params ...interface{},
) (Iterator, error) {
it := c.db.Raw(query, params...)
parser ChunkParser,
) error {
it := c.db.Raw(parser.Query, parser.Params...)
if it.Error != nil {
return &noopCloser, it.Error
return it.Error
}
rows, err := it.Rows()
if err != nil {
return &noopCloser, err
return err
}
defer rows.Close()
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk)
if err != nil {
return err
}
return &iterator{
isClosed: false,
rows: rows,
}, nil
}
slice := sliceRef.Elem()
var idx = 0
for ; rows.Next(); idx++ {
if slice.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(structType)
if !isSliceOfPtrs {
elemValue = elemValue.Elem()
}
slice = reflect.Append(slice, elemValue)
}
// QueryNext parses the next row of a query
// and updates the item argument that must be
// passed by reference.
func (c Client) QueryNext(
ctx context.Context,
rawIt Iterator,
item interface{},
) (done bool, err error) {
it, ok := rawIt.(*iterator)
if !ok {
return false, fmt.Errorf("invalid iterator received on QueryNext()")
err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface())
if err != nil {
return err
}
if idx == parser.ChunkSize-1 {
idx = 0
sliceRef.Elem().Set(slice)
err = parser.ForEachChunk()
if err != nil {
return err
}
}
}
if it.isClosed {
return false, fmt.Errorf("received closed iterator")
// If no rows were found or idx was reset to 0
// on the last iteration skip this last call to ForEachChunk:
if idx > 0 {
sliceRef.Elem().Set(slice.Slice(0, idx))
err = parser.ForEachChunk()
if err != nil {
return err
}
}
if !it.rows.Next() {
it.Close()
return true, it.rows.Err()
}
return false, c.db.ScanRows(it.rows, item)
return nil
}
// Insert one or more instances on the database
@ -315,24 +328,64 @@ func FillStructWith(entity interface{}, dbRow map[string]interface{}) error {
// and the second is a slice of maps representing the database rows you want
// to use to update this struct.
func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error {
slicePtrValue := reflect.ValueOf(entities)
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities)
if err != nil {
return err
}
info, found := tagInfoCache[structType]
if !found {
info = getTagNames(structType)
tagInfoCache[structType] = info
}
slice := sliceRef.Elem()
for idx, row := range dbRows {
if slice.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(structType)
if !isSliceOfPtrs {
elemValue = elemValue.Elem()
}
slice = reflect.Append(slice, elemValue)
}
err := FillStructWith(slice.Index(idx).Addr().Interface(), row)
if err != nil {
return err
}
}
sliceRef.Elem().Set(slice)
return nil
}
func decodeAsSliceOfStructs(slice interface{}) (
sliceRef reflect.Value,
structType reflect.Type,
isSliceOfPtrs bool,
err error,
) {
slicePtrValue := reflect.ValueOf(slice)
slicePtrType := slicePtrValue.Type()
if slicePtrType.Kind() != reflect.Ptr {
return fmt.Errorf(
err = fmt.Errorf(
"FillListWith: expected input to be a pointer to struct but got %T",
entities,
slice,
)
return
}
t := slicePtrType.Elem()
v := slicePtrValue.Elem()
if t.Kind() != reflect.Slice {
return fmt.Errorf(
err = fmt.Errorf(
"FillListWith: expected input kind to be a slice but got %T",
entities,
slice,
)
return
}
elemType := t.Elem()
@ -342,36 +395,13 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
elemType = elemType.Elem()
}
info, found := tagInfoCache[elemType]
if !found {
info = getTagNames(elemType)
tagInfoCache[elemType] = info
}
if elemType.Kind() != reflect.Struct {
return fmt.Errorf(
err = fmt.Errorf(
"FillListWith: expected input to be a slice of structs but got %T",
entities,
slice,
)
return
}
for idx, row := range dbRows {
if v.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(elemType)
if !isPtr {
elemValue = elemValue.Elem()
}
v = reflect.Append(v, elemValue)
}
err := FillStructWith(v.Index(idx).Addr().Interface(), row)
if err != nil {
return err
}
}
slicePtrValue.Elem().Set(v)
return nil
return slicePtrValue, elemType, isPtr, nil
}

View File

@ -305,8 +305,8 @@ func TestStructToMap(t *testing.T) {
})
}
func TestQuery(t *testing.T) {
t.Run("should execute query one correctly", func(t *testing.T) {
func TestQueryChunks(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
@ -323,136 +323,29 @@ func TestQuery(t *testing.T) {
_ = c.Insert(ctx, &User{Name: "User1"})
it, err := c.Query(ctx, `select * from users where name = ?;`, "User1")
assert.Equal(t, nil, err)
var length int
var u User
var users []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name = ?;`,
Params: []interface{}{"User1"},
u := User{}
_, err = c.QueryNext(ctx, it, &u)
it.Close()
ChunkSize: 100,
Chunk: &users,
ForEachChunk: func() error {
length = len(users)
if length > 0 {
u = users[0]
}
return nil
},
})
assert.Equal(t, nil, err)
assert.Equal(t, 1, length)
assert.NotEqual(t, 0, u.ID)
assert.Equal(t, "User1", u.Name)
})
t.Run("should execute query many correctly", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
it, err := c.Query(ctx, `select * from users where name in (?,?);`, "User1", "User3")
assert.Equal(t, nil, err)
// var results []User
u := User{}
u2 := User{}
u3 := User{}
done, err := c.QueryNext(ctx, it, &u)
assert.Equal(t, false, done)
assert.Equal(t, nil, err)
done, err = c.QueryNext(ctx, it, &u2)
assert.Equal(t, false, done)
assert.Equal(t, nil, err)
done, err = c.QueryNext(ctx, it, &u3)
assert.Equal(t, true, done)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
assert.Equal(t, "User1", u.Name)
assert.NotEqual(t, 0, u2.ID)
assert.Equal(t, "User3", u2.Name)
})
t.Run("should return error for an invalid iterator", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
u := User{}
_, err = c.QueryNext(ctx, Iterator(nil), &u)
assert.NotEqual(t, nil, err)
})
t.Run("should return noop closer when syntax error occurs", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * users`)
assert.NotEqual(t, nil, err)
assert.Equal(t, &noopCloser, it)
})
t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * from users`)
assert.Equal(t, nil, err)
err = it.Close()
assert.Equal(t, nil, err)
u := User{}
_, err = c.QueryNext(ctx, it, &u)
assert.NotEqual(t, nil, err)
})
}
func TestIterator(t *testing.T) {
t.Run("should return no errors if it's closed multiple times", func(t *testing.T) {
it := iterator{isClosed: true}
err := it.Close()
assert.Equal(t, nil, err)
assert.Equal(t, true, it.isClosed)
})
}
func TestFillSliceWith(t *testing.T) {