From e929eba364732aa3a7eb76e2c9c38b7d23b8a980 Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Wed, 20 May 2015 16:10:07 -0600 Subject: [PATCH] Wait for pending tx on close. This commit fixes the DB.Close() function so that it waits for any open transactions to finish before closing. --- db.go | 7 +++++++ db_test.go | 45 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index c415c1b..f5ca05e 100644 --- a/db.go +++ b/db.go @@ -351,8 +351,15 @@ func (db *DB) init() error { // Close releases all database resources. // All transactions must be closed before closing the database. func (db *DB) Close() error { + db.rwlock.Lock() + defer db.rwlock.Unlock() + db.metalock.Lock() defer db.metalock.Unlock() + + db.mmaplock.RLock() + defer db.mmaplock.RUnlock() + return db.close() } diff --git a/db_test.go b/db_test.go index 6af6423..dddf22b 100644 --- a/db_test.go +++ b/db_test.go @@ -324,6 +324,49 @@ func TestDB_BeginRW_Closed(t *testing.T) { assert(t, tx == nil, "") } +func TestDB_Close_PendingTx_RW(t *testing.T) { testDB_Close_PendingTx(t, true) } +func TestDB_Close_PendingTx_RO(t *testing.T) { testDB_Close_PendingTx(t, false) } + +// Ensure that a database cannot close while transactions are open. +func testDB_Close_PendingTx(t *testing.T, writable bool) { + db := NewTestDB() + defer db.Close() + + // Start transaction. + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + // Open update in separate goroutine. + done := make(chan struct{}) + go func() { + db.Close() + close(done) + }() + + // Ensure database hasn't closed. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("database closed too early") + default: + } + + // Commit transaction. + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Ensure database closed now. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + default: + t.Fatal("database did not close") + } +} + // Ensure a database can provide a transactional block. func TestDB_Update(t *testing.T) { db := NewTestDB() @@ -748,7 +791,7 @@ func (db *TestDB) PrintStats() { // MustCheck runs a consistency check on the database and panics if any errors are found. func (db *TestDB) MustCheck() { - db.View(func(tx *bolt.Tx) error { + db.Update(func(tx *bolt.Tx) error { // Collect all the errors. var errors []error for err := range tx.Check() {