bbolt/internal/btesting/btesting.go

214 lines
4.9 KiB
Go

package btesting
import (
"flag"
"fmt"
"os"
"path/filepath"
"regexp"
"testing"
"time"
"github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt"
)
// statsFlag is the "-stats" command-line flag. When it is set, Close prints
// the database statistics (via PrintStats) before closing the database.
var statsFlag = flag.Bool("stats", false, "show performance stats")
// TestFreelistType is the name of the environment variable that tests use to
// select the freelist type. MustOpenDBWithOption reads it and uses
// bolt.FreelistMapType when the value matches that type, and
// bolt.FreelistArrayType otherwise.
const TestFreelistType = "TEST_FREELIST_TYPE"
// DB is a test wrapper for bolt.DB.
type DB struct {
// The embedded handle is set to nil by Close, so a nil value means the
// database is currently closed.
*bolt.DB
// f is the path of the database file.
f string
// o holds the options passed to bolt.Open when opening and reopening.
o *bolt.Options
// t is the test handle used for logging, temporary directories and cleanup.
t testing.TB
}
// MustCreateDB returns a new, open DB at a temporary location.
// It is shorthand for MustCreateDBWithOption with nil options.
func MustCreateDB(t testing.TB) *DB {
return MustCreateDBWithOption(t, nil)
}
// MustCreateDBWithOption opens a new DB with the given options. The database
// file is named "db" and lives in a fresh directory obtained from t.TempDir().
func MustCreateDBWithOption(t testing.TB, o *bolt.Options) *DB {
	return MustOpenDBWithOption(t, filepath.Join(t.TempDir(), "db"), o)
}
// MustOpenDBWithOption opens the bbolt database at path f with options o and
// returns it wrapped in a DB. It panics if the database cannot be opened.
//
// If o is nil, a private copy of bolt.DefaultOptions is used. The FreelistType
// of the options is always overwritten: bolt.FreelistMapType when the
// TestFreelistType environment variable selects it, bolt.FreelistArrayType
// otherwise. Caller-supplied options are modified in place.
//
// PostTestCleanup is registered with t.Cleanup, so the database is checked
// and closed when the test finishes.
func MustOpenDBWithOption(t testing.TB, f string, o *bolt.Options) *DB {
	t.Logf("Opening bbolt DB at: %s", f)
	if o == nil {
		// Copy the defaults instead of aliasing them: FreelistType is assigned
		// below, and writing through bolt.DefaultOptions would silently change
		// the package-level defaults for every later bolt.Open call.
		defaults := *bolt.DefaultOptions
		o = &defaults
	}

	freelistType := bolt.FreelistArrayType
	if env := os.Getenv(TestFreelistType); env == string(bolt.FreelistMapType) {
		freelistType = bolt.FreelistMapType
	}
	o.FreelistType = freelistType

	db, err := bolt.Open(f, 0666, o)
	if err != nil {
		panic(err)
	}

	resDB := &DB{
		DB: db,
		f:  f,
		o:  o,
		t:  t,
	}
	t.Cleanup(resDB.PostTestCleanup)
	return resDB
}
// PostTestCleanup runs the consistency check and then closes the database.
// It does nothing when the database has already been closed.
func (db *DB) PostTestCleanup() {
	if db.DB == nil {
		return
	}
	// Verify database consistency after every test before releasing it.
	db.MustCheck()
	db.MustClose()
}
// Close closes the database if it is open; the underlying file is NOT
// deleted. Calling Close on an already-closed DB is a no-op returning nil.
// When closing fails the error is returned and the handle is kept.
func (db *DB) Close() error {
	if db.DB == nil {
		return nil
	}

	// Optionally report statistics before the handle goes away.
	if *statsFlag {
		db.PrintStats()
	}

	db.t.Logf("Closing bbolt DB at: %s", db.f)
	if err := db.DB.Close(); err != nil {
		return err
	}

	// A nil handle marks the DB as closed for MustReopen and PostTestCleanup.
	db.DB = nil
	return nil
}
// MustClose closes the database like Close and panics if Close returns an
// error. The underlying file is NOT deleted.
func (db *DB) MustClose() {
	err := db.Close()
	if err != nil {
		panic(err)
	}
}
// MustDeleteFile removes the database file at db.Path() and fails the test
// (via require.NoError) if the removal returns an error.
func (db *DB) MustDeleteFile() {
	require.NoError(db.t, os.Remove(db.Path()))
}
// SetOptions replaces the options stored on the wrapper. The stored options
// are the ones MustReopen passes to bolt.Open.
func (db *DB) SetOptions(o *bolt.Options) {
db.o = o
}
// MustReopen reopens the database at the same path using the stored options.
// The DB must have been closed first; it panics if the database is still
// open or if opening fails.
func (db *DB) MustReopen() {
	if db.DB != nil {
		panic("Please call Close() before MustReopen()")
	}

	db.t.Logf("Reopening bbolt DB at: %s", db.f)

	reopened, err := bolt.Open(db.Path(), 0666, db.o)
	if err != nil {
		panic(err)
	}
	db.DB = reopened
}
// MustCheck runs a consistency check on the database inside an Update
// transaction.
//
// If the check reports any errors, the database is copied to "db.backup" in a
// test temp directory, the errors and the backup path are printed to stdout,
// and the process exits with os.Exit(-1). It panics if the copy fails or if
// the transaction returns an error other than bolt.ErrDatabaseNotOpen.
func (db *DB) MustCheck() {
	err := db.Update(func(tx *bolt.Tx) error {
		// Gather the reported problems; stop reading once more than 10 are seen.
		var found []error
		for e := range tx.Check() {
			found = append(found, e)
			if len(found) > 10 {
				break
			}
		}
		if len(found) == 0 {
			return nil
		}

		// Keep a copy of the inconsistent database for later inspection.
		backup := filepath.Join(db.t.TempDir(), "db.backup")
		if copyErr := tx.CopyFile(backup, 0600); copyErr != nil {
			panic(copyErr)
		}

		// Report what was found and where the copy was saved.
		fmt.Print("\n\n")
		fmt.Printf("consistency check failed (%d errors)\n", len(found))
		for _, e := range found {
			fmt.Println(e)
		}
		fmt.Println("")
		fmt.Println("db saved to:")
		fmt.Println(backup)
		fmt.Print("\n\n")
		os.Exit(-1)
		return nil
	})
	if err != nil && err != bolt.ErrDatabaseNotOpen {
		panic(err)
	}
}
// Fill populates bucket using numTx Update transactions, each writing
// numKeysPerTx key/value pairs. The bucket is created if it does not exist.
//
// keyGen and valueGen are called with the transaction index and the key index
// within that transaction, and return the key and value to store.
//
// Fill stops at, and returns, the first error from creating the bucket,
// writing a key, or committing a transaction.
func (db *DB) Fill(bucket []byte, numTx int, numKeysPerTx int,
	keyGen func(tx int, key int) []byte,
	valueGen func(tx int, key int) []byte) error {
	for tr := 0; tr < numTx; tr++ {
		err := db.Update(func(tx *bolt.Tx) error {
			// Return the bucket error rather than dropping it: on failure b is
			// nil and the Put below would panic with a nil dereference.
			b, err := tx.CreateBucketIfNotExists(bucket)
			if err != nil {
				return err
			}
			for i := 0; i < numKeysPerTx; i++ {
				if err := b.Put(keyGen(tr, i), valueGen(tr, i)); err != nil {
					return err
				}
			}
			return nil
		})
		if err != nil {
			return err
		}
	}
	return nil
}
// Path returns the database file path stored on the wrapper when it was
// created.
func (db *DB) Path() string {
return db.f
}
// CopyTempFile copies the database, inside a read-only View transaction, to
// a file named "db.copy" in a test temp directory and prints the destination
// path to stdout. It panics if the copy fails.
func (db *DB) CopyTempFile() {
	dst := filepath.Join(db.t.TempDir(), "db.copy")

	err := db.View(func(tx *bolt.Tx) error {
		return tx.CopyFile(dst, 0600)
	})
	if err != nil {
		panic(err)
	}

	fmt.Println("db copied to: ", dst)
}
// PrintStats prints the database's transaction statistics to stdout as two
// rows of three left-aligned columns: page, cursor and node counters first,
// then rebalance, spill and write counters with their truncated durations.
func (db *DB) PrintStats() {
	txs := db.Stats().TxStats

	pages := fmt.Sprintf("pg(%d/%d)", txs.PageCount, txs.PageAlloc)
	cursors := fmt.Sprintf("cur(%d)", txs.CursorCount)
	nodes := fmt.Sprintf("node(%d/%d)", txs.NodeCount, txs.NodeDeref)
	fmt.Printf("[db] %-20s %-20s %-20s\n", pages, cursors, nodes)

	rebalance := fmt.Sprintf("rebal(%d/%v)", txs.Rebalance, truncDuration(txs.RebalanceTime))
	spill := fmt.Sprintf("spill(%d/%v)", txs.Spill, truncDuration(txs.SpillTime))
	write := fmt.Sprintf("w(%d/%v)", txs.Write, truncDuration(txs.WriteTime))
	fmt.Printf(" %-20s %-20s %-20s\n", rebalance, spill, write)
}
// truncDurationRe matches a leading integer followed by a decimal fraction,
// e.g. the "1.5" in "1.5s". It is compiled once at package initialization
// instead of on every truncDuration call.
var truncDurationRe = regexp.MustCompile(`^(\d+)(\.\d+)`)

// truncDuration formats d with Duration.String and drops the fractional part
// of a leading decimal number, so "1.5s" becomes "1s". Strings that do not
// start with a decimal number (e.g. "1m30s", "-1.5s") are returned unchanged.
func truncDuration(d time.Duration) string {
	return truncDurationRe.ReplaceAllString(d.String(), "$1")
}