diff --git a/contracts.go b/contracts.go index 0723f96..814a294 100644 --- a/contracts.go +++ b/contracts.go @@ -5,9 +5,12 @@ import ( "fmt" ) -// EntityNotFound ... +// EntityNotFoundErr ... var EntityNotFoundErr error = fmt.Errorf("kissorm: the query returned no results") +// AbortIteration ... +var AbortIteration error = fmt.Errorf("kissorm: abort iteration, should only be used inside QueryChunks function") + // ORMProvider describes the public behavior of this ORM type ORMProvider interface { Insert(ctx context.Context, records ...interface{}) error diff --git a/kiss_orm.go b/kiss_orm.go index f2ab40e..234801a 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -144,8 +144,12 @@ func (c Client) QueryChunks( } slice := sliceRef.Elem() + if slice.Len() > parser.ChunkSize { + slice = slice.Slice(0, parser.ChunkSize) + } + var idx = 0 - for ; rows.Next(); idx++ { + for rows.Next() { if slice.Len() <= idx { var elemValue reflect.Value elemValue = reflect.New(structType) @@ -160,13 +164,19 @@ func (c Client) QueryChunks( return err } - if idx == parser.ChunkSize-1 { - idx = 0 - sliceRef.Elem().Set(slice) - err = parser.ForEachChunk() - if err != nil { - return err + if idx < parser.ChunkSize-1 { + idx++ + continue + } + + idx = 0 + sliceRef.Elem().Set(slice) + err = parser.ForEachChunk() + if err != nil { + if err == AbortIteration { + return nil } + return err } } @@ -176,6 +186,9 @@ func (c Client) QueryChunks( sliceRef.Elem().Set(slice.Slice(0, idx)) err = parser.ForEachChunk() if err != nil { + if err == AbortIteration { + return nil + } return err } } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index 8b84d38..261f661 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -445,6 +445,225 @@ func TestQueryChunks(t *testing.T) { assert.NotEqual(t, 0, u.ID) assert.Equal(t, "User1", u.Name) }) + + t.Run("should query one chunk correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &users, + ForEachChunk: func() error { + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + }) + + t.Run("should query chunks of 1 correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 1, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{1, 1}, lengths) + }) + + t.Run("should load partially filled chunks correctly", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, 0, users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should abort the first iteration when the callback returns an AbortIteration", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return AbortIteration + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should abort the last iteration when the callback returns an AbortIteration", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + returnVals := []error{nil, AbortIteration} + var lengths []int + var users []User + var buffer []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: `select * from users where name like ? order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + Chunk: &buffer, + ForEachChunk: func() error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, 0, users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, 0, users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, 0, users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) } func TestFillSliceWith(t *testing.T) { @@ -490,3 +709,9 @@ func connectDB(t *testing.T) *gorm.DB { } return db } + +func shiftErrSlice(errs *[]error) error { + err := (*errs)[0] + *errs = (*errs)[1:] + return err +}