mirror of https://github.com/VinGarcia/ksql.git
improve iterator.Close to prevent accidental panics
parent
fabffed6d1
commit
614cfde4b7
42
kiss_orm.go
42
kiss_orm.go
|
@ -74,6 +74,22 @@ func (c Client) Find(
|
|||
return it.Error
|
||||
}
|
||||
|
||||
type iterator struct {
|
||||
isClosed bool
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
// Close ...
|
||||
func (i *iterator) Close() error {
|
||||
if i.isClosed {
|
||||
return nil
|
||||
}
|
||||
i.isClosed = true
|
||||
return i.rows.Close()
|
||||
}
|
||||
|
||||
var noopCloser = iterator{isClosed: true}
|
||||
|
||||
// Query build an iterator for querying several
|
||||
// results from the database
|
||||
func (c Client) Query(
|
||||
|
@ -83,10 +99,18 @@ func (c Client) Query(
|
|||
) (Iterator, error) {
|
||||
it := c.db.Raw(query, params...)
|
||||
if it.Error != nil {
|
||||
return nil, it.Error
|
||||
return &noopCloser, it.Error
|
||||
}
|
||||
|
||||
return it.Rows()
|
||||
rows, err := it.Rows()
|
||||
if err != nil {
|
||||
return &noopCloser, err
|
||||
}
|
||||
|
||||
return &iterator{
|
||||
isClosed: false,
|
||||
rows: rows,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryNext parses the next row of a query
|
||||
|
@ -97,17 +121,21 @@ func (c Client) QueryNext(
|
|||
rawIt Iterator,
|
||||
item interface{},
|
||||
) (done bool, err error) {
|
||||
rows, ok := rawIt.(*sql.Rows)
|
||||
it, ok := rawIt.(*iterator)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid iterator received on QueryNext()")
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
rows.Close()
|
||||
return true, rows.Err()
|
||||
if it.isClosed {
|
||||
return false, fmt.Errorf("received closed iterator")
|
||||
}
|
||||
|
||||
return false, c.db.ScanRows(rows, item)
|
||||
if !it.rows.Next() {
|
||||
it.Close()
|
||||
return true, it.rows.Err()
|
||||
}
|
||||
|
||||
return false, c.db.ScanRows(it.rows, item)
|
||||
}
|
||||
|
||||
// GetByID recovers a single entity from the database by the ID field.
|
||||
|
|
|
@ -413,6 +413,60 @@ func TestQuery(t *testing.T) {
|
|||
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("should return noop closer when syntax error occurs", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
it, err := c.Query(ctx, `select * users`)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, &noopCloser, it)
|
||||
})
|
||||
|
||||
t.Run("should return error if queryNext receives a closed iterator", func(t *testing.T) {
|
||||
err := createTable()
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!")
|
||||
}
|
||||
|
||||
db := connectDB(t)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := Client{
|
||||
db: db,
|
||||
tableName: "users",
|
||||
}
|
||||
|
||||
it, err := c.Query(ctx, `select * from users`)
|
||||
assert.Equal(t, nil, err)
|
||||
err = it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
u := User{}
|
||||
_, err = c.QueryNext(ctx, it, &u)
|
||||
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIterator(t *testing.T) {
|
||||
t.Run("should return no errors if it's closed multiple times", func(t *testing.T) {
|
||||
it := iterator{isClosed: true}
|
||||
err := it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, it.isClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func createTable() error {
|
||||
|
|
Loading…
Reference in New Issue