diff --git a/tx.go b/tx.go index 5b345c2..ed693be 100644 --- a/tx.go +++ b/tx.go @@ -29,15 +29,16 @@ type txid uint64 // are using them. A long running read transaction can cause the database to // quickly grow. type Tx struct { - writable bool - managed bool - db *DB - meta *meta - buckets *buckets - nodes map[pgid]*node - pages map[pgid]*page - pending []*node - stats TxStats + writable bool + managed bool + db *DB + meta *meta + buckets *buckets + nodes map[pgid]*node + pages map[pgid]*page + pending []*node + stats TxStats + commitHandlers []func() } // init initializes the transaction. @@ -175,6 +176,11 @@ func (t *Tx) DeleteBucket(name string) error { return nil } +// OnCommit adds a handler function to be executed after the transaction successfully commits. +func (t *Tx) OnCommit(fn func()) { + t.commitHandlers = append(t.commitHandlers, fn) +} + // Commit writes all changes to disk and updates the meta page. // Returns an error if a disk write error occurs. func (t *Tx) Commit() error { @@ -185,7 +191,6 @@ func (t *Tx) Commit() error { } else if !t.writable { return ErrTxNotWritable } - defer t.close() // TODO(benbjohnson): Use vectorized I/O to write out dirty pages. @@ -197,6 +202,7 @@ func (t *Tx) Commit() error { // spill data onto dirty pages. startTime = time.Now() if err := t.spill(); err != nil { + t.close() return err } t.stats.SpillTime += time.Since(startTime) @@ -204,6 +210,7 @@ func (t *Tx) Commit() error { // Spill buckets page. p, err := t.allocate((t.buckets.size() / t.db.pageSize) + 1) if err != nil { + t.close() return err } t.buckets.write(p) @@ -217,6 +224,7 @@ func (t *Tx) Commit() error { t.db.freelist.free(t.id(), t.page(t.meta.freelist)) p, err = t.allocate((t.db.freelist.size() / t.db.pageSize) + 1) if err != nil { + t.close() return err } t.db.freelist.write(p) @@ -225,15 +233,25 @@ func (t *Tx) Commit() error { // Write dirty pages to disk. startTime = time.Now() if err := t.write(); err != nil { + t.close() return err } // Write meta to disk. if err := t.writeMeta(); err != nil { + t.close() return err } t.stats.WriteTime += time.Since(startTime) + // Finalize the transaction. + t.close() + + // Execute commit handlers now that the locks have been removed. + for _, fn := range t.commitHandlers { + fn() + } + return nil } @@ -250,13 +268,13 @@ func (t *Tx) Rollback() error { func (t *Tx) close() { if t.writable { - t.db.rwlock.Unlock() - // Merge statistics. t.db.metalock.Lock() t.db.stats.TxStats.add(&t.stats) t.db.metalock.Unlock() + // Remove writer lock. + t.db.rwlock.Unlock() } else { t.db.removeTx(t) } diff --git a/tx_test.go b/tx_test.go index 4a6c329..3372dff 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,6 +1,7 @@ package bolt import ( + "errors" "fmt" "math/rand" "os" @@ -484,6 +485,33 @@ func TestTxCursorIterateReverse(t *testing.T) { fmt.Fprint(os.Stderr, "\n") } +// Ensure that Tx commit handlers are called after a transaction successfully commits. +func TestTx_OnCommit(t *testing.T) { + var x int + withOpenDB(func(db *DB, path string) { + db.Update(func(tx *Tx) error { + tx.OnCommit(func() { x += 1 }) + tx.OnCommit(func() { x += 2 }) + return tx.CreateBucket("widgets") + }) + }) + assert.Equal(t, 3, x) +} + +// Ensure that Tx commit handlers are NOT called after a transaction rolls back. +func TestTx_OnCommit_Rollback(t *testing.T) { + var x int + withOpenDB(func(db *DB, path string) { + db.Update(func(tx *Tx) error { + tx.OnCommit(func() { x += 1 }) + tx.OnCommit(func() { x += 2 }) + tx.CreateBucket("widgets") + return errors.New("rollback this commit") + }) + }) + assert.Equal(t, 0, x) +} + // Benchmark the performance iterating over a cursor. func BenchmarkTxCursor(b *testing.B) { indexes := rand.Perm(b.N)