diff --git a/kiss_orm.go b/kiss_orm.go index fcc1df4..3c60bcf 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -74,6 +74,22 @@ func (c Client) Find( return it.Error } +type iterator struct { + isClosed bool + rows *sql.Rows +} + +// Close ... +func (i *iterator) Close() error { + if i.isClosed { + return nil + } + i.isClosed = true + return i.rows.Close() +} + +var noopCloser = iterator{isClosed: true} + // Query build an iterator for querying several // results from the database func (c Client) Query( @@ -83,10 +99,18 @@ func (c Client) Query( ) (Iterator, error) { it := c.db.Raw(query, params...) if it.Error != nil { - return nil, it.Error + return &noopCloser, it.Error } - return it.Rows() + rows, err := it.Rows() + if err != nil { + return &noopCloser, err + } + + return &iterator{ + isClosed: false, + rows: rows, + }, nil } // QueryNext parses the next row of a query @@ -97,17 +121,21 @@ func (c Client) QueryNext( rawIt Iterator, item interface{}, ) (done bool, err error) { - rows, ok := rawIt.(*sql.Rows) + it, ok := rawIt.(*iterator) if !ok { return false, fmt.Errorf("invalid iterator received on QueryNext()") } - if !rows.Next() { - rows.Close() - return true, rows.Err() + if it.isClosed { + return false, fmt.Errorf("received closed iterator") } - return false, c.db.ScanRows(rows, item) + if !it.rows.Next() { + it.Close() + return true, it.rows.Err() + } + + return false, c.db.ScanRows(it.rows, item) } // GetByID recovers a single entity from the database by the ID field. diff --git a/kiss_orm_test.go b/kiss_orm_test.go index b9bd2d3..ef2f71d 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -413,6 +413,60 @@ func TestQuery(t *testing.T) { assert.NotEqual(t, nil, err) }) + + t.Run("should return noop closer when syntax error occurs", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + it, err := c.Query(ctx, `select * users`) + assert.NotEqual(t, nil, err) + assert.Equal(t, &noopCloser, it) + }) + + t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + + it, err := c.Query(ctx, `select * from users`) + assert.Equal(t, nil, err) + err = it.Close() + assert.Equal(t, nil, err) + u := User{} + _, err = c.QueryNext(ctx, it, &u) + + assert.NotEqual(t, nil, err) + }) +} + +func TestIterator(t *testing.T) { + t.Run("should return no errors if it's closed multiple times", func(t *testing.T) { + it := iterator{isClosed: true} + err := it.Close() + assert.Equal(t, nil, err) + assert.Equal(t, true, it.isClosed) + }) } func createTable() error {