Improve QueryChunks signature to be easier to use

The changes made on this commit were designed by
Raí Tamarindo (raitamarindo@gmail.com) on a previous meeting.
pull/2/head
Vinícius Garcia 2020-10-28 08:55:55 -03:00
parent c1f645216c
commit 6978474d41
3 changed files with 76 additions and 60 deletions

View File

@ -30,14 +30,9 @@ type ChunkParser struct {
Params []interface{}
ChunkSize int
Chunk interface{} // Must be a pointer to a slice of structs
// The closure that will be called right after
// filling the Chunk with ChunkSize records
//
// Each chunk consecutively parsed will overwrite the
// same slice, so don't keep references to it, if you
// need some data to be preserved after all chunks are
// processed copy the records by value.
ForEachChunk func() error
// This attribute must be a func(chunk []<Record>) error,
// where the actual Record should be a struct
// representing the rows you are expecting to receive.
ForEachChunk interface{}
}

View File

@ -139,28 +139,33 @@ func (c Client) QueryChunks(
}
defer rows.Close()
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(parser.Chunk)
fnValue := reflect.ValueOf(parser.ForEachChunk)
chunkType, err := parseInputFunc(parser.ForEachChunk)
if err != nil {
return err
}
slice := sliceRef.Elem()
if slice.Len() > parser.ChunkSize {
slice = slice.Slice(0, parser.ChunkSize)
chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(chunkType)
if err != nil {
return err
}
var idx = 0
for rows.Next() {
if slice.Len() <= idx {
// Allocate new slice elements
// only if they are not already allocated:
if chunk.Len() <= idx {
var elemValue reflect.Value
elemValue = reflect.New(structType)
if !isSliceOfPtrs {
elemValue = elemValue.Elem()
}
slice = reflect.Append(slice, elemValue)
chunk = reflect.Append(chunk, elemValue)
}
err = c.db.ScanRows(rows, slice.Index(idx).Addr().Interface())
err = c.db.ScanRows(rows, chunk.Index(idx).Addr().Interface())
if err != nil {
return err
}
@ -171,8 +176,7 @@ func (c Client) QueryChunks(
}
idx = 0
sliceRef.Elem().Set(slice)
err = parser.ForEachChunk()
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
if err != nil {
if err == ErrAbortIteration {
return nil
@ -184,8 +188,9 @@ func (c Client) QueryChunks(
// If no rows were found or idx was reset to 0
// on the last iteration skip this last call to ForEachChunk:
if idx > 0 {
sliceRef.Elem().Set(slice.Slice(0, idx))
err = parser.ForEachChunk()
chunk = chunk.Slice(0, idx)
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
if err != nil {
if err == ErrAbortIteration {
return nil
@ -435,9 +440,18 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
// and the second is a slice of maps representing the database rows you want
// to use to update this struct.
func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error {
sliceRef, structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(entities)
sliceRef := reflect.ValueOf(entities)
sliceType := sliceRef.Type()
if sliceType.Kind() != reflect.Ptr {
return fmt.Errorf(
"FillSliceWith: expected input to be a pointer to struct but got %v",
sliceType,
)
}
structType, isSliceOfPtrs, err := decodeAsSliceOfStructs(sliceType.Elem())
if err != nil {
return err
return fmt.Errorf("FillSliceWith: %s", err.Error())
}
info, found := tagInfoCache[structType]
@ -468,34 +482,20 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error
return nil
}
func decodeAsSliceOfStructs(slice interface{}) (
sliceRef reflect.Value,
func decodeAsSliceOfStructs(slice reflect.Type) (
structType reflect.Type,
isSliceOfPtrs bool,
err error,
) {
slicePtrValue := reflect.ValueOf(slice)
slicePtrType := slicePtrValue.Type()
if slicePtrType.Kind() != reflect.Ptr {
if slice.Kind() != reflect.Slice {
err = fmt.Errorf(
"FillListWith: expected input to be a pointer to struct but got %T",
"expected input kind to be a slice but got %v",
slice,
)
return
}
t := slicePtrType.Elem()
if t.Kind() != reflect.Slice {
err = fmt.Errorf(
"FillListWith: expected input kind to be a slice but got %T",
slice,
)
return
}
elemType := t.Elem()
elemType := slice.Elem()
isPtr := elemType.Kind() == reflect.Ptr
if isPtr {
@ -504,11 +504,39 @@ func decodeAsSliceOfStructs(slice interface{}) (
if elemType.Kind() != reflect.Struct {
err = fmt.Errorf(
"FillListWith: expected input to be a slice of structs but got %T",
"expected input to be a slice of structs but got %v",
slice,
)
return
}
return slicePtrValue, elemType, isPtr, nil
return elemType, isPtr, nil
}
var errType = reflect.TypeOf(new(error)).Elem()
func parseInputFunc(fn interface{}) (reflect.Type, error) {
t := reflect.TypeOf(fn)
if t.Kind() != reflect.Func {
return nil, fmt.Errorf("the ForEachChunk callback must be a function")
}
if t.NumIn() != 1 {
return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument")
}
if t.NumOut() != 1 {
return nil, fmt.Errorf("the ForEachChunk callback must have a single return value")
}
if t.Out(0) != errType {
return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error")
}
argsType := t.In(0)
if argsType.Kind() != reflect.Slice {
return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs")
}
return argsType, nil
}

View File

@ -565,14 +565,12 @@ func TestQueryChunks(t *testing.T) {
var length int
var u User
var users []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name = ?;`,
Params: []interface{}{"User1"},
ChunkSize: 100,
Chunk: &users,
ForEachChunk: func() error {
ForEachChunk: func(users []User) error {
length = len(users)
if length > 0 {
u = users[0]
@ -605,20 +603,23 @@ func TestQueryChunks(t *testing.T) {
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
var lengths []int
var users []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name like ? order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
Chunk: &users,
ForEachChunk: func() error {
ForEachChunk: func(buffer []User) error {
users = append(users, buffer...)
lengths = append(lengths, len(buffer))
return nil
},
})
assert.Equal(t, nil, err)
assert.Equal(t, 2, len(users))
assert.Equal(t, 1, len(lengths))
assert.Equal(t, 2, lengths[0])
assert.NotEqual(t, 0, users[0].ID)
assert.Equal(t, "User1", users[0].Name)
assert.NotEqual(t, 0, users[1].ID)
@ -645,14 +646,12 @@ func TestQueryChunks(t *testing.T) {
var lengths []int
var users []User
var buffer []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name like ? order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 1,
Chunk: &buffer,
ForEachChunk: func() error {
ForEachChunk: func(buffer []User) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return nil
@ -689,14 +688,12 @@ func TestQueryChunks(t *testing.T) {
var lengths []int
var users []User
var buffer []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name like ? order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
Chunk: &buffer,
ForEachChunk: func() error {
ForEachChunk: func(buffer []User) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return nil
@ -735,14 +732,12 @@ func TestQueryChunks(t *testing.T) {
var lengths []int
var users []User
var buffer []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name like ? order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
Chunk: &buffer,
ForEachChunk: func() error {
ForEachChunk: func(buffer []User) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)
return ErrAbortIteration
@ -780,14 +775,12 @@ func TestQueryChunks(t *testing.T) {
returnVals := []error{nil, ErrAbortIteration}
var lengths []int
var users []User
var buffer []User
err = c.QueryChunks(ctx, ChunkParser{
Query: `select * from users where name like ? order by name asc;`,
Params: []interface{}{"User%"},
ChunkSize: 2,
Chunk: &buffer,
ForEachChunk: func() error {
ForEachChunk: func(buffer []User) error {
lengths = append(lengths, len(buffer))
users = append(users, buffer...)