diff --git a/ksql.go b/ksql.go index f5f53c6..9ddee20 100644 --- a/ksql.go +++ b/ksql.go @@ -122,6 +122,16 @@ func (c DB) Query( query string, params ...interface{}, ) error { + // Check if we the user wants to use the lazy chunked approach: + if chunksPtr, ok := records.(*Chunks); ok { + *chunksPtr = chunks{ + db: c, + query: query, + params: params, + } + return nil + } + slicePtr := reflect.ValueOf(records) slicePtrType := slicePtr.Type() if slicePtrType.Kind() != reflect.Ptr { @@ -271,6 +281,28 @@ func (c DB) QueryOne( return rows.Close() } +type Chunks interface { + ForEach(ctx context.Context, chunkSize int, fn interface{}) error +} + +// chunks stores a query to be executed lazily afterwards. +// the chunks ForEach function can be called any number of times and will +// repeat the query each time it is called. +type chunks struct { + db Provider + query string + params []interface{} +} + +func (c chunks) ForEach(ctx context.Context, chunkSize int, fn interface{}) error { + return c.db.QueryChunks(ctx, ChunkParser{ + ChunkSize: chunkSize, + Query: c.query, + Params: c.params, + ForEachChunk: fn, + }) +} + // QueryChunks is meant to perform queries that returns // more results than would normally fit on memory, // for others cases the Query and QueryOne functions are indicated. diff --git a/test_adapters.go b/test_adapters.go index fa3b4ae..d6a7bc1 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -70,7 +70,22 @@ func RunTestsForAdapter( InsertTest(t, driver, connStr, newDBAdapter) DeleteTest(t, driver, connStr, newDBAdapter) UpdateTest(t, driver, connStr, newDBAdapter) - QueryChunksTest(t, driver, connStr, newDBAdapter) + + // We are keeping this callback to simplify how we are testing both ways of querying chunks. + // In the future we plan on deprecating and eventually deleting one of them: + QueryChunksTest(t, driver, connStr, newDBAdapter, func(db Provider, ctx context.Context, parser ChunkParser) error { + return db.QueryChunks(ctx, parser) + }) + QueryChunksTest(t, driver, connStr, newDBAdapter, func(db Provider, ctx context.Context, parser ChunkParser) error { + var chunks Chunks + err := db.Query(ctx, &chunks, parser.Query, parser.Params...) + if err != nil { + return err + } + + return chunks.ForEach(ctx, parser.ChunkSize, parser.ForEachChunk) + }) + TransactionTest(t, driver, connStr, newDBAdapter) ScanRowsTest(t, driver, connStr, newDBAdapter) }) @@ -1531,6 +1546,7 @@ func QueryChunksTest( driver string, connStr string, newDBAdapter func(t *testing.T) (DBAdapter, io.Closer), + queryChunks func(Provider, context.Context, ChunkParser) error, ) { t.Run("QueryChunks", func(t *testing.T) { variations := []struct { @@ -1567,7 +1583,7 @@ func QueryChunksTest( var length int var u user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `FROM users WHERE name = ` + c.dialect.Placeholder(0), Params: []interface{}{"User1"}, @@ -1605,7 +1621,7 @@ func QueryChunksTest( var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1647,7 +1663,7 @@ func QueryChunksTest( var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1690,7 +1706,7 @@ func QueryChunksTest( var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1713,7 +1729,6 @@ func QueryChunksTest( assert.Equal(t, []int{2, 1}, lengths) }) - // xxx t.Run("should query joined tables correctly", func(t *testing.T) { // This test only makes sense with no query prefix if variation.queryPrefix != "" { @@ -1747,7 +1762,7 @@ func QueryChunksTest( var lengths []int var users []user var posts []post - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: fmt.Sprint( `FROM users u JOIN posts p ON p.user_id = u.id`, ` WHERE u.name like `, c.dialect.Placeholder(0), @@ -1803,7 +1818,7 @@ func QueryChunksTest( var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1843,7 +1858,7 @@ func QueryChunksTest( returnVals := []error{nil, ErrAbortIteration} var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1885,7 +1900,7 @@ func QueryChunksTest( var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1925,7 +1940,7 @@ func QueryChunksTest( returnVals := []error{nil, errors.New("fake error msg")} var lengths []int var users []user - err = c.QueryChunks(ctx, ChunkParser{ + err = queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, Params: []interface{}{"User%"}, @@ -1982,7 +1997,7 @@ func QueryChunksTest( } for _, fn := range funcs { - err := c.QueryChunks(ctx, ChunkParser{ + err := queryChunks(c, ctx, ChunkParser{ Query: variation.queryPrefix + `FROM users`, Params: []interface{}{}, @@ -1999,7 +2014,7 @@ func QueryChunksTest( ctx := context.Background() c := newTestDB(db, driver) - err := c.QueryChunks(ctx, ChunkParser{ + err := queryChunks(c, ctx, ChunkParser{ Query: `SELECT * FROM not a valid query`, Params: []interface{}{}, @@ -2018,7 +2033,7 @@ func QueryChunksTest( ctx := context.Background() c := newTestDB(db, driver) - err := c.QueryChunks(ctx, ChunkParser{ + err := queryChunks(c, ctx, ChunkParser{ Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`, Params: []interface{}{},