diff --git a/db.go b/db.go index 035a310..0bce184 100644 --- a/db.go +++ b/db.go @@ -349,19 +349,26 @@ func (db *DB) removeTx(t *Tx) { } } -// Do executes a function within the context of a read-write transaction. +// Do executes a function within the context of a read-write managed transaction. // If no error is returned from the function then the transaction is committed. // If an error is returned then the entire transaction is rolled back. // Any error that is returned from the function or returned from the commit is // returned from the Do() method. +// +// Attempting to manually commit or rollback within the function will cause a panic. func (db *DB) Do(fn func(*Tx) error) error { t, err := db.RWTx() if err != nil { return err } + // Mark as a managed tx so that the inner function cannot manually commit. + t.managed = true + // If an error is returned from the function then rollback and return error. - if err := fn(t); err != nil { + err = fn(t) + t.managed = false + if err != nil { t.Rollback() return err } @@ -369,8 +376,10 @@ func (db *DB) Do(fn func(*Tx) error) error { return t.Commit() } -// With executes a function within the context of a transaction. +// With executes a function within the context of a managed transaction. // Any error that is returned from the function is returned from the With() method. +// +// Attempting to manually rollback within the function will cause a panic. func (db *DB) With(fn func(*Tx) error) error { t, err := db.Tx() if err != nil { @@ -378,8 +387,14 @@ func (db *DB) With(fn func(*Tx) error) error { } defer t.Rollback() + // Mark as a managed tx so that the inner function cannot manually rollback. + t.managed = true + // If an error is returned from the function then pass it through. - return fn(t) + err = fn(t) + t.managed = false + + return err } // Copy writes the entire database to a writer. diff --git a/db_test.go b/db_test.go index 04abd75..2882ba8 100644 --- a/db_test.go +++ b/db_test.go @@ -111,6 +111,23 @@ func TestDBTxBlockWhileClosed(t *testing.T) { }) } +// Ensure a panic occurs while trying to commit a managed transaction. +func TestDBTxBlockWithManualCommitAndRollback(t *testing.T) { + withOpenDB(func(db *DB, path string) { + db.Do(func(tx *Tx) error { + tx.CreateBucket("widgets") + assert.Panics(t, func() { tx.Commit() }) + assert.Panics(t, func() { tx.Rollback() }) + return nil + }) + db.With(func(tx *Tx) error { + assert.Panics(t, func() { tx.Commit() }) + assert.Panics(t, func() { tx.Rollback() }) + return nil + }) + }) +} + // Ensure that the database can be copied to a file path. func TestDBCopyFile(t *testing.T) { withOpenDB(func(db *DB, path string) { diff --git a/tx.go b/tx.go index 181444e..c18fbf7 100644 --- a/tx.go +++ b/tx.go @@ -18,6 +18,7 @@ type txid uint64 // quickly grow. type Tx struct { writable bool + managed bool db *DB meta *meta buckets *buckets @@ -155,7 +156,9 @@ func (t *Tx) DeleteBucket(name string) error { // Commit writes all changes to disk and updates the meta page. // Returns an error if a disk write error occurs. func (t *Tx) Commit() error { - if t.db == nil { + if t.managed { + panic("managed tx commit not allowed") + } else if t.db == nil { return nil } else if !t.writable { t.Rollback() @@ -194,6 +197,9 @@ func (t *Tx) Commit() error { // Rollback closes the transaction and ignores all previous updates. func (t *Tx) Rollback() { + if t.managed { + panic("managed tx rollback not allowed") + } t.close() }