mirror of https://github.com/pressly/goose.git
testing: replace check with stretchr/testify (#842)
parent
cf53a224e1
commit
053b1fd49e
|
@ -8,7 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/database"
|
"github.com/pressly/goose/v3/database"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/multierr"
|
"go.uber.org/multierr"
|
||||||
"modernc.org/sqlite"
|
"modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
@ -22,47 +22,47 @@ func TestDialectStore(t *testing.T) {
|
||||||
t.Run("invalid", func(t *testing.T) {
|
t.Run("invalid", func(t *testing.T) {
|
||||||
// Test empty table name.
|
// Test empty table name.
|
||||||
_, err := database.NewStore(database.DialectSQLite3, "")
|
_, err := database.NewStore(database.DialectSQLite3, "")
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Test unknown dialect.
|
// Test unknown dialect.
|
||||||
_, err = database.NewStore("unknown-dialect", "foo")
|
_, err = database.NewStore("unknown-dialect", "foo")
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Test empty dialect.
|
// Test empty dialect.
|
||||||
_, err = database.NewStore("", "foo")
|
_, err = database.NewStore("", "foo")
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
// Test generic behavior.
|
// Test generic behavior.
|
||||||
t.Run("sqlite3", func(t *testing.T) {
|
t.Run("sqlite3", func(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite", ":memory:")
|
db, err := sql.Open("sqlite", ":memory:")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
|
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
|
||||||
var sqliteErr *sqlite.Error
|
var sqliteErr *sqlite.Error
|
||||||
ok := errors.As(err, &sqliteErr)
|
ok := errors.As(err, &sqliteErr)
|
||||||
check.Bool(t, ok, true)
|
require.True(t, ok)
|
||||||
check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR)
|
require.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR)
|
||||||
check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
|
require.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("ListMigrations", func(t *testing.T) {
|
t.Run("ListMigrations", func(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
store, err := database.NewStore(database.DialectSQLite3, "foo")
|
store, err := database.NewStore(database.DialectSQLite3, "foo")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = store.CreateVersionTable(context.Background(), db)
|
err = store.CreateVersionTable(context.Background(), db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
insert := func(db *sql.DB, version int64) error {
|
insert := func(db *sql.DB, version int64) error {
|
||||||
return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
|
return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
|
||||||
}
|
}
|
||||||
check.NoError(t, insert(db, 1))
|
require.NoError(t, insert(db, 1))
|
||||||
check.NoError(t, insert(db, 3))
|
require.NoError(t, insert(db, 3))
|
||||||
check.NoError(t, insert(db, 2))
|
require.NoError(t, insert(db, 2))
|
||||||
res, err := store.ListMigrations(context.Background(), db)
|
res, err := store.ListMigrations(context.Background(), db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), 3)
|
require.Equal(t, len(res), 3)
|
||||||
// Check versions are in descending order: [2, 3, 1]
|
// Check versions are in descending order: [2, 3, 1]
|
||||||
check.Number(t, res[0].Version, 2)
|
require.EqualValues(t, res[0].Version, 2)
|
||||||
check.Number(t, res[1].Version, 3)
|
require.EqualValues(t, res[1].Version, 3)
|
||||||
check.Number(t, res[2].Version, 1)
|
require.EqualValues(t, res[2].Version, 1)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,95 +81,95 @@ func testStore(
|
||||||
tablename = "test_goose_db_version"
|
tablename = "test_goose_db_version"
|
||||||
)
|
)
|
||||||
store, err := database.NewStore(d, tablename)
|
store, err := database.NewStore(d, tablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Create the version table.
|
// Create the version table.
|
||||||
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
||||||
return store.CreateVersionTable(ctx, tx)
|
return store.CreateVersionTable(ctx, tx)
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Create the version table again. This should fail.
|
// Create the version table again. This should fail.
|
||||||
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
||||||
return store.CreateVersionTable(ctx, tx)
|
return store.CreateVersionTable(ctx, tx)
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
if alreadyExists != nil {
|
if alreadyExists != nil {
|
||||||
alreadyExists(t, err)
|
alreadyExists(t, err)
|
||||||
}
|
}
|
||||||
// Get the latest version. There should be none.
|
// Get the latest version. There should be none.
|
||||||
_, err = store.GetLatestVersion(ctx, db)
|
_, err = store.GetLatestVersion(ctx, db)
|
||||||
check.IsError(t, err, database.ErrVersionNotFound)
|
require.ErrorIs(t, err, database.ErrVersionNotFound)
|
||||||
|
|
||||||
// List migrations. There should be none.
|
// List migrations. There should be none.
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
res, err := store.ListMigrations(ctx, conn)
|
res, err := store.ListMigrations(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), 0)
|
require.Equal(t, len(res), 0)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Insert 5 migrations in addition to the zero migration.
|
// Insert 5 migrations in addition to the zero migration.
|
||||||
for i := 0; i < 6; i++ {
|
for i := 0; i < 6; i++ {
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
|
err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
latest, err := store.GetLatestVersion(ctx, conn)
|
latest, err := store.GetLatestVersion(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, latest, int64(i))
|
require.Equal(t, latest, int64(i))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List migrations. There should be 6.
|
// List migrations. There should be 6.
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
res, err := store.ListMigrations(ctx, conn)
|
res, err := store.ListMigrations(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), 6)
|
require.Equal(t, len(res), 6)
|
||||||
// Check versions are in descending order.
|
// Check versions are in descending order.
|
||||||
for i := 0; i < 6; i++ {
|
for i := 0; i < 6; i++ {
|
||||||
check.Number(t, res[i].Version, 5-i)
|
require.EqualValues(t, res[i].Version, 5-i)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Delete 3 migrations backwards
|
// Delete 3 migrations backwards
|
||||||
for i := 5; i >= 3; i-- {
|
for i := 5; i >= 3; i-- {
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
err := store.Delete(ctx, conn, int64(i))
|
err := store.Delete(ctx, conn, int64(i))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
latest, err := store.GetLatestVersion(ctx, conn)
|
latest, err := store.GetLatestVersion(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, latest, int64(i-1))
|
require.Equal(t, latest, int64(i-1))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List migrations. There should be 3.
|
// List migrations. There should be 3.
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
res, err := store.ListMigrations(ctx, conn)
|
res, err := store.ListMigrations(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), 3)
|
require.Equal(t, len(res), 3)
|
||||||
// Check that the remaining versions are in descending order.
|
// Check that the remaining versions are in descending order.
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
check.Number(t, res[i].Version, 2-i)
|
require.EqualValues(t, res[i].Version, 2-i)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get remaining migrations one by one.
|
// Get remaining migrations one by one.
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
res, err := store.GetMigration(ctx, conn, int64(i))
|
res, err := store.GetMigration(ctx, conn, int64(i))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, res.IsApplied, true)
|
require.Equal(t, res.IsApplied, true)
|
||||||
check.Equal(t, res.Timestamp.IsZero(), false)
|
require.Equal(t, res.Timestamp.IsZero(), false)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete remaining migrations one by one and use all 3 connection types:
|
// Delete remaining migrations one by one and use all 3 connection types:
|
||||||
|
@ -177,46 +177,46 @@ func testStore(
|
||||||
// 1. *sql.Tx
|
// 1. *sql.Tx
|
||||||
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
||||||
err := store.Delete(ctx, tx, 2)
|
err := store.Delete(ctx, tx, 2)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
latest, err := store.GetLatestVersion(ctx, tx)
|
latest, err := store.GetLatestVersion(ctx, tx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, latest, 1)
|
require.EqualValues(t, latest, 1)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// 2. *sql.Conn
|
// 2. *sql.Conn
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
err := store.Delete(ctx, conn, 1)
|
err := store.Delete(ctx, conn, 1)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
latest, err := store.GetLatestVersion(ctx, conn)
|
latest, err := store.GetLatestVersion(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, latest, 0)
|
require.EqualValues(t, latest, 0)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// 3. *sql.DB
|
// 3. *sql.DB
|
||||||
err = store.Delete(ctx, db, 0)
|
err = store.Delete(ctx, db, 0)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = store.GetLatestVersion(ctx, db)
|
_, err = store.GetLatestVersion(ctx, db)
|
||||||
check.IsError(t, err, database.ErrVersionNotFound)
|
require.ErrorIs(t, err, database.ErrVersionNotFound)
|
||||||
|
|
||||||
// List migrations. There should be none.
|
// List migrations. There should be none.
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
res, err := store.ListMigrations(ctx, conn)
|
res, err := store.ListMigrations(ctx, conn)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), 0)
|
require.Equal(t, len(res), 0)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Try to get a migration that does not exist.
|
// Try to get a migration that does not exist.
|
||||||
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
||||||
_, err := store.GetMigration(ctx, conn, 0)
|
_, err := store.GetMigration(ctx, conn, 0)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
|
require.True(t, errors.Is(err, database.ErrVersionNotFound))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
|
func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
|
||||||
|
|
234
globals_test.go
234
globals_test.go
|
@ -5,31 +5,31 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewGoMigration(t *testing.T) {
|
func TestNewGoMigration(t *testing.T) {
|
||||||
t.Run("valid_both_nil", func(t *testing.T) {
|
t.Run("valid_both_nil", func(t *testing.T) {
|
||||||
m := NewGoMigration(1, nil, nil)
|
m := NewGoMigration(1, nil, nil)
|
||||||
// roundtrip
|
// roundtrip
|
||||||
check.Equal(t, m.Version, int64(1))
|
require.Equal(t, m.Version, int64(1))
|
||||||
check.Equal(t, m.Type, TypeGo)
|
require.Equal(t, m.Type, TypeGo)
|
||||||
check.Equal(t, m.Registered, true)
|
require.Equal(t, m.Registered, true)
|
||||||
check.Equal(t, m.Next, int64(-1))
|
require.Equal(t, m.Next, int64(-1))
|
||||||
check.Equal(t, m.Previous, int64(-1))
|
require.Equal(t, m.Previous, int64(-1))
|
||||||
check.Equal(t, m.Source, "")
|
require.Equal(t, m.Source, "")
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
require.Nil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
require.Nil(t, m.DownFnNoTxContext)
|
||||||
check.Bool(t, m.UpFnContext == nil, true)
|
require.Nil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, true)
|
require.Nil(t, m.DownFnContext)
|
||||||
check.Bool(t, m.UpFn == nil, true)
|
require.Nil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, true)
|
require.Nil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
require.Nil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
require.Nil(t, m.DownFnNoTx)
|
||||||
check.Bool(t, m.goUp != nil, true)
|
require.True(t, m.goUp != nil)
|
||||||
check.Bool(t, m.goDown != nil, true)
|
require.True(t, m.goDown != nil)
|
||||||
check.Equal(t, m.goUp.Mode, TransactionEnabled)
|
require.Equal(t, m.goUp.Mode, TransactionEnabled)
|
||||||
check.Equal(t, m.goDown.Mode, TransactionEnabled)
|
require.Equal(t, m.goDown.Mode, TransactionEnabled)
|
||||||
})
|
})
|
||||||
t.Run("all_set", func(t *testing.T) {
|
t.Run("all_set", func(t *testing.T) {
|
||||||
// This will eventually be an error when registering migrations.
|
// This will eventually be an error when registering migrations.
|
||||||
|
@ -39,14 +39,14 @@ func TestNewGoMigration(t *testing.T) {
|
||||||
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
|
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
|
||||||
)
|
)
|
||||||
// check only functions
|
// check only functions
|
||||||
check.Bool(t, m.UpFn != nil, true)
|
require.True(t, m.UpFn != nil)
|
||||||
check.Bool(t, m.UpFnContext != nil, true)
|
require.True(t, m.UpFnContext != nil)
|
||||||
check.Bool(t, m.UpFnNoTx != nil, true)
|
require.True(t, m.UpFnNoTx != nil)
|
||||||
check.Bool(t, m.UpFnNoTxContext != nil, true)
|
require.True(t, m.UpFnNoTxContext != nil)
|
||||||
check.Bool(t, m.DownFn != nil, true)
|
require.True(t, m.DownFn != nil)
|
||||||
check.Bool(t, m.DownFnContext != nil, true)
|
require.True(t, m.DownFnContext != nil)
|
||||||
check.Bool(t, m.DownFnNoTx != nil, true)
|
require.True(t, m.DownFnNoTx != nil)
|
||||||
check.Bool(t, m.DownFnNoTxContext != nil, true)
|
require.True(t, m.DownFnNoTxContext != nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,67 +59,67 @@ func TestTransactionMode(t *testing.T) {
|
||||||
err := SetGlobalMigrations(
|
err := SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both
|
NewGoMigration(1, &GoFunc{RunTx: runTx, RunDB: runDB}, nil), // cannot specify both
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB")
|
require.Contains(t, err.Error(), "up function: must specify exactly one of RunTx or RunDB")
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both
|
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, RunDB: runDB}), // cannot specify both
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB")
|
require.Contains(t, err.Error(), "down function: must specify exactly one of RunTx or RunDB")
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx
|
NewGoMigration(1, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}, nil), // invalid explicit mode tx
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set")
|
require.Contains(t, err.Error(), "up function: transaction mode must be enabled or unspecified when RunTx is set")
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx
|
NewGoMigration(1, nil, &GoFunc{RunTx: runTx, Mode: TransactionDisabled}), // invalid explicit mode tx
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set")
|
require.Contains(t, err.Error(), "down function: transaction mode must be enabled or unspecified when RunTx is set")
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx
|
NewGoMigration(1, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}, nil), // invalid explicit mode no-tx
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set")
|
require.Contains(t, err.Error(), "up function: transaction mode must be disabled or unspecified when RunDB is set")
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx
|
NewGoMigration(1, nil, &GoFunc{RunDB: runDB, Mode: TransactionEnabled}), // invalid explicit mode no-tx
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set")
|
require.Contains(t, err.Error(), "down function: transaction mode must be disabled or unspecified when RunDB is set")
|
||||||
|
|
||||||
t.Run("default_mode", func(t *testing.T) {
|
t.Run("default_mode", func(t *testing.T) {
|
||||||
t.Cleanup(ResetGlobalMigrations)
|
t.Cleanup(ResetGlobalMigrations)
|
||||||
|
|
||||||
m := NewGoMigration(1, nil, nil)
|
m := NewGoMigration(1, nil, nil)
|
||||||
err = SetGlobalMigrations(m)
|
err = SetGlobalMigrations(m)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(registeredGoMigrations), 1)
|
require.Equal(t, len(registeredGoMigrations), 1)
|
||||||
registered := registeredGoMigrations[1]
|
registered := registeredGoMigrations[1]
|
||||||
check.Bool(t, registered.goUp != nil, true)
|
require.True(t, registered.goUp != nil)
|
||||||
check.Bool(t, registered.goDown != nil, true)
|
require.True(t, registered.goDown != nil)
|
||||||
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
require.Equal(t, registered.goUp.Mode, TransactionEnabled)
|
||||||
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
require.Equal(t, registered.goDown.Mode, TransactionEnabled)
|
||||||
|
|
||||||
migration2 := NewGoMigration(2, nil, nil)
|
migration2 := NewGoMigration(2, nil, nil)
|
||||||
// reset so we can check the default is set
|
// reset so we can check the default is set
|
||||||
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
|
||||||
err = SetGlobalMigrations(migration2)
|
err = SetGlobalMigrations(migration2)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
|
require.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
|
||||||
|
|
||||||
migration3 := NewGoMigration(3, nil, nil)
|
migration3 := NewGoMigration(3, nil, nil)
|
||||||
// reset so we can check the default is set
|
// reset so we can check the default is set
|
||||||
migration3.goDown.Mode = 0
|
migration3.goDown.Mode = 0
|
||||||
err = SetGlobalMigrations(migration3)
|
err = SetGlobalMigrations(migration3)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
|
require.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
|
||||||
})
|
})
|
||||||
t.Run("unknown_mode", func(t *testing.T) {
|
t.Run("unknown_mode", func(t *testing.T) {
|
||||||
m := NewGoMigration(1, nil, nil)
|
m := NewGoMigration(1, nil, nil)
|
||||||
m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default
|
m.goUp.Mode, m.goDown.Mode = 3, 3 // reset to default
|
||||||
err := SetGlobalMigrations(m)
|
err := SetGlobalMigrations(m)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "invalid mode: 3")
|
require.Contains(t, err.Error(), "invalid mode: 3")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,12 +131,12 @@ func TestLegacyFunctions(t *testing.T) {
|
||||||
|
|
||||||
assertMigration := func(t *testing.T, m *Migration, version int64) {
|
assertMigration := func(t *testing.T, m *Migration, version int64) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, m.Version, version)
|
require.Equal(t, m.Version, version)
|
||||||
check.Equal(t, m.Type, TypeGo)
|
require.Equal(t, m.Type, TypeGo)
|
||||||
check.Equal(t, m.Registered, true)
|
require.Equal(t, m.Registered, true)
|
||||||
check.Equal(t, m.Next, int64(-1))
|
require.Equal(t, m.Next, int64(-1))
|
||||||
check.Equal(t, m.Previous, int64(-1))
|
require.Equal(t, m.Previous, int64(-1))
|
||||||
check.Equal(t, m.Source, "")
|
require.Equal(t, m.Source, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("all_tx", func(t *testing.T) {
|
t.Run("all_tx", func(t *testing.T) {
|
||||||
|
@ -144,46 +144,46 @@ func TestLegacyFunctions(t *testing.T) {
|
||||||
err := SetGlobalMigrations(
|
err := SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}),
|
NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(registeredGoMigrations), 1)
|
require.Equal(t, len(registeredGoMigrations), 1)
|
||||||
m := registeredGoMigrations[1]
|
m := registeredGoMigrations[1]
|
||||||
assertMigration(t, m, 1)
|
assertMigration(t, m, 1)
|
||||||
// Legacy functions.
|
// Legacy functions.
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
require.Nil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
require.Nil(t, m.DownFnNoTxContext)
|
||||||
// Context-aware functions.
|
// Context-aware functions.
|
||||||
check.Bool(t, m.goUp == nil, false)
|
require.NotNil(t, m.goUp)
|
||||||
check.Bool(t, m.UpFnContext == nil, false)
|
require.NotNil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.goDown == nil, false)
|
require.NotNil(t, m.goDown)
|
||||||
check.Bool(t, m.DownFnContext == nil, false)
|
require.NotNil(t, m.DownFnContext)
|
||||||
// Always nil
|
// Always nil
|
||||||
check.Bool(t, m.UpFn == nil, false)
|
require.NotNil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, false)
|
require.NotNil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
require.Nil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
require.Nil(t, m.DownFnNoTx)
|
||||||
})
|
})
|
||||||
t.Run("all_db", func(t *testing.T) {
|
t.Run("all_db", func(t *testing.T) {
|
||||||
t.Cleanup(ResetGlobalMigrations)
|
t.Cleanup(ResetGlobalMigrations)
|
||||||
err := SetGlobalMigrations(
|
err := SetGlobalMigrations(
|
||||||
NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}),
|
NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(registeredGoMigrations), 1)
|
require.Equal(t, len(registeredGoMigrations), 1)
|
||||||
m := registeredGoMigrations[2]
|
m := registeredGoMigrations[2]
|
||||||
assertMigration(t, m, 2)
|
assertMigration(t, m, 2)
|
||||||
// Legacy functions.
|
// Legacy functions.
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, false)
|
require.NotNil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.goUp == nil, false)
|
require.NotNil(t, m.goUp)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, false)
|
require.NotNil(t, m.DownFnNoTxContext)
|
||||||
check.Bool(t, m.goDown == nil, false)
|
require.NotNil(t, m.goDown)
|
||||||
// Context-aware functions.
|
// Context-aware functions.
|
||||||
check.Bool(t, m.UpFnContext == nil, true)
|
require.Nil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, true)
|
require.Nil(t, m.DownFnContext)
|
||||||
// Always nil
|
// Always nil
|
||||||
check.Bool(t, m.UpFn == nil, true)
|
require.Nil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, true)
|
require.Nil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnNoTx == nil, false)
|
require.NotNil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, false)
|
require.NotNil(t, m.DownFnNoTx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,91 +195,91 @@ func TestGlobalRegister(t *testing.T) {
|
||||||
|
|
||||||
// Success.
|
// Success.
|
||||||
err := SetGlobalMigrations([]*Migration{}...)
|
err := SetGlobalMigrations([]*Migration{}...)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Try to register the same migration again.
|
// Try to register the same migration again.
|
||||||
err = SetGlobalMigrations(
|
err = SetGlobalMigrations(
|
||||||
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "go migration with version 1 already registered")
|
require.Contains(t, err.Error(), "go migration with version 1 already registered")
|
||||||
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
|
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
require.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckMigration(t *testing.T) {
|
func TestCheckMigration(t *testing.T) {
|
||||||
// Success.
|
// Success.
|
||||||
err := checkGoMigration(NewGoMigration(1, nil, nil))
|
err := checkGoMigration(NewGoMigration(1, nil, nil))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Failures.
|
// Failures.
|
||||||
err = checkGoMigration(&Migration{})
|
err = checkGoMigration(&Migration{})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
require.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
|
||||||
err = checkGoMigration(&Migration{construct: true})
|
err = checkGoMigration(&Migration{construct: true})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must be registered")
|
require.Contains(t, err.Error(), "must be registered")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true})
|
err = checkGoMigration(&Migration{construct: true, Registered: true})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `type must be "go"`)
|
require.Contains(t, err.Error(), `type must be "go"`)
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "version must be greater than zero")
|
require.Contains(t, err.Error(), "version must be greater than zero")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "up function: invalid mode: 0")
|
require.Contains(t, err.Error(), "up function: invalid mode: 0")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "down function: invalid mode: 0")
|
require.Contains(t, err.Error(), "down function: invalid mode: 0")
|
||||||
// Success.
|
// Success.
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Failures.
|
// Failures.
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
require.Contains(t, err.Error(), `source must have .go extension: "foo"`)
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `no filename separator '_' found`)
|
require.Contains(t, err.Error(), `no filename separator '_' found`)
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
require.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
require.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||||
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
require.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
|
||||||
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
|
||||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
require.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
UpFn: func(*sql.Tx) error { return nil },
|
UpFn: func(*sql.Tx) error { return nil },
|
||||||
UpFnNoTx: func(*sql.DB) error { return nil },
|
UpFnNoTx: func(*sql.DB) error { return nil },
|
||||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
require.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
|
||||||
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
|
||||||
DownFn: func(*sql.Tx) error { return nil },
|
DownFn: func(*sql.Tx) error { return nil },
|
||||||
DownFnNoTx: func(*sql.DB) error { return nil },
|
DownFnNoTx: func(*sql.DB) error { return nil },
|
||||||
goUp: &GoFunc{Mode: TransactionEnabled},
|
goUp: &GoFunc{Mode: TransactionEnabled},
|
||||||
goDown: &GoFunc{Mode: TransactionEnabled},
|
goDown: &GoFunc{Mode: TransactionEnabled},
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
require.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,8 +23,8 @@ func TestFullBinary(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cli := buildGooseCLI(t, false)
|
cli := buildGooseCLI(t, false)
|
||||||
out, err := cli.run("--version")
|
out, err := cli.run("--version")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n")
|
require.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLiteBinary(t *testing.T) {
|
func TestLiteBinary(t *testing.T) {
|
||||||
|
@ -34,8 +34,8 @@ func TestLiteBinary(t *testing.T) {
|
||||||
t.Run("binary_version", func(t *testing.T) {
|
t.Run("binary_version", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
out, err := cli.run("--version")
|
out, err := cli.run("--version")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n")
|
require.Equal(t, out, "goose version: "+gooseTestBinaryVersion+"\n")
|
||||||
})
|
})
|
||||||
t.Run("default_binary", func(t *testing.T) {
|
t.Run("default_binary", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
@ -55,8 +55,8 @@ func TestLiteBinary(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, c := range commands {
|
for _, c := range commands {
|
||||||
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd)
|
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, c.out)
|
require.Contains(t, out, c.out)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("gh_issue_532", func(t *testing.T) {
|
t.Run("gh_issue_532", func(t *testing.T) {
|
||||||
|
@ -65,13 +65,13 @@ func TestLiteBinary(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
total := countSQLFiles(t, "testdata/migrations")
|
total := countSQLFiles(t, "testdata/migrations")
|
||||||
_, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
_, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, "goose: no migrations to run. current version: "+strconv.Itoa(total))
|
require.Contains(t, out, "goose: no migrations to run. current version: "+strconv.Itoa(total))
|
||||||
out, err = cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "version")
|
out, err = cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), "version")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, "goose: version "+strconv.Itoa(total))
|
require.Contains(t, out, "goose: version "+strconv.Itoa(total))
|
||||||
})
|
})
|
||||||
t.Run("gh_issue_293", func(t *testing.T) {
|
t.Run("gh_issue_293", func(t *testing.T) {
|
||||||
// https://github.com/pressly/goose/issues/293
|
// https://github.com/pressly/goose/issues/293
|
||||||
|
@ -92,8 +92,8 @@ func TestLiteBinary(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, c := range commands {
|
for _, c := range commands {
|
||||||
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd)
|
out, err := cli.run("-dir=testdata/migrations", "sqlite3", filepath.Join(dir, "sql.db"), c.cmd)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, c.out)
|
require.Contains(t, out, c.out)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("gh_issue_336", func(t *testing.T) {
|
t.Run("gh_issue_336", func(t *testing.T) {
|
||||||
|
@ -101,8 +101,8 @@ func TestLiteBinary(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
_, err := cli.run("-dir="+dir, "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
_, err := cli.run("-dir="+dir, "sqlite3", filepath.Join(dir, "sql.db"), "up")
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "goose run: no migration files found")
|
require.Contains(t, err.Error(), "goose run: no migration files found")
|
||||||
})
|
})
|
||||||
t.Run("create_and_fix", func(t *testing.T) {
|
t.Run("create_and_fix", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
@ -112,8 +112,8 @@ func TestLiteBinary(t *testing.T) {
|
||||||
createEmptyFile(t, dir, "20230826163141_charlie.sql")
|
createEmptyFile(t, dir, "20230826163141_charlie.sql")
|
||||||
createEmptyFile(t, dir, "20230826163151_delta.go")
|
createEmptyFile(t, dir, "20230826163151_delta.go")
|
||||||
total, err := os.ReadDir(dir)
|
total, err := os.ReadDir(dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(total), 4)
|
require.Equal(t, len(total), 4)
|
||||||
migrationFiles := []struct {
|
migrationFiles := []struct {
|
||||||
name string
|
name string
|
||||||
fileType string
|
fileType string
|
||||||
|
@ -128,22 +128,22 @@ func TestLiteBinary(t *testing.T) {
|
||||||
args = append(args, f.fileType)
|
args = append(args, f.fileType)
|
||||||
}
|
}
|
||||||
out, err := cli.run(args...)
|
out, err := cli.run(args...)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, "Created new file")
|
require.Contains(t, out, "Created new file")
|
||||||
// ensure different timestamps, granularity is 1 second
|
// ensure different timestamps, granularity is 1 second
|
||||||
if i < len(migrationFiles)-1 {
|
if i < len(migrationFiles)-1 {
|
||||||
time.Sleep(1100 * time.Millisecond)
|
time.Sleep(1100 * time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
total, err = os.ReadDir(dir)
|
total, err = os.ReadDir(dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(total), 7)
|
require.Equal(t, len(total), 7)
|
||||||
out, err := cli.run("-dir="+dir, "fix")
|
out, err := cli.run("-dir="+dir, "fix")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Contains(t, out, "RENAMED")
|
require.Contains(t, out, "RENAMED")
|
||||||
files, err := os.ReadDir(dir)
|
files, err := os.ReadDir(dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(files), 7)
|
require.Equal(t, len(files), 7)
|
||||||
expected := []string{
|
expected := []string{
|
||||||
"00001_alpha.sql",
|
"00001_alpha.sql",
|
||||||
"00003_bravo.sql",
|
"00003_bravo.sql",
|
||||||
|
@ -154,7 +154,7 @@ func TestLiteBinary(t *testing.T) {
|
||||||
"00008_golf.go",
|
"00008_golf.go",
|
||||||
}
|
}
|
||||||
for i, f := range files {
|
for i, f := range files {
|
||||||
check.Equal(t, f.Name(), expected[i])
|
require.Equal(t, f.Name(), expected[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -201,7 +201,7 @@ func buildGooseCLI(t *testing.T, lite bool) gooseBinary {
|
||||||
func countSQLFiles(t *testing.T, dir string) int {
|
func countSQLFiles(t *testing.T, dir string) int {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
files, err := filepath.Glob(filepath.Join(dir, "*.sql"))
|
files, err := filepath.Glob(filepath.Join(dir, "*.sql"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return len(files)
|
return len(files)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,6 +209,6 @@ func createEmptyFile(t *testing.T, dir, name string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
path := filepath.Join(dir, name)
|
path := filepath.Join(dir, name)
|
||||||
f, err := os.Create(path)
|
f, err := os.Create(path)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,43 +20,43 @@ func TestEmbeddedMigrations(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
// not using t.Parallel here to avoid races
|
// not using t.Parallel here to avoid races
|
||||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
|
|
||||||
migrationFiles, err := fs.ReadDir(embedMigrations, "testdata/migrations")
|
migrationFiles, err := fs.ReadDir(embedMigrations, "testdata/migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
total := len(migrationFiles)
|
total := len(migrationFiles)
|
||||||
|
|
||||||
// decouple from existing structure
|
// decouple from existing structure
|
||||||
fsys, err := fs.Sub(embedMigrations, "testdata/migrations")
|
fsys, err := fs.Sub(embedMigrations, "testdata/migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
goose.SetBaseFS(fsys)
|
goose.SetBaseFS(fsys)
|
||||||
t.Cleanup(func() { goose.SetBaseFS(nil) })
|
t.Cleanup(func() { goose.SetBaseFS(nil) })
|
||||||
check.NoError(t, goose.SetDialect("sqlite3"))
|
require.NoError(t, goose.SetDialect("sqlite3"))
|
||||||
|
|
||||||
t.Run("migration_cycle", func(t *testing.T) {
|
t.Run("migration_cycle", func(t *testing.T) {
|
||||||
err := goose.Up(db, ".")
|
err := goose.Up(db, ".")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ver, err := goose.GetDBVersion(db)
|
ver, err := goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, ver, total)
|
require.EqualValues(t, ver, total)
|
||||||
err = goose.Reset(db, ".")
|
err = goose.Reset(db, ".")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ver, err = goose.GetDBVersion(db)
|
ver, err = goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, ver, 0)
|
require.EqualValues(t, ver, 0)
|
||||||
})
|
})
|
||||||
t.Run("create_uses_os_fs", func(t *testing.T) {
|
t.Run("create_uses_os_fs", func(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
err := goose.Create(db, dir, "test", "sql")
|
err := goose.Create(db, dir, "test", "sql")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
paths, _ := filepath.Glob(filepath.Join(dir, "*test.sql"))
|
paths, _ := filepath.Glob(filepath.Join(dir, "*test.sql"))
|
||||||
check.NumberNotZero(t, len(paths))
|
require.NotZero(t, len(paths))
|
||||||
err = goose.Fix(dir)
|
err = goose.Fix(dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = os.Stat(filepath.Join(dir, "00001_test.sql"))
|
_, err = os.Stat(filepath.Join(dir, "00001_test.sql"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,86 +0,0 @@
|
||||||
package check
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NoError(t *testing.T, err error) {
|
|
||||||
t.Helper()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func HasError(t *testing.T, err error) {
|
|
||||||
t.Helper()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expecting an error: got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsError(t *testing.T, err, target error) {
|
|
||||||
t.Helper()
|
|
||||||
if !errors.Is(err, target) {
|
|
||||||
t.Fatalf("expecting specific error:\ngot: %v\nwant: %v", err, target)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Number(t *testing.T, got, want interface{}) {
|
|
||||||
t.Helper()
|
|
||||||
gotNumber, err := reflectToInt64(got)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
wantNumber, err := reflectToInt64(want)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if gotNumber != wantNumber {
|
|
||||||
t.Fatalf("unexpected number value: got:%d want:%d ", gotNumber, wantNumber)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Equal(t *testing.T, got, want interface{}) {
|
|
||||||
t.Helper()
|
|
||||||
if !reflect.DeepEqual(got, want) {
|
|
||||||
t.Fatalf("failed deep equal:\ngot:\t%v\nwant:\t%v\v", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NumberNotZero(t *testing.T, got interface{}) {
|
|
||||||
t.Helper()
|
|
||||||
gotNumber, err := reflectToInt64(got)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if gotNumber == 0 {
|
|
||||||
t.Fatalf("unexpected number value: got:%d want non-zero ", gotNumber)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Bool(t *testing.T, got, want bool) {
|
|
||||||
t.Helper()
|
|
||||||
if got != want {
|
|
||||||
t.Fatalf("unexpected boolean value: got:%t want:%t", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Contains(t *testing.T, got, want string) {
|
|
||||||
t.Helper()
|
|
||||||
if !strings.Contains(got, want) {
|
|
||||||
t.Errorf("failed to find substring:\n%s\n\nin string value:\n%s", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reflectToInt64(v interface{}) (int64, error) {
|
|
||||||
switch typ := v.(type) {
|
|
||||||
case int, int8, int16, int32, int64:
|
|
||||||
return reflect.ValueOf(typ).Int(), nil
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("invalid number: must be int64 type: got:%T", v)
|
|
||||||
}
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParsingGoMigrations(t *testing.T) {
|
func TestParsingGoMigrations(t *testing.T) {
|
||||||
|
@ -31,11 +31,11 @@ func TestParsingGoMigrations(t *testing.T) {
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
g, err := parseGoFile(strings.NewReader(tc.input))
|
g, err := parseGoFile(strings.NewReader(tc.input))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, g.useTx != nil, true)
|
require.Equal(t, g.useTx != nil, true)
|
||||||
check.Bool(t, *g.useTx, tc.wantTx)
|
require.Equal(t, *g.useTx, tc.wantTx)
|
||||||
check.Equal(t, g.downFuncName, tc.wantDownName)
|
require.Equal(t, g.downFuncName, tc.wantDownName)
|
||||||
check.Equal(t, g.upFuncName, tc.wantUpName)
|
require.Equal(t, g.upFuncName, tc.wantUpName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -45,15 +45,15 @@ func TestGoMigrationStats(t *testing.T) {
|
||||||
|
|
||||||
base := "../../tests/gomigrations/success/testdata"
|
base := "../../tests/gomigrations/success/testdata"
|
||||||
all, err := os.ReadDir(base)
|
all, err := os.ReadDir(base)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, len(all), 16)
|
require.Equal(t, len(all), 16)
|
||||||
files := make([]string, 0, len(all))
|
files := make([]string, 0, len(all))
|
||||||
for _, f := range all {
|
for _, f := range all {
|
||||||
files = append(files, filepath.Join(base, f.Name()))
|
files = append(files, filepath.Join(base, f.Name()))
|
||||||
}
|
}
|
||||||
stats, err := GatherStats(NewFileWalker(files...), false)
|
stats, err := GatherStats(NewFileWalker(files...), false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, len(stats), 16)
|
require.Equal(t, len(stats), 16)
|
||||||
checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true)
|
checkGoStats(t, stats[0], "001_up_down.go", 1, 1, 1, true)
|
||||||
checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true)
|
checkGoStats(t, stats[1], "002_up_only.go", 2, 1, 0, true)
|
||||||
checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true)
|
checkGoStats(t, stats[2], "003_down_only.go", 3, 0, 1, true)
|
||||||
|
@ -74,22 +74,22 @@ func TestGoMigrationStats(t *testing.T) {
|
||||||
|
|
||||||
func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) {
|
func checkGoStats(t *testing.T, stats *Stats, filename string, version int64, upCount, downCount int, tx bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, filepath.Base(stats.FileName), filename)
|
require.Equal(t, filepath.Base(stats.FileName), filename)
|
||||||
check.Equal(t, stats.Version, version)
|
require.Equal(t, stats.Version, version)
|
||||||
check.Equal(t, stats.UpCount, upCount)
|
require.Equal(t, stats.UpCount, upCount)
|
||||||
check.Equal(t, stats.DownCount, downCount)
|
require.Equal(t, stats.DownCount, downCount)
|
||||||
check.Equal(t, stats.Tx, tx)
|
require.Equal(t, stats.Tx, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParsingGoMigrationsError(t *testing.T) {
|
func TestParsingGoMigrationsError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
_, err := parseGoFile(strings.NewReader(emptyInit))
|
_, err := parseGoFile(strings.NewReader(emptyInit))
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "no registered goose functions")
|
require.Contains(t, err.Error(), "no registered goose functions")
|
||||||
|
|
||||||
_, err = parseGoFile(strings.NewReader(wrongName))
|
_, err = parseGoFile(strings.NewReader(wrongName))
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext")
|
require.Contains(t, err.Error(), "AddMigration, AddMigrationNoTx, AddMigrationContext, AddMigrationNoTxContext")
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
package sqlparser_test
|
package sqlparser_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
|
||||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseAllFromFS(t *testing.T) {
|
func TestParseAllFromFS(t *testing.T) {
|
||||||
|
@ -15,17 +14,17 @@ func TestParseAllFromFS(t *testing.T) {
|
||||||
t.Run("file_not_exist", func(t *testing.T) {
|
t.Run("file_not_exist", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{}
|
mapFS := fstest.MapFS{}
|
||||||
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, os.ErrNotExist), true)
|
require.ErrorIs(t, err, os.ErrNotExist)
|
||||||
})
|
})
|
||||||
t.Run("empty_file", func(t *testing.T) {
|
t.Run("empty_file", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
"001_foo.sql": &fstest.MapFile{},
|
"001_foo.sql": &fstest.MapFile{},
|
||||||
}
|
}
|
||||||
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
_, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "failed to parse migration")
|
require.Contains(t, err.Error(), "failed to parse migration")
|
||||||
check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation")
|
require.Contains(t, err.Error(), "must start with '-- +goose Up' annotation")
|
||||||
})
|
})
|
||||||
t.Run("all_statements", func(t *testing.T) {
|
t.Run("all_statements", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -53,26 +52,26 @@ DROP TABLE foo;
|
||||||
`),
|
`),
|
||||||
}
|
}
|
||||||
parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
||||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false)
|
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
assertParsedSQL(t, parsedSQL, true, 0, 0)
|
||||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false)
|
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertParsedSQL(t, parsedSQL, true, 2, 1)
|
assertParsedSQL(t, parsedSQL, true, 2, 1)
|
||||||
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false)
|
parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertParsedSQL(t, parsedSQL, false, 1, 1)
|
assertParsedSQL(t, parsedSQL, false, 1, 1)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) {
|
func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Bool(t, got != nil, true)
|
require.NotNil(t, got)
|
||||||
check.Equal(t, len(got.Up), up)
|
require.Equal(t, len(got.Up), up)
|
||||||
check.Equal(t, len(got.Down), down)
|
require.Equal(t, len(got.Down), down)
|
||||||
check.Equal(t, got.UseTx, useTx)
|
require.Equal(t, got.UseTx, useTx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFile(data string) *fstest.MapFile {
|
func newFile(data string) *fstest.MapFile {
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -91,14 +91,14 @@ func TestInvalidUp(t *testing.T) {
|
||||||
|
|
||||||
testdataDir := filepath.Join("testdata", "invalid", "up")
|
testdataDir := filepath.Join("testdata", "invalid", "up")
|
||||||
entries, err := os.ReadDir(testdataDir)
|
entries, err := os.ReadDir(testdataDir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.NumberNotZero(t, len(entries))
|
require.NotZero(t, len(entries))
|
||||||
|
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
by, err := os.ReadFile(filepath.Join(testdataDir, entry.Name()))
|
by, err := os.ReadFile(filepath.Join(testdataDir, entry.Name()))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, _, err = ParseSQLMigration(strings.NewReader(string(by)), DirectionUp, false)
|
_, _, err = ParseSQLMigration(strings.NewReader(string(by)), DirectionUp, false)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -410,11 +410,11 @@ func testValid(t *testing.T, dir string, count int, direction Direction) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.Open(filepath.Join(dir, "input.sql"))
|
f, err := os.Open(filepath.Join(dir, "input.sql"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { f.Close() })
|
t.Cleanup(func() { f.Close() })
|
||||||
statements, _, err := ParseSQLMigration(f, direction, debug)
|
statements, _, err := ParseSQLMigration(f, direction, debug)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(statements), count)
|
require.Equal(t, len(statements), count)
|
||||||
compareStatements(t, dir, statements, direction)
|
compareStatements(t, dir, statements, direction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -422,7 +422,7 @@ func compareStatements(t *testing.T, dir string, statements []string, direction
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
files, err := filepath.Glob(filepath.Join(dir, fmt.Sprintf("*.%s.golden.sql", direction)))
|
files, err := filepath.Glob(filepath.Join(dir, fmt.Sprintf("*.%s.golden.sql", direction)))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if len(statements) != len(files) {
|
if len(statements) != len(files) {
|
||||||
t.Fatalf("mismatch between parsed statements (%d) and golden files (%d), did you check in NN.{up|down}.golden.sql file in %q?", len(statements), len(files), dir)
|
t.Fatalf("mismatch between parsed statements (%d) and golden files (%d), did you check in NN.{up|down}.golden.sql file in %q?", len(statements), len(files), dir)
|
||||||
}
|
}
|
||||||
|
@ -433,12 +433,12 @@ func compareStatements(t *testing.T, dir string, statements []string, direction
|
||||||
t.Fatal(`failed to cut on file delimiter ".", must be of the format NN.{up|down}.golden.sql`)
|
t.Fatal(`failed to cut on file delimiter ".", must be of the format NN.{up|down}.golden.sql`)
|
||||||
}
|
}
|
||||||
index, err := strconv.Atoi(before)
|
index, err := strconv.Atoi(before)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
index--
|
index--
|
||||||
|
|
||||||
goldenFilePath := filepath.Join(dir, goldenFile)
|
goldenFilePath := filepath.Join(dir, goldenFile)
|
||||||
by, err := os.ReadFile(goldenFilePath)
|
by, err := os.ReadFile(goldenFilePath)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, want := statements[index], string(by)
|
got, want := statements[index], string(by)
|
||||||
|
|
||||||
|
@ -452,7 +452,7 @@ func compareStatements(t *testing.T, dir string, statements []string, direction
|
||||||
filepath.Join("internal", "sqlparser", goldenFilePath),
|
filepath.Join("internal", "sqlparser", goldenFilePath),
|
||||||
)
|
)
|
||||||
err := os.WriteFile(goldenFilePath+".FAIL", []byte(got), 0644)
|
err := os.WriteFile(goldenFilePath+".FAIL", []byte(got), 0644)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -504,8 +504,8 @@ CREATE TABLE post (
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
_, _, err := ParseSQLMigration(strings.NewReader(s), DirectionUp, debug)
|
_, _, err := ParseSQLMigration(strings.NewReader(s), DirectionUp, debug)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:")
|
require.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_extractAnnotation(t *testing.T) {
|
func Test_extractAnnotation(t *testing.T) {
|
||||||
|
@ -513,76 +513,79 @@ func Test_extractAnnotation(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
input string
|
input string
|
||||||
want annotation
|
want annotation
|
||||||
wantErr func(t *testing.T, err error)
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Up",
|
name: "Up",
|
||||||
input: "-- +goose Up",
|
input: "-- +goose Up",
|
||||||
want: annotationUp,
|
want: annotationUp,
|
||||||
wantErr: check.NoError,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Down",
|
name: "Down",
|
||||||
input: "-- +goose Down",
|
input: "-- +goose Down",
|
||||||
want: annotationDown,
|
want: annotationDown,
|
||||||
wantErr: check.NoError,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "StmtBegin",
|
name: "StmtBegin",
|
||||||
input: "-- +goose StatementBegin",
|
input: "-- +goose StatementBegin",
|
||||||
want: annotationStatementBegin,
|
want: annotationStatementBegin,
|
||||||
wantErr: check.NoError,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "NoTransact",
|
name: "NoTransact",
|
||||||
input: "-- +goose NO TRANSACTION",
|
input: "-- +goose NO TRANSACTION",
|
||||||
want: annotationNoTransaction,
|
want: annotationNoTransaction,
|
||||||
wantErr: check.NoError,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Unsupported",
|
name: "Unsupported",
|
||||||
input: "-- +goose unsupported",
|
input: "-- +goose unsupported",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: check.HasError,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty",
|
name: "Empty",
|
||||||
input: "-- +goose",
|
input: "-- +goose",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: check.HasError,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "statement with spaces and Uppercase",
|
name: "statement with spaces and Uppercase",
|
||||||
input: "-- +goose UP ",
|
input: "-- +goose UP ",
|
||||||
want: annotationUp,
|
want: annotationUp,
|
||||||
wantErr: check.NoError,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "statement with leading whitespace - error",
|
name: "statement with leading whitespace - error",
|
||||||
input: " -- +goose UP ",
|
input: " -- +goose UP ",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: check.HasError,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "statement with leading \t - error",
|
name: "statement with leading \t - error",
|
||||||
input: "\t-- +goose UP ",
|
input: "\t-- +goose UP ",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: check.HasError,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple +goose annotations - error",
|
name: "multiple +goose annotations - error",
|
||||||
input: "-- +goose +goose Up",
|
input: "-- +goose +goose Up",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: check.HasError,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := extractAnnotation(tt.input)
|
got, err := extractAnnotation(tt.input)
|
||||||
tt.wantErr(t, err)
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
check.Equal(t, got, tt.want)
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, got, tt.want)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
|
||||||
"github.com/pressly/goose/v3/internal/testing/testdb"
|
"github.com/pressly/goose/v3/internal/testing/testdb"
|
||||||
"github.com/pressly/goose/v3/lock"
|
"github.com/pressly/goose/v3/lock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -433,23 +432,23 @@ func TestPostgresPending(t *testing.T) {
|
||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS(testDir))
|
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS(testDir))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
hasPending, err := p.HasPending(context.Background())
|
hasPending, err := p.HasPending(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
boolCh <- hasPending
|
boolCh <- hasPending
|
||||||
current, target, err := p.GetVersions(context.Background())
|
current, target, err := p.GetVersions(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, int64(wantCurrent))
|
require.Equal(t, current, int64(wantCurrent))
|
||||||
check.Number(t, target, int64(wantTarget))
|
require.Equal(t, target, int64(wantTarget))
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
check.NoError(t, g.Wait())
|
require.NoError(t, g.Wait())
|
||||||
close(boolCh)
|
close(boolCh)
|
||||||
// expect all values to be true
|
// expect all values to be true
|
||||||
for hasPending := range boolCh {
|
for hasPending := range boolCh {
|
||||||
check.Bool(t, hasPending, want)
|
require.Equal(t, hasPending, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.Run("concurrent_has_pending", func(t *testing.T) {
|
t.Run("concurrent_has_pending", func(t *testing.T) {
|
||||||
|
@ -458,9 +457,9 @@ func TestPostgresPending(t *testing.T) {
|
||||||
|
|
||||||
// apply all migrations
|
// apply all migrations
|
||||||
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
|
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(context.Background())
|
_, err = p.Up(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("concurrent_no_pending", func(t *testing.T) {
|
t.Run("concurrent_no_pending", func(t *testing.T) {
|
||||||
run(t, false, len(files), len(files))
|
run(t, false, len(files), len(files))
|
||||||
|
@ -480,10 +479,10 @@ SELECT pg_sleep_for('4 seconds');
|
||||||
sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 5min. Try every 1s up to 10 times.
|
sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 5min. Try every 1s up to 10 times.
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker))
|
newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(newProvider.ListSources()), 1)
|
require.Equal(t, len(newProvider.ListSources()), 1)
|
||||||
oldProvider := p
|
oldProvider := p
|
||||||
check.Number(t, len(oldProvider.ListSources()), len(files))
|
require.Equal(t, len(oldProvider.ListSources()), len(files))
|
||||||
|
|
||||||
var g errgroup.Group
|
var g errgroup.Group
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
|
@ -491,13 +490,13 @@ SELECT pg_sleep_for('4 seconds');
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
check.Bool(t, hasPending, true)
|
require.True(t, hasPending)
|
||||||
current, target, err := newProvider.GetVersions(context.Background())
|
current, target, err := newProvider.GetVersions(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
check.Number(t, current, lastVersion)
|
require.EqualValues(t, current, lastVersion)
|
||||||
check.Number(t, target, lastVersion+1)
|
require.EqualValues(t, target, lastVersion+1)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
|
@ -505,16 +504,16 @@ SELECT pg_sleep_for('4 seconds');
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
check.Bool(t, hasPending, false)
|
require.False(t, hasPending)
|
||||||
current, target, err := oldProvider.GetVersions(context.Background())
|
current, target, err := oldProvider.GetVersions(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
check.Number(t, current, lastVersion)
|
require.EqualValues(t, current, lastVersion)
|
||||||
check.Number(t, target, lastVersion)
|
require.EqualValues(t, target, lastVersion)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
check.NoError(t, g.Wait())
|
require.NoError(t, g.Wait())
|
||||||
|
|
||||||
// A new provider is running in the background with a session lock to simulate a long running
|
// A new provider is running in the background with a session lock to simulate a long running
|
||||||
// migration. If older instances come up, they should not have any pending migrations and not be
|
// migration. If older instances come up, they should not have any pending migrations and not be
|
||||||
|
@ -526,29 +525,29 @@ SELECT pg_sleep_for('4 seconds');
|
||||||
})
|
})
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
isLocked, err := existsPgLock(context.Background(), db, lockID)
|
isLocked, err := existsPgLock(context.Background(), db, lockID)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, isLocked, true)
|
require.True(t, isLocked)
|
||||||
hasPending, err := oldProvider.HasPending(context.Background())
|
hasPending, err := oldProvider.HasPending(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, hasPending, false)
|
require.False(t, hasPending)
|
||||||
current, target, err := oldProvider.GetVersions(context.Background())
|
current, target, err := oldProvider.GetVersions(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, lastVersion)
|
require.EqualValues(t, current, lastVersion)
|
||||||
check.Number(t, target, lastVersion)
|
require.EqualValues(t, target, lastVersion)
|
||||||
// Wait for the long running migration to finish
|
// Wait for the long running migration to finish
|
||||||
check.NoError(t, g.Wait())
|
require.NoError(t, g.Wait())
|
||||||
// Check that the new migration was applied
|
// Check that the new migration was applied
|
||||||
hasPending, err = newProvider.HasPending(context.Background())
|
hasPending, err = newProvider.HasPending(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, hasPending, false)
|
require.False(t, hasPending)
|
||||||
current, target, err = newProvider.GetVersions(context.Background())
|
current, target, err = newProvider.GetVersions(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, lastVersion+1)
|
require.EqualValues(t, current, lastVersion+1)
|
||||||
check.Number(t, target, lastVersion+1)
|
require.EqualValues(t, target, lastVersion+1)
|
||||||
// The max version should be the new migration
|
// The max version should be the new migration
|
||||||
currentVersion, err := newProvider.GetDBVersion(context.Background())
|
currentVersion, err := newProvider.GetDBVersion(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, lastVersion+1)
|
require.EqualValues(t, currentVersion, lastVersion+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
|
func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
|
||||||
|
|
110
migrate_test.go
110
migrate_test.go
|
@ -7,7 +7,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrationSort(t *testing.T) {
|
func TestMigrationSort(t *testing.T) {
|
||||||
|
@ -68,10 +68,10 @@ func TestCollectMigrations(t *testing.T) {
|
||||||
t.Run("no_migration_files_found", func(t *testing.T) {
|
t.Run("no_migration_files_found", func(t *testing.T) {
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
err := os.MkdirAll(filepath.Join(tmp, "migrations-test"), 0755)
|
err := os.MkdirAll(filepath.Join(tmp, "migrations-test"), 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = collectMigrationsFS(os.DirFS(tmp), "migrations-test", 0, math.MaxInt64, nil)
|
_, err = collectMigrationsFS(os.DirFS(tmp), "migrations-test", 0, math.MaxInt64, nil)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "no migration files found")
|
require.Contains(t, err.Error(), "no migration files found")
|
||||||
})
|
})
|
||||||
t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) {
|
t.Run("filesystem_registered_with_single_dirpath", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
|
@ -79,26 +79,26 @@ func TestCollectMigrations(t *testing.T) {
|
||||||
file3, file4 := "19081_a.go", "19082_b.go"
|
file3, file4 := "19081_a.go", "19082_b.go"
|
||||||
AddNamedMigrationContext(file1, nil, nil)
|
AddNamedMigrationContext(file1, nil, nil)
|
||||||
AddNamedMigrationContext(file2, nil, nil)
|
AddNamedMigrationContext(file2, nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 2)
|
require.Equal(t, len(registeredGoMigrations), 2)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
dir := filepath.Join(tmp, "migrations", "dir1")
|
dir := filepath.Join(tmp, "migrations", "dir1")
|
||||||
err := os.MkdirAll(dir, 0755)
|
err := os.MkdirAll(dir, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
createEmptyFile(t, dir, file1)
|
createEmptyFile(t, dir, file1)
|
||||||
createEmptyFile(t, dir, file2)
|
createEmptyFile(t, dir, file2)
|
||||||
createEmptyFile(t, dir, file3)
|
createEmptyFile(t, dir, file3)
|
||||||
createEmptyFile(t, dir, file4)
|
createEmptyFile(t, dir, file4)
|
||||||
fsys := os.DirFS(tmp)
|
fsys := os.DirFS(tmp)
|
||||||
files, err := fs.ReadDir(fsys, "migrations/dir1")
|
files, err := fs.ReadDir(fsys, "migrations/dir1")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(files), 4)
|
require.Equal(t, len(files), 4)
|
||||||
all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 4)
|
require.Equal(t, len(all), 4)
|
||||||
check.Number(t, all[0].Version, 9081)
|
require.EqualValues(t, all[0].Version, 9081)
|
||||||
check.Number(t, all[1].Version, 9082)
|
require.EqualValues(t, all[1].Version, 9082)
|
||||||
check.Number(t, all[2].Version, 19081)
|
require.EqualValues(t, all[2].Version, 19081)
|
||||||
check.Number(t, all[3].Version, 19082)
|
require.EqualValues(t, all[3].Version, 19082)
|
||||||
})
|
})
|
||||||
t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) {
|
t.Run("filesystem_registered_with_multiple_dirpath", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
|
@ -106,14 +106,14 @@ func TestCollectMigrations(t *testing.T) {
|
||||||
AddNamedMigrationContext(file1, nil, nil)
|
AddNamedMigrationContext(file1, nil, nil)
|
||||||
AddNamedMigrationContext(file2, nil, nil)
|
AddNamedMigrationContext(file2, nil, nil)
|
||||||
AddNamedMigrationContext(file3, nil, nil)
|
AddNamedMigrationContext(file3, nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 3)
|
require.Equal(t, len(registeredGoMigrations), 3)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
||||||
dir2 := filepath.Join(tmp, "migrations", "dir2")
|
dir2 := filepath.Join(tmp, "migrations", "dir2")
|
||||||
err := os.MkdirAll(dir1, 0755)
|
err := os.MkdirAll(dir1, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = os.MkdirAll(dir2, 0755)
|
err = os.MkdirAll(dir2, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
createEmptyFile(t, dir1, file1)
|
createEmptyFile(t, dir1, file1)
|
||||||
createEmptyFile(t, dir1, file2)
|
createEmptyFile(t, dir1, file2)
|
||||||
createEmptyFile(t, dir2, file3)
|
createEmptyFile(t, dir2, file3)
|
||||||
|
@ -122,33 +122,33 @@ func TestCollectMigrations(t *testing.T) {
|
||||||
// even though 3 Go migrations have been registered.
|
// even though 3 Go migrations have been registered.
|
||||||
{
|
{
|
||||||
all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(fsys, "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 2)
|
require.Equal(t, len(all), 2)
|
||||||
check.Number(t, all[0].Version, 1)
|
require.EqualValues(t, all[0].Version, 1)
|
||||||
check.Number(t, all[1].Version, 2)
|
require.EqualValues(t, all[1].Version, 2)
|
||||||
}
|
}
|
||||||
// Validate if dirpath 2 is specified we only get the one Go migration in migrations/dir2 folder
|
// Validate if dirpath 2 is specified we only get the one Go migration in migrations/dir2 folder
|
||||||
// even though 3 Go migrations have been registered.
|
// even though 3 Go migrations have been registered.
|
||||||
{
|
{
|
||||||
all, err := collectMigrationsFS(fsys, "migrations/dir2", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(fsys, "migrations/dir2", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 1)
|
require.Equal(t, len(all), 1)
|
||||||
check.Number(t, all[0].Version, 1111)
|
require.EqualValues(t, all[0].Version, 1111)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("empty_filesystem_registered_manually", func(t *testing.T) {
|
t.Run("empty_filesystem_registered_manually", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
AddNamedMigrationContext("00101_a.go", nil, nil)
|
AddNamedMigrationContext("00101_a.go", nil, nil)
|
||||||
AddNamedMigrationContext("00102_b.go", nil, nil)
|
AddNamedMigrationContext("00102_b.go", nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 2)
|
require.Equal(t, len(registeredGoMigrations), 2)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
err := os.MkdirAll(filepath.Join(tmp, "migrations"), 0755)
|
err := os.MkdirAll(filepath.Join(tmp, "migrations"), 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 2)
|
require.Equal(t, len(all), 2)
|
||||||
check.Number(t, all[0].Version, 101)
|
require.EqualValues(t, all[0].Version, 101)
|
||||||
check.Number(t, all[1].Version, 102)
|
require.EqualValues(t, all[1].Version, 102)
|
||||||
})
|
})
|
||||||
t.Run("unregistered_go_migrations", func(t *testing.T) {
|
t.Run("unregistered_go_migrations", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
|
@ -157,67 +157,67 @@ func TestCollectMigrations(t *testing.T) {
|
||||||
// valid looking file2 Go migration
|
// valid looking file2 Go migration
|
||||||
AddNamedMigrationContext(file1, nil, nil)
|
AddNamedMigrationContext(file1, nil, nil)
|
||||||
AddNamedMigrationContext(file3, nil, nil)
|
AddNamedMigrationContext(file3, nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 2)
|
require.Equal(t, len(registeredGoMigrations), 2)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
||||||
err := os.MkdirAll(dir1, 0755)
|
err := os.MkdirAll(dir1, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Include the valid file2 with file1, file3. But remember, it has NOT been
|
// Include the valid file2 with file1, file3. But remember, it has NOT been
|
||||||
// registered.
|
// registered.
|
||||||
createEmptyFile(t, dir1, file1)
|
createEmptyFile(t, dir1, file1)
|
||||||
createEmptyFile(t, dir1, file2)
|
createEmptyFile(t, dir1, file2)
|
||||||
createEmptyFile(t, dir1, file3)
|
createEmptyFile(t, dir1, file3)
|
||||||
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 3)
|
require.Equal(t, len(all), 3)
|
||||||
check.Number(t, all[0].Version, 1)
|
require.EqualValues(t, all[0].Version, 1)
|
||||||
check.Bool(t, all[0].Registered, true)
|
require.True(t, all[0].Registered)
|
||||||
check.Number(t, all[1].Version, 998)
|
require.EqualValues(t, all[1].Version, 998)
|
||||||
// This migrations is marked unregistered and will lazily raise an error if/when this
|
// This migrations is marked unregistered and will lazily raise an error if/when this
|
||||||
// migration is run
|
// migration is run
|
||||||
check.Bool(t, all[1].Registered, false)
|
require.False(t, all[1].Registered)
|
||||||
check.Number(t, all[2].Version, 999)
|
require.EqualValues(t, all[2].Version, 999)
|
||||||
check.Bool(t, all[2].Registered, true)
|
require.True(t, all[2].Registered)
|
||||||
})
|
})
|
||||||
t.Run("with_skipped_go_files", func(t *testing.T) {
|
t.Run("with_skipped_go_files", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
file1, file2, file3, file4 := "00001_a.go", "00002_b.sql", "00999_c_test.go", "embed.go"
|
file1, file2, file3, file4 := "00001_a.go", "00002_b.sql", "00999_c_test.go", "embed.go"
|
||||||
AddNamedMigrationContext(file1, nil, nil)
|
AddNamedMigrationContext(file1, nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 1)
|
require.Equal(t, len(registeredGoMigrations), 1)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
||||||
err := os.MkdirAll(dir1, 0755)
|
err := os.MkdirAll(dir1, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
createEmptyFile(t, dir1, file1)
|
createEmptyFile(t, dir1, file1)
|
||||||
createEmptyFile(t, dir1, file2)
|
createEmptyFile(t, dir1, file2)
|
||||||
createEmptyFile(t, dir1, file3)
|
createEmptyFile(t, dir1, file3)
|
||||||
createEmptyFile(t, dir1, file4)
|
createEmptyFile(t, dir1, file4)
|
||||||
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 0, math.MaxInt64, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 2)
|
require.Equal(t, len(all), 2)
|
||||||
check.Number(t, all[0].Version, 1)
|
require.EqualValues(t, all[0].Version, 1)
|
||||||
check.Bool(t, all[0].Registered, true)
|
require.True(t, all[0].Registered)
|
||||||
check.Number(t, all[1].Version, 2)
|
require.EqualValues(t, all[1].Version, 2)
|
||||||
check.Bool(t, all[1].Registered, false)
|
require.False(t, all[1].Registered)
|
||||||
})
|
})
|
||||||
t.Run("current_and_target", func(t *testing.T) {
|
t.Run("current_and_target", func(t *testing.T) {
|
||||||
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
t.Cleanup(func() { clearMap(registeredGoMigrations) })
|
||||||
file1, file2, file3 := "01001_a.go", "01002_b.sql", "01003_c.go"
|
file1, file2, file3 := "01001_a.go", "01002_b.sql", "01003_c.go"
|
||||||
AddNamedMigrationContext(file1, nil, nil)
|
AddNamedMigrationContext(file1, nil, nil)
|
||||||
AddNamedMigrationContext(file3, nil, nil)
|
AddNamedMigrationContext(file3, nil, nil)
|
||||||
check.Number(t, len(registeredGoMigrations), 2)
|
require.Equal(t, len(registeredGoMigrations), 2)
|
||||||
tmp := t.TempDir()
|
tmp := t.TempDir()
|
||||||
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
dir1 := filepath.Join(tmp, "migrations", "dir1")
|
||||||
err := os.MkdirAll(dir1, 0755)
|
err := os.MkdirAll(dir1, 0755)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
createEmptyFile(t, dir1, file1)
|
createEmptyFile(t, dir1, file1)
|
||||||
createEmptyFile(t, dir1, file2)
|
createEmptyFile(t, dir1, file2)
|
||||||
createEmptyFile(t, dir1, file3)
|
createEmptyFile(t, dir1, file3)
|
||||||
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 1001, 1003, registeredGoMigrations)
|
all, err := collectMigrationsFS(os.DirFS(tmp), "migrations/dir1", 1001, 1003, registeredGoMigrations)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(all), 2)
|
require.Equal(t, len(all), 2)
|
||||||
check.Number(t, all[0].Version, 1002)
|
require.EqualValues(t, all[0].Version, 1002)
|
||||||
check.Number(t, all[1].Version, 1003)
|
require.EqualValues(t, all[1].Version, 1003)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ func TestVersionFilter(t *testing.T) {
|
||||||
func createEmptyFile(t *testing.T, dir, name string) {
|
func createEmptyFile(t *testing.T, dir, name string) {
|
||||||
path := filepath.Join(dir, name)
|
path := filepath.Join(dir, name)
|
||||||
f, err := os.Create(path)
|
f, err := os.Create(path)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,31 +5,31 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCollectFileSources(t *testing.T) {
|
func TestCollectFileSources(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("nil_fsys", func(t *testing.T) {
|
t.Run("nil_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(nil, false, nil, nil)
|
sources, err := collectFilesystemSources(nil, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, sources != nil, true)
|
require.True(t, sources != nil)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
require.Equal(t, len(sources.sqlSources), 0)
|
||||||
})
|
})
|
||||||
t.Run("noop_fsys", func(t *testing.T) {
|
t.Run("noop_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
|
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, sources != nil, true)
|
require.True(t, sources != nil)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
require.Equal(t, len(sources.sqlSources), 0)
|
||||||
})
|
})
|
||||||
t.Run("empty_fsys", func(t *testing.T) {
|
t.Run("empty_fsys", func(t *testing.T) {
|
||||||
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
|
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
require.Equal(t, len(sources.sqlSources), 0)
|
||||||
check.Bool(t, sources != nil, true)
|
require.True(t, sources != nil)
|
||||||
})
|
})
|
||||||
t.Run("incorrect_fsys", func(t *testing.T) {
|
t.Run("incorrect_fsys", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -37,21 +37,21 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
}
|
}
|
||||||
// strict disable - should not error
|
// strict disable - should not error
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
check.Number(t, len(sources.sqlSources), 0)
|
require.Equal(t, len(sources.sqlSources), 0)
|
||||||
// strict enabled - should error
|
// strict enabled - should error
|
||||||
_, err = collectFilesystemSources(mapFS, true, nil, nil)
|
_, err = collectFilesystemSources(mapFS, true, nil, nil)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "migration version must be greater than zero")
|
require.Contains(t, err.Error(), "migration version must be greater than zero")
|
||||||
})
|
})
|
||||||
t.Run("collect", func(t *testing.T) {
|
t.Run("collect", func(t *testing.T) {
|
||||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 4)
|
require.Equal(t, len(sources.sqlSources), 4)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
expected := fileSources{
|
expected := fileSources{
|
||||||
sqlSources: []Source{
|
sqlSources: []Source{
|
||||||
newSource(TypeSQL, "00001_foo.sql", 1),
|
newSource(TypeSQL, "00001_foo.sql", 1),
|
||||||
|
@ -61,12 +61,12 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for i := 0; i < len(sources.sqlSources); i++ {
|
for i := 0; i < len(sources.sqlSources); i++ {
|
||||||
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
require.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("excludes", func(t *testing.T) {
|
t.Run("excludes", func(t *testing.T) {
|
||||||
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(
|
sources, err := collectFilesystemSources(
|
||||||
fsys,
|
fsys,
|
||||||
false,
|
false,
|
||||||
|
@ -77,9 +77,9 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 2)
|
require.Equal(t, len(sources.sqlSources), 2)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
expected := fileSources{
|
expected := fileSources{
|
||||||
sqlSources: []Source{
|
sqlSources: []Source{
|
||||||
newSource(TypeSQL, "00001_foo.sql", 1),
|
newSource(TypeSQL, "00001_foo.sql", 1),
|
||||||
|
@ -87,7 +87,7 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for i := 0; i < len(sources.sqlSources); i++ {
|
for i := 0; i < len(sources.sqlSources); i++ {
|
||||||
check.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
require.Equal(t, sources.sqlSources[i], expected.sqlSources[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("strict", func(t *testing.T) {
|
t.Run("strict", func(t *testing.T) {
|
||||||
|
@ -95,10 +95,10 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
// Add a file with no version number
|
// Add a file with no version number
|
||||||
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = collectFilesystemSources(fsys, true, nil, nil)
|
_, err = collectFilesystemSources(fsys, true, nil, nil)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
require.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
|
||||||
})
|
})
|
||||||
t.Run("skip_go_test_files", func(t *testing.T) {
|
t.Run("skip_go_test_files", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -109,9 +109,9 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
"5_foo_test.go": {Data: []byte(`package goose_test`)},
|
||||||
}
|
}
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 4)
|
require.Equal(t, len(sources.sqlSources), 4)
|
||||||
check.Number(t, len(sources.goSources), 0)
|
require.Equal(t, len(sources.goSources), 0)
|
||||||
})
|
})
|
||||||
t.Run("skip_random_files", func(t *testing.T) {
|
t.Run("skip_random_files", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -124,18 +124,18 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
|
||||||
}
|
}
|
||||||
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(sources.sqlSources), 2)
|
require.Equal(t, len(sources.sqlSources), 2)
|
||||||
check.Number(t, len(sources.goSources), 1)
|
require.Equal(t, len(sources.goSources), 1)
|
||||||
// 1
|
// 1
|
||||||
check.Equal(t, sources.sqlSources[0].Path, "1_foo.sql")
|
require.Equal(t, sources.sqlSources[0].Path, "1_foo.sql")
|
||||||
check.Equal(t, sources.sqlSources[0].Version, int64(1))
|
require.Equal(t, sources.sqlSources[0].Version, int64(1))
|
||||||
// 2
|
// 2
|
||||||
check.Equal(t, sources.sqlSources[1].Path, "5_qux.sql")
|
require.Equal(t, sources.sqlSources[1].Path, "5_qux.sql")
|
||||||
check.Equal(t, sources.sqlSources[1].Version, int64(5))
|
require.Equal(t, sources.sqlSources[1].Version, int64(5))
|
||||||
// 3
|
// 3
|
||||||
check.Equal(t, sources.goSources[0].Path, "4_something.go")
|
require.Equal(t, sources.goSources[0].Path, "4_something.go")
|
||||||
check.Equal(t, sources.goSources[0].Version, int64(4))
|
require.Equal(t, sources.goSources[0].Version, int64(4))
|
||||||
})
|
})
|
||||||
t.Run("duplicate_versions", func(t *testing.T) {
|
t.Run("duplicate_versions", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -143,8 +143,8 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
"01_bar.sql": sqlMapFile,
|
"01_bar.sql": sqlMapFile,
|
||||||
}
|
}
|
||||||
_, err := collectFilesystemSources(mapFS, false, nil, nil)
|
_, err := collectFilesystemSources(mapFS, false, nil, nil)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
require.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||||
})
|
})
|
||||||
t.Run("dirpath", func(t *testing.T) {
|
t.Run("dirpath", func(t *testing.T) {
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -157,13 +157,13 @@ func TestCollectFileSources(t *testing.T) {
|
||||||
assertDirpath := func(dirpath string, sqlSources []Source) {
|
assertDirpath := func(dirpath string, sqlSources []Source) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
f, err := fs.Sub(mapFS, dirpath)
|
f, err := fs.Sub(mapFS, dirpath)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
got, err := collectFilesystemSources(f, false, nil, nil)
|
got, err := collectFilesystemSources(f, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(got.sqlSources), len(sqlSources))
|
require.Equal(t, len(got.sqlSources), len(sqlSources))
|
||||||
check.Number(t, len(got.goSources), 0)
|
require.Equal(t, len(got.goSources), 0)
|
||||||
for i := 0; i < len(got.sqlSources); i++ {
|
for i := 0; i < len(got.sqlSources); i++ {
|
||||||
check.Equal(t, got.sqlSources[i], sqlSources[i])
|
require.Equal(t, got.sqlSources[i], sqlSources[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertDirpath(".", []Source{
|
assertDirpath(".", []Source{
|
||||||
|
@ -193,35 +193,35 @@ func TestMerge(t *testing.T) {
|
||||||
"migrations/00003_baz.go": {Data: []byte(`package migrations`)},
|
"migrations/00003_baz.go": {Data: []byte(`package migrations`)},
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, len(sources.sqlSources), 1)
|
require.Equal(t, len(sources.sqlSources), 1)
|
||||||
check.Equal(t, len(sources.goSources), 2)
|
require.Equal(t, len(sources.goSources), 2)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
registered := map[int64]*Migration{
|
registered := map[int64]*Migration{
|
||||||
2: NewGoMigration(2, nil, nil),
|
2: NewGoMigration(2, nil, nil),
|
||||||
3: NewGoMigration(3, nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
}
|
}
|
||||||
migrations, err := merge(sources, registered)
|
migrations, err := merge(sources, registered)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 3)
|
require.Equal(t, len(migrations), 3)
|
||||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||||
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
|
||||||
})
|
})
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
_, err := merge(sources, nil)
|
_, err := merge(sources, nil)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:")
|
require.Contains(t, err.Error(), "error: detected 2 unregistered Go files:")
|
||||||
check.Contains(t, err.Error(), "00002_bar.go")
|
require.Contains(t, err.Error(), "00002_bar.go")
|
||||||
check.Contains(t, err.Error(), "00003_baz.go")
|
require.Contains(t, err.Error(), "00003_baz.go")
|
||||||
})
|
})
|
||||||
t.Run("unregistered_some", func(t *testing.T) {
|
t.Run("unregistered_some", func(t *testing.T) {
|
||||||
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
|
_, err := merge(sources, map[int64]*Migration{2: NewGoMigration(2, nil, nil)})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
require.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
|
||||||
check.Contains(t, err.Error(), "00003_baz.go")
|
require.Contains(t, err.Error(), "00003_baz.go")
|
||||||
})
|
})
|
||||||
t.Run("duplicate_sql", func(t *testing.T) {
|
t.Run("duplicate_sql", func(t *testing.T) {
|
||||||
_, err := merge(sources, map[int64]*Migration{
|
_, err := merge(sources, map[int64]*Migration{
|
||||||
|
@ -229,8 +229,8 @@ func TestMerge(t *testing.T) {
|
||||||
2: NewGoMigration(2, nil, nil),
|
2: NewGoMigration(2, nil, nil),
|
||||||
3: NewGoMigration(3, nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
})
|
})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "found duplicate migration version 1")
|
require.Contains(t, err.Error(), "found duplicate migration version 1")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("no_go_files_on_disk", func(t *testing.T) {
|
t.Run("no_go_files_on_disk", func(t *testing.T) {
|
||||||
|
@ -241,17 +241,17 @@ func TestMerge(t *testing.T) {
|
||||||
"migrations/00005_baz.sql": sqlMapFile,
|
"migrations/00005_baz.sql": sqlMapFile,
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
migrations, err := merge(sources, map[int64]*Migration{
|
migrations, err := merge(sources, map[int64]*Migration{
|
||||||
3: NewGoMigration(3, nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
// 4 is missing
|
// 4 is missing
|
||||||
6: NewGoMigration(6, nil, nil),
|
6: NewGoMigration(6, nil, nil),
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 5)
|
require.Equal(t, len(migrations), 5)
|
||||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||||
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
|
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
|
||||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||||
|
@ -265,9 +265,9 @@ func TestMerge(t *testing.T) {
|
||||||
"migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)},
|
"migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)},
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
sources, err := collectFilesystemSources(fsys, false, nil, nil)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Run("unregistered_all", func(t *testing.T) {
|
t.Run("unregistered_all", func(t *testing.T) {
|
||||||
migrations, err := merge(sources, map[int64]*Migration{
|
migrations, err := merge(sources, map[int64]*Migration{
|
||||||
// This is the only Go file on disk.
|
// This is the only Go file on disk.
|
||||||
|
@ -276,8 +276,8 @@ func TestMerge(t *testing.T) {
|
||||||
3: NewGoMigration(3, nil, nil),
|
3: NewGoMigration(3, nil, nil),
|
||||||
6: NewGoMigration(6, nil, nil),
|
6: NewGoMigration(6, nil, nil),
|
||||||
})
|
})
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 4)
|
require.Equal(t, len(migrations), 4)
|
||||||
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
|
||||||
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
|
||||||
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
|
||||||
|
@ -288,15 +288,15 @@ func TestMerge(t *testing.T) {
|
||||||
|
|
||||||
func assertMigration(t *testing.T, got *Migration, want Source) {
|
func assertMigration(t *testing.T, got *Migration, want Source) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.Type, want.Type)
|
require.Equal(t, got.Type, want.Type)
|
||||||
check.Equal(t, got.Version, want.Version)
|
require.Equal(t, got.Version, want.Version)
|
||||||
check.Equal(t, got.Source, want.Path)
|
require.Equal(t, got.Source, want.Path)
|
||||||
switch got.Type {
|
switch got.Type {
|
||||||
case TypeGo:
|
case TypeGo:
|
||||||
check.Bool(t, got.goUp != nil, true)
|
require.True(t, got.goUp != nil)
|
||||||
check.Bool(t, got.goDown != nil, true)
|
require.True(t, got.goDown != nil)
|
||||||
case TypeSQL:
|
case TypeSQL:
|
||||||
check.Bool(t, got.sql.Parsed, false)
|
require.False(t, got.sql.Parsed)
|
||||||
default:
|
default:
|
||||||
t.Fatalf("unknown migration type: %s", got.Type)
|
t.Fatalf("unknown migration type: %s", got.Type)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,14 +8,14 @@ import (
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/database"
|
"github.com/pressly/goose/v3/database"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewProvider(t *testing.T) {
|
func TestNewProvider(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
fsys := fstest.MapFS{
|
fsys := fstest.MapFS{
|
||||||
"1_foo.sql": {Data: []byte(migration1)},
|
"1_foo.sql": {Data: []byte(migration1)},
|
||||||
"2_bar.sql": {Data: []byte(migration2)},
|
"2_bar.sql": {Data: []byte(migration2)},
|
||||||
|
@ -25,41 +25,41 @@ func TestNewProvider(t *testing.T) {
|
||||||
t.Run("invalid", func(t *testing.T) {
|
t.Run("invalid", func(t *testing.T) {
|
||||||
// Empty dialect not allowed
|
// Empty dialect not allowed
|
||||||
_, err = goose.NewProvider("", db, fsys)
|
_, err = goose.NewProvider("", db, fsys)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Invalid dialect not allowed
|
// Invalid dialect not allowed
|
||||||
_, err = goose.NewProvider("unknown-dialect", db, fsys)
|
_, err = goose.NewProvider("unknown-dialect", db, fsys)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Nil db not allowed
|
// Nil db not allowed
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, nil, fsys)
|
_, err = goose.NewProvider(goose.DialectSQLite3, nil, fsys)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Nil store not allowed
|
// Nil store not allowed
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(nil))
|
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(nil))
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Cannot set both dialect and store
|
// Cannot set both dialect and store
|
||||||
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
|
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(store))
|
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(store))
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// Multiple stores not allowed
|
// Multiple stores not allowed
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil,
|
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil,
|
||||||
goose.WithStore(store),
|
goose.WithStore(store),
|
||||||
goose.WithStore(store),
|
goose.WithStore(store),
|
||||||
)
|
)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
// Valid dialect, db, and fsys allowed
|
// Valid dialect, db, and fsys allowed
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys)
|
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Valid dialect, db, fsys, and verbose allowed
|
// Valid dialect, db, fsys, and verbose allowed
|
||||||
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
||||||
goose.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Custom store allowed
|
// Custom store allowed
|
||||||
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
|
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
|
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/database"
|
"github.com/pressly/goose/v3/database"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProviderRun(t *testing.T) {
|
func TestProviderRun(t *testing.T) {
|
||||||
|
@ -24,38 +24,38 @@ func TestProviderRun(t *testing.T) {
|
||||||
|
|
||||||
t.Run("closed_db", func(t *testing.T) {
|
t.Run("closed_db", func(t *testing.T) {
|
||||||
p, db := newProviderWithDB(t)
|
p, db := newProviderWithDB(t)
|
||||||
check.NoError(t, db.Close())
|
require.NoError(t, db.Close())
|
||||||
_, err := p.Up(context.Background())
|
_, err := p.Up(context.Background())
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Equal(t, err.Error(), "failed to initialize: sql: database is closed")
|
require.Equal(t, err.Error(), "failed to initialize: sql: database is closed")
|
||||||
})
|
})
|
||||||
t.Run("ping_and_close", func(t *testing.T) {
|
t.Run("ping_and_close", func(t *testing.T) {
|
||||||
p, _ := newProviderWithDB(t)
|
p, _ := newProviderWithDB(t)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
check.NoError(t, p.Close())
|
require.NoError(t, p.Close())
|
||||||
})
|
})
|
||||||
check.NoError(t, p.Ping(context.Background()))
|
require.NoError(t, p.Ping(context.Background()))
|
||||||
})
|
})
|
||||||
t.Run("apply_unknown_version", func(t *testing.T) {
|
t.Run("apply_unknown_version", func(t *testing.T) {
|
||||||
p, _ := newProviderWithDB(t)
|
p, _ := newProviderWithDB(t)
|
||||||
_, err := p.ApplyVersion(context.Background(), 999, true)
|
_, err := p.ApplyVersion(context.Background(), 999, true)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
require.True(t, errors.Is(err, goose.ErrVersionNotFound))
|
||||||
_, err = p.ApplyVersion(context.Background(), 999, false)
|
_, err = p.ApplyVersion(context.Background(), 999, false)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrVersionNotFound), true)
|
require.True(t, errors.Is(err, goose.ErrVersionNotFound))
|
||||||
})
|
})
|
||||||
t.Run("run_zero", func(t *testing.T) {
|
t.Run("run_zero", func(t *testing.T) {
|
||||||
p, _ := newProviderWithDB(t)
|
p, _ := newProviderWithDB(t)
|
||||||
_, err := p.UpTo(context.Background(), 0)
|
_, err := p.UpTo(context.Background(), 0)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Equal(t, err.Error(), "version must be greater than 0")
|
require.Equal(t, err.Error(), "version must be greater than 0")
|
||||||
_, err = p.DownTo(context.Background(), -1)
|
_, err = p.DownTo(context.Background(), -1)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
require.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
|
||||||
_, err = p.ApplyVersion(context.Background(), 0, true)
|
_, err = p.ApplyVersion(context.Background(), 0, true)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Equal(t, err.Error(), "version must be greater than 0")
|
require.Equal(t, err.Error(), "version must be greater than 0")
|
||||||
})
|
})
|
||||||
t.Run("up_and_down_all", func(t *testing.T) {
|
t.Run("up_and_down_all", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -64,15 +64,15 @@ func TestProviderRun(t *testing.T) {
|
||||||
numCount = 7
|
numCount = 7
|
||||||
)
|
)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Number(t, len(sources), numCount)
|
require.Equal(t, len(sources), numCount)
|
||||||
// Ensure only SQL migrations are returned
|
// Ensure only SQL migrations are returned
|
||||||
for _, s := range sources {
|
for _, s := range sources {
|
||||||
check.Equal(t, s.Type, goose.TypeSQL)
|
require.Equal(t, s.Type, goose.TypeSQL)
|
||||||
}
|
}
|
||||||
// Test Up
|
// Test Up
|
||||||
res, err := p.Up(ctx)
|
res, err := p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), numCount)
|
require.Equal(t, len(res), numCount)
|
||||||
assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, res[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
assertResult(t, res[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||||
assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
assertResult(t, res[2], newSource(goose.TypeSQL, "00003_comments_table.sql", 3), "up", false)
|
||||||
|
@ -82,8 +82,8 @@ func TestProviderRun(t *testing.T) {
|
||||||
assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
assertResult(t, res[6], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "up", true)
|
||||||
// Test Down
|
// Test Down
|
||||||
res, err = p.DownTo(ctx, 0)
|
res, err = p.DownTo(ctx, 0)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(res), numCount)
|
require.Equal(t, len(res), numCount)
|
||||||
assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
assertResult(t, res[0], newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), "down", true)
|
||||||
assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
assertResult(t, res[1], newSource(goose.TypeSQL, "00006_empty_up.sql", 6), "down", true)
|
||||||
assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
assertResult(t, res[2], newSource(goose.TypeSQL, "00005_posts_view.sql", 5), "down", false)
|
||||||
|
@ -107,13 +107,13 @@ func TestProviderRun(t *testing.T) {
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, res != nil, true)
|
require.True(t, res != nil)
|
||||||
check.Number(t, res.Source.Version, int64(counter))
|
require.Equal(t, res.Source.Version, int64(counter))
|
||||||
}
|
}
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, int64(maxVersion))
|
require.Equal(t, currentVersion, int64(maxVersion))
|
||||||
// Reset counter
|
// Reset counter
|
||||||
counter = 0
|
counter = 0
|
||||||
// Rollback all migrations one-by-one.
|
// Rollback all migrations one-by-one.
|
||||||
|
@ -126,14 +126,14 @@ func TestProviderRun(t *testing.T) {
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, res != nil, true)
|
require.True(t, res != nil)
|
||||||
check.Number(t, res.Source.Version, int64(maxVersion-counter+1))
|
require.Equal(t, res.Source.Version, int64(maxVersion-counter+1))
|
||||||
}
|
}
|
||||||
// Once everything is tested the version should match the highest testdata version
|
// Once everything is tested the version should match the highest testdata version
|
||||||
currentVersion, err = p.GetDBVersion(ctx)
|
currentVersion, err = p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, 0)
|
require.EqualValues(t, currentVersion, 0)
|
||||||
})
|
})
|
||||||
t.Run("up_to", func(t *testing.T) {
|
t.Run("up_to", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -142,18 +142,18 @@ func TestProviderRun(t *testing.T) {
|
||||||
upToVersion int64 = 2
|
upToVersion int64 = 2
|
||||||
)
|
)
|
||||||
results, err := p.UpTo(ctx, upToVersion)
|
results, err := p.UpTo(ctx, upToVersion)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(results), upToVersion)
|
require.EqualValues(t, len(results), upToVersion)
|
||||||
assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, results[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
assertResult(t, results[1], newSource(goose.TypeSQL, "00002_posts_table.sql", 2), "up", false)
|
||||||
// Fetch the goose version from DB
|
// Fetch the goose version from DB
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, upToVersion)
|
require.Equal(t, currentVersion, upToVersion)
|
||||||
// Validate the version actually matches what goose claims it is
|
// Validate the version actually matches what goose claims it is
|
||||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, gotVersion, upToVersion)
|
require.Equal(t, gotVersion, upToVersion)
|
||||||
})
|
})
|
||||||
t.Run("sql_connections", func(t *testing.T) {
|
t.Run("sql_connections", func(t *testing.T) {
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
|
@ -177,26 +177,26 @@ func TestProviderRun(t *testing.T) {
|
||||||
db.SetMaxIdleConns(tc.maxIdleConns)
|
db.SetMaxIdleConns(tc.maxIdleConns)
|
||||||
}
|
}
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.NumberNotZero(t, len(sources))
|
require.NotZero(t, len(sources))
|
||||||
|
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, 0)
|
require.EqualValues(t, currentVersion, 0)
|
||||||
|
|
||||||
{
|
{
|
||||||
// Apply all up migrations
|
// Apply all up migrations
|
||||||
upResult, err := p.Up(ctx)
|
upResult, err := p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(upResult), len(sources))
|
require.Equal(t, len(upResult), len(sources))
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version)
|
require.Equal(t, currentVersion, p.ListSources()[len(sources)-1].Version)
|
||||||
// Validate the db migration version actually matches what goose claims it is
|
// Validate the db migration version actually matches what goose claims it is
|
||||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, gotVersion, currentVersion)
|
require.Equal(t, gotVersion, currentVersion)
|
||||||
tables, err := getTableNames(db)
|
tables, err := getTableNames(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if !reflect.DeepEqual(tables, knownTables) {
|
if !reflect.DeepEqual(tables, knownTables) {
|
||||||
t.Logf("got tables: %v", tables)
|
t.Logf("got tables: %v", tables)
|
||||||
t.Logf("known tables: %v", knownTables)
|
t.Logf("known tables: %v", knownTables)
|
||||||
|
@ -206,14 +206,14 @@ func TestProviderRun(t *testing.T) {
|
||||||
{
|
{
|
||||||
// Apply all down migrations
|
// Apply all down migrations
|
||||||
downResult, err := p.DownTo(ctx, 0)
|
downResult, err := p.DownTo(ctx, 0)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(downResult), len(sources))
|
require.Equal(t, len(downResult), len(sources))
|
||||||
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
gotVersion, err := getMaxVersionID(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, gotVersion, 0)
|
require.EqualValues(t, gotVersion, 0)
|
||||||
// Should only be left with a single table, the default goose table
|
// Should only be left with a single table, the default goose table
|
||||||
tables, err := getTableNames(db)
|
tables, err := getTableNames(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
knownTables := []string{goose.DefaultTablename, "sqlite_sequence"}
|
knownTables := []string{goose.DefaultTablename, "sqlite_sequence"}
|
||||||
if !reflect.DeepEqual(tables, knownTables) {
|
if !reflect.DeepEqual(tables, knownTables) {
|
||||||
t.Logf("got tables: %v", tables)
|
t.Logf("got tables: %v", tables)
|
||||||
|
@ -231,7 +231,7 @@ func TestProviderRun(t *testing.T) {
|
||||||
// Apply all migrations in the up direction.
|
// Apply all migrations in the up direction.
|
||||||
for _, s := range sources {
|
for _, s := range sources {
|
||||||
res, err := p.ApplyVersion(ctx, s.Version, true)
|
res, err := p.ApplyVersion(ctx, s.Version, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Round-trip the migration result through the database to ensure it's valid.
|
// Round-trip the migration result through the database to ensure it's valid.
|
||||||
var empty bool
|
var empty bool
|
||||||
if s.Version == 6 || s.Version == 7 {
|
if s.Version == 6 || s.Version == 7 {
|
||||||
|
@ -243,7 +243,7 @@ func TestProviderRun(t *testing.T) {
|
||||||
for i := len(sources) - 1; i >= 0; i-- {
|
for i := len(sources) - 1; i >= 0; i-- {
|
||||||
s := sources[i]
|
s := sources[i]
|
||||||
res, err := p.ApplyVersion(ctx, s.Version, false)
|
res, err := p.ApplyVersion(ctx, s.Version, false)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Round-trip the migration result through the database to ensure it's valid.
|
// Round-trip the migration result through the database to ensure it's valid.
|
||||||
var empty bool
|
var empty bool
|
||||||
if s.Version == 6 || s.Version == 7 {
|
if s.Version == 6 || s.Version == 7 {
|
||||||
|
@ -253,11 +253,11 @@ func TestProviderRun(t *testing.T) {
|
||||||
}
|
}
|
||||||
// Try apply version 1 multiple times
|
// Try apply version 1 multiple times
|
||||||
_, err := p.ApplyVersion(ctx, 1, true)
|
_, err := p.ApplyVersion(ctx, 1, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.ApplyVersion(ctx, 1, true)
|
_, err = p.ApplyVersion(ctx, 1, true)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrAlreadyApplied), true)
|
require.True(t, errors.Is(err, goose.ErrAlreadyApplied))
|
||||||
check.Contains(t, err.Error(), "version 1: migration already applied")
|
require.Contains(t, err.Error(), "version 1: migration already applied")
|
||||||
})
|
})
|
||||||
t.Run("status", func(t *testing.T) {
|
t.Run("status", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -265,8 +265,8 @@ func TestProviderRun(t *testing.T) {
|
||||||
numCount := len(p.ListSources())
|
numCount := len(p.ListSources())
|
||||||
// Before any migrations are applied, the status should be empty.
|
// Before any migrations are applied, the status should be empty.
|
||||||
status, err := p.Status(ctx)
|
status, err := p.Status(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(status), numCount)
|
require.Equal(t, len(status), numCount)
|
||||||
assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true)
|
assertStatus(t, status[0], goose.StatePending, newSource(goose.TypeSQL, "00001_users_table.sql", 1), true)
|
||||||
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true)
|
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), true)
|
||||||
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true)
|
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), true)
|
||||||
|
@ -276,10 +276,10 @@ func TestProviderRun(t *testing.T) {
|
||||||
assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
assertStatus(t, status[6], goose.StatePending, newSource(goose.TypeSQL, "00007_empty_up_down.sql", 7), true)
|
||||||
// Apply all migrations
|
// Apply all migrations
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
status, err = p.Status(ctx)
|
status, err = p.Status(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(status), numCount)
|
require.Equal(t, len(status), numCount)
|
||||||
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||||
assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false)
|
assertStatus(t, status[1], goose.StateApplied, newSource(goose.TypeSQL, "00002_posts_table.sql", 2), false)
|
||||||
assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false)
|
assertStatus(t, status[2], goose.StateApplied, newSource(goose.TypeSQL, "00003_comments_table.sql", 3), false)
|
||||||
|
@ -317,35 +317,35 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
|
||||||
`),
|
`),
|
||||||
}
|
}
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, mapFS)
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, mapFS)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)")
|
require.Contains(t, err.Error(), "partial migration error (type:sql,version:2)")
|
||||||
var expected *goose.PartialError
|
var expected *goose.PartialError
|
||||||
check.Bool(t, errors.As(err, &expected), true)
|
require.True(t, errors.As(err, &expected))
|
||||||
// Check Err field
|
// Check Err field
|
||||||
check.Bool(t, expected.Err != nil, true)
|
require.True(t, expected.Err != nil)
|
||||||
check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
require.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||||
// Check Results field
|
// Check Results field
|
||||||
check.Number(t, len(expected.Applied), 1)
|
require.Equal(t, len(expected.Applied), 1)
|
||||||
assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
assertResult(t, expected.Applied[0], newSource(goose.TypeSQL, "00001_users_table.sql", 1), "up", false)
|
||||||
// Check Failed field
|
// Check Failed field
|
||||||
check.Bool(t, expected.Failed != nil, true)
|
require.True(t, expected.Failed != nil)
|
||||||
assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2)
|
assertSource(t, expected.Failed.Source, goose.TypeSQL, "00002_partial_error.sql", 2)
|
||||||
check.Bool(t, expected.Failed.Empty, false)
|
require.False(t, expected.Failed.Empty)
|
||||||
check.Bool(t, expected.Failed.Error != nil, true)
|
require.True(t, expected.Failed.Error != nil)
|
||||||
check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
|
require.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)")
|
||||||
check.Equal(t, expected.Failed.Direction, "up")
|
require.Equal(t, expected.Failed.Direction, "up")
|
||||||
check.Bool(t, expected.Failed.Duration > 0, true)
|
require.True(t, expected.Failed.Duration > 0)
|
||||||
|
|
||||||
// Ensure the partial error did not affect the database.
|
// Ensure the partial error did not affect the database.
|
||||||
count, err := countOwners(db)
|
count, err := countOwners(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, count, 0)
|
require.Equal(t, count, 0)
|
||||||
|
|
||||||
status, err := p.Status(ctx)
|
status, err := p.Status(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(status), 3)
|
require.Equal(t, len(status), 3)
|
||||||
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
assertStatus(t, status[0], goose.StateApplied, newSource(goose.TypeSQL, "00001_users_table.sql", 1), false)
|
||||||
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true)
|
assertStatus(t, status[1], goose.StatePending, newSource(goose.TypeSQL, "00002_partial_error.sql", 2), true)
|
||||||
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true)
|
assertStatus(t, status[2], goose.StatePending, newSource(goose.TypeSQL, "00003_insert_data.sql", 3), true)
|
||||||
|
@ -391,13 +391,13 @@ func TestConcurrentProvider(t *testing.T) {
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
check.Number(t, len(versions), maxVersion)
|
require.Equal(t, len(versions), maxVersion)
|
||||||
for i := 0; i < maxVersion; i++ {
|
for i := 0; i < maxVersion; i++ {
|
||||||
check.Number(t, versions[i], int64(i+1))
|
require.Equal(t, versions[i], int64(i+1))
|
||||||
}
|
}
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, maxVersion)
|
require.EqualValues(t, currentVersion, maxVersion)
|
||||||
})
|
})
|
||||||
t.Run("down", func(t *testing.T) {
|
t.Run("down", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -405,10 +405,10 @@ func TestConcurrentProvider(t *testing.T) {
|
||||||
maxVersion := len(p.ListSources())
|
maxVersion := len(p.ListSources())
|
||||||
// Apply all migrations
|
// Apply all migrations
|
||||||
_, err := p.Up(ctx)
|
_, err := p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, maxVersion)
|
require.EqualValues(t, currentVersion, maxVersion)
|
||||||
|
|
||||||
ch := make(chan []*goose.MigrationResult)
|
ch := make(chan []*goose.MigrationResult)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -444,10 +444,10 @@ func TestConcurrentProvider(t *testing.T) {
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
check.Equal(t, len(valid), 1)
|
require.Equal(t, len(valid), 1)
|
||||||
check.Equal(t, len(empty), maxVersion-1)
|
require.Equal(t, len(empty), maxVersion-1)
|
||||||
// Ensure the valid result is correct.
|
// Ensure the valid result is correct.
|
||||||
check.Number(t, len(valid[0]), maxVersion)
|
require.Equal(t, len(valid[0]), maxVersion)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -473,7 +473,7 @@ func TestNoVersioning(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8))
|
dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8))
|
||||||
db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName))
|
db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations"))
|
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations"))
|
||||||
const (
|
const (
|
||||||
// Total owners created by the seed files.
|
// Total owners created by the seed files.
|
||||||
|
@ -485,50 +485,50 @@ func TestNoVersioning(t *testing.T) {
|
||||||
goose.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
goose.WithDisableVersioning(false), // This is the default.
|
goose.WithDisableVersioning(false), // This is the default.
|
||||||
)
|
)
|
||||||
check.Number(t, len(p.ListSources()), 3)
|
require.Equal(t, len(p.ListSources()), 3)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
baseVersion, err := p.GetDBVersion(ctx)
|
baseVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, baseVersion, 3)
|
require.EqualValues(t, baseVersion, 3)
|
||||||
t.Run("seed-up-down-to-zero", func(t *testing.T) {
|
t.Run("seed-up-down-to-zero", func(t *testing.T) {
|
||||||
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
|
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
||||||
goose.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
goose.WithDisableVersioning(true), // Provider with no versioning.
|
goose.WithDisableVersioning(true), // Provider with no versioning.
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
require.Equal(t, len(p.ListSources()), 2)
|
||||||
|
|
||||||
// Run (all) up migrations from the seed dir
|
// Run (all) up migrations from the seed dir
|
||||||
{
|
{
|
||||||
upResult, err := p.Up(ctx)
|
upResult, err := p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(upResult), 2)
|
require.Equal(t, len(upResult), 2)
|
||||||
// When versioning is disabled, we cannot track the version of the seed files.
|
// When versioning is disabled, we cannot track the version of the seed files.
|
||||||
_, err = p.GetDBVersion(ctx)
|
_, err = p.GetDBVersion(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
seedOwnerCount, err := countSeedOwners(db)
|
seedOwnerCount, err := countSeedOwners(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, seedOwnerCount, wantSeedOwnerCount)
|
require.Equal(t, seedOwnerCount, wantSeedOwnerCount)
|
||||||
}
|
}
|
||||||
// Run (all) down migrations from the seed dir
|
// Run (all) down migrations from the seed dir
|
||||||
{
|
{
|
||||||
downResult, err := p.DownTo(ctx, 0)
|
downResult, err := p.DownTo(ctx, 0)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(downResult), 2)
|
require.Equal(t, len(downResult), 2)
|
||||||
// When versioning is disabled, we cannot track the version of the seed files.
|
// When versioning is disabled, we cannot track the version of the seed files.
|
||||||
_, err = p.GetDBVersion(ctx)
|
_, err = p.GetDBVersion(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
seedOwnerCount, err := countSeedOwners(db)
|
seedOwnerCount, err := countSeedOwners(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, seedOwnerCount, 0)
|
require.Equal(t, seedOwnerCount, 0)
|
||||||
}
|
}
|
||||||
// The migrations added 4 non-seed owners, they must remain in the database afterwards
|
// The migrations added 4 non-seed owners, they must remain in the database afterwards
|
||||||
ownerCount, err := countOwners(db)
|
ownerCount, err := countOwners(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, ownerCount, wantOwnerCount)
|
require.Equal(t, ownerCount, wantOwnerCount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -548,22 +548,22 @@ func TestAllowMissing(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
|
||||||
goose.WithAllowOutofOrder(false),
|
goose.WithAllowOutofOrder(false),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create and apply first 3 migrations.
|
// Create and apply first 3 migrations.
|
||||||
_, err = p.UpTo(ctx, 3)
|
_, err = p.UpTo(ctx, 3)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, 3)
|
require.EqualValues(t, currentVersion, 3)
|
||||||
|
|
||||||
// Developer A - migration 5 (mistakenly applied)
|
// Developer A - migration 5 (mistakenly applied)
|
||||||
result, err := p.ApplyVersion(ctx, 5, true)
|
result, err := p.ApplyVersion(ctx, 5, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, result.Source.Version, 5)
|
require.EqualValues(t, result.Source.Version, 5)
|
||||||
current, err := p.GetDBVersion(ctx)
|
current, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 5)
|
require.EqualValues(t, current, 5)
|
||||||
|
|
||||||
// The database has migrations 1,2,3,5 applied.
|
// The database has migrations 1,2,3,5 applied.
|
||||||
|
|
||||||
|
@ -571,31 +571,31 @@ func TestAllowMissing(t *testing.T) {
|
||||||
// default goose does not allow missing (out-of-order) migrations, which means halt if a
|
// default goose does not allow missing (out-of-order) migrations, which means halt if a
|
||||||
// missing migration is detected.
|
// missing migration is detected.
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
||||||
check.Contains(t, err.Error(), "missing (out-of-order) migration")
|
require.Contains(t, err.Error(), "missing (out-of-order) migration")
|
||||||
// Confirm db version is unchanged.
|
// Confirm db version is unchanged.
|
||||||
current, err = p.GetDBVersion(ctx)
|
current, err = p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 5)
|
require.EqualValues(t, current, 5)
|
||||||
|
|
||||||
_, err = p.UpByOne(ctx)
|
_, err = p.UpByOne(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
||||||
check.Contains(t, err.Error(), "missing (out-of-order) migration")
|
require.Contains(t, err.Error(), "missing (out-of-order) migration")
|
||||||
// Confirm db version is unchanged.
|
// Confirm db version is unchanged.
|
||||||
current, err = p.GetDBVersion(ctx)
|
current, err = p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 5)
|
require.EqualValues(t, current, 5)
|
||||||
|
|
||||||
_, err = p.UpTo(ctx, math.MaxInt64)
|
_, err = p.UpTo(ctx, math.MaxInt64)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
// found 1 missing (out-of-order) migration: [00004_insert_data.sql]
|
||||||
check.Contains(t, err.Error(), "missing (out-of-order) migration")
|
require.Contains(t, err.Error(), "missing (out-of-order) migration")
|
||||||
// Confirm db version is unchanged.
|
// Confirm db version is unchanged.
|
||||||
current, err = p.GetDBVersion(ctx)
|
current, err = p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 5)
|
require.EqualValues(t, current, 5)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("missing_allowed", func(t *testing.T) {
|
t.Run("missing_allowed", func(t *testing.T) {
|
||||||
|
@ -603,43 +603,43 @@ func TestAllowMissing(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
|
||||||
goose.WithAllowOutofOrder(true),
|
goose.WithAllowOutofOrder(true),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create and apply first 3 migrations.
|
// Create and apply first 3 migrations.
|
||||||
_, err = p.UpTo(ctx, 3)
|
_, err = p.UpTo(ctx, 3)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, 3)
|
require.EqualValues(t, currentVersion, 3)
|
||||||
|
|
||||||
// Developer A - migration 5 (mistakenly applied)
|
// Developer A - migration 5 (mistakenly applied)
|
||||||
{
|
{
|
||||||
_, err = p.ApplyVersion(ctx, 5, true)
|
_, err = p.ApplyVersion(ctx, 5, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
current, err := p.GetDBVersion(ctx)
|
current, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 5)
|
require.EqualValues(t, current, 5)
|
||||||
}
|
}
|
||||||
// Developer B - migration 4 (missing) and 6 (new)
|
// Developer B - migration 4 (missing) and 6 (new)
|
||||||
{
|
{
|
||||||
// 4
|
// 4
|
||||||
upResult, err := p.UpByOne(ctx)
|
upResult, err := p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, upResult != nil, true)
|
require.True(t, upResult != nil)
|
||||||
check.Number(t, upResult.Source.Version, 4)
|
require.EqualValues(t, upResult.Source.Version, 4)
|
||||||
// 6
|
// 6
|
||||||
upResult, err = p.UpByOne(ctx)
|
upResult, err = p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, upResult != nil, true)
|
require.True(t, upResult != nil)
|
||||||
check.Number(t, upResult.Source.Version, 6)
|
require.EqualValues(t, upResult.Source.Version, 6)
|
||||||
|
|
||||||
count, err := getGooseVersionCount(db, goose.DefaultTablename)
|
count, err := getGooseVersionCount(db, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, count, 6)
|
require.EqualValues(t, count, 6)
|
||||||
current, err := p.GetDBVersion(ctx)
|
current, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Expecting max(version_id) to be 8
|
// Expecting max(version_id) to be 8
|
||||||
check.Number(t, current, 6)
|
require.EqualValues(t, current, 6)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The applied order in the database is expected to be:
|
// The applied order in the database is expected to be:
|
||||||
|
@ -649,12 +649,12 @@ func TestAllowMissing(t *testing.T) {
|
||||||
|
|
||||||
testDownAndVersion := func(wantDBVersion, wantResultVersion int64) {
|
testDownAndVersion := func(wantDBVersion, wantResultVersion int64) {
|
||||||
currentVersion, err := p.GetDBVersion(ctx)
|
currentVersion, err := p.GetDBVersion(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, currentVersion, wantDBVersion)
|
require.Equal(t, currentVersion, wantDBVersion)
|
||||||
downRes, err := p.Down(ctx)
|
downRes, err := p.Down(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, downRes != nil, true)
|
require.True(t, downRes != nil)
|
||||||
check.Number(t, downRes.Source.Version, wantResultVersion)
|
require.Equal(t, downRes.Source.Version, wantResultVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This behaviour may need to change, see the following issues for more details:
|
// This behaviour may need to change, see the following issues for more details:
|
||||||
|
@ -668,8 +668,8 @@ func TestAllowMissing(t *testing.T) {
|
||||||
testDownAndVersion(2, 2)
|
testDownAndVersion(2, 2)
|
||||||
testDownAndVersion(1, 1)
|
testDownAndVersion(1, 1)
|
||||||
_, err = p.Down(ctx)
|
_, err = p.Down(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrNoNextVersion), true)
|
require.True(t, errors.Is(err, goose.ErrNoNextVersion))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -690,30 +690,30 @@ func TestSQLiteSharedCache(t *testing.T) {
|
||||||
// database connections as follows: file::memory:?cache=shared"
|
// database connections as follows: file::memory:?cache=shared"
|
||||||
t.Run("shared_cache", func(t *testing.T) {
|
t.Run("shared_cache", func(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
db, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)}
|
fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)}
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
||||||
goose.WithGoMigrations(
|
goose.WithGoMigrations(
|
||||||
goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil),
|
goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(context.Background())
|
_, err = p.Up(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
t.Run("no_shared_cache", func(t *testing.T) {
|
t.Run("no_shared_cache", func(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite", "file::memory:")
|
db, err := sql.Open("sqlite", "file::memory:")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)}
|
fsys := fstest.MapFS{"00001_a.sql": newMapFile(`-- +goose Up`)}
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
|
||||||
goose.WithGoMigrations(
|
goose.WithGoMigrations(
|
||||||
goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil),
|
goose.NewGoMigration(2, &goose.GoFunc{Mode: goose.TransactionDisabled}, nil),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(context.Background())
|
_, err = p.Up(context.Background())
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "SQL logic error: no such table: goose_db_version")
|
require.Contains(t, err.Error(), "SQL logic error: no such table: goose_db_version")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -736,26 +736,26 @@ func TestGoMigrationPanic(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), nil,
|
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), nil,
|
||||||
goose.WithGoMigrations(migration), // Add a Go migration that panics.
|
goose.WithGoMigrations(migration), // Add a Go migration that panics.
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), wantErrString)
|
require.Contains(t, err.Error(), wantErrString)
|
||||||
var expected *goose.PartialError
|
var expected *goose.PartialError
|
||||||
check.Bool(t, errors.As(err, &expected), true)
|
require.True(t, errors.As(err, &expected))
|
||||||
check.Contains(t, expected.Err.Error(), wantErrString)
|
require.Contains(t, expected.Err.Error(), wantErrString)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomStoreTableExists(t *testing.T) {
|
func TestCustomStoreTableExists(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
|
store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p, err := goose.NewProvider("", newDB(t), newFsys(),
|
p, err := goose.NewProvider("", newDB(t), newFsys(),
|
||||||
goose.WithStore(&customStoreSQLite3{store}),
|
goose.WithStore(&customStoreSQLite3{store}),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.Up(context.Background())
|
_, err = p.Up(context.Background())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProviderApply(t *testing.T) {
|
func TestProviderApply(t *testing.T) {
|
||||||
|
@ -763,13 +763,13 @@ func TestProviderApply(t *testing.T) {
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys())
|
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys())
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.ApplyVersion(ctx, 1, true)
|
_, err = p.ApplyVersion(ctx, 1, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// This version has a corresponding down migration, but has never been applied.
|
// This version has a corresponding down migration, but has never been applied.
|
||||||
_, err = p.ApplyVersion(ctx, 2, false)
|
_, err = p.ApplyVersion(ctx, 2, false)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
|
require.True(t, errors.Is(err, goose.ErrNotApplied))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPending(t *testing.T) {
|
func TestPending(t *testing.T) {
|
||||||
|
@ -780,31 +780,31 @@ func TestPending(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
|
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
|
||||||
goose.WithAllowOutofOrder(true),
|
goose.WithAllowOutofOrder(true),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Some migrations have been applied out of order.
|
// Some migrations have been applied out of order.
|
||||||
_, err = p.ApplyVersion(ctx, 1, true)
|
_, err = p.ApplyVersion(ctx, 1, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.ApplyVersion(ctx, 3, true)
|
_, err = p.ApplyVersion(ctx, 3, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Even though the latest migration HAS been applied, there are still pending out-of-order
|
// Even though the latest migration HAS been applied, there are still pending out-of-order
|
||||||
// migrations.
|
// migrations.
|
||||||
current, target, err := p.GetVersions(ctx)
|
current, target, err := p.GetVersions(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 3)
|
require.EqualValues(t, current, 3)
|
||||||
check.Number(t, target, len(fsys))
|
require.EqualValues(t, target, len(fsys))
|
||||||
hasPending, err := p.HasPending(ctx)
|
hasPending, err := p.HasPending(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, hasPending, true)
|
require.True(t, hasPending)
|
||||||
// Apply the missing migrations.
|
// Apply the missing migrations.
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// All migrations have been applied.
|
// All migrations have been applied.
|
||||||
hasPending, err = p.HasPending(ctx)
|
hasPending, err = p.HasPending(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Bool(t, hasPending, false)
|
require.False(t, hasPending)
|
||||||
current, target, err = p.GetVersions(ctx)
|
current, target, err = p.GetVersions(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, target)
|
require.Equal(t, current, target)
|
||||||
})
|
})
|
||||||
t.Run("disallow_out_of_order", func(t *testing.T) {
|
t.Run("disallow_out_of_order", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -814,24 +814,24 @@ func TestPending(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
|
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
|
||||||
goose.WithAllowOutofOrder(false),
|
goose.WithAllowOutofOrder(false),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Some migrations have been applied.
|
// Some migrations have been applied.
|
||||||
_, err = p.ApplyVersion(ctx, 1, true)
|
_, err = p.ApplyVersion(ctx, 1, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = p.ApplyVersion(ctx, versionToApply, true)
|
_, err = p.ApplyVersion(ctx, versionToApply, true)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// TODO(mf): revisit the pending check behavior in addition to the HasPending
|
// TODO(mf): revisit the pending check behavior in addition to the HasPending
|
||||||
// method.
|
// method.
|
||||||
current, target, err := p.GetVersions(ctx)
|
current, target, err := p.GetVersions(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, versionToApply)
|
require.Equal(t, current, versionToApply)
|
||||||
check.Number(t, target, len(fsys))
|
require.EqualValues(t, target, len(fsys))
|
||||||
_, err = p.HasPending(ctx)
|
_, err = p.HasPending(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "missing (out-of-order) migration")
|
require.Contains(t, err.Error(), "missing (out-of-order) migration")
|
||||||
_, err = p.Up(ctx)
|
_, err = p.Up(ctx)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "missing (out-of-order) migration")
|
require.Contains(t, err.Error(), "missing (out-of-order) migration")
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("latest_version", func(t *testing.T) {
|
t.Run("latest_version", func(t *testing.T) {
|
||||||
|
@ -874,7 +874,7 @@ func TestGoOnly(t *testing.T) {
|
||||||
q := `SELECT count(*)FROM users`
|
q := `SELECT count(*)FROM users`
|
||||||
var count int
|
var count int
|
||||||
err := db.QueryRow(q).Scan(&count)
|
err := db.QueryRow(q).Scan(&count)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -888,7 +888,7 @@ func TestGoOnly(t *testing.T) {
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
err := goose.SetGlobalMigrations(register...)
|
err := goose.SetGlobalMigrations(register...)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(goose.ResetGlobalMigrations)
|
t.Cleanup(goose.ResetGlobalMigrations)
|
||||||
|
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
|
@ -902,33 +902,33 @@ func TestGoOnly(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
|
||||||
goose.WithGoMigrations(register...),
|
goose.WithGoMigrations(register...),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
require.Equal(t, len(p.ListSources()), 2)
|
||||||
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||||
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||||
// Apply migration 1
|
// Apply migration 1
|
||||||
res, err := p.UpByOne(ctx)
|
res, err := p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||||
check.Number(t, countUser(db), 0)
|
require.Equal(t, countUser(db), 0)
|
||||||
check.Bool(t, tableExists(t, db, "users"), true)
|
require.True(t, tableExists(t, db, "users"))
|
||||||
// Apply migration 2
|
// Apply migration 2
|
||||||
res, err = p.UpByOne(ctx)
|
res, err = p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||||
check.Number(t, countUser(db), 3)
|
require.Equal(t, countUser(db), 3)
|
||||||
// Rollback migration 2
|
// Rollback migration 2
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||||
check.Number(t, countUser(db), 0)
|
require.Equal(t, countUser(db), 0)
|
||||||
// Rollback migration 1
|
// Rollback migration 1
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||||
// Check table does not exist
|
// Check table does not exist
|
||||||
check.Bool(t, tableExists(t, db, "users"), false)
|
require.False(t, tableExists(t, db, "users"))
|
||||||
})
|
})
|
||||||
t.Run("with_db", func(t *testing.T) {
|
t.Run("with_db", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -944,7 +944,7 @@ func TestGoOnly(t *testing.T) {
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
err := goose.SetGlobalMigrations(register...)
|
err := goose.SetGlobalMigrations(register...)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(goose.ResetGlobalMigrations)
|
t.Cleanup(goose.ResetGlobalMigrations)
|
||||||
|
|
||||||
db := newDB(t)
|
db := newDB(t)
|
||||||
|
@ -958,33 +958,33 @@ func TestGoOnly(t *testing.T) {
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
|
||||||
goose.WithGoMigrations(register...),
|
goose.WithGoMigrations(register...),
|
||||||
)
|
)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Number(t, len(p.ListSources()), 2)
|
require.Equal(t, len(p.ListSources()), 2)
|
||||||
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
assertSource(t, sources[0], goose.TypeGo, "", 1)
|
||||||
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
assertSource(t, sources[1], goose.TypeGo, "", 2)
|
||||||
// Apply migration 1
|
// Apply migration 1
|
||||||
res, err := p.UpByOne(ctx)
|
res, err := p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "up", false)
|
||||||
check.Number(t, countUser(db), 0)
|
require.Equal(t, countUser(db), 0)
|
||||||
check.Bool(t, tableExists(t, db, "users"), true)
|
require.True(t, tableExists(t, db, "users"))
|
||||||
// Apply migration 2
|
// Apply migration 2
|
||||||
res, err = p.UpByOne(ctx)
|
res, err = p.UpByOne(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "up", false)
|
||||||
check.Number(t, countUser(db), 3)
|
require.Equal(t, countUser(db), 3)
|
||||||
// Rollback migration 2
|
// Rollback migration 2
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 2), "down", false)
|
||||||
check.Number(t, countUser(db), 0)
|
require.Equal(t, countUser(db), 0)
|
||||||
// Rollback migration 1
|
// Rollback migration 1
|
||||||
res, err = p.Down(ctx)
|
res, err = p.Down(ctx)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
assertResult(t, res, newSource(goose.TypeGo, "", 1), "down", false)
|
||||||
// Check table does not exist
|
// Check table does not exist
|
||||||
check.Bool(t, tableExists(t, db, "users"), false)
|
require.False(t, tableExists(t, db, "users"))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1006,7 +1006,7 @@ func tableExists(t *testing.T, db *sql.DB, table string) bool {
|
||||||
q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table)
|
q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table)
|
||||||
var b string
|
var b string
|
||||||
err := db.QueryRow(q).Scan(&b)
|
err := db.QueryRow(q).Scan(&b)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return b == "1"
|
return b == "1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1030,7 +1030,7 @@ func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provi
|
||||||
goose.WithVerbose(testing.Verbose()),
|
goose.WithVerbose(testing.Verbose()),
|
||||||
)
|
)
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), opts...)
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), opts...)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return p, db
|
return p, db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1038,7 +1038,7 @@ func newDB(t *testing.T) *sql.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8))
|
dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8))
|
||||||
db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName))
|
db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1074,26 +1074,26 @@ func getTableNames(db *sql.DB) ([]string, error) {
|
||||||
|
|
||||||
func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) {
|
func assertStatus(t *testing.T, got *goose.MigrationStatus, state goose.State, source *goose.Source, appliedIsZero bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.State, state)
|
require.Equal(t, got.State, state)
|
||||||
check.Equal(t, got.Source, source)
|
require.Equal(t, got.Source, source)
|
||||||
check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero)
|
require.Equal(t, got.AppliedAt.IsZero(), appliedIsZero)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) {
|
func assertResult(t *testing.T, got *goose.MigrationResult, source *goose.Source, direction string, isEmpty bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Bool(t, got != nil, true)
|
require.True(t, got != nil)
|
||||||
check.Equal(t, got.Source, source)
|
require.Equal(t, got.Source, source)
|
||||||
check.Equal(t, got.Direction, direction)
|
require.Equal(t, got.Direction, direction)
|
||||||
check.Equal(t, got.Empty, isEmpty)
|
require.Equal(t, got.Empty, isEmpty)
|
||||||
check.Bool(t, got.Error == nil, true)
|
require.Nil(t, got.Error)
|
||||||
check.Bool(t, got.Duration > 0, true)
|
require.True(t, got.Duration > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) {
|
func assertSource(t *testing.T, got *goose.Source, typ goose.MigrationType, name string, version int64) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.Type, typ)
|
require.Equal(t, got.Type, typ)
|
||||||
check.Equal(t, got.Path, name)
|
require.Equal(t, got.Path, name)
|
||||||
check.Equal(t, got.Version, version)
|
require.Equal(t, got.Version, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source {
|
func newSource(t goose.MigrationType, fullpath string, version int64) *goose.Source {
|
||||||
|
|
|
@ -9,18 +9,18 @@ import (
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProvider(t *testing.T) {
|
func TestProvider(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Run("empty", func(t *testing.T) {
|
t.Run("empty", func(t *testing.T) {
|
||||||
_, err := goose.NewProvider(goose.DialectSQLite3, db, fstest.MapFS{})
|
_, err := goose.NewProvider(goose.DialectSQLite3, db, fstest.MapFS{})
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true)
|
require.True(t, errors.Is(err, goose.ErrNoMigrations))
|
||||||
})
|
})
|
||||||
|
|
||||||
mapFS := fstest.MapFS{
|
mapFS := fstest.MapFS{
|
||||||
|
@ -28,13 +28,13 @@ func TestProvider(t *testing.T) {
|
||||||
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
"migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)},
|
||||||
}
|
}
|
||||||
fsys, err := fs.Sub(mapFS, "migrations")
|
fsys, err := fs.Sub(mapFS, "migrations")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys)
|
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sources := p.ListSources()
|
sources := p.ListSources()
|
||||||
check.Equal(t, len(sources), 2)
|
require.Equal(t, len(sources), 2)
|
||||||
check.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1))
|
require.Equal(t, sources[0], newSource(goose.TypeSQL, "001_foo.sql", 1))
|
||||||
check.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2))
|
require.Equal(t, sources[1], newSource(goose.TypeSQL, "002_bar.sql", 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -78,7 +78,5 @@ ALTER TABLE my_foo RENAME TO foo;
|
||||||
|
|
||||||
func TestPartialErrorUnwrap(t *testing.T) {
|
func TestPartialErrorUnwrap(t *testing.T) {
|
||||||
err := &goose.PartialError{Err: goose.ErrNoCurrentVersion}
|
err := &goose.PartialError{Err: goose.ErrNoCurrentVersion}
|
||||||
|
require.ErrorIs(t, err, goose.ErrNoCurrentVersion)
|
||||||
got := errors.Is(err, goose.ErrNoCurrentVersion)
|
|
||||||
check.Bool(t, got, true)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,68 +6,68 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
|
||||||
_ "github.com/pressly/goose/v3/tests/gomigrations/error/testdata"
|
_ "github.com/pressly/goose/v3/tests/gomigrations/error/testdata"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGoMigrationByOne(t *testing.T) {
|
func TestGoMigrationByOne(t *testing.T) {
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db"))
|
db, err := sql.Open("sqlite", filepath.Join(tempDir, "test.db"))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = goose.SetDialect(string(goose.DialectSQLite3))
|
err = goose.SetDialect(string(goose.DialectSQLite3))
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Create goose table.
|
// Create goose table.
|
||||||
current, err := goose.EnsureDBVersion(db)
|
current, err := goose.EnsureDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, current, 0)
|
require.Equal(t, current, 0)
|
||||||
// Collect migrations.
|
// Collect migrations.
|
||||||
dir := "testdata"
|
dir := "testdata"
|
||||||
migrations, err := goose.CollectMigrations(dir, 0, goose.MaxVersion)
|
migrations, err := goose.CollectMigrations(dir, 0, goose.MaxVersion)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(migrations), 4)
|
require.Equal(t, len(migrations), 4)
|
||||||
|
|
||||||
// Setup table.
|
// Setup table.
|
||||||
err = migrations[0].Up(db)
|
err = migrations[0].Up(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
version, err := goose.GetDBVersion(db)
|
version, err := goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, version, 1)
|
require.Equal(t, version, 1)
|
||||||
|
|
||||||
// Registered Go migration run outside a goose tx using *sql.DB.
|
// Registered Go migration run outside a goose tx using *sql.DB.
|
||||||
err = migrations[1].Up(db)
|
err = migrations[1].Up(db)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "failed to run go migration")
|
require.Contains(t, err.Error(), "failed to run go migration")
|
||||||
version, err = goose.GetDBVersion(db)
|
version, err = goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, version, 1)
|
require.Equal(t, version, 1)
|
||||||
|
|
||||||
// This migration was inserting 100 rows, but fails at 50, and
|
// This migration was inserting 100 rows, but fails at 50, and
|
||||||
// because it's run outside a goose tx then we expect 50 rows.
|
// because it's run outside a goose tx then we expect 50 rows.
|
||||||
var count int
|
var count int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
|
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, count, 50)
|
require.Equal(t, count, 50)
|
||||||
|
|
||||||
// Truncate table so we have 0 rows.
|
// Truncate table so we have 0 rows.
|
||||||
err = migrations[2].Up(db)
|
err = migrations[2].Up(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
version, err = goose.GetDBVersion(db)
|
version, err = goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// We're at version 3, but keep in mind 2 was never applied because it failed.
|
// We're at version 3, but keep in mind 2 was never applied because it failed.
|
||||||
check.Number(t, version, 3)
|
require.Equal(t, version, 3)
|
||||||
|
|
||||||
// Registered Go migration run within a tx.
|
// Registered Go migration run within a tx.
|
||||||
err = migrations[3].Up(db)
|
err = migrations[3].Up(db)
|
||||||
check.HasError(t, err)
|
require.Error(t, err)
|
||||||
check.Contains(t, err.Error(), "failed to run go migration")
|
require.Contains(t, err.Error(), "failed to run go migration")
|
||||||
version, err = goose.GetDBVersion(db)
|
version, err = goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, version, 3) // This migration failed, so we're still at 3.
|
require.Equal(t, version, 3) // This migration failed, so we're still at 3.
|
||||||
// This migration was inserting 100 rows, but fails at 50. However, since it's
|
// This migration was inserting 100 rows, but fails at 50. However, since it's
|
||||||
// running within a tx we expect none of the inserts to persist.
|
// running within a tx we expect none of the inserts to persist.
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
|
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, count, 0)
|
require.Equal(t, count, 0)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,14 +6,14 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
|
||||||
_ "github.com/pressly/goose/v3/tests/gomigrations/register/testdata"
|
_ "github.com/pressly/goose/v3/tests/gomigrations/register/testdata"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAddFunctions(t *testing.T) {
|
func TestAddFunctions(t *testing.T) {
|
||||||
goMigrations, err := goose.CollectMigrations("testdata", 0, math.MaxInt64)
|
goMigrations, err := goose.CollectMigrations("testdata", 0, math.MaxInt64)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, len(goMigrations), 4)
|
require.Equal(t, len(goMigrations), 4)
|
||||||
|
|
||||||
checkMigration(t, goMigrations[0], &goose.Migration{
|
checkMigration(t, goMigrations[0], &goose.Migration{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
|
@ -51,12 +51,12 @@ func TestAddFunctions(t *testing.T) {
|
||||||
|
|
||||||
func checkMigration(t *testing.T, got *goose.Migration, want *goose.Migration) {
|
func checkMigration(t *testing.T, got *goose.Migration, want *goose.Migration) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
check.Equal(t, got.Version, want.Version)
|
require.Equal(t, got.Version, want.Version)
|
||||||
check.Equal(t, got.Next, want.Next)
|
require.Equal(t, got.Next, want.Next)
|
||||||
check.Equal(t, got.Previous, want.Previous)
|
require.Equal(t, got.Previous, want.Previous)
|
||||||
check.Equal(t, filepath.Base(got.Source), want.Source)
|
require.Equal(t, filepath.Base(got.Source), want.Source)
|
||||||
check.Equal(t, got.Registered, want.Registered)
|
require.Equal(t, got.Registered, want.Registered)
|
||||||
check.Bool(t, got.UseTx, want.UseTx)
|
require.Equal(t, got.UseTx, want.UseTx)
|
||||||
checkFunctions(t, got)
|
checkFunctions(t, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,48 +65,48 @@ func checkFunctions(t *testing.T, m *goose.Migration) {
|
||||||
switch filepath.Base(m.Source) {
|
switch filepath.Base(m.Source) {
|
||||||
case "001_addmigration.go":
|
case "001_addmigration.go":
|
||||||
// With transaction
|
// With transaction
|
||||||
check.Bool(t, m.UpFn == nil, false)
|
require.NotNil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, false)
|
require.NotNil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnContext == nil, false)
|
require.NotNil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, false)
|
require.NotNil(t, m.DownFnContext)
|
||||||
// No transaction
|
// No transaction
|
||||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
require.Nil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
require.Nil(t, m.DownFnNoTx)
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
require.Nil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
require.Nil(t, m.DownFnNoTxContext)
|
||||||
case "002_addmigrationnotx.go":
|
case "002_addmigrationnotx.go":
|
||||||
// With transaction
|
// With transaction
|
||||||
check.Bool(t, m.UpFn == nil, true)
|
require.Nil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, true)
|
require.Nil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnContext == nil, true)
|
require.Nil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, true)
|
require.Nil(t, m.DownFnContext)
|
||||||
// No transaction
|
// No transaction
|
||||||
check.Bool(t, m.UpFnNoTx == nil, false)
|
require.NotNil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, false)
|
require.NotNil(t, m.DownFnNoTx)
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, false)
|
require.NotNil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, false)
|
require.NotNil(t, m.DownFnNoTxContext)
|
||||||
case "003_addmigrationcontext.go":
|
case "003_addmigrationcontext.go":
|
||||||
// With transaction
|
// With transaction
|
||||||
check.Bool(t, m.UpFn == nil, false)
|
require.NotNil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, false)
|
require.NotNil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnContext == nil, false)
|
require.NotNil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, false)
|
require.NotNil(t, m.DownFnContext)
|
||||||
// No transaction
|
// No transaction
|
||||||
check.Bool(t, m.UpFnNoTx == nil, true)
|
require.Nil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, true)
|
require.Nil(t, m.DownFnNoTx)
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, true)
|
require.Nil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, true)
|
require.Nil(t, m.DownFnNoTxContext)
|
||||||
case "004_addmigrationnotxcontext.go":
|
case "004_addmigrationnotxcontext.go":
|
||||||
// With transaction
|
// With transaction
|
||||||
check.Bool(t, m.UpFn == nil, true)
|
require.Nil(t, m.UpFn)
|
||||||
check.Bool(t, m.DownFn == nil, true)
|
require.Nil(t, m.DownFn)
|
||||||
check.Bool(t, m.UpFnContext == nil, true)
|
require.Nil(t, m.UpFnContext)
|
||||||
check.Bool(t, m.DownFnContext == nil, true)
|
require.Nil(t, m.DownFnContext)
|
||||||
// No transaction
|
// No transaction
|
||||||
check.Bool(t, m.UpFnNoTx == nil, false)
|
require.NotNil(t, m.UpFnNoTx)
|
||||||
check.Bool(t, m.DownFnNoTx == nil, false)
|
require.NotNil(t, m.DownFnNoTx)
|
||||||
check.Bool(t, m.UpFnNoTxContext == nil, false)
|
require.NotNil(t, m.UpFnNoTxContext)
|
||||||
check.Bool(t, m.DownFnNoTxContext == nil, false)
|
require.NotNil(t, m.DownFnNoTxContext)
|
||||||
default:
|
default:
|
||||||
t.Fatalf("unexpected migration: %s", filepath.Base(m.Source))
|
t.Fatalf("unexpected migration: %s", filepath.Base(m.Source))
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pressly/goose/v3"
|
"github.com/pressly/goose/v3"
|
||||||
"github.com/pressly/goose/v3/internal/check"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
_ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata"
|
_ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
|
@ -15,39 +15,39 @@ import (
|
||||||
func TestGoMigrationByOne(t *testing.T) {
|
func TestGoMigrationByOne(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
check.NoError(t, goose.SetDialect("sqlite3"))
|
require.NoError(t, goose.SetDialect("sqlite3"))
|
||||||
db, err := sql.Open("sqlite", ":memory:")
|
db, err := sql.Open("sqlite", ":memory:")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dir := "testdata"
|
dir := "testdata"
|
||||||
files, err := filepath.Glob(dir + "/*.go")
|
files, err := filepath.Glob(dir + "/*.go")
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
upByOne := func(t *testing.T) int64 {
|
upByOne := func(t *testing.T) int64 {
|
||||||
err = goose.UpByOne(db, dir)
|
err = goose.UpByOne(db, dir)
|
||||||
t.Logf("err: %v %s", err, dir)
|
t.Logf("err: %v %s", err, dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
version, err := goose.GetDBVersion(db)
|
version, err := goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return version
|
return version
|
||||||
}
|
}
|
||||||
downByOne := func(t *testing.T) int64 {
|
downByOne := func(t *testing.T) int64 {
|
||||||
err = goose.Down(db, dir)
|
err = goose.Down(db, dir)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
version, err := goose.GetDBVersion(db)
|
version, err := goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return version
|
return version
|
||||||
}
|
}
|
||||||
// Migrate all files up-by-one.
|
// Migrate all files up-by-one.
|
||||||
for i := 1; i <= len(files); i++ {
|
for i := 1; i <= len(files); i++ {
|
||||||
check.Number(t, upByOne(t), i)
|
require.Equal(t, upByOne(t), i)
|
||||||
}
|
}
|
||||||
version, err := goose.GetDBVersion(db)
|
version, err := goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, version, len(files))
|
require.Equal(t, version, len(files))
|
||||||
|
|
||||||
tables, err := ListTables(db)
|
tables, err := ListTables(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, tables, []string{
|
require.Equal(t, tables, []string{
|
||||||
"alpha",
|
"alpha",
|
||||||
"bravo",
|
"bravo",
|
||||||
"charlie",
|
"charlie",
|
||||||
|
@ -62,15 +62,15 @@ func TestGoMigrationByOne(t *testing.T) {
|
||||||
|
|
||||||
// Migrate all files down-by-one.
|
// Migrate all files down-by-one.
|
||||||
for i := len(files) - 1; i >= 0; i-- {
|
for i := len(files) - 1; i >= 0; i-- {
|
||||||
check.Number(t, downByOne(t), i)
|
require.Equal(t, downByOne(t), i)
|
||||||
}
|
}
|
||||||
version, err = goose.GetDBVersion(db)
|
version, err = goose.GetDBVersion(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Number(t, version, 0)
|
require.Equal(t, version, 0)
|
||||||
|
|
||||||
tables, err = ListTables(db)
|
tables, err = ListTables(db)
|
||||||
check.NoError(t, err)
|
require.NoError(t, err)
|
||||||
check.Equal(t, tables, []string{
|
require.Equal(t, tables, []string{
|
||||||
"goose_db_version",
|
"goose_db_version",
|
||||||
"sqlite_sequence",
|
"sqlite_sequence",
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue