From 394e42e3ebd70b10488aa791894db28e9410077c Mon Sep 17 00:00:00 2001
From: Ben Johnson <benbjohnson@yahoo.com>
Date: Fri, 4 Apr 2014 07:51:01 -0600
Subject: [PATCH] Add Tx.OnCommit() handler.

This commit adds the ability to execute a function after a transaction has
successfully committed.
---
 tx.go      | 42 ++++++++++++++++++++++++++++++------------
 tx_test.go | 28 ++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/tx.go b/tx.go
index 5b345c2..ed693be 100644
--- a/tx.go
+++ b/tx.go
@@ -29,15 +29,16 @@ type txid uint64
 // are using them. A long running read transaction can cause the database to
 // quickly grow.
 type Tx struct {
-	writable bool
-	managed  bool
-	db       *DB
-	meta     *meta
-	buckets  *buckets
-	nodes    map[pgid]*node
-	pages    map[pgid]*page
-	pending  []*node
-	stats    TxStats
+	writable       bool
+	managed        bool
+	db             *DB
+	meta           *meta
+	buckets        *buckets
+	nodes          map[pgid]*node
+	pages          map[pgid]*page
+	pending        []*node
+	stats          TxStats
+	commitHandlers []func()
 }
 
 // init initializes the transaction.
@@ -175,6 +176,11 @@ func (t *Tx) DeleteBucket(name string) error {
 	return nil
 }
 
+// OnCommit adds a handler function to be executed after the transaction successfully commits.
+func (t *Tx) OnCommit(fn func()) {
+	t.commitHandlers = append(t.commitHandlers, fn)
+}
+
 // Commit writes all changes to disk and updates the meta page.
 // Returns an error if a disk write error occurs.
 func (t *Tx) Commit() error {
@@ -185,7 +191,6 @@ func (t *Tx) Commit() error {
 	} else if !t.writable {
 		return ErrTxNotWritable
 	}
-	defer t.close()
 
 	// TODO(benbjohnson): Use vectorized I/O to write out dirty pages.
 
@@ -197,6 +202,7 @@ func (t *Tx) Commit() error {
 	// spill data onto dirty pages.
 	startTime = time.Now()
 	if err := t.spill(); err != nil {
+		t.close()
 		return err
 	}
 	t.stats.SpillTime += time.Since(startTime)
@@ -204,6 +210,7 @@ func (t *Tx) Commit() error {
 	// Spill buckets page.
 	p, err := t.allocate((t.buckets.size() / t.db.pageSize) + 1)
 	if err != nil {
+		t.close()
 		return err
 	}
 	t.buckets.write(p)
@@ -217,6 +224,7 @@ func (t *Tx) Commit() error {
 	t.db.freelist.free(t.id(), t.page(t.meta.freelist))
 	p, err = t.allocate((t.db.freelist.size() / t.db.pageSize) + 1)
 	if err != nil {
+		t.close()
 		return err
 	}
 	t.db.freelist.write(p)
@@ -225,15 +233,25 @@ func (t *Tx) Commit() error {
 	// Write dirty pages to disk.
 	startTime = time.Now()
 	if err := t.write(); err != nil {
+		t.close()
 		return err
 	}
 
 	// Write meta to disk.
 	if err := t.writeMeta(); err != nil {
+		t.close()
 		return err
 	}
 	t.stats.WriteTime += time.Since(startTime)
 
+	// Finalize the transaction.
+	t.close()
+
+	// Execute commit handlers now that the locks have been removed.
+	for _, fn := range t.commitHandlers {
+		fn()
+	}
+
 	return nil
 }
 
@@ -250,13 +268,13 @@ func (t *Tx) Rollback() error {
 
 func (t *Tx) close() {
 	if t.writable {
-		t.db.rwlock.Unlock()
-
 		// Merge statistics.
 		t.db.metalock.Lock()
 		t.db.stats.TxStats.add(&t.stats)
 		t.db.metalock.Unlock()
 
+		// Remove writer lock.
+		t.db.rwlock.Unlock()
 	} else {
 		t.db.removeTx(t)
 	}
diff --git a/tx_test.go b/tx_test.go
index 4a6c329..3372dff 100644
--- a/tx_test.go
+++ b/tx_test.go
@@ -1,6 +1,7 @@
 package bolt
 
 import (
+	"errors"
 	"fmt"
 	"math/rand"
 	"os"
@@ -484,6 +485,33 @@ func TestTxCursorIterateReverse(t *testing.T) {
 	fmt.Fprint(os.Stderr, "\n")
 }
 
+// Ensure that Tx commit handlers are called after a transaction successfully commits.
+func TestTx_OnCommit(t *testing.T) {
+	var x int
+	withOpenDB(func(db *DB, path string) {
+		db.Update(func(tx *Tx) error {
+			tx.OnCommit(func() { x += 1 })
+			tx.OnCommit(func() { x += 2 })
+			return tx.CreateBucket("widgets")
+		})
+	})
+	assert.Equal(t, 3, x)
+}
+
+// Ensure that Tx commit handlers are NOT called after a transaction rolls back.
+func TestTx_OnCommit_Rollback(t *testing.T) {
+	var x int
+	withOpenDB(func(db *DB, path string) {
+		db.Update(func(tx *Tx) error {
+			tx.OnCommit(func() { x += 1 })
+			tx.OnCommit(func() { x += 2 })
+			tx.CreateBucket("widgets")
+			return errors.New("rollback this commit")
+		})
+	})
+	assert.Equal(t, 0, x)
+}
+
 // Benchmark the performance iterating over a cursor.
 func BenchmarkTxCursor(b *testing.B) {
 	indexes := rand.Perm(b.N)