improve iterator.Close to prevent accidental panics

pull/2/head
Breno Almeida 2020-09-28 18:26:17 -03:00
parent fabffed6d1
commit 614cfde4b7
2 changed files with 89 additions and 7 deletions

View File

@ -74,6 +74,22 @@ func (c Client) Find(
return it.Error
}
type iterator struct {
isClosed bool
rows *sql.Rows
}
// Close ...
func (i *iterator) Close() error {
if i.isClosed {
return nil
}
i.isClosed = true
return i.rows.Close()
}
var noopCloser = iterator{isClosed: true}
// Query build an iterator for querying several
// results from the database
func (c Client) Query(
@ -83,10 +99,18 @@ func (c Client) Query(
) (Iterator, error) {
it := c.db.Raw(query, params...)
if it.Error != nil {
return nil, it.Error
return &noopCloser, it.Error
}
return it.Rows()
rows, err := it.Rows()
if err != nil {
return &noopCloser, err
}
return &iterator{
isClosed: false,
rows: rows,
}, nil
}
// QueryNext parses the next row of a query
@ -97,17 +121,21 @@ func (c Client) QueryNext(
rawIt Iterator,
item interface{},
) (done bool, err error) {
rows, ok := rawIt.(*sql.Rows)
it, ok := rawIt.(*iterator)
if !ok {
return false, fmt.Errorf("invalid iterator received on QueryNext()")
}
if !rows.Next() {
rows.Close()
return true, rows.Err()
if it.isClosed {
return false, fmt.Errorf("received closed iterator")
}
return false, c.db.ScanRows(rows, item)
if !it.rows.Next() {
it.Close()
return true, it.rows.Err()
}
return false, c.db.ScanRows(it.rows, item)
}
// GetByID recovers a single entity from the database by the ID field.

View File

@ -413,6 +413,60 @@ func TestQuery(t *testing.T) {
assert.NotEqual(t, nil, err)
})
t.Run("should return noop closer when syntax error occurs", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * users`)
assert.NotEqual(t, nil, err)
assert.Equal(t, &noopCloser, it)
})
t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
it, err := c.Query(ctx, `select * from users`)
assert.Equal(t, nil, err)
err = it.Close()
assert.Equal(t, nil, err)
u := User{}
_, err = c.QueryNext(ctx, it, &u)
assert.NotEqual(t, nil, err)
})
}
func TestIterator(t *testing.T) {
t.Run("should return no errors if it's closed multiple times", func(t *testing.T) {
it := iterator{isClosed: true}
err := it.Close()
assert.Equal(t, nil, err)
assert.Equal(t, true, it.isClosed)
})
}
func createTable() error {