From a1f43f4d60bf434baebf13b4810e5db4916baec5 Mon Sep 17 00:00:00 2001
From: Ben Johnson <benbjohnson@yahoo.com>
Date: Sat, 1 Mar 2014 09:13:59 -0700
Subject: [PATCH] Allow reads of unflushed nodes.

This commit allows cursors to read updated values from within the
RWTransaction.
---
 bucket.go        |   2 +-
 bucket_test.go   |  14 ++++
 cursor.go        | 185 ++++++++++++++++++++++++++++++++++-------------
 db.go            |   5 +-
 page.go          |   6 --
 rwtransaction.go |   4 +-
 transaction.go   |  18 +++++
 7 files changed, 172 insertions(+), 62 deletions(-)

diff --git a/bucket.go b/bucket.go
index d62d798..6c1eb5c 100644
--- a/bucket.go
+++ b/bucket.go
@@ -35,7 +35,7 @@ func (b *Bucket) Cursor() *Cursor {
 	return &Cursor{
 		transaction: b.transaction,
 		root:        b.root,
-		stack:       make([]pageElementRef, 0),
+		stack:       make([]elemRef, 0),
 	}
 }
 
diff --git a/bucket_test.go b/bucket_test.go
index 5a2c0ae..e4ccabd 100644
--- a/bucket_test.go
+++ b/bucket_test.go
@@ -23,6 +23,20 @@ func TestBucketGetNonExistent(t *testing.T) {
 	})
 }
 
+// Ensure that a bucket can read a value that is not flushed yet.
+func TestBucketGetFromNode(t *testing.T) {
+	withOpenDB(func(db *DB, path string) {
+		db.CreateBucket("widgets")
+		db.Do(func(txn *RWTransaction) error {
+			b := txn.Bucket("widgets")
+			b.Put([]byte("foo"), []byte("bar"))
+			value := b.Get([]byte("foo"))
+			assert.Equal(t, value, []byte("bar"))
+			return nil
+		})
+	})
+}
+
 // Ensure that a bucket can write a key/value.
 func TestBucketPut(t *testing.T) {
 	withOpenDB(func(db *DB, path string) {
diff --git a/cursor.go b/cursor.go
index cb0d9a5..f8e28ea 100644
--- a/cursor.go
+++ b/cursor.go
@@ -10,15 +10,15 @@ import (
 type Cursor struct {
 	transaction *Transaction
 	root        pgid
-	stack       []pageElementRef
+	stack       []elemRef
 }
 
 // First moves the cursor to the first item in the bucket and returns its key and value.
 // If the bucket is empty then a nil key and value are returned.
 func (c *Cursor) First() (key []byte, value []byte) {
 	c.stack = c.stack[:0]
-	p := c.transaction.page(c.root)
-	c.stack = append(c.stack, pageElementRef{page: p, index: 0})
+	p, n := c.transaction.pageNode(c.root)
+	c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
 	c.first()
 	return c.keyValue()
 }
@@ -27,8 +27,10 @@ func (c *Cursor) First() (key []byte, value []byte) {
 // If the bucket is empty then a nil key and value are returned.
 func (c *Cursor) Last() (key []byte, value []byte) {
 	c.stack = c.stack[:0]
-	p := c.transaction.page(c.root)
-	c.stack = append(c.stack, pageElementRef{page: p, index: p.count - 1})
+	p, n := c.transaction.pageNode(c.root)
+	ref := elemRef{page: p, node: n}
+	ref.index = ref.count() - 1
+	c.stack = append(c.stack, ref)
 	c.last()
 	return c.keyValue()
 }
@@ -40,7 +42,7 @@ func (c *Cursor) Next() (key []byte, value []byte) {
 	// Move up the stack as we hit the end of each page in our stack.
 	for i := len(c.stack) - 1; i >= 0; i-- {
 		elem := &c.stack[i]
-		if elem.index < elem.page.count-1 {
+		if elem.index < elem.count()-1 {
 			elem.index++
 			break
 		}
@@ -85,61 +87,107 @@ func (c *Cursor) Prev() (key []byte, value []byte) {
 // If the key does not exist then the next key is used. If no keys
 // follow, a nil value is returned.
 func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) {
-	// Start from root page and traverse to correct page.
+	// Start from root page/node and traverse to correct page.
 	c.stack = c.stack[:0]
-	c.search(seek, c.transaction.page(c.root))
-	p, index := c.top()
+	c.search(seek, c.root)
+	ref := &c.stack[len(c.stack)-1]
 
-	// If the cursor is pointing to the end of page then return nil.
-	if index == p.count {
+	// If the cursor is pointing to the end of page/node then return nil.
+	if ref.index >= ref.count() {
 		return nil, nil
 	}
 
-	return c.element().key(), c.element().value()
+	return c.keyValue()
 }
 
 // first moves the cursor to the first leaf element under the last page in the stack.
 func (c *Cursor) first() {
-	p := c.stack[len(c.stack)-1].page
 	for {
 		// Exit when we hit a leaf page.
-		if (p.flags & leafPageFlag) != 0 {
+		ref := &c.stack[len(c.stack)-1]
+		if ref.isLeaf() {
 			break
 		}
 
 		// Keep adding pages pointing to the first element to the stack.
-		p = c.transaction.page(p.branchPageElement(c.stack[len(c.stack)-1].index).pgid)
-		c.stack = append(c.stack, pageElementRef{page: p, index: 0})
+		var pgid pgid
+		if ref.node != nil {
+			pgid = ref.node.inodes[ref.index].pgid
+		} else {
+			pgid = ref.page.branchPageElement(uint16(ref.index)).pgid
+		}
+		p, n := c.transaction.pageNode(pgid)
+		c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
 	}
 }
 
 // last moves the cursor to the last leaf element under the last page in the stack.
 func (c *Cursor) last() {
-	p := c.stack[len(c.stack)-1].page
 	for {
 		// Exit when we hit a leaf page.
-		if (p.flags & leafPageFlag) != 0 {
+		ref := &c.stack[len(c.stack)-1]
+		if ref.isLeaf() {
 			break
 		}
 
 		// Keep adding pages pointing to the last element in the stack.
-		p = c.transaction.page(p.branchPageElement(c.stack[len(c.stack)-1].index).pgid)
-		c.stack = append(c.stack, pageElementRef{page: p, index: p.count - 1})
+		var pgid pgid
+		if ref.node != nil {
+			pgid = ref.node.inodes[ref.index].pgid
+		} else {
+			pgid = ref.page.branchPageElement(uint16(ref.index)).pgid
+		}
+		p, n := c.transaction.pageNode(pgid)
+
+		var nextRef = elemRef{page: p, node: n}
+		nextRef.index = nextRef.count() - 1
+		c.stack = append(c.stack, nextRef)
 	}
 }
 
-// search recursively performs a binary search against a given page until it finds a given key.
-func (c *Cursor) search(key []byte, p *page) {
-	_assert((p.flags&(branchPageFlag|leafPageFlag)) != 0, "invalid page type: "+p.typ())
-	e := pageElementRef{page: p}
+// search recursively performs a binary search against a given page/node until it finds a given key.
+func (c *Cursor) search(key []byte, pgid pgid) {
+	p, n := c.transaction.pageNode(pgid)
+	if p != nil {
+		_assert((p.flags&(branchPageFlag|leafPageFlag)) != 0, "invalid page type: "+p.typ())
+	}
+	e := elemRef{page: p, node: n}
 	c.stack = append(c.stack, e)
 
-	// If we're on a leaf page then find the specific node.
-	if (p.flags & leafPageFlag) != 0 {
-		c.nsearch(key, p)
+	// If we're on a leaf page/node then find the specific node.
+	if e.isLeaf() {
+		c.nsearch(key)
 		return
 	}
 
+	if n != nil {
+		c.searchNode(key, n)
+		return
+	}
+	c.searchPage(key, p)
+}
+
+func (c *Cursor) searchNode(key []byte, n *node) {
+	var exact bool
+	index := sort.Search(len(n.inodes), func(i int) bool {
+		// TODO(benbjohnson): Optimize this range search. It's a bit hacky right now.
+		// sort.Search() finds the lowest index where f() != -1 but we need the highest index.
+		ret := bytes.Compare(n.inodes[i].key, key)
+		if ret == 0 {
+			exact = true
+		}
+		return ret != -1
+	})
+	if !exact && index > 0 {
+		index--
+	}
+	c.stack[len(c.stack)-1].index = index
+
+	// Recursively search to the next page.
+	c.search(key, n.inodes[index].pgid)
+}
+
+func (c *Cursor) searchPage(key []byte, p *page) {
 	// Binary search for the correct range.
 	inodes := p.branchPageElements()
 
@@ -156,58 +204,93 @@ func (c *Cursor) search(key []byte, p *page) {
 	if !exact && index > 0 {
 		index--
 	}
-	c.stack[len(c.stack)-1].index = uint16(index)
+	c.stack[len(c.stack)-1].index = index
 
 	// Recursively search to the next page.
-	c.search(key, c.transaction.page(inodes[index].pgid))
+	c.search(key, inodes[index].pgid)
 }
 
-// nsearch searches a leaf node for the index of the node that matches key.
-func (c *Cursor) nsearch(key []byte, p *page) {
+// nsearch searches the leaf node on the top of the stack for a key.
+func (c *Cursor) nsearch(key []byte) {
 	e := &c.stack[len(c.stack)-1]
+	p, n := e.page, e.node
 
-	// Binary search for the correct leaf node index.
+	// If we have a node then search its inodes.
+	if n != nil {
+		index := sort.Search(len(n.inodes), func(i int) bool {
+			return bytes.Compare(n.inodes[i].key, key) != -1
+		})
+		e.index = index
+		return
+	}
+
+	// If we have a page then search its leaf elements.
 	inodes := p.leafPageElements()
 	index := sort.Search(int(p.count), func(i int) bool {
 		return bytes.Compare(inodes[i].key(), key) != -1
 	})
-	e.index = uint16(index)
-}
-
-// top returns the page and leaf node that the cursor is currently pointing at.
-func (c *Cursor) top() (*page, uint16) {
-	ptr := c.stack[len(c.stack)-1]
-	return ptr.page, ptr.index
-}
-
-// element returns the leaf element that the cursor is currently positioned on.
-func (c *Cursor) element() *leafPageElement {
-	ref := c.stack[len(c.stack)-1]
-	return ref.page.leafPageElement(ref.index)
+	e.index = index
 }
 
 // keyValue returns the key and value of the current leaf element.
 func (c *Cursor) keyValue() ([]byte, []byte) {
 	ref := &c.stack[len(c.stack)-1]
-	if ref.index >= ref.page.count {
+	if ref.index >= ref.count() {
 		return nil, nil
 	}
-	e := ref.page.leafPageElement(ref.index)
-	return e.key(), e.value()
+
+	// Retrieve value from node.
+	if ref.node != nil {
+		inode := &ref.node.inodes[ref.index]
+		return inode.key, inode.value
+	}
+
+	// Or retrieve value from page.
+	elem := ref.page.leafPageElement(uint16(ref.index))
+	return elem.key(), elem.value()
 }
 
 // node returns the node that the cursor is currently positioned on.
 func (c *Cursor) node(t *RWTransaction) *node {
 	_assert(len(c.stack) > 0, "accessing a node with a zero-length cursor stack")
 
+	// If the top of the stack is a leaf node then just return it.
+	if ref := &c.stack[len(c.stack)-1]; ref.node != nil && ref.isLeaf() {
+		return ref.node
+	}
+
 	// Start from root and traverse down the hierarchy.
-	n := t.node(c.stack[0].page.id, nil)
+	var n = c.stack[0].node
+	if n == nil {
+		n = t.node(c.stack[0].page.id, nil)
+	}
 	for _, ref := range c.stack[:len(c.stack)-1] {
 		_assert(!n.isLeaf, "expected branch node")
-		_assert(ref.page.id == n.pgid, "node/page mismatch a: %d != %d", ref.page.id, n.childAt(int(ref.index)).pgid)
 		n = n.childAt(int(ref.index))
 	}
 	_assert(n.isLeaf, "expected leaf node")
-	_assert(n.pgid == c.stack[len(c.stack)-1].page.id, "node/page mismatch b: %d != %d", n.pgid, c.stack[len(c.stack)-1].page.id)
 	return n
 }
+
+// elemRef represents a reference to an element on a given page/node.
+type elemRef struct {
+	page  *page
+	node  *node
+	index int
+}
+
+// isLeaf returns whether the ref is pointing at a leaf page/node.
+func (r *elemRef) isLeaf() bool {
+	if r.node != nil {
+		return r.node.isLeaf
+	}
+	return (r.page.flags & leafPageFlag) != 0
+}
+
+// count returns the number of inodes or page elements.
+func (r *elemRef) count() int {
+	if r.node != nil {
+		return len(r.node.inodes)
+	}
+	return int(r.page.count)
+}
diff --git a/db.go b/db.go
index 22da2e6..8403afc 100644
--- a/db.go
+++ b/db.go
@@ -304,7 +304,7 @@ func (db *DB) RWTransaction() (*RWTransaction, error) {
 	}
 
 	// Create a transaction associated with the database.
-	t := &RWTransaction{nodes: make(map[pgid]*node)}
+	t := &RWTransaction{}
 	t.init(db)
 	db.rwtransaction = t
 
@@ -571,7 +571,8 @@ func (db *DB) Stat() (*Stat, error) {
 
 // page retrieves a page reference from the mmap based on the current page size.
 func (db *DB) page(id pgid) *page {
-	return (*page)(unsafe.Pointer(&db.data[id*pgid(db.pageSize)]))
+	pos := id*pgid(db.pageSize)
+	return (*page)(unsafe.Pointer(&db.data[pos]))
 }
 
 // pageInBuffer retrieves a page reference from a given byte array based on the current page size.
diff --git a/page.go b/page.go
index 77a31e4..5b60c4d 100644
--- a/page.go
+++ b/page.go
@@ -33,12 +33,6 @@ type page struct {
 	ptr      uintptr
 }
 
-// pageElementRef represents a reference to an element on a given page.
-type pageElementRef struct {
-	page  *page
-	index uint16
-}
-
 // typ returns a human readable page type string used for debugging.
 func (p *page) typ() string {
 	if (p.flags & branchPageFlag) != 0 {
diff --git a/rwtransaction.go b/rwtransaction.go
index ddd4b96..f47597f 100644
--- a/rwtransaction.go
+++ b/rwtransaction.go
@@ -11,7 +11,6 @@ import (
 // functions provided by Transaction.
 type RWTransaction struct {
 	Transaction
-	nodes   map[pgid]*node
 	pending []*node
 }
 
@@ -20,6 +19,7 @@ func (t *RWTransaction) init(db *DB) {
 	t.Transaction.init(db)
 	t.Transaction.rwtransaction = t
 	t.pages = make(map[pgid]*page)
+	t.nodes = make(map[pgid]*node)
 
 	// Increment the transaction id.
 	t.meta.txnid += txnid(1)
@@ -266,7 +266,7 @@ func (t *RWTransaction) writeMeta() error {
 // node creates a node from a page and associates it with a given parent.
 func (t *RWTransaction) node(pgid pgid, parent *node) *node {
 	// Retrieve node if it has already been fetched.
-	if n := t.nodes[pgid]; n != nil {
+	if n := t.Transaction.node(pgid); n != nil {
 		return n
 	}
 
diff --git a/transaction.go b/transaction.go
index 90f388b..d680e24 100644
--- a/transaction.go
+++ b/transaction.go
@@ -12,6 +12,7 @@ type Transaction struct {
 	rwtransaction *RWTransaction
 	meta          *meta
 	buckets       *buckets
+	nodes         map[pgid]*node
 	pages         map[pgid]*page
 }
 
@@ -95,6 +96,23 @@ func (t *Transaction) page(id pgid) *page {
 	return t.db.page(id)
 }
 
+// node returns a reference to the in-memory node for a given page, if it exists.
+func (t *Transaction) node(id pgid) *node {
+	if t.nodes == nil {
+		return nil
+	}
+	return t.nodes[id]
+}
+
+// pageNode returns the in-memory node, if it exists.
+// Otherwise returns the underlying page.
+func (t *Transaction) pageNode(id pgid) (*page, *node) {
+	if n := t.node(id); n != nil {
+		return nil, n
+	}
+	return t.page(id), nil
+}
+
 // forEachPage iterates over every page within a given page and executes a function.
 func (t *Transaction) forEachPage(pgid pgid, depth int, fn func(*page, int)) {
 	p := t.page(pgid)