goose/database/store_test.go

254 lines
7.1 KiB
Go

package database_test
import (
"context"
"database/sql"
"errors"
"path/filepath"
"testing"
"github.com/pressly/goose/v3/database"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
"modernc.org/sqlite"
)
// The goal of this test is to verify the database store package works as expected. This test is not
// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store
// interface works against a real database.
func TestDialectStore(t *testing.T) {
	t.Parallel()
	t.Run("invalid", func(t *testing.T) {
		// Test empty table name.
		_, err := database.NewStore(database.DialectSQLite3, "")
		require.Error(t, err)
		// Test unknown dialect.
		_, err = database.NewStore("unknown-dialect", "foo")
		require.Error(t, err)
		// Test empty dialect.
		_, err = database.NewStore("", "foo")
		require.Error(t, err)
	})
	// Test generic behavior.
	t.Run("sqlite3", func(t *testing.T) {
		// NOTE: every SQLite connection opened on ":memory:" gets its own private database. This
		// test relies on testStore only ever using one connection at a time, so the pool keeps
		// handing back the same idle connection (and therefore the same database).
		db, err := sql.Open("sqlite", ":memory:")
		require.NoError(t, err)
		t.Cleanup(func() { require.NoError(t, db.Close()) })
		testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
			t.Helper()
			var sqliteErr *sqlite.Error
			require.ErrorAs(t, err, &sqliteErr)
			require.Equal(t, 1, sqliteErr.Code()) // Generic error (SQLITE_ERROR)
			require.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
		})
	})
	t.Run("ListMigrations", func(t *testing.T) {
		dir := t.TempDir()
		db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
		require.NoError(t, err)
		// Registered after TempDir, so (cleanups run LIFO) the DB file is closed before the
		// directory is removed; an open file would make the removal fail on Windows.
		t.Cleanup(func() { require.NoError(t, db.Close()) })
		store, err := database.NewStore(database.DialectSQLite3, "foo")
		require.NoError(t, err)
		err = store.CreateVersionTable(context.Background(), db)
		require.NoError(t, err)
		insert := func(db *sql.DB, version int64) error {
			return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
		}
		require.NoError(t, insert(db, 1))
		require.NoError(t, insert(db, 3))
		require.NoError(t, insert(db, 2))
		res, err := store.ListMigrations(context.Background(), db)
		require.NoError(t, err)
		require.Len(t, res, 3)
		// Migrations come back in reverse insertion order (most recently inserted first), not
		// sorted by version: inserted [1, 3, 2] -> listed [2, 3, 1].
		require.EqualValues(t, 2, res[0].Version)
		require.EqualValues(t, 3, res[1].Version)
		require.EqualValues(t, 1, res[2].Version)
	})
}
// testStore exercises the lifecycle of a Store against db: creating the version table (twice, the
// second time expecting failure), then inserting, listing, fetching and deleting migrations. It
// passes *sql.DB, *sql.Conn and *sql.Tx as the executor at various points.
//
// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable
// when the version table already exists.
func testStore(
	ctx context.Context,
	t *testing.T,
	d database.Dialect,
	db *sql.DB,
	alreadyExists func(t *testing.T, err error),
) {
	const (
		tablename = "test_goose_db_version"
	)
	store, err := database.NewStore(d, tablename)
	require.NoError(t, err)
	// Create the version table.
	err = runTx(ctx, db, func(tx *sql.Tx) error {
		return store.CreateVersionTable(ctx, tx)
	})
	require.NoError(t, err)
	// Create the version table again. This should fail.
	err = runTx(ctx, db, func(tx *sql.Tx) error {
		return store.CreateVersionTable(ctx, tx)
	})
	require.Error(t, err)
	if alreadyExists != nil {
		alreadyExists(t, err)
	}
	// Get the latest version. There should be none.
	_, err = store.GetLatestVersion(ctx, db)
	require.ErrorIs(t, err, database.ErrVersionNotFound)
	// List migrations. There should be none.
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		res, err := store.ListMigrations(ctx, conn)
		require.NoError(t, err)
		require.Empty(t, res)
		return nil
	})
	require.NoError(t, err)
	// Insert versions 0 through 5 (the zero migration plus 5 more), checking after each insert
	// that the latest version has advanced to it.
	for i := 0; i < 6; i++ {
		err = runConn(ctx, db, func(conn *sql.Conn) error {
			err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
			require.NoError(t, err)
			latest, err := store.GetLatestVersion(ctx, conn)
			require.NoError(t, err)
			require.Equal(t, int64(i), latest)
			return nil
		})
		require.NoError(t, err)
	}
	// List migrations. There should be 6.
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		res, err := store.ListMigrations(ctx, conn)
		require.NoError(t, err)
		require.Len(t, res, 6)
		// Check versions are in descending order.
		for i := 0; i < 6; i++ {
			require.EqualValues(t, 5-i, res[i].Version)
		}
		return nil
	})
	require.NoError(t, err)
	// Delete 3 migrations backwards (5, 4, 3); the latest version should step down each time.
	for i := 5; i >= 3; i-- {
		err = runConn(ctx, db, func(conn *sql.Conn) error {
			err := store.Delete(ctx, conn, int64(i))
			require.NoError(t, err)
			latest, err := store.GetLatestVersion(ctx, conn)
			require.NoError(t, err)
			require.Equal(t, int64(i-1), latest)
			return nil
		})
		require.NoError(t, err)
	}
	// List migrations. There should be 3.
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		res, err := store.ListMigrations(ctx, conn)
		require.NoError(t, err)
		require.Len(t, res, 3)
		// Check that the remaining versions are in descending order.
		for i := 0; i < 3; i++ {
			require.EqualValues(t, 2-i, res[i].Version)
		}
		return nil
	})
	require.NoError(t, err)
	// Get remaining migrations one by one.
	for i := 0; i < 3; i++ {
		err = runConn(ctx, db, func(conn *sql.Conn) error {
			res, err := store.GetMigration(ctx, conn, int64(i))
			require.NoError(t, err)
			require.True(t, res.IsApplied)
			require.False(t, res.Timestamp.IsZero())
			return nil
		})
		require.NoError(t, err)
	}
	// Delete remaining migrations one by one and use all 3 connection types:
	// 1. *sql.Tx
	err = runTx(ctx, db, func(tx *sql.Tx) error {
		err := store.Delete(ctx, tx, 2)
		require.NoError(t, err)
		latest, err := store.GetLatestVersion(ctx, tx)
		require.NoError(t, err)
		require.EqualValues(t, 1, latest)
		return nil
	})
	require.NoError(t, err)
	// 2. *sql.Conn
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		err := store.Delete(ctx, conn, 1)
		require.NoError(t, err)
		latest, err := store.GetLatestVersion(ctx, conn)
		require.NoError(t, err)
		require.EqualValues(t, 0, latest)
		return nil
	})
	require.NoError(t, err)
	// 3. *sql.DB
	err = store.Delete(ctx, db, 0)
	require.NoError(t, err)
	_, err = store.GetLatestVersion(ctx, db)
	require.ErrorIs(t, err, database.ErrVersionNotFound)
	// List migrations. There should be none.
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		res, err := store.ListMigrations(ctx, conn)
		require.NoError(t, err)
		require.Empty(t, res)
		return nil
	})
	require.NoError(t, err)
	// Try to get a migration that does not exist.
	err = runConn(ctx, db, func(conn *sql.Conn) error {
		_, err := store.GetMigration(ctx, conn, 0)
		require.ErrorIs(t, err, database.ErrVersionNotFound)
		return nil
	})
	require.NoError(t, err)
}
// runTx runs fn inside a transaction on db and commits it if fn returns nil.
//
// The deferred Rollback runs on every exit path, including when fn returns an error, panics, or
// stops the goroutine via runtime.Goexit (which a failed require assertion does). After a Commit
// (successful or not), Rollback reports sql.ErrTxDone, which is ignored. Any other rollback error
// is appended to the returned error.
func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			retErr = multierr.Append(retErr, rbErr)
		}
	}()
	if err := fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}
// runConn checks out a single dedicated connection from db, runs fn with it, and then closes it
// (returning it to the pool).
//
// The connection is closed on every exit path, including when fn panics or stops the goroutine via
// runtime.Goexit (as a failed require assertion does). Any Close error is appended to fn's error.
func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) {
	conn, err := db.Conn(ctx)
	if err != nil {
		return err
	}
	defer func() {
		retErr = multierr.Append(retErr, conn.Close())
	}()
	return fn(conn)
}