Replace Query&QueryNext for the new QueryChunks func

pull/2/head
Vinícius Garcia 2020-10-09 15:26:00 -03:00
parent 76e5ad0f0f
commit a3bf34146d
3 changed files with 151 additions and 215 deletions

View File

@ -10,11 +10,24 @@ type ORMProvider interface {
Insert(ctx context.Context, items ...interface{}) error Insert(ctx context.Context, items ...interface{}) error
Delete(ctx context.Context, ids ...interface{}) error Delete(ctx context.Context, ids ...interface{}) error
Update(ctx context.Context, items ...interface{}) error Update(ctx context.Context, items ...interface{}) error
Query(ctx context.Context, query string, params ...interface{}) (Iterator, error) QueryChunks(ctx context.Context, parser ChunkParser) error
QueryNext(ctx context.Context, rawIt Iterator, item interface{}) (done bool, err error)
} }
// Iterator ... type ChunkParser struct {
type Iterator interface { // The Query and Params are used together to build a query with
Close() error // protection from injection, just like when using the Find function.
Query string
Params []interface{}
ChunkSize int
Chunk interface{} // Must be a pointer to a slice of structs
// The closure that will be called right after
// filling the Chunk with ChunkSize items
//
// Each chunk consecutively parsed will overwrite the
// same slice, so don't keep references to it, if you
// need some data to be preserved after all chunks are
// processed copy the items by value.
ForEachChunk func() error
} }

View File

@ -2,7 +2,6 @@ package kissorm
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
@ -16,7 +15,12 @@ type Client struct {
} }
// NewClient instantiates a new client // NewClient instantiates a new client
func NewClient(dbDriver string, connectionString string, maxOpenConns int, tableName string) (Client, error) { func NewClient(
dbDriver string,
connectionString string,
maxOpenConns int,
tableName string,
) (Client, error) {
db, err := gorm.Open(dbDriver, connectionString) db, err := gorm.Open(dbDriver, connectionString)
if err != nil { if err != nil {
return Client{}, err return Client{}, err
@ -58,68 +62,77 @@ func (c Client) Find(
return it.Error return it.Error
} }
type iterator struct { // QueryChunks is meant to perform queries that returns
isClosed bool // many results and should only be used for that purpose.
rows *sql.Rows //
} // It ChunkParser argument will inform the query and its params,
// and the information that will be used to iterate on the results,
// Close ... // namely:
func (i *iterator) Close() error { // (1) The Chunk, which must be a pointer to a slice of structs where
if i.isClosed { // the results of the query will be kept on each iteration.
return nil // (2) The ChunkSize that describes how many rows should be loaded
} // on the Chunk slice before running the iteration callback.
i.isClosed = true // (3) The ForEachChunk function, which is the iteration callback
return i.rows.Close() // and will be called right after the Chunk is filled with rows
} // and/or after the last row is read from the database.
func (c Client) QueryChunks(
var noopCloser = iterator{isClosed: true}
// Query builds an iterator for querying several
// results from the database
func (c Client) Query(
ctx context.Context, ctx context.Context,
query string, parser ChunkParser,
params ...interface{}, ) error {
) (Iterator, error) { it := c.db.Raw(parser.Query, parser.Params...)
it := c.db.Raw(query, params...)
if it.Error != nil { if it.Error != nil {
return &noopCloser, it.Error return it.Error
} }
rows, err := it.Rows() rows, err := it.Rows()
if err != nil { if err != nil {
return &noopCloser, err return err
}
defer rows.Close()
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk)
if err != nil {
return err
} }
return &iterator{ slice := sliceRef.Elem()
isClosed: false, var idx = 0
rows: rows, for ; rows.Next(); idx++ {
}, nil if slice.Len() <= idx {
} var elemValue reflect.Value
elemValue = reflect.New(structType)
if !isSliceOfPtrs {
elemValue = elemValue.Elem()
}
slice = reflect.Append(slice, elemValue)
}
// QueryNext parses the next row of a query err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface())
// and updates the item argument that must be if err != nil {
// passed by reference. return err
func (c Client) QueryNext( }
ctx context.Context,
rawIt Iterator, if idx == parser.ChunkSize-1 {
item interface{}, idx = 0
) (done bool, err error) { sliceRef.Elem().Set(slice)
it, ok := rawIt.(*iterator) err = parser.ForEachChunk()
if !ok { if err != nil {
return false, fmt.Errorf("invalid iterator received on QueryNext()") return err
}
}
} }
if it.isClosed { // If no rows were found or idx was reset to 0
return false, fmt.Errorf("received closed iterator") // on the last iteration skip this last call to ForEachChunk:
if idx > 0 {
sliceRef.Elem().Set(slice.Slice(0, idx))
err = parser.ForEachChunk()
if err != nil {
return err
}
} }
if !it.rows.Next() { return nil
it.Close()
return true, it.rows.Err()
}
return false, c.db.ScanRows(it.rows, item)
} }
// Insert one or more instances on the database // Insert one or more instances on the database
@ -315,24 +328,64 @@ func FillStructWith(entity interface{}, dbRow map[string]interface{}) error {
// and the second is a slice of maps representing the database rows you want // and the second is a slice of maps representing the database rows you want
// to use to update this struct. // to use to update this struct.
func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error {
slicePtrValue := reflect.ValueOf(entities) sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities)
if err != nil {
return err
}
info, found := tagInfoCache[structType]
if !found {
info = getTagNames(structType)
tagInfoCache[structType] = info
}
slice := sliceRef.Elem()
for idx, row := range dbRows {
if slice.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(structType)
if !isSliceOfPtrs {
elemValue = elemValue.Elem()
}
slice = reflect.Append(slice, elemValue)
}
err := FillStructWith(slice.Index(idx).Addr().Interface(), row)
if err != nil {
return err
}
}
sliceRef.Elem().Set(slice)
return nil
}
func decodeAsSliceOfStructs(slice interface{}) (
sliceRef reflect.Value,
structType reflect.Type,
isSliceOfPtrs bool,
err error,
) {
slicePtrValue := reflect.ValueOf(slice)
slicePtrType := slicePtrValue.Type() slicePtrType := slicePtrValue.Type()
if slicePtrType.Kind() != reflect.Ptr { if slicePtrType.Kind() != reflect.Ptr {
return fmt.Errorf( err = fmt.Errorf(
"FillListWith: expected input to be a pointer to struct but got %T", "FillListWith: expected input to be a pointer to struct but got %T",
entities, slice,
) )
return
} }
t := slicePtrType.Elem() t := slicePtrType.Elem()
v := slicePtrValue.Elem()
if t.Kind() != reflect.Slice { if t.Kind() != reflect.Slice {
return fmt.Errorf( err = fmt.Errorf(
"FillListWith: expected input kind to be a slice but got %T", "FillListWith: expected input kind to be a slice but got %T",
entities, slice,
) )
return
} }
elemType := t.Elem() elemType := t.Elem()
@ -342,36 +395,13 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
elemType = elemType.Elem() elemType = elemType.Elem()
} }
info, found := tagInfoCache[elemType]
if !found {
info = getTagNames(elemType)
tagInfoCache[elemType] = info
}
if elemType.Kind() != reflect.Struct { if elemType.Kind() != reflect.Struct {
return fmt.Errorf( err = fmt.Errorf(
"FillListWith: expected input to be a slice of structs but got %T", "FillListWith: expected input to be a slice of structs but got %T",
entities, slice,
) )
return
} }
for idx, row := range dbRows { return slicePtrValue, elemType, isPtr, nil
if v.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(elemType)
if !isPtr {
elemValue = elemValue.Elem()
}
v = reflect.Append(v, elemValue)
}
err := FillStructWith(v.Index(idx).Addr().Interface(), row)
if err != nil {
return err
}
}
slicePtrValue.Elem().Set(v)
return nil
} }

View File

@ -305,8 +305,8 @@ func TestStructToMap(t *testing.T) {
}) })
} }
func TestQuery(t *testing.T) { func TestQueryChunks(t *testing.T) {
t.Run("should execute query one correctly", func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable() err := createTable()
if err != nil { if err != nil {
t.Fatal("could not create test table!") t.Fatal("could not create test table!")
@ -323,136 +323,29 @@ func TestQuery(t *testing.T) {
_ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User1"})
it, err := c.Query(ctx, `select * from users where name = ?;`, "User1") var length int
assert.Equal(t, nil, err) var u User
var users []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name = ?;`,
Params: []interface{}{"User1"},
u := User{} ChunkSize: 100,
_, err = c.QueryNext(ctx, it, &u) Chunk: &users,
it.Close() ForEachChunk: func() error {
length = len(users)
if length > 0 {
u = users[0]
}
return nil
},
})
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 1, length)
assert.NotEqual(t, 0, u.ID) assert.NotEqual(t, 0, u.ID)
assert.Equal(t, "User1", u.Name) assert.Equal(t, "User1", u.Name)
}) })
t.Run("should execute query many correctly", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
it, err := c.Query(ctx, `select * from users where name in (?,?);`, "User1", "User3")
assert.Equal(t, nil, err)
// var results []User
u := User{}
u2 := User{}
u3 := User{}
done, err := c.QueryNext(ctx, it, &u)
assert.Equal(t, false, done)
assert.Equal(t, nil, err)
done, err = c.QueryNext(ctx, it, &u2)
assert.Equal(t, false, done)
assert.Equal(t, nil, err)
done, err = c.QueryNext(ctx, it, &u3)
assert.Equal(t, true, done)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
assert.Equal(t, "User1", u.Name)
assert.NotEqual(t, 0, u2.ID)
assert.Equal(t, "User3", u2.Name)
})
t.Run("should return error for an invalid iterator", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
u := User{}
_, err = c.QueryNext(ctx, Iterator(nil), &u)
assert.NotEqual(t, nil, err)
})
t.Run("should return noop closer when syntax error occurs", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * users`)
assert.NotEqual(t, nil, err)
assert.Equal(t, &noopCloser, it)
})
t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * from users`)
assert.Equal(t, nil, err)
err = it.Close()
assert.Equal(t, nil, err)
u := User{}
_, err = c.QueryNext(ctx, it, &u)
assert.NotEqual(t, nil, err)
})
}
func TestIterator(t *testing.T) {
t.Run("should return no errors if it's closed multiple times", func(t *testing.T) {
it := iterator{isClosed: true}
err := it.Close()
assert.Equal(t, nil, err)
assert.Equal(t, true, it.isClosed)
})
} }
func TestFillSliceWith(t *testing.T) { func TestFillSliceWith(t *testing.T) {