diff --git a/tx.go b/tx.go index 6ebe04e..7210e73 100644 --- a/tx.go +++ b/tx.go @@ -242,7 +242,6 @@ func (tx *Tx) Copy(w io.Writer) error { // Open reader on the database. f, err := os.OpenFile(tx.db.path, os.O_RDONLY|odirect, 0) if err != nil { - _ = tx.Rollback() return err } @@ -251,14 +250,12 @@ func (tx *Tx) Copy(w io.Writer) error { _, err = io.CopyN(w, f, int64(tx.db.pageSize*2)) tx.db.metalock.Unlock() if err != nil { - _ = tx.Rollback() _ = f.Close() return fmt.Errorf("meta copy: %s", err) } // Copy data pages. if _, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2)); err != nil { - _ = tx.Rollback() _ = f.Close() return err } diff --git a/tx_test.go b/tx_test.go index dd04ae6..5cbe20f 100644 --- a/tx_test.go +++ b/tx_test.go @@ -338,6 +338,57 @@ func TestTx_CopyFile(t *testing.T) { }) } +type failWriterError struct{} + +func (failWriterError) Error() string { + return "error injected for tests" +} + +type failWriter struct { + // fail after this many bytes + After int +} + +func (f *failWriter) Write(p []byte) (n int, err error) { + n = len(p) + if n > f.After { + n = f.After + err = failWriterError{} + } + f.After -= n + return n, err +} + +// Ensure that Copy handles write errors right. +func TestTx_CopyFile_Error_Meta(t *testing.T) { + withOpenDB(func(db *DB, path string) { + db.Update(func(tx *Tx) error { + tx.CreateBucket([]byte("widgets")) + tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("bat")) + return nil + }) + + err := db.View(func(tx *Tx) error { return tx.Copy(&failWriter{}) }) + assert.EqualError(t, err, "meta copy: error injected for tests") + }) +} + +// Ensure that Copy handles write errors right. +func TestTx_CopyFile_Error_Normal(t *testing.T) { + withOpenDB(func(db *DB, path string) { + db.Update(func(tx *Tx) error { + tx.CreateBucket([]byte("widgets")) + tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("bat")) + return nil + }) + + err := db.View(func(tx *Tx) error { return tx.Copy(&failWriter{3 * db.pageSize}) }) + assert.EqualError(t, err, "error injected for tests") + }) +} + func ExampleTx_Rollback() { // Open the database. db, _ := Open(tempfile(), 0666)