feat: add support for `*sql.DB`-registered Go migration (#450)

pull/446/head^2
Michael Fridman 2023-01-25 08:15:50 -05:00 committed by GitHub
parent 635add3280
commit 203277344b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 577 additions and 73 deletions

View File

@ -1,9 +1,9 @@
-- +goose Up
CREATE TABLE users (
id int NOT NULL PRIMARY KEY,
username text,
name text,
surname text
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT,
name TEXT,
surname TEXT
);
INSERT INTO users VALUES

View File

@ -0,0 +1,43 @@
package main
import (
"database/sql"
"errors"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(Up00003, Down00003)
}
func Up00003(db *sql.DB) error {
id, err := getUserID(db, "jamesbond")
if err != nil {
return err
}
if id == 0 {
query := "INSERT INTO users (username, name, surname) VALUES ($1, $2, $3)"
if _, err := db.Exec(query, "jamesbond", "James", "Bond"); err != nil {
return err
}
}
return nil
}
func getUserID(db *sql.DB, username string) (int, error) {
var id int
err := db.QueryRow("SELECT id FROM users WHERE username = $1", username).Scan(&id)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, err
}
return id, nil
}
func Down00003(db *sql.DB) error {
query := "DELETE FROM users WHERE username = $1"
if _, err := db.Exec(query, "jamesbond"); err != nil {
return err
}
return nil
}

View File

@ -3,26 +3,29 @@
## This example: Custom goose binary with built-in Go migrations
```bash
$ go build -o goose *.go
$ go build -o goose-custom *.go
```
```
$ ./goose sqlite3 ./foo.db status
```bash
$ ./goose-custom sqlite3 ./foo.db status
Applied At Migration
=======================================
Pending -- 00001_create_users_table.sql
Pending -- 00002_rename_root.go
Pending -- 00003_add_user_no_tx.go
$ ./goose sqlite3 ./foo.db up
OK 00001_create_users_table.sql
OK 00002_rename_root.go
goose: no migrations to run. current version: 2
$ ./goose-custom sqlite3 ./foo.db up
OK 00001_create_users_table.sql (711.58µs)
OK 00002_rename_root.go (302.08µs)
OK 00003_add_user_no_tx.go (648.71µs)
goose: no migrations to run. current version: 3
$
$ ./goose-custom sqlite3 ./foo.db status
Applied At Migration
=======================================
Mon Jun 19 21:56:00 2017 -- 00001_create_users_table.sql
Mon Jun 19 21:56:00 2017 -- 00002_rename_root.go
00001_create_users_table.sql
00002_rename_root.go
00003_add_user_no_tx.go
```
## Best practice: Split migrations into a standalone package
@ -33,9 +36,9 @@ $
3. Import this `migrations` package from your custom [cmd/main.go](main.go) file:
```go
import (
// Invoke init() functions within migrations pkg.
_ "github.com/pressly/goose/example/migrations-go"
)
```
```go
import (
// Invoke init() functions within migrations pkg.
_ "github.com/pressly/goose/example/migrations-go"
)
```

Binary file not shown.

Binary file not shown.

View File

@ -123,22 +123,70 @@ func (ms Migrations) String() string {
return str
}
// AddMigration adds a migration.
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
// GoMigration is a Go migration func that is run within a transaction.
type GoMigration func(tx *sql.Tx) error
// GoMigrationNoTx is a Go migration func that is run outside a transaction.
type GoMigrationNoTx func(db *sql.DB) error
// AddMigration adds Go migrations.
func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down)
}
// AddNamedMigration : Add a named migration.
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := NumericComponent(filename)
migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
// AddNamedMigration adds named Go migrations.
func AddNamedMigration(filename string, up, down GoMigration) {
if err := register(filename, true, up, down, nil, nil); err != nil {
panic(err)
}
}
registeredGoMigrations[v] = migration
// AddMigrationNoTx adds Go migrations that will be run outside transaction.
func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTx(filename, up, down)
}
// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
if err := register(filename, false, nil, nil, up, down); err != nil {
panic(err)
}
}
func register(
filename string,
useTx bool,
up, down GoMigration,
upNoTx, downNoTx GoMigrationNoTx,
) error {
// Sanity check caller did not mix tx and non-tx based functions.
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
return fmt.Errorf("cannot mix tx and non-tx based go migrations functions")
}
v, _ := NumericComponent(filename)
if existing, ok := registeredGoMigrations[v]; ok {
return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
filename,
v,
existing.Source,
)
}
// Add to global as a registered migration.
registeredGoMigrations[v] = &Migration{
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFn: up,
DownFn: down,
UpFnNoTx: upNoTx,
DownFnNoTx: downNoTx,
}
return nil
}
func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) {

View File

@ -21,14 +21,15 @@ type MigrationRecord struct {
// Migration struct.
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UpFn func(*sql.Tx) error // Up go migration function
DownFn func(*sql.Tx) error // Down go migration function
noVersioning bool
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UseTx bool
UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx
noVersioning bool
}
func (m *Migration) String() string {
@ -82,54 +83,121 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
if !m.Registered {
return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("ERROR failed to begin transaction: %w", err)
}
fn := m.UpFn
if !direction {
fn = m.DownFn
}
if fn != nil {
// Run Go migration function.
if err := fn(tx); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR %v: failed to run Go migration function %T: %w", filepath.Base(m.Source), fn, err)
}
}
if !m.noVersioning {
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
}
}
start := time.Now()
if err := tx.Commit(); err != nil {
return fmt.Errorf("ERROR failed to commit transaction: %w", err)
var empty bool
if m.UseTx {
// Run go-based migration inside a tx.
fn := m.DownFn
if direction {
fn = m.UpFn
}
empty = (fn == nil)
if err := runGoMigration(
db,
fn,
m.Version,
direction,
!m.noVersioning,
); err != nil {
return fmt.Errorf("ERROR go migration: %q: %w", filepath.Base(m.Source), err)
}
} else {
// Run go-based migration outside a tx.
fn := m.DownFnNoTx
if direction {
fn = m.UpFnNoTx
}
empty = (fn == nil)
if err := runGoMigrationNoTx(
db,
fn,
m.Version,
direction,
!m.noVersioning,
); err != nil {
return fmt.Errorf("ERROR go migration no tx: %q: %w", filepath.Base(m.Source), err)
}
}
finish := truncateDuration(time.Since(start))
if fn != nil {
if !empty {
log.Printf("OK %s (%s)\n", filepath.Base(m.Source), finish)
} else {
log.Printf("EMPTY %s (%s)\n", filepath.Base(m.Source), finish)
}
}
return nil
}
func runGoMigrationNoTx(
db *sql.DB,
fn GoMigrationNoTx,
version int64,
direction bool,
recordVersion bool,
) error {
if fn != nil {
// Run go migration function.
if err := fn(db); err != nil {
return fmt.Errorf("failed to run go migration: %w", err)
}
}
if recordVersion {
return insertOrDeleteVersionNoTx(db, version, direction)
}
return nil
}
func runGoMigration(
db *sql.DB,
fn GoMigration,
version int64,
direction bool,
recordVersion bool,
) error {
if fn == nil && !recordVersion {
return nil
}
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
if fn != nil {
// Run go migration function.
if err := fn(tx); err != nil {
tx.Rollback()
return fmt.Errorf("failed to run go migration: %w", err)
}
}
if recordVersion {
if err := insertOrDeleteVersion(tx, version, direction); err != nil {
tx.Rollback()
return fmt.Errorf("failed to update version: %w", err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
func insertOrDeleteVersion(tx *sql.Tx, version int64, direction bool) error {
if direction {
_, err := tx.Exec(GetDialect().insertVersionSQL(), version, direction)
return err
}
_, err := tx.Exec(GetDialect().deleteVersionSQL(), version)
return err
}
func insertOrDeleteVersionNoTx(db *sql.DB, version int64, direction bool) error {
if direction {
_, err := db.Exec(GetDialect().insertVersionSQL(), version, direction)
return err
}
_, err := db.Exec(GetDialect().deleteVersionSQL(), version)
return err
}
// NumericComponent looks for migration scripts with names in the form:
// XXX_descriptivename.ext where XXX specifies the version number
// and ext specifies the type of migration

View File

@ -0,0 +1,70 @@
package gomigrations
import (
"testing"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
_ "github.com/pressly/goose/v3/tests/gomigrations/error/testdata"
)
func TestGoMigrationByOne(t *testing.T) {
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
// Create goose table.
current, err := goose.EnsureDBVersion(db)
check.NoError(t, err)
check.Number(t, current, 0)
// Collect migrations.
dir := "testdata"
migrations, err := goose.CollectMigrations(dir, 0, goose.MaxVersion)
check.NoError(t, err)
check.Number(t, len(migrations), 4)
// Setup table.
err = migrations[0].Up(db)
check.NoError(t, err)
version, err := goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, 1)
// Registered Go migration run outside a goose tx using *sql.DB.
err = migrations[1].Up(db)
check.HasError(t, err)
check.Contains(t, err.Error(), "failed to run go migration")
version, err = goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, 1)
// This migration was inserting 100 rows, but fails at 50, and
// because it's run outside a goose tx then we expect 50 rows.
var count int
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
check.NoError(t, err)
check.Number(t, count, 50)
// Truncate table so we have 0 rows.
err = migrations[2].Up(db)
check.NoError(t, err)
version, err = goose.GetDBVersion(db)
check.NoError(t, err)
// We're at version 3, but keep in mind 2 was never applied because it failed.
check.Number(t, version, 3)
// Registered Go migration run within a tx.
err = migrations[3].Up(db)
check.HasError(t, err)
check.Contains(t, err.Error(), "failed to run go migration")
version, err = goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, 3) // This migration failed, so we're still at 3.
// This migration was inserting 100 rows, but fails at 50. However, since it's
// running within a tx we expect none of the inserts to persist.
err = db.QueryRow("SELECT COUNT(*) FROM foo").Scan(&count)
check.NoError(t, err)
check.Number(t, count, 0)
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up001, nil)
}
func up001(db *sql.DB) error {
q := "CREATE TABLE foo (id INT)"
_, err := db.Exec(q)
return err
}

View File

@ -0,0 +1,27 @@
package gomigrations
import (
"database/sql"
"fmt"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up002, nil)
}
func up002(db *sql.DB) error {
for i := 1; i <= 100; i++ {
q := "INSERT INTO foo VALUES ($1)"
if _, err := db.Exec(q, i); err != nil {
return err
}
// Simulate an error when no tx. We should have 50 rows
// inserted in the DB.
if i == 50 {
return fmt.Errorf("simulate error: too many inserts")
}
}
return nil
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up003, nil)
}
func up003(tx *sql.Tx) error {
q := "TRUNCATE TABLE foo"
_, err := tx.Exec(q)
return err
}

View File

@ -0,0 +1,27 @@
package gomigrations
import (
"database/sql"
"fmt"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up004, nil)
}
func up004(tx *sql.Tx) error {
for i := 1; i <= 100; i++ {
// Simulate an error when no tx. We should have 50 rows
// inserted in the DB.
if i == 50 {
return fmt.Errorf("simulate error: too many inserts")
}
q := "INSERT INTO foo VALUES ($1)"
if _, err := tx.Exec(q); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,52 @@
package gomigrations
import (
"path/filepath"
"testing"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
_ "github.com/pressly/goose/v3/tests/gomigrations/success/testdata"
)
func TestGoMigrationByOne(t *testing.T) {
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
dir := "testdata"
files, err := filepath.Glob(dir + "/*.go")
check.NoError(t, err)
upByOne := func(t *testing.T) int64 {
err = goose.UpByOne(db, dir)
check.NoError(t, err)
version, err := goose.GetDBVersion(db)
check.NoError(t, err)
return version
}
downByOne := func(t *testing.T) int64 {
err = goose.Down(db, dir)
check.NoError(t, err)
version, err := goose.GetDBVersion(db)
check.NoError(t, err)
return version
}
// Migrate all files up-by-one.
for i := 1; i <= len(files); i++ {
check.Number(t, upByOne(t), i)
}
version, err := goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, len(files))
// Migrate all files down-by-one.
for i := len(files) - 1; i >= 0; i-- {
check.Number(t, downByOne(t), i)
}
version, err = goose.GetDBVersion(db)
check.NoError(t, err)
check.Number(t, version, 0)
}

View File

@ -0,0 +1,23 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up001, down001)
}
func up001(tx *sql.Tx) error {
q := "CREATE TABLE foo (id INT, subid INT, name TEXT)"
_, err := tx.Exec(q)
return err
}
func down001(tx *sql.Tx) error {
q := "DROP TABLE IF EXISTS foo"
_, err := tx.Exec(q)
return err
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up002, nil)
}
func up002(tx *sql.Tx) error {
q := "INSERT INTO foo VALUES (1, 1, 'Alice')"
_, err := tx.Exec(q)
return err
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(nil, down003)
}
func down003(tx *sql.Tx) error {
q := "TRUNCATE TABLE foo"
_, err := tx.Exec(q)
return err
}

View File

@ -0,0 +1,9 @@
package gomigrations
import (
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(nil, nil)
}

View File

@ -0,0 +1,23 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up005, down005)
}
func up005(db *sql.DB) error {
q := "CREATE TABLE users (id INT, email TEXT)"
_, err := db.Exec(q)
return err
}
func down005(db *sql.DB) error {
q := "DROP TABLE IF EXISTS users"
_, err := db.Exec(q)
return err
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(up006, nil)
}
func up006(db *sql.DB) error {
q := "INSERT INTO users VALUES (1, 'admin@example.com')"
_, err := db.Exec(q)
return err
}

View File

@ -0,0 +1,17 @@
package gomigrations
import (
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigrationNoTx(nil, down007)
}
func down007(db *sql.DB) error {
q := "TRUNCATE TABLE users"
_, err := db.Exec(q)
return err
}

View File

@ -0,0 +1,9 @@
package gomigrations
import (
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(nil, nil)
}