From 5822e23de424c01ed56becca24adf6ab0cb586e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Wed, 14 Oct 2020 15:18:17 -0300 Subject: [PATCH] Add tests to QueryChunks function This commit also fixes some bugs on this function and adds a feature: Now you can return an kissorm.AbortIteration error to abort the iteration and stop processing chunks. This does not causes the call to QueryChunks to return an error, since this is an expected error, thus, it is just ignored. --- contracts.go | 5 +- kiss_orm.go | 27 ++++-- kiss_orm_test.go | 225 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 249 insertions(+), 8 deletions(-) diff --git a/contracts.go b/contracts.go index 0723f96..814a294 100644 --- a/contracts.go +++ b/contracts.go @@ -5,9 +5,12 @@ import ( "fmt" ) -// EntityNotFound ... +// EntityNotFoundErr ... var EntityNotFoundErr error = fmt.Errorf("kissorm: the query returned no results") +// AbortIteration ... +var AbortIteration error = fmt.Errorf("kissorm: abort iteration, should only be used inside QueryChunks function") + // ORMProvider describes the public behavior of this ORM type ORMProvider interface { Insert(ctx context.Context, records ...interface{}) error diff --git a/kiss_orm.go b/kiss_orm.go index f2ab40e..234801a 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -144,8 +144,12 @@ func (c Client) QueryChunks( } slice := sliceRef.Elem() + if slice.Len() > parser.ChunkSize { + slice = slice.Slice(0, parser.ChunkSize) + } + var idx = 0 - for ; rows.Next(); idx++ { + for rows.Next() { if slice.Len() <= idx { var elemValue reflect.Value elemValue = reflect.New(structType) @@ -160,13 +164,19 @@ func (c Client) QueryChunks( return err } - if idx == parser.ChunkSize-1 { - idx = 0 - sliceRef.Elem().Set(slice) - err = parser.ForEachChunk() - if err != nil { - return err + if idx < parser.ChunkSize-1 { + idx++ + continue + } + + idx = 0 + sliceRef.Elem().Set(slice) + err = parser.ForEachChunk() + if err != nil { + if err == AbortIteration { + return nil } + return err } } @@ -176,6 +186,9 @@ func (c Client) QueryChunks( sliceRef.Elem().Set(slice.Slice(0, idx)) err = parser.ForEachChunk() if err != nil { + if err == AbortIteration { + return nil + } return err } } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index 8b84d38..261f661 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -445,6 +445,225 @@ func TestQueryChunks(t *testing.T) { assert.NotEqual(t, 0, u.ID) assert.Equal(t, "User1", u.Name) }) + + t.Run("should query one chunk correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &users, + ForEachChunk: func() error { + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + }) + + t.Run("should query chunks of 1 correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 1, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{1, 1}, lengths) + }) + + t.Run("should load partially filled chunks correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, 0, users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should abort the first iteration when the callback returns an AbortIteration", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return AbortIteration + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should abort the last iteration when the callback returns an AbortIteration", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + returnVals := []error{nil, AbortIteration} + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, 0, users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) } func TestFillSliceWith(t *testing.T) { @@ -490,3 +709,9 @@ func connectDB(t *testing.T) *gorm.DB { } return db } + +func shiftErrSlice(errs *[]error) error { + err := (*errs)[0] + *errs = (*errs)[1:] + return err +}