diff --git a/db.go b/db.go index 962248c..870c8b1 100644 --- a/db.go +++ b/db.go @@ -121,6 +121,7 @@ type DB struct { AllocSize int path string + openFile func(string, int, os.FileMode) (*os.File, error) file *os.File dataref []byte // mmap'ed readonly, write throws SEGV data *[maxMapSize]byte @@ -199,10 +200,15 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) { db.readOnly = true } + db.openFile = options.OpenFile + if db.openFile == nil { + db.openFile = os.OpenFile + } + // Open data file and separate sync handler for metadata writes. db.path = path var err error - if db.file, err = os.OpenFile(db.path, flag|os.O_CREATE, mode); err != nil { + if db.file, err = db.openFile(db.path, flag|os.O_CREATE, mode); err != nil { _ = db.close() return nil, err } @@ -1054,6 +1060,10 @@ type Options struct { // set directly on the DB itself when returned from Open(), but this option // is useful in APIs which expose Options but not the underlying DB. NoSync bool + + // OpenFile is used to open files. It defaults to os.OpenFile. This option + // is useful for writing hermetic tests. + OpenFile func(string, int, os.FileMode) (*os.File, error) } // DefaultOptions represent the options used if nil options are passed into Open(). diff --git a/tx.go b/tx.go index f508641..52ab139 100644 --- a/tx.go +++ b/tx.go @@ -315,7 +315,7 @@ func (tx *Tx) Copy(w io.Writer) error { // If err == nil then exactly tx.Size() bytes will be written into the writer. func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) { // Attempt to open reader with WriteFlag - f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0) + f, err := tx.db.openFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0) if err != nil { return 0, err } @@ -369,7 +369,7 @@ func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) { // A reader transaction is maintained during the copy so it is safe to continue // using the database while a copy is in progress. func (tx *Tx) CopyFile(path string, mode os.FileMode) error { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + f, err := tx.db.openFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) if err != nil { return err }