diff --git a/contracts.go b/contracts.go index ceebfe3..006ee77 100644 --- a/contracts.go +++ b/contracts.go @@ -30,14 +30,9 @@ type ChunkParser struct { Params []interface{} ChunkSize int - Chunk interface{} // Must be a pointer to a slice of structs - // The closure that will be called right after - // filling the Chunk with ChunkSize records - // - // Each chunk consecutively parsed will overwrite the - // same slice, so don't keep references to it, if you - // need some data to be preserved after all chunks are - // processed copy the records by value. - ForEachChunk func() error + // This attribute must be a func(chunk []) error, + // where the actual Record should be a struct + // representing the rows you are expecting to receive. + ForEachChunk interface{} } diff --git a/kiss_orm.go b/kiss_orm.go index 7dc0d86..15561cf 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -139,28 +139,33 @@ func (c Client) QueryChunks( } defer rows.Close() - sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk) + fnValue := reflect.ValueOf(parser.ForEachChunk) + chunkType, err := parseInputFunc(parser.ForEachChunk) if err != nil { return err } - slice := sliceRef.Elem() - if slice.Len() > parser.ChunkSize { - slice = slice.Slice(0, parser.ChunkSize) + chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize) + + structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(chunkType) + if err != nil { + return err } var idx = 0 for rows.Next() { - if slice.Len() <= idx { + // Allocate new slice elements + // only if they are not already allocated: + if chunk.Len() <= idx { var elemValue reflect.Value elemValue = reflect.New(structType) if !isSliceOfPtrs { elemValue = elemValue.Elem() } - slice = reflect.Append(slice, elemValue) + chunk = reflect.Append(chunk, elemValue) } - err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface()) + err = c.db.ScanRows(rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -171,8 +176,7 @@ func (c Client) QueryChunks( } idx = 0 - sliceRef.Elem().Set(slice) - err = parser.ForEachChunk() + err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error) if err != nil { if err == ErrAbortIteration { return nil @@ -184,8 +188,9 @@ func (c Client) QueryChunks( // If no rows were found or idx was reset to 0 // on the last iteration skip this last call to ForEachChunk: if idx > 0 { - sliceRef.Elem().Set(slice.Slice(0, idx)) - err = parser.ForEachChunk() + chunk = chunk.Slice(0, idx) + + err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error) if err != nil { if err == ErrAbortIteration { return nil @@ -435,9 +440,18 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error { // and the second is a slice of maps representing the database rows you want // to use to update this struct. func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { - sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities) + sliceRef := reflect.ValueOf(entities) + sliceType := sliceRef.Type() + if sliceType.Kind() != reflect.Ptr { + return fmt.Errorf( + "FillSliceWith: expected input to be a pointer to struct but got %v", + sliceType, + ) + } + + structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(sliceType.Elem()) if err != nil { - return err + return fmt.Errorf("FillSliceWith: %s", err.Error()) } info, found := tagInfoCache[structType] @@ -468,34 +482,20 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error return nil } -func decodeAsSliceOfStructs(slice interface{}) ( - sliceRef reflect.Value, +func decodeAsSliceOfStructs(slice reflect.Type) ( structType reflect.Type, isSliceOfPtrs bool, err error, ) { - slicePtrValue := reflect.ValueOf(slice) - slicePtrType := slicePtrValue.Type() - - if slicePtrType.Kind() != reflect.Ptr { + if slice.Kind() != reflect.Slice { err = fmt.Errorf( - "FillListWith: expected input to be a pointer to struct but got %T", + "expected input kind to be a slice but got %v", slice, ) return } - t := slicePtrType.Elem() - - if t.Kind() != reflect.Slice { - err = fmt.Errorf( - "FillListWith: expected input kind to be a slice but got %T", - slice, - ) - return - } - - elemType := t.Elem() + elemType := slice.Elem() isPtr := elemType.Kind() == reflect.Ptr if isPtr { @@ -504,11 +504,39 @@ func decodeAsSliceOfStructs(slice interface{}) ( if elemType.Kind() != reflect.Struct { err = fmt.Errorf( - "FillListWith: expected input to be a slice of structs but got %T", + "expected input to be a slice of structs but got %v", slice, ) return } - return slicePtrValue, elemType, isPtr, nil + return elemType, isPtr, nil +} + +var errType = reflect.TypeOf(new(error)).Elem() + +func parseInputFunc(fn interface{}) (reflect.Type, error) { + t := reflect.TypeOf(fn) + + if t.Kind() != reflect.Func { + return nil, fmt.Errorf("the ForEachChunk callback must be a function") + } + if t.NumIn() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument") + } + + if t.NumOut() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have a single return value") + } + + if t.Out(0) != errType { + return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error") + } + + argsType := t.In(0) + if argsType.Kind() != reflect.Slice { + return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs") + } + + return argsType, nil } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index d5dbd11..86e56f4 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -565,14 +565,12 @@ func TestQueryChunks(t *testing.T) { var length int var u User - var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name = ?;`, Params: []interface{}{"User1"}, ChunkSize: 100, - Chunk: &users, - ForEachChunk: func() error { + ForEachChunk: func(users []User) error { length = len(users) if length > 0 { u = users[0] @@ -605,20 +603,23 @@ func TestQueryChunks(t *testing.T) { _ = c.Insert(ctx, &User{Name: "User1"}) _ = c.Insert(ctx, &User{Name: "User2"}) + var lengths []int var users []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ? order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, - Chunk: &users, - ForEachChunk: func() error { + ForEachChunk: func(buffer []User) error { + users = append(users, buffer...) + lengths = append(lengths, len(buffer)) return nil }, }) assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) + assert.Equal(t, 1, len(lengths)) + assert.Equal(t, 2, lengths[0]) assert.NotEqual(t, 0, users[0].ID) assert.Equal(t, "User1", users[0].Name) assert.NotEqual(t, 0, users[1].ID) @@ -645,14 +646,12 @@ func TestQueryChunks(t *testing.T) { var lengths []int var users []User - var buffer []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ? order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 1, - Chunk: &buffer, - ForEachChunk: func() error { + ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) return nil @@ -689,14 +688,12 @@ func TestQueryChunks(t *testing.T) { var lengths []int var users []User - var buffer []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ? order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, - Chunk: &buffer, - ForEachChunk: func() error { + ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) return nil @@ -735,14 +732,12 @@ func TestQueryChunks(t *testing.T) { var lengths []int var users []User - var buffer []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ? order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, - Chunk: &buffer, - ForEachChunk: func() error { + ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...) return ErrAbortIteration @@ -780,14 +775,12 @@ func TestQueryChunks(t *testing.T) { returnVals := []error{nil, ErrAbortIteration} var lengths []int var users []User - var buffer []User err = c.QueryChunks(ctx, ChunkParser{ Query: `select * from users where name like ? order by name asc;`, Params: []interface{}{"User%"}, ChunkSize: 2, - Chunk: &buffer, - ForEachChunk: func() error { + ForEachChunk: func(buffer []User) error { lengths = append(lengths, len(buffer)) users = append(users, buffer...)