feat(experimental): add internal migrate package and SessionLocker interface ()

pull/614/head
Michael Fridman 2023-10-09 15:08:51 -04:00 committed by GitHub
parent ccfb885423
commit c590380f39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 723 additions and 3 deletions

1
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/jackc/pgx/v5 v5.4.3
github.com/microsoft/go-mssqldb v1.6.0
github.com/ory/dockertest/v3 v3.10.0
github.com/sethvargo/go-retry v0.2.4
github.com/vertica/vertica-sql-go v1.3.3
github.com/ziutek/mymysql v1.5.4
go.uber.org/multierr v1.11.0

2
go.sum
View File

@ -127,6 +127,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec=
github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw=
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=

9
internal/migrate/doc.go Normal file
View File

@ -0,0 +1,9 @@
// Package migrate defines a Migration struct and implements the migration logic for executing Go
// and SQL migrations.
//
// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported.
// - For SQL migrations, all three are supported.
//
// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first
// time it is executed.
package migrate

View File

@ -0,0 +1,166 @@
package migrate
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/pressly/goose/v3/internal/sqlextended"
)
// Migration describes a single versioned migration. It carries either a Go payload or a SQL
// payload, selected by Type.
type Migration struct {
// Fullpath is the full path to the migration file.
//
// Example: /path/to/migrations/123_create_users_table.go
Fullpath string
// Version is the version of the migration.
Version int64
// Type is the type of migration (TypeGo or TypeSQL). The methods on Migration dispatch on this
// field to decide whether to consult Go or SQL.
Type MigrationType
// A migration is either a Go migration or a SQL migration, but never both.
//
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
// majority of the time migrations are incremental, so it is likely that the user will only want
// to run the last few migrations, and there is no need to parse ALL prior migrations.
//
// Exactly one of these fields will be set:
Go *Go
// -- or --
// SQLParsed reports whether SQL has been populated by the parser; SQL must not be used for
// execution until it is true.
SQLParsed bool
SQL *SQL
}
// MigrationType discriminates between the two kinds of migration payload.
type MigrationType int

// The zero value is deliberately left unassigned so that an uninitialized MigrationType is
// distinguishable from a valid one.
const (
	TypeGo MigrationType = iota + 1
	TypeSQL
)

// String returns the lowercase name of the migration type, or "unknown" for any value that is
// not one of the declared constants.
func (t MigrationType) String() string {
	if t == TypeGo {
		return "go"
	}
	if t == TypeSQL {
		return "sql"
	}
	// Not expected in practice: every Migration should carry one of the declared types.
	return "unknown"
}
// UseTx reports whether the migration should run inside a transaction, as recorded on the Go or
// SQL payload selected by m.Type. It panics if m.Type is neither TypeGo nor TypeSQL.
func (m *Migration) UseTx() bool {
	if m.Type == TypeGo {
		return m.Go.UseTx
	}
	if m.Type == TypeSQL {
		return m.SQL.UseTx
	}
	// Reaching this point means the Migration was constructed with an invalid type.
	panic("unknown migration type: use tx")
}
// IsEmpty reports whether the migration has nothing to run in the given direction (true for up,
// false for down). It delegates to the Go or SQL payload selected by m.Type and panics if m.Type
// is neither TypeGo nor TypeSQL.
func (m *Migration) IsEmpty(direction bool) bool {
	if m.Type == TypeGo {
		return m.Go.IsEmpty(direction)
	}
	if m.Type == TypeSQL {
		return m.SQL.IsEmpty(direction)
	}
	// Reaching this point means the Migration was constructed with an invalid type.
	panic("unknown migration type: is empty")
}
// GetSQLStatements returns the parsed SQL statements for the given direction (true for up, false
// for down).
//
// An error is returned when the migration is not a SQL migration, when its SQL payload is nil, or
// when the payload has not been parsed yet.
func (m *Migration) GetSQLStatements(direction bool) ([]string, error) {
	switch {
	case m.Type != TypeSQL:
		return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type)
	case m.SQL == nil:
		return nil, errors.New("sql migration has not been initialized")
	case !m.SQLParsed:
		return nil, errors.New("sql migration has not been parsed")
	}
	statements := m.SQL.DownStatements
	if direction {
		statements = m.SQL.UpStatements
	}
	return statements, nil
}
// Go holds the registered functions of a Go migration.
type Go struct {
	// UseTx is an explicit flag rather than something derived from which function pointers are
	// set, because registered funcs may legitimately be nil. Such migrations are still valid,
	// versioned Go migrations; they are simply empty.
	//
	// For example: goose.AddMigration(nil, nil)
	UseTx bool
	// Only one of these func pairs will be set:
	UpFn, DownFn func(context.Context, *sql.Tx) error
	// -- or --
	UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error
}

// IsEmpty reports whether no function, transactional or not, is registered for the given
// direction (true for up, false for down).
func (g *Go) IsEmpty(direction bool) bool {
	if !direction {
		return g.DownFn == nil && g.DownFnNoTx == nil
	}
	return g.UpFn == nil && g.UpFnNoTx == nil
}

// run invokes the transactional function for the given direction with tx. A nil function is
// treated as a no-op and yields a nil error.
func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error {
	migrate := g.DownFn
	if direction {
		migrate = g.UpFn
	}
	if migrate == nil {
		return nil
	}
	return migrate(ctx, tx)
}

// runNoTx invokes the non-transactional function for the given direction with db. A nil function
// is treated as a no-op and yields a nil error.
func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
	migrate := g.DownFnNoTx
	if direction {
		migrate = g.UpFnNoTx
	}
	if migrate == nil {
		return nil
	}
	return migrate(ctx, db)
}
// SQL holds the parsed statements of a SQL migration.
type SQL struct {
	// UseTx records whether the migration is meant to run inside a transaction.
	UseTx bool
	// UpStatements and DownStatements are the statements for each direction, in execution order.
	UpStatements   []string
	DownStatements []string
}

// IsEmpty reports whether there are no statements for the given direction (true for up, false
// for down).
func (s *SQL) IsEmpty(direction bool) bool {
	stmts := s.DownStatements
	if direction {
		stmts = s.UpStatements
	}
	return len(stmts) == 0
}
// run executes the statements for the given direction (true for up, false for down) one by one
// against db, stopping at and returning the first error.
func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
	pending := s.DownStatements
	if direction {
		pending = s.UpStatements
	}
	for i := range pending {
		_, err := db.ExecContext(ctx, pending[i])
		if err != nil {
			return err
		}
	}
	return nil
}

75
internal/migrate/parse.go Normal file
View File

@ -0,0 +1,75 @@
package migrate
import (
"bytes"
"io"
"io/fs"
"github.com/pressly/goose/v3/internal/sqlparser"
)
// ParseSQL parses every not-yet-parsed SQL migration in BOTH directions, reading the files from
// fsys. Go migrations and SQL migrations already marked as parsed are skipped.
//
// Important: This function mutates the SQL migrations it parses (SQL and SQLParsed are set).
// Parsing stops at the first error, which is returned as-is.
func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error {
	for _, m := range migrations {
		if m.Type != TypeSQL || m.SQLParsed {
			continue
		}
		parsed, err := parseSQL(fsys, m.Fullpath, parseAll, debug)
		if err != nil {
			return err
		}
		m.SQL = parsed
		m.SQLParsed = true
	}
	return nil
}
// parse is used to determine which direction to parse the SQL migration.
type parse int
// The values start at 1, so the zero value of parse selects no direction at all in parseSQL.
const (
// parseAll parses all SQL statements in BOTH directions.
parseAll parse = iota + 1
// parseUp parses all SQL statements in the UP direction.
parseUp
// parseDown parses all SQL statements in the DOWN direction.
parseDown
)
// parseSQL reads filename from fsys and parses it into a SQL value.
//
// The p argument selects which directions are parsed: parseUp fills UpStatements, parseDown fills
// DownStatements, and parseAll fills both. UseTx is taken from the parser's result; when both
// directions are parsed, the value reported by the down pass is the one that is kept.
//
// The file is fully read into memory and closed before parsing, so each direction can be parsed
// from its own reader over the same bytes. Any open, read, close, or parse error is returned
// as-is with a nil *SQL.
func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) {
	r, err := fsys.Open(filename)
	if err != nil {
		return nil, err
	}
	by, err := io.ReadAll(r)
	if err != nil {
		// Do not leak the file handle on a failed read. The read error is the one worth
		// reporting, so a secondary close error is intentionally dropped.
		_ = r.Close()
		return nil, err
	}
	if err := r.Close(); err != nil {
		return nil, err
	}
	s := new(SQL)
	if p == parseAll || p == parseUp {
		s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
			bytes.NewReader(by),
			sqlparser.DirectionUp,
			debug,
		)
		if err != nil {
			return nil, err
		}
	}
	if p == parseAll || p == parseDown {
		s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
			bytes.NewReader(by),
			sqlparser.DirectionDown,
			debug,
		)
		if err != nil {
			return nil, err
		}
	}
	return s, nil
}

53
internal/migrate/run.go Normal file
View File

@ -0,0 +1,53 @@
package migrate
import (
"context"
"database/sql"
"fmt"
"path/filepath"
)
// Run executes the migration in the given direction (true for up, false for down) using the
// supplied transaction.
//
// SQL migrations must already be parsed, otherwise an error is returned. Go migrations run their
// transactional function. Any other migration type results in an error.
func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
	if m.Type == TypeSQL {
		if m.SQL != nil && m.SQLParsed {
			return m.SQL.run(ctx, tx, direction)
		}
		return fmt.Errorf("tx: sql migration has not been parsed")
	}
	if m.Type == TypeGo {
		return m.Go.run(ctx, tx, direction)
	}
	// Only reachable when the Migration carries an invalid type.
	return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
}
// RunNoTx executes the migration in the given direction (true for up, false for down) directly
// against db, without a transaction.
//
// SQL migrations must already be parsed, otherwise an error is returned. Go migrations run their
// non-transactional function. Any other migration type results in an error.
func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
	if m.Type == TypeSQL {
		if m.SQL != nil && m.SQLParsed {
			return m.SQL.run(ctx, db, direction)
		}
		return fmt.Errorf("db: sql migration has not been parsed")
	}
	if m.Type == TypeGo {
		return m.Go.runNoTx(ctx, db, direction)
	}
	// Only reachable when the Migration carries an invalid type.
	return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
}
// RunConn executes the migration in the given direction (true for up, false for down) on the
// provided connection, without a transaction.
//
// Only SQL migrations are supported, and they must already be parsed. Go migrations are rejected
// with an error because they cannot be run on a *sql.Conn. Any other migration type results in an
// error as well.
func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
	if m.Type == TypeSQL {
		if m.SQL != nil && m.SQLParsed {
			return m.SQL.run(ctx, conn, direction)
		}
		return fmt.Errorf("conn: sql migration has not been parsed")
	}
	if m.Type == TypeGo {
		return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
	}
	// Only reachable when the Migration carries an invalid type.
	return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
}

View File

@ -11,7 +11,7 @@ import (
// There is a long outstanding issue to formalize a std lib interface, but alas... See:
// https://github.com/golang/go/issues/14468
type DBTxConn interface {
	// ExecContext, QueryContext and QueryRowContext mirror the methods of the same name shared by
	// *sql.DB, *sql.Tx and *sql.Conn, so any of the three satisfies this interface.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

110
lock/postgres.go Normal file
View File

@ -0,0 +1,110 @@
package lock
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/sethvargo/go-retry"
)
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
// session-level advisory lock mechanism.
//
// This function creates a SessionLocker that can be used to acquire and release locks for
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the
// default unlock duration is set to 1 minute.
//
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
cfg := sessionLockerConfig{
lockID: DefaultLockID,
lockTimeout: DefaultLockTimeout,
unlockTimeout: DefaultUnlockTimeout,
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
return nil, err
}
}
return &postgresSessionLocker{
lockID: cfg.lockID,
retryLock: retry.WithMaxDuration(
cfg.lockTimeout,
retry.NewConstant(2*time.Second),
),
retryUnlock: retry.WithMaxDuration(
cfg.unlockTimeout,
retry.NewConstant(2*time.Second),
),
}, nil
}
type postgresSessionLocker struct {
lockID int64
retryLock retry.Backoff
retryUnlock retry.Backoff
}
var _ SessionLocker = (*postgresSessionLocker)(nil)
func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn) error {
return retry.Do(ctx, l.retryLock, func(ctx context.Context) error {
row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID)
var locked bool
if err := row.Scan(&locked); err != nil {
return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err)
}
if locked {
// A session-level advisory lock was acquired.
return nil
}
// A session-level advisory lock could not be acquired. This is likely because another
// process has already acquired the lock. We will continue retrying until the lock is
// acquired or the maximum number of retries is reached.
return retry.RetryableError(errors.New("failed to acquire lock"))
})
}
func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Conn) error {
return retry.Do(ctx, l.retryUnlock, func(ctx context.Context) error {
var unlocked bool
row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID)
if err := row.Scan(&unlocked); err != nil {
return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err)
}
if unlocked {
// A session-level advisory lock was released.
return nil
}
/*
TODO(mf): provide users with some documentation on how they can unlock the session
manually.
This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will
release all session level advisory locks held by the current session. This function is
implicitly invoked at session end, even if the client disconnects ungracefully.
Here is output from a session that has a lock held:
SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks
WHERE locktype='advisory';
| pid | granted | goose_lock_id |
|-----|---------|---------------------|
| 191 | t | 5887940537704921958 |
A forceful way to unlock the session is to terminate the backend with SIGTERM:
SELECT pg_terminate_backend(191);
Subsequent commands on the same connection will fail with:
Query 1 ERROR: FATAL: terminating connection due to administrator command
*/
return retry.RetryableError(errors.New("failed to unlock session"))
})
}

193
lock/postgres_test.go Normal file
View File

@ -0,0 +1,193 @@
package lock_test
import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"time"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
)
// TestPostgresSessionLocker exercises the PostgreSQL session locker against a database obtained
// from testdb.NewPostgres. It is skipped in -short mode. The subtests share one database and
// inspect pg_locks through queryPgLocks, so they run sequentially.
func TestPostgresSessionLocker(t *testing.T) {
if testing.Short() {
t.Skip("skip long running test")
}
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
const (
lockID int64 = 123456789
)
// Do not run tests in parallel, because they are using the same database.
// lock_and_unlock: lock and unlock on the same connection with a custom lock ID, verifying
// pg_locks before and after.
t.Run("lock_and_unlock", func(t *testing.T) {
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
)
check.NoError(t, err)
ctx := context.Background()
conn, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn.Close())
})
err = locker.SessionLock(ctx, conn)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
// Check that the lock was acquired.
check.Bool(t, pgLocks[0].granted, true)
// Check that the custom lock ID is the same as the one used by the locker.
check.Equal(t, pgLocks[0].gooseLockID, lockID)
check.NumberNotZero(t, pgLocks[0].pid)
// Check that the lock is released.
err = locker.SessionUnlock(ctx, conn)
check.NoError(t, err)
pgLocks, err = queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 0)
})
// lock_close_conn_unlock: lock with the default lock ID, close the connection, then expect
// unlocking on the closed connection to fail with sql.ErrConnDone.
t.Run("lock_close_conn_unlock", func(t *testing.T) {
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
)
check.NoError(t, err)
ctx := context.Background()
conn, err := db.Conn(ctx)
check.NoError(t, err)
err = locker.SessionLock(ctx, conn)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
// Simulate a connection close.
err = conn.Close()
check.NoError(t, err)
// Check an error is returned when unlocking, because the connection is already closed.
err = locker.SessionUnlock(ctx, conn)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrConnDone), true)
})
// multiple_connections: several workers, each with its own connection and locker, compete for
// the default lock ID; each worker reports its SessionLock result on ch.
t.Run("multiple_connections", func(t *testing.T) {
const (
workers = 5
)
ch := make(chan error)
var wg sync.WaitGroup
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
conn, err := db.Conn(ctx)
// NOTE(review): check.NoError and t.Cleanup are called from a worker goroutine here. If
// check.NoError fails the test via t.FailNow/t.Fatal, that must only happen on the test
// goroutine per the testing package docs — confirm how the check package reports failures.
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn.Close())
})
// Exactly one connection should acquire the lock. While the other connections
// should fail to acquire the lock and timeout.
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
)
check.NoError(t, err)
ch <- locker.SessionLock(ctx, conn)
}()
}
// Close ch once every worker has reported, so the range loop below terminates.
go func() {
wg.Wait()
close(ch)
}()
// Note: this local slice shadows the imported errors package for the rest of the subtest.
var errors []error
for err := range ch {
if err != nil {
errors = append(errors, err)
}
}
check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail.
for _, err := range errors {
check.HasError(t, err)
check.Equal(t, err.Error(), "failed to acquire lock")
}
pgLocks, err := queryPgLocks(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
})
// unlock_with_different_connection: lock on conn1, then expect unlocking through conn2 to
// return an error.
t.Run("unlock_with_different_connection", func(t *testing.T) {
ctx := context.Background()
const (
lockID int64 = 999
)
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithLockTimeout(4*time.Second),
lock.WithUnlockTimeout(4*time.Second),
)
check.NoError(t, err)
conn1, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn1.Close())
})
err = locker.SessionLock(ctx, conn1)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
// NOTE(review): expecting exactly one advisory lock here presumably relies on the lock taken
// in multiple_connections having been released when that subtest's connections were closed
// by its cleanups — confirm.
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lockID)
// Unlock with a different connection.
conn2, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn2.Close())
})
// Check an error is returned when unlocking with a different connection.
err = locker.SessionUnlock(ctx, conn2)
check.HasError(t, err)
})
}
// pgLock is one row of the advisory-lock query issued by queryPgLocks.
type pgLock struct {
	pid         int
	granted     bool
	gooseLockID int64
}

// queryPgLocks returns every advisory lock currently listed in pg_locks. The 64-bit lock ID is
// reassembled from the classid (high 32 bits) and objid (low 32 bits) columns.
//
// The result set is always closed before returning, including when scanning a row fails.
func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
	q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'`
	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	// Without this, an early return on a Scan error would leave the rows (and the underlying
	// connection) checked out. Close is safe to call even after Next has exhausted the rows.
	defer rows.Close()
	var pgLocks []pgLock
	for rows.Next() {
		var p pgLock
		if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil {
			return nil, err
		}
		pgLocks = append(pgLocks, p)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return pgLocks, nil
}

23
lock/session_locker.go Normal file
View File

@ -0,0 +1,23 @@
// Package lock defines the SessionLocker interface and implements the locking logic.
package lock
import (
"context"
"database/sql"
"errors"
)
// Sentinel errors exported by this package. Being plain errors.New values, they can be matched
// with errors.Is.
var (
// ErrLockNotImplemented is returned when the database does not support locking.
ErrLockNotImplemented = errors.New("lock not implemented")
// ErrUnlockNotImplemented is returned when the database does not support unlocking.
ErrUnlockNotImplemented = errors.New("unlock not implemented")
)
// SessionLocker is the interface to lock and unlock the database for the duration of a session. The
// session is defined as the duration of a single connection and both methods must be called on the
// same connection.
type SessionLocker interface {
// SessionLock locks the database using conn. The same conn must later be passed to
// SessionUnlock.
SessionLock(ctx context.Context, conn *sql.Conn) error
// SessionUnlock unlocks the database using conn, which must be the connection that was passed
// to SessionLock.
SessionUnlock(ctx context.Context, conn *sql.Conn) error
}

View File

@ -0,0 +1,63 @@
package lock
import (
"time"
)
const (
// DefaultLockID is the id used to lock the database for migrations. It is a crc64 hash of the
// string "goose". This is used to ensure that the lock is unique to goose.
//
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
DefaultLockID int64 = 5887940537704921958
// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
// wait for the lock to be released) wait durations.
//
// DefaultLockTimeout is 60 minutes; override it with WithLockTimeout.
DefaultLockTimeout time.Duration = 60 * time.Minute
// DefaultUnlockTimeout is 1 minute; override it with WithUnlockTimeout.
DefaultUnlockTimeout time.Duration = 1 * time.Minute
)
// SessionLockerOption is used to configure a SessionLocker.
type SessionLockerOption interface {
	apply(*sessionLockerConfig) error
}

// WithLockID sets the lock ID to use when locking the database.
//
// If WithLockID is not called, the DefaultLockID is used.
func WithLockID(lockID int64) SessionLockerOption {
	var opt sessionLockerConfigFunc = func(cfg *sessionLockerConfig) error {
		cfg.lockID = lockID
		return nil
	}
	return opt
}

// WithLockTimeout sets the max duration to wait for the lock to be acquired.
func WithLockTimeout(duration time.Duration) SessionLockerOption {
	var opt sessionLockerConfigFunc = func(cfg *sessionLockerConfig) error {
		cfg.lockTimeout = duration
		return nil
	}
	return opt
}

// WithUnlockTimeout sets the max duration to wait for the lock to be released.
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
	var opt sessionLockerConfigFunc = func(cfg *sessionLockerConfig) error {
		cfg.unlockTimeout = duration
		return nil
	}
	return opt
}

// sessionLockerConfig collects the settings that the options above write to.
type sessionLockerConfig struct {
	lockID        int64
	lockTimeout   time.Duration
	unlockTimeout time.Duration
}

var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)

// sessionLockerConfigFunc lets an ordinary function act as a SessionLockerOption.
type sessionLockerConfigFunc func(*sessionLockerConfig) error

// apply calls the underlying function with cfg and returns its error.
func (f sessionLockerConfigFunc) apply(cfg *sessionLockerConfig) error {
	return f(cfg)
}

View File

@ -3,6 +3,8 @@ package goose
import (
"errors"
"fmt"
"github.com/pressly/goose/v3/lock"
)
const (
@ -38,13 +40,36 @@ func WithVerbose() ProviderOption {
})
}
// WithSessionLocker enables locking using the provided SessionLocker.
//
// If WithSessionLocker is not called, locking is disabled.
//
// Applying the option fails when locking has already been enabled, when a session locker has
// already been set, or when locker is nil.
func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
	return configFunc(func(c *config) error {
		switch {
		case c.lockEnabled:
			return errors.New("lock already enabled")
		case c.sessionLocker != nil:
			return errors.New("session locker already set")
		case locker == nil:
			return errors.New("session locker must not be nil")
		}
		c.sessionLocker = locker
		c.lockEnabled = true
		return nil
	})
}
// config accumulates the settings written by ProviderOption values.
type config struct {
// NOTE(review): tableName and verbose are set by options outside this view; presumably the
// version table name and verbose output respectively — confirm.
tableName string
verbose bool
// lockEnabled and sessionLocker are set together by WithSessionLocker; both stay at their zero
// value when that option is not used.
lockEnabled bool
sessionLocker lock.SessionLocker
}
type configFunc func(*config) error
func (o configFunc) apply(cfg *config) error {
return o(cfg)
func (f configFunc) apply(cfg *config) error {
return f(cfg)
}