mirror of https://github.com/pressly/goose.git
254 lines
7.1 KiB
Go
254 lines
7.1 KiB
Go
package database_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/pressly/goose/v3/database"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/multierr"
|
|
"modernc.org/sqlite"
|
|
)
|
|
|
|
// The goal of this test is to verify the database store package works as expected. This test is not
|
|
// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store
|
|
// interface works against a real database.
|
|
|
|
func TestDialectStore(t *testing.T) {
|
|
t.Parallel()
|
|
t.Run("invalid", func(t *testing.T) {
|
|
// Test empty table name.
|
|
_, err := database.NewStore(database.DialectSQLite3, "")
|
|
require.Error(t, err)
|
|
// Test unknown dialect.
|
|
_, err = database.NewStore("unknown-dialect", "foo")
|
|
require.Error(t, err)
|
|
// Test empty dialect.
|
|
_, err = database.NewStore("", "foo")
|
|
require.Error(t, err)
|
|
})
|
|
// Test generic behavior.
|
|
t.Run("sqlite3", func(t *testing.T) {
|
|
db, err := sql.Open("sqlite", ":memory:")
|
|
require.NoError(t, err)
|
|
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
|
|
t.Helper()
|
|
var sqliteErr *sqlite.Error
|
|
ok := errors.As(err, &sqliteErr)
|
|
require.True(t, ok)
|
|
require.Equal(t, 1, sqliteErr.Code()) // Generic error (SQLITE_ERROR)
|
|
require.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists")
|
|
})
|
|
})
|
|
t.Run("ListMigrations", func(t *testing.T) {
|
|
dir := t.TempDir()
|
|
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
|
|
require.NoError(t, err)
|
|
store, err := database.NewStore(database.DialectSQLite3, "foo")
|
|
require.NoError(t, err)
|
|
err = store.CreateVersionTable(context.Background(), db)
|
|
require.NoError(t, err)
|
|
insert := func(db *sql.DB, version int64) error {
|
|
return store.Insert(context.Background(), db, database.InsertRequest{Version: version})
|
|
}
|
|
require.NoError(t, insert(db, 1))
|
|
require.NoError(t, insert(db, 3))
|
|
require.NoError(t, insert(db, 2))
|
|
res, err := store.ListMigrations(context.Background(), db)
|
|
require.NoError(t, err)
|
|
require.Len(t, res, 3)
|
|
// Check versions are in descending order: [2, 3, 1]
|
|
require.EqualValues(t, 2, res[0].Version)
|
|
require.EqualValues(t, 3, res[1].Version)
|
|
require.EqualValues(t, 1, res[2].Version)
|
|
})
|
|
}
|
|
|
|
// testStore tests various store operations.
|
|
//
|
|
// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable
|
|
// when the version table already exists.
|
|
func testStore(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
d database.Dialect,
|
|
db *sql.DB,
|
|
alreadyExists func(t *testing.T, err error),
|
|
) {
|
|
const (
|
|
tablename = "test_goose_db_version"
|
|
)
|
|
store, err := database.NewStore(d, tablename)
|
|
require.NoError(t, err)
|
|
// Create the version table.
|
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
|
return store.CreateVersionTable(ctx, tx)
|
|
})
|
|
require.NoError(t, err)
|
|
// Create the version table again. This should fail.
|
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
|
return store.CreateVersionTable(ctx, tx)
|
|
})
|
|
require.Error(t, err)
|
|
if alreadyExists != nil {
|
|
alreadyExists(t, err)
|
|
}
|
|
// Get the latest version. There should be none.
|
|
_, err = store.GetLatestVersion(ctx, db)
|
|
require.ErrorIs(t, err, database.ErrVersionNotFound)
|
|
|
|
// List migrations. There should be none.
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
res, err := store.ListMigrations(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Empty(t, res, 0)
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert 5 migrations in addition to the zero migration.
|
|
for i := 0; i < 6; i++ {
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
err := store.Insert(ctx, conn, database.InsertRequest{Version: int64(i)})
|
|
require.NoError(t, err)
|
|
latest, err := store.GetLatestVersion(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Equal(t, latest, int64(i))
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// List migrations. There should be 6.
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
res, err := store.ListMigrations(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Len(t, res, 6)
|
|
// Check versions are in descending order.
|
|
for i := 0; i < 6; i++ {
|
|
require.EqualValues(t, res[i].Version, 5-i)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Delete 3 migrations backwards
|
|
for i := 5; i >= 3; i-- {
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
err := store.Delete(ctx, conn, int64(i))
|
|
require.NoError(t, err)
|
|
latest, err := store.GetLatestVersion(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Equal(t, latest, int64(i-1))
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// List migrations. There should be 3.
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
res, err := store.ListMigrations(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Len(t, res, 3)
|
|
// Check that the remaining versions are in descending order.
|
|
for i := 0; i < 3; i++ {
|
|
require.EqualValues(t, res[i].Version, 2-i)
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Get remaining migrations one by one.
|
|
for i := 0; i < 3; i++ {
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
res, err := store.GetMigration(ctx, conn, int64(i))
|
|
require.NoError(t, err)
|
|
require.True(t, res.IsApplied)
|
|
require.False(t, res.Timestamp.IsZero())
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Delete remaining migrations one by one and use all 3 connection types:
|
|
|
|
// 1. *sql.Tx
|
|
err = runTx(ctx, db, func(tx *sql.Tx) error {
|
|
err := store.Delete(ctx, tx, 2)
|
|
require.NoError(t, err)
|
|
latest, err := store.GetLatestVersion(ctx, tx)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, latest)
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
// 2. *sql.Conn
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
err := store.Delete(ctx, conn, 1)
|
|
require.NoError(t, err)
|
|
latest, err := store.GetLatestVersion(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, latest)
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
// 3. *sql.DB
|
|
err = store.Delete(ctx, db, 0)
|
|
require.NoError(t, err)
|
|
_, err = store.GetLatestVersion(ctx, db)
|
|
require.ErrorIs(t, err, database.ErrVersionNotFound)
|
|
|
|
// List migrations. There should be none.
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
res, err := store.ListMigrations(ctx, conn)
|
|
require.NoError(t, err)
|
|
require.Empty(t, res)
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Try to get a migration that does not exist.
|
|
err = runConn(ctx, db, func(conn *sql.Conn) error {
|
|
_, err := store.GetMigration(ctx, conn, 0)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, database.ErrVersionNotFound)
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) {
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if retErr != nil {
|
|
retErr = multierr.Append(retErr, tx.Rollback())
|
|
}
|
|
}()
|
|
if err := fn(tx); err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) {
|
|
conn, err := db.Conn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if retErr != nil {
|
|
retErr = multierr.Append(retErr, conn.Close())
|
|
}
|
|
}()
|
|
if err := fn(conn); err != nil {
|
|
return err
|
|
}
|
|
return conn.Close()
|
|
}
|