From a3bf34146dbfee9cbbdf5882c880704e0f6215d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Fri, 9 Oct 2020 15:26:00 -0300 Subject: [PATCH] Replace Query&QueryNext for the new QueryChunks func --- contracts.go | 23 ++++-- kiss_orm.go | 198 +++++++++++++++++++++++++++-------------------- kiss_orm_test.go | 145 +++++----------------------------- 3 files changed, 151 insertions(+), 215 deletions(-) diff --git a/contracts.go b/contracts.go index 8ccae90..95aa456 100644 --- a/contracts.go +++ b/contracts.go @@ -10,11 +10,24 @@ type ORMProvider interface { Insert(ctx context.Context, items ...interface{}) error Delete(ctx context.Context, ids ...interface{}) error Update(ctx context.Context, items ...interface{}) error - Query(ctx context.Context, query string, params ...interface{}) (Iterator, error) - QueryNext(ctx context.Context, rawIt Iterator, item interface{}) (done bool, err error) + QueryChunks(ctx context.Context, parser ChunkParser) error } -// Iterator ... -type Iterator interface { - Close() error +type ChunkParser struct { + // The Query and Params are used together to build a query with + // protection from injection, just like when using the Find function. + Query string + Params []interface{} + + ChunkSize int + Chunk interface{} // Must be a pointer to a slice of structs + + // The closure that will be called right after + // filling the Chunk with ChunkSize items + // + // Each chunk consecutively parsed will overwrite the + // same slice, so don't keep references to it, if you + // need some data to be preserved after all chunks are + // processed copy the items by value. + ForEachChunk func() error } diff --git a/kiss_orm.go b/kiss_orm.go index 3f78445..57fa2f0 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -2,7 +2,6 @@ package kissorm import ( "context" - "database/sql" "fmt" "reflect" @@ -16,7 +15,12 @@ type Client struct { } // NewClient instantiates a new client -func NewClient(dbDriver string, connectionString string, maxOpenConns int, tableName string) (Client, error) { +func NewClient( + dbDriver string, + connectionString string, + maxOpenConns int, + tableName string, +) (Client, error) { db, err := gorm.Open(dbDriver, connectionString) if err != nil { return Client{}, err @@ -58,68 +62,77 @@ func (c Client) Find( return it.Error } -type iterator struct { - isClosed bool - rows *sql.Rows -} - -// Close ... -func (i *iterator) Close() error { - if i.isClosed { - return nil - } - i.isClosed = true - return i.rows.Close() -} - -var noopCloser = iterator{isClosed: true} - -// Query builds an iterator for querying several -// results from the database -func (c Client) Query( +// QueryChunks is meant to perform queries that returns +// many results and should only be used for that purpose. +// +// It ChunkParser argument will inform the query and its params, +// and the information that will be used to iterate on the results, +// namely: +// (1) The Chunk, which must be a pointer to a slice of structs where +// the results of the query will be kept on each iteration. +// (2) The ChunkSize that describes how many rows should be loaded +// on the Chunk slice before running the iteration callback. +// (3) The ForEachChunk function, which is the iteration callback +// and will be called right after the Chunk is filled with rows +// and/or after the last row is read from the database. +func (c Client) QueryChunks( ctx context.Context, - query string, - params ...interface{}, -) (Iterator, error) { - it := c.db.Raw(query, params...) + parser ChunkParser, +) error { + it := c.db.Raw(parser.Query, parser.Params...) if it.Error != nil { - return &noopCloser, it.Error + return it.Error } rows, err := it.Rows() if err != nil { - return &noopCloser, err + return err + } + defer rows.Close() + + sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk) + if err != nil { + return err } - return &iterator{ - isClosed: false, - rows: rows, - }, nil -} + slice := sliceRef.Elem() + var idx = 0 + for ; rows.Next(); idx++ { + if slice.Len() <= idx { + var elemValue reflect.Value + elemValue = reflect.New(structType) + if !isSliceOfPtrs { + elemValue = elemValue.Elem() + } + slice = reflect.Append(slice, elemValue) + } -// QueryNext parses the next row of a query -// and updates the item argument that must be -// passed by reference. -func (c Client) QueryNext( - ctx context.Context, - rawIt Iterator, - item interface{}, -) (done bool, err error) { - it, ok := rawIt.(*iterator) - if !ok { - return false, fmt.Errorf("invalid iterator received on QueryNext()") + err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface()) + if err != nil { + return err + } + + if idx == parser.ChunkSize-1 { + idx = 0 + sliceRef.Elem().Set(slice) + err = parser.ForEachChunk() + if err != nil { + return err + } + } } - if it.isClosed { - return false, fmt.Errorf("received closed iterator") + // If no rows were found or idx was reset to 0 + // on the last iteration skip this last call to ForEachChunk: + if idx > 0 { + sliceRef.Elem().Set(slice.Slice(0, idx)) + err = parser.ForEachChunk() + if err != nil { + return err + } } - if !it.rows.Next() { - it.Close() - return true, it.rows.Err() - } - - return false, c.db.ScanRows(it.rows, item) + return nil } // Insert one or more instances on the database @@ -315,24 +328,64 @@ func FillStructWith(entity interface{}, dbRow map[string]interface{}) error { // and the second is a slice of maps representing the database rows you want // to use to update this struct. func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { - slicePtrValue := reflect.ValueOf(entities) + sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities) + if err != nil { + return err + } + + info, found := tagInfoCache[structType] + if !found { + info = getTagNames(structType) + tagInfoCache[structType] = info + } + + slice := sliceRef.Elem() + for idx, row := range dbRows { + if slice.Len() <= idx { + var elemValue reflect.Value + elemValue = reflect.New(structType) + if !isSliceOfPtrs { + elemValue = elemValue.Elem() + } + slice = reflect.Append(slice, elemValue) + } + + err := FillStructWith(slice.Index(idx).Addr().Interface(), row) + if err != nil { + return err + } + } + + sliceRef.Elem().Set(slice) + + return nil +} + +func decodeAsSliceOfStructs(slice interface{}) ( + sliceRef reflect.Value, + structType reflect.Type, + isSliceOfPtrs bool, + err error, +) { + slicePtrValue := reflect.ValueOf(slice) slicePtrType := slicePtrValue.Type() if slicePtrType.Kind() != reflect.Ptr { - return fmt.Errorf( + err = fmt.Errorf( "FillListWith: expected input to be a pointer to struct but got %T", - entities, + slice, ) + return } t := slicePtrType.Elem() - v := slicePtrValue.Elem() if t.Kind() != reflect.Slice { - return fmt.Errorf( + err = fmt.Errorf( "FillListWith: expected input kind to be a slice but got %T", - entities, + slice, ) + return } elemType := t.Elem() @@ -342,36 +395,13 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error elemType = elemType.Elem() } - info, found := tagInfoCache[elemType] - if !found { - info = getTagNames(elemType) - tagInfoCache[elemType] = info - } - if elemType.Kind() != reflect.Struct { - return fmt.Errorf( + err = fmt.Errorf( "FillListWith: expected input to be a slice of structs but got %T", - entities, + slice, ) + return } - for idx, row := range dbRows { - if v.Len() <= idx { - var elemValue reflect.Value - elemValue = reflect.New(elemType) - if !isPtr { - elemValue = elemValue.Elem() - } - v = reflect.Append(v, elemValue) - } - - err := FillStructWith(v.Index(idx).Addr().Interface(), row) - if err != nil { - return err - } - } - - slicePtrValue.Elem().Set(v) - - return nil + return slicePtrValue, elemType, isPtr, nil } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index 9b8e89f..29ec778 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -305,8 +305,8 @@ func TestStructToMap(t *testing.T) { }) } -func TestQuery(t *testing.T) { - t.Run("should execute query one correctly", func(t *testing.T) { +func TestQueryChunks(t *testing.T) { + t.Run("should query a single row correctly", func(t *testing.T) { err := createTable() if err != nil { t.Fatal("could not create test table!") @@ -323,136 +323,29 @@ func TestQuery(t *testing.T) { _ = c.Insert(ctx, &User{Name: "User1"}) - it, err := c.Query(ctx, `select * from users where name = ?;`, "User1") - assert.Equal(t, nil, err) + var length int + var u User + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name = ?;`, + Params: []interface{}{"User1"}, - u := User{} - _, err = c.QueryNext(ctx, it, &u) - it.Close() + ChunkSize: 100, + Chunk: &users, + ForEachChunk: func() error { + length = len(users) + if length > 0 { + u = users[0] + } + return nil + }, + }) assert.Equal(t, nil, err) + assert.Equal(t, 1, length) assert.NotEqual(t, 0, u.ID) assert.Equal(t, "User1", u.Name) }) - - t.Run("should execute query many correctly", func(t *testing.T) { - err := createTable() - if err != nil { - t.Fatal("could not create test table!") - } - - db := connectDB(t) - defer db.Close() - - ctx := context.Background() - c := Client{ - db: db, - tableName: "users", - } - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - it, err := c.Query(ctx, `select * from users where name in (?,?);`, "User1", "User3") - assert.Equal(t, nil, err) - - // var results []User - u := User{} - u2 := User{} - u3 := User{} - done, err := c.QueryNext(ctx, it, &u) - assert.Equal(t, false, done) - assert.Equal(t, nil, err) - - done, err = c.QueryNext(ctx, it, &u2) - assert.Equal(t, false, done) - assert.Equal(t, nil, err) - - done, err = c.QueryNext(ctx, it, &u3) - assert.Equal(t, true, done) - assert.Equal(t, nil, err) - - assert.NotEqual(t, 0, u.ID) - assert.Equal(t, "User1", u.Name) - assert.NotEqual(t, 0, u2.ID) - assert.Equal(t, "User3", u2.Name) - }) - - t.Run("should return error for an invalid iterator", func(t *testing.T) { - err := createTable() - if err != nil { - t.Fatal("could not create test table!") - } - - db := connectDB(t) - defer db.Close() - - ctx := context.Background() - c := Client{ - db: db, - tableName: "users", - } - - u := User{} - _, err = c.QueryNext(ctx, Iterator(nil), &u) - - assert.NotEqual(t, nil, err) - }) - - t.Run("should return noop closer when syntax error occurs", func(t *testing.T) { - err := createTable() - if err != nil { - t.Fatal("could not create test table!") - } - - db := connectDB(t) - defer db.Close() - - ctx := context.Background() - c := Client{ - db: db, - tableName: "users", - } - - it, err := c.Query(ctx, `select * users`) - assert.NotEqual(t, nil, err) - assert.Equal(t, &noopCloser, it) - }) - - t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) { - err := createTable() - if err != nil { - t.Fatal("could not create test table!") - } - - db := connectDB(t) - defer db.Close() - - ctx := context.Background() - c := Client{ - db: db, - tableName: "users", - } - - it, err := c.Query(ctx, `select * from users`) - assert.Equal(t, nil, err) - err = it.Close() - assert.Equal(t, nil, err) - u := User{} - _, err = c.QueryNext(ctx, it, &u) - - assert.NotEqual(t, nil, err) - }) -} - -func TestIterator(t *testing.T) { - t.Run("should return no errors if it's closed multiple times", func(t *testing.T) { - it := iterator{isClosed: true} - err := it.Close() - assert.Equal(t, nil, err) - assert.Equal(t, true, it.isClosed) - }) } func TestFillSliceWith(t *testing.T) {