diff --git a/examples/go-migrations/00001_create_users_table.sql b/examples/go-migrations/00001_create_users_table.sql index efce49f..aa58084 100644 --- a/examples/go-migrations/00001_create_users_table.sql +++ b/examples/go-migrations/00001_create_users_table.sql @@ -1,9 +1,9 @@ -- +goose Up CREATE TABLE users ( - id int NOT NULL PRIMARY KEY, - username text, - name text, - surname text + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT, + name TEXT, + surname TEXT ); INSERT INTO users VALUES diff --git a/examples/go-migrations/00003_add_user_no_tx.go b/examples/go-migrations/00003_add_user_no_tx.go new file mode 100644 index 0000000..618dc6d --- /dev/null +++ b/examples/go-migrations/00003_add_user_no_tx.go @@ -0,0 +1,43 @@ +package main + +import ( + "database/sql" + "errors" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(Up00003, Down00003) +} + +func Up00003(db *sql.DB) error { + id, err := getUserID(db, "jamesbond") + if err != nil { + return err + } + if id == 0 { + query := "INSERT INTO users (username, name, surname) VALUES ($1, $2, $3)" + if _, err := db.Exec(query, "jamesbond", "James", "Bond"); err != nil { + return err + } + } + return nil +} + +func getUserID(db *sql.DB, username string) (int, error) { + var id int + err := db.QueryRow("SELECT id FROM users WHERE username = $1", username).Scan(&id) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return 0, err + } + return id, nil +} + +func Down00003(db *sql.DB) error { + query := "DELETE FROM users WHERE username = $1" + if _, err := db.Exec(query, "jamesbond"); err != nil { + return err + } + return nil +} diff --git a/examples/go-migrations/README.md b/examples/go-migrations/README.md index 5954ec6..720d11c 100644 --- a/examples/go-migrations/README.md +++ b/examples/go-migrations/README.md @@ -3,26 +3,29 @@ ## This example: Custom goose binary with built-in Go migrations ```bash -$ go build -o goose *.go +$ go build -o goose-custom *.go ``` -``` -$ ./goose sqlite3 ./foo.db status +```bash +$ ./goose-custom sqlite3 ./foo.db status Applied At Migration ======================================= Pending -- 00001_create_users_table.sql Pending -- 00002_rename_root.go + Pending -- 00003_add_user_no_tx.go -$ ./goose sqlite3 ./foo.db up -OK 00001_create_users_table.sql -OK 00002_rename_root.go -goose: no migrations to run. current version: 2 +$ ./goose-custom sqlite3 ./foo.db up + OK 00001_create_users_table.sql (711.58µs) + OK 00002_rename_root.go (302.08µs) + OK 00003_add_user_no_tx.go (648.71µs) + goose: no migrations to run. current version: 3 -$ +$ ./goose-custom sqlite3 ./foo.db status Applied At Migration ======================================= - Mon Jun 19 21:56:00 2017 -- 00001_create_users_table.sql - Mon Jun 19 21:56:00 2017 -- 00002_rename_root.go + 00001_create_users_table.sql + 00002_rename_root.go + 00003_add_user_no_tx.go ``` ## Best practice: Split migrations into a standalone package @@ -33,9 +36,9 @@ $ 3. Import this `migrations` package from your custom [cmd/main.go](main.go) file: - ```go - import ( - // Invoke init() functions within migrations pkg. - _ "github.com/pressly/goose/example/migrations-go" - ) - ``` + ```go + import ( + // Invoke init() functions within migrations pkg. + _ "github.com/pressly/goose/example/migrations-go" + ) + ``` diff --git a/examples/go-migrations/foo.db b/examples/go-migrations/foo.db deleted file mode 100644 index d857491..0000000 Binary files a/examples/go-migrations/foo.db and /dev/null differ diff --git a/examples/go-migrations/goose b/examples/go-migrations/goose deleted file mode 100755 index 0bf60fa..0000000 Binary files a/examples/go-migrations/goose and /dev/null differ diff --git a/migrate.go b/migrate.go index a4cf399..0e93f88 100644 --- a/migrate.go +++ b/migrate.go @@ -123,22 +123,70 @@ func (ms Migrations) String() string { return str } -// AddMigration adds a migration. -func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { +// GoMigration is a Go migration func that is run within a transaction. +type GoMigration func(tx *sql.Tx) error + +// GoMigrationNoTx is a Go migration func that is run outside a transaction. +type GoMigrationNoTx func(db *sql.DB) error + +// AddMigration adds Go migrations. +func AddMigration(up, down GoMigration) { _, filename, _, _ := runtime.Caller(1) AddNamedMigration(filename, up, down) } -// AddNamedMigration : Add a named migration. -func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { - v, _ := NumericComponent(filename) - migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename} - - if existing, ok := registeredGoMigrations[v]; ok { - panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) +// AddNamedMigration adds named Go migrations. +func AddNamedMigration(filename string, up, down GoMigration) { + if err := register(filename, true, up, down, nil, nil); err != nil { + panic(err) } +} - registeredGoMigrations[v] = migration +// AddMigrationNoTx adds Go migrations that will be run outside transaction. +func AddMigrationNoTx(up, down GoMigrationNoTx) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTx(filename, up, down) +} + +// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. +func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { + if err := register(filename, false, nil, nil, up, down); err != nil { + panic(err) + } +} + +func register( + filename string, + useTx bool, + up, down GoMigration, + upNoTx, downNoTx GoMigrationNoTx, +) error { + // Sanity check caller did not mix tx and non-tx based functions. + if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { + return fmt.Errorf("cannot mix tx and non-tx based go migrations functions") + } + v, _ := NumericComponent(filename) + if existing, ok := registeredGoMigrations[v]; ok { + return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", + filename, + v, + existing.Source, + ) + } + // Add to global as a registered migration. + registeredGoMigrations[v] = &Migration{ + Version: v, + Next: -1, + Previous: -1, + Registered: true, + Source: filename, + UseTx: useTx, + UpFn: up, + DownFn: down, + UpFnNoTx: upNoTx, + DownFnNoTx: downNoTx, + } + return nil } func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) { diff --git a/migration.go b/migration.go index 22c945c..1910630 100644 --- a/migration.go +++ b/migration.go @@ -21,14 +21,15 @@ type MigrationRecord struct { // Migration struct. type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script or go file - Registered bool - UpFn func(*sql.Tx) error // Up go migration function - DownFn func(*sql.Tx) error // Down go migration function - noVersioning bool + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script or go file + Registered bool + UseTx bool + UpFn, DownFn GoMigration + UpFnNoTx, DownFnNoTx GoMigrationNoTx + noVersioning bool } func (m *Migration) String() string { @@ -82,54 +83,121 @@ func (m *Migration) run(db *sql.DB, direction bool) error { if !m.Registered { return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source) } - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("ERROR failed to begin transaction: %w", err) - } - - fn := m.UpFn - if !direction { - fn = m.DownFn - } - - if fn != nil { - // Run Go migration function. - if err := fn(tx); err != nil { - tx.Rollback() - return fmt.Errorf("ERROR %v: failed to run Go migration function %T: %w", filepath.Base(m.Source), fn, err) - } - } - if !m.noVersioning { - if direction { - if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { - tx.Rollback() - return fmt.Errorf("ERROR failed to execute transaction: %w", err) - } - } else { - if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil { - tx.Rollback() - return fmt.Errorf("ERROR failed to execute transaction: %w", err) - } - } - } start := time.Now() - if err := tx.Commit(); err != nil { - return fmt.Errorf("ERROR failed to commit transaction: %w", err) + var empty bool + if m.UseTx { + // Run go-based migration inside a tx. + fn := m.DownFn + if direction { + fn = m.UpFn + } + empty = (fn == nil) + if err := runGoMigration( + db, + fn, + m.Version, + direction, + !m.noVersioning, + ); err != nil { + return fmt.Errorf("ERROR go migration: %q: %w", filepath.Base(m.Source), err) + } + } else { + // Run go-based migration outside a tx. + fn := m.DownFnNoTx + if direction { + fn = m.UpFnNoTx + } + empty = (fn == nil) + if err := runGoMigrationNoTx( + db, + fn, + m.Version, + direction, + !m.noVersioning, + ); err != nil { + return fmt.Errorf("ERROR go migration no tx: %q: %w", filepath.Base(m.Source), err) + } } finish := truncateDuration(time.Since(start)) - - if fn != nil { + if !empty { log.Printf("OK %s (%s)\n", filepath.Base(m.Source), finish) } else { log.Printf("EMPTY %s (%s)\n", filepath.Base(m.Source), finish) } + } + return nil +} +func runGoMigrationNoTx( + db *sql.DB, + fn GoMigrationNoTx, + version int64, + direction bool, + recordVersion bool, +) error { + if fn != nil { + // Run go migration function. + if err := fn(db); err != nil { + return fmt.Errorf("failed to run go migration: %w", err) + } + } + if recordVersion { + return insertOrDeleteVersionNoTx(db, version, direction) + } + return nil +} + +func runGoMigration( + db *sql.DB, + fn GoMigration, + version int64, + direction bool, + recordVersion bool, +) error { + if fn == nil && !recordVersion { return nil } - + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + if fn != nil { + // Run go migration function. + if err := fn(tx); err != nil { + tx.Rollback() + return fmt.Errorf("failed to run go migration: %w", err) + } + } + if recordVersion { + if err := insertOrDeleteVersion(tx, version, direction); err != nil { + tx.Rollback() + return fmt.Errorf("failed to update version: %w", err) + } + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } return nil } +func insertOrDeleteVersion(tx *sql.Tx, version int64, direction bool) error { + if direction { + _, err := tx.Exec(GetDialect().insertVersionSQL(), version, direction) + return err + } + _, err := tx.Exec(GetDialect().deleteVersionSQL(), version) + return err +} + +func insertOrDeleteVersionNoTx(db *sql.DB, version int64, direction bool) error { + if direction { + _, err := db.Exec(GetDialect().insertVersionSQL(), version, direction) + return err + } + _, err := db.Exec(GetDialect().deleteVersionSQL(), version) + return err +} + // NumericComponent looks for migration scripts with names in the form: // XXX_descriptivename.ext where XXX specifies the version number // and ext specifies the type of migration diff --git a/tests/gomigrations/error/gomigrations_error_test.go b/tests/gomigrations/error/gomigrations_error_test.go new file mode 100644 index 0000000..c67da9e --- /dev/null +++ b/tests/gomigrations/error/gomigrations_error_test.go @@ -0,0 +1,70 @@ +package gomigrations + +import ( + "testing" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + + _ "github.com/pressly/goose/v3/tests/gomigrations/error/testdata" +) + +func TestGoMigrationByOne(t *testing.T) { + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + // Create goose table. + current, err := goose.EnsureDBVersion(db) + check.NoError(t, err) + check.Number(t, current, 0) + // Collect migrations. + dir := "testdata" + migrations, err := goose.CollectMigrations(dir, 0, goose.MaxVersion) + check.NoError(t, err) + check.Number(t, len(migrations), 4) + + // Setup table. + err = migrations[0].Up(db) + check.NoError(t, err) + version, err := goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, version, 1) + + // Registered Go migration run outside a goose tx using *sql.DB. + err = migrations[1].Up(db) + check.HasError(t, err) + check.Contains(t, err.Error(), "failed to run go migration") + version, err = goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, version, 1) + + // This migration was inserting 100 rows, but fails at 50, and + // because it's run outside a goose tx then we expect 50 rows. + var count int + err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count) + check.NoError(t, err) + check.Number(t, count, 50) + + // Truncate table so we have 0 rows. + err = migrations[2].Up(db) + check.NoError(t, err) + version, err = goose.GetDBVersion(db) + check.NoError(t, err) + // We're at version 3, but keep in mind 2 was never applied because it failed. + check.Number(t, version, 3) + + // Registered Go migration run within a tx. + err = migrations[3].Up(db) + check.HasError(t, err) + check.Contains(t, err.Error(), "failed to run go migration") + version, err = goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, version, 3) // This migration failed, so we're still at 3. + // This migration was inserting 100 rows, but fails at 50. However, since it's + // running within a tx we expect none of the inserts to persist. + err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count) + check.NoError(t, err) + check.Number(t, count, 0) + +} diff --git a/tests/gomigrations/error/testdata/001_up_no_tx.go b/tests/gomigrations/error/testdata/001_up_no_tx.go new file mode 100644 index 0000000..c5f22fd --- /dev/null +++ b/tests/gomigrations/error/testdata/001_up_no_tx.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up001, nil) +} + +func up001(db *sql.DB) error { + q := "CREATE TABLE foo (id INT)" + _, err := db.Exec(q) + return err +} diff --git a/tests/gomigrations/error/testdata/002_ERROR_insert_no_tx.go b/tests/gomigrations/error/testdata/002_ERROR_insert_no_tx.go new file mode 100644 index 0000000..e07dc59 --- /dev/null +++ b/tests/gomigrations/error/testdata/002_ERROR_insert_no_tx.go @@ -0,0 +1,27 @@ +package gomigrations + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up002, nil) +} + +func up002(db *sql.DB) error { + for i := 1; i <= 100; i++ { + q := "INSERT INTO foo VALUES ($1)" + if _, err := db.Exec(q, i); err != nil { + return err + } + // Simulate an error when no tx. We should have 50 rows + // inserted in the DB. + if i == 50 { + return fmt.Errorf("simulate error: too many inserts") + } + } + return nil +} diff --git a/tests/gomigrations/error/testdata/003_truncate.go b/tests/gomigrations/error/testdata/003_truncate.go new file mode 100644 index 0000000..836af1c --- /dev/null +++ b/tests/gomigrations/error/testdata/003_truncate.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up003, nil) +} + +func up003(tx *sql.Tx) error { + q := "TRUNCATE TABLE foo" + _, err := tx.Exec(q) + return err +} diff --git a/tests/gomigrations/error/testdata/004_ERROR_insert.go b/tests/gomigrations/error/testdata/004_ERROR_insert.go new file mode 100644 index 0000000..7cad068 --- /dev/null +++ b/tests/gomigrations/error/testdata/004_ERROR_insert.go @@ -0,0 +1,27 @@ +package gomigrations + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up004, nil) +} + +func up004(tx *sql.Tx) error { + for i := 1; i <= 100; i++ { + // Simulate an error when no tx. We should have 50 rows + // inserted in the DB. + if i == 50 { + return fmt.Errorf("simulate error: too many inserts") + } + q := "INSERT INTO foo VALUES ($1)" + if _, err := tx.Exec(q); err != nil { + return err + } + } + return nil +} diff --git a/tests/gomigrations/success/gomigrations_success_test.go b/tests/gomigrations/success/gomigrations_success_test.go new file mode 100644 index 0000000..87bebf4 --- /dev/null +++ b/tests/gomigrations/success/gomigrations_success_test.go @@ -0,0 +1,52 @@ +package gomigrations + +import ( + "path/filepath" + "testing" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + + _ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata" +) + +func TestGoMigrationByOne(t *testing.T) { + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + + dir := "testdata" + files, err := filepath.Glob(dir + "/*.go") + check.NoError(t, err) + + upByOne := func(t *testing.T) int64 { + err = goose.UpByOne(db, dir) + check.NoError(t, err) + version, err := goose.GetDBVersion(db) + check.NoError(t, err) + return version + } + downByOne := func(t *testing.T) int64 { + err = goose.Down(db, dir) + check.NoError(t, err) + version, err := goose.GetDBVersion(db) + check.NoError(t, err) + return version + } + // Migrate all files up-by-one. + for i := 1; i <= len(files); i++ { + check.Number(t, upByOne(t), i) + } + version, err := goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, version, len(files)) + + // Migrate all files down-by-one. + for i := len(files) - 1; i >= 0; i-- { + check.Number(t, downByOne(t), i) + } + version, err = goose.GetDBVersion(db) + check.NoError(t, err) + check.Number(t, version, 0) +} diff --git a/tests/gomigrations/success/testdata/001_up_down.go b/tests/gomigrations/success/testdata/001_up_down.go new file mode 100644 index 0000000..9fed61c --- /dev/null +++ b/tests/gomigrations/success/testdata/001_up_down.go @@ -0,0 +1,23 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up001, down001) +} + +func up001(tx *sql.Tx) error { + q := "CREATE TABLE foo (id INT, subid INT, name TEXT)" + _, err := tx.Exec(q) + return err +} + +func down001(tx *sql.Tx) error { + q := "DROP TABLE IF EXISTS foo" + _, err := tx.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/002_up_only.go b/tests/gomigrations/success/testdata/002_up_only.go new file mode 100644 index 0000000..6ece192 --- /dev/null +++ b/tests/gomigrations/success/testdata/002_up_only.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(up002, nil) +} + +func up002(tx *sql.Tx) error { + q := "INSERT INTO foo VALUES (1, 1, 'Alice')" + _, err := tx.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/003_down_only.go b/tests/gomigrations/success/testdata/003_down_only.go new file mode 100644 index 0000000..ff39f5f --- /dev/null +++ b/tests/gomigrations/success/testdata/003_down_only.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(nil, down003) +} + +func down003(tx *sql.Tx) error { + q := "TRUNCATE TABLE foo" + _, err := tx.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/004_empty.go b/tests/gomigrations/success/testdata/004_empty.go new file mode 100644 index 0000000..5efb376 --- /dev/null +++ b/tests/gomigrations/success/testdata/004_empty.go @@ -0,0 +1,9 @@ +package gomigrations + +import ( + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(nil, nil) +} diff --git a/tests/gomigrations/success/testdata/005_up_down_no_tx.go b/tests/gomigrations/success/testdata/005_up_down_no_tx.go new file mode 100644 index 0000000..7a6838d --- /dev/null +++ b/tests/gomigrations/success/testdata/005_up_down_no_tx.go @@ -0,0 +1,23 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up005, down005) +} + +func up005(db *sql.DB) error { + q := "CREATE TABLE users (id INT, email TEXT)" + _, err := db.Exec(q) + return err +} + +func down005(db *sql.DB) error { + q := "DROP TABLE IF EXISTS users" + _, err := db.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/006_up_only_no_tx.go b/tests/gomigrations/success/testdata/006_up_only_no_tx.go new file mode 100644 index 0000000..26aa88c --- /dev/null +++ b/tests/gomigrations/success/testdata/006_up_only_no_tx.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(up006, nil) +} + +func up006(db *sql.DB) error { + q := "INSERT INTO users VALUES (1, 'admin@example.com')" + _, err := db.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/007_down_only_no_tx.go b/tests/gomigrations/success/testdata/007_down_only_no_tx.go new file mode 100644 index 0000000..318b02e --- /dev/null +++ b/tests/gomigrations/success/testdata/007_down_only_no_tx.go @@ -0,0 +1,17 @@ +package gomigrations + +import ( + "database/sql" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTx(nil, down007) +} + +func down007(db *sql.DB) error { + q := "TRUNCATE TABLE users" + _, err := db.Exec(q) + return err +} diff --git a/tests/gomigrations/success/testdata/008_empty_no_tx.go b/tests/gomigrations/success/testdata/008_empty_no_tx.go new file mode 100644 index 0000000..5efb376 --- /dev/null +++ b/tests/gomigrations/success/testdata/008_empty_no_tx.go @@ -0,0 +1,9 @@ +package gomigrations + +import ( + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigration(nil, nil) +}