mirror of https://github.com/pressly/goose.git
feat(experimental): add internal migrate package and SessionLocker interface (#606)
parent
ccfb885423
commit
c590380f39
1
go.mod
1
go.mod
|
@ -8,6 +8,7 @@ require (
|
|||
github.com/jackc/pgx/v5 v5.4.3
|
||||
github.com/microsoft/go-mssqldb v1.6.0
|
||||
github.com/ory/dockertest/v3 v3.10.0
|
||||
github.com/sethvargo/go-retry v0.2.4
|
||||
github.com/vertica/vertica-sql-go v1.3.3
|
||||
github.com/ziutek/mymysql v1.5.4
|
||||
go.uber.org/multierr v1.11.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -127,6 +127,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq
|
|||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec=
|
||||
github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw=
|
||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
|
||||
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// Package migrate defines a Migration struct and implements the migration logic for executing Go
|
||||
// and SQL migrations.
|
||||
//
|
||||
// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported.
|
||||
// - For SQL migrations, all three are supported.
|
||||
//
|
||||
// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first
|
||||
// time it is executed.
|
||||
package migrate
|
|
@ -0,0 +1,166 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlextended"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
// Fullpath is the full path to the migration file.
|
||||
//
|
||||
// Example: /path/to/migrations/123_create_users_table.go
|
||||
Fullpath string
|
||||
// Version is the version of the migration.
|
||||
Version int64
|
||||
// Type is the type of migration.
|
||||
Type MigrationType
|
||||
// A migration is either a Go migration or a SQL migration, but never both.
|
||||
//
|
||||
// Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is
|
||||
// an optimization to avoid parsing the SQL migration if it is never required. Also, the
|
||||
// majority of the time migrations are incremental, so it is likely that the user will only want
|
||||
// to run the last few migrations, and there is no need to parse ALL prior migrations.
|
||||
//
|
||||
// Exactly one of these fields will be set:
|
||||
Go *Go
|
||||
// -- or --
|
||||
SQLParsed bool
|
||||
SQL *SQL
|
||||
}
|
||||
|
||||
type MigrationType int
|
||||
|
||||
const (
|
||||
TypeGo MigrationType = iota + 1
|
||||
TypeSQL
|
||||
)
|
||||
|
||||
func (t MigrationType) String() string {
|
||||
switch t {
|
||||
case TypeGo:
|
||||
return "go"
|
||||
case TypeSQL:
|
||||
return "sql"
|
||||
default:
|
||||
// This should never happen.
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) UseTx() bool {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
return m.Go.UseTx
|
||||
case TypeSQL:
|
||||
return m.SQL.UseTx
|
||||
default:
|
||||
// This should never happen.
|
||||
panic("unknown migration type: use tx")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) IsEmpty(direction bool) bool {
|
||||
switch m.Type {
|
||||
case TypeGo:
|
||||
return m.Go.IsEmpty(direction)
|
||||
case TypeSQL:
|
||||
return m.SQL.IsEmpty(direction)
|
||||
default:
|
||||
// This should never happen.
|
||||
panic("unknown migration type: is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migration) GetSQLStatements(direction bool) ([]string, error) {
|
||||
if m.Type != TypeSQL {
|
||||
return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type)
|
||||
}
|
||||
if m.SQL == nil {
|
||||
return nil, errors.New("sql migration has not been initialized")
|
||||
}
|
||||
if !m.SQLParsed {
|
||||
return nil, errors.New("sql migration has not been parsed")
|
||||
}
|
||||
if direction {
|
||||
return m.SQL.UpStatements, nil
|
||||
}
|
||||
return m.SQL.DownStatements, nil
|
||||
}
|
||||
|
||||
type Go struct {
|
||||
// We used an explicit bool instead of relying on a pointer because registered funcs may be nil.
|
||||
// These are still valid Go and versioned migrations, but they are just empty.
|
||||
//
|
||||
// For example: goose.AddMigration(nil, nil)
|
||||
UseTx bool
|
||||
|
||||
// Only one of these func pairs will be set:
|
||||
UpFn, DownFn func(context.Context, *sql.Tx) error
|
||||
// -- or --
|
||||
UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error
|
||||
}
|
||||
|
||||
func (g *Go) IsEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return g.UpFn == nil && g.UpFnNoTx == nil
|
||||
}
|
||||
return g.DownFn == nil && g.DownFnNoTx == nil
|
||||
}
|
||||
|
||||
func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
var fn func(context.Context, *sql.Tx) error
|
||||
if direction {
|
||||
fn = g.UpFn
|
||||
} else {
|
||||
fn = g.DownFn
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
var fn func(context.Context, *sql.DB) error
|
||||
if direction {
|
||||
fn = g.UpFnNoTx
|
||||
} else {
|
||||
fn = g.DownFnNoTx
|
||||
}
|
||||
if fn != nil {
|
||||
return fn(ctx, db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SQL struct {
|
||||
UseTx bool
|
||||
UpStatements []string
|
||||
DownStatements []string
|
||||
}
|
||||
|
||||
func (s *SQL) IsEmpty(direction bool) bool {
|
||||
if direction {
|
||||
return len(s.UpStatements) == 0
|
||||
}
|
||||
return len(s.DownStatements) == 0
|
||||
}
|
||||
|
||||
func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
|
||||
var statements []string
|
||||
if direction {
|
||||
statements = s.UpStatements
|
||||
} else {
|
||||
statements = s.DownStatements
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
if _, err := db.ExecContext(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/fs"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/sqlparser"
|
||||
)
|
||||
|
||||
// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it
|
||||
// will not be parsed again.
|
||||
//
|
||||
// Important: This function will mutate SQL migrations.
|
||||
func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error {
|
||||
for _, m := range migrations {
|
||||
if m.Type == TypeSQL && !m.SQLParsed {
|
||||
parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.SQLParsed = true
|
||||
m.SQL = parsedSQLMigration
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse is used to determine which direction to parse the SQL migration.
|
||||
type parse int
|
||||
|
||||
const (
|
||||
// parseAll parses all SQL statements in BOTH directions.
|
||||
parseAll parse = iota + 1
|
||||
// parseUp parses all SQL statements in the UP direction.
|
||||
parseUp
|
||||
// parseDown parses all SQL statements in the DOWN direction.
|
||||
parseDown
|
||||
)
|
||||
|
||||
func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) {
|
||||
r, err := fsys.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
by, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := new(SQL)
|
||||
if p == parseAll || p == parseUp {
|
||||
s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
|
||||
bytes.NewReader(by),
|
||||
sqlparser.DirectionUp,
|
||||
debug,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if p == parseAll || p == parseDown {
|
||||
s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration(
|
||||
bytes.NewReader(by),
|
||||
sqlparser.DirectionDown,
|
||||
debug,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Run runs the migration inside of a transaction.
|
||||
func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error {
|
||||
switch m.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
return fmt.Errorf("tx: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, tx, direction)
|
||||
case TypeGo:
|
||||
return m.Go.run(ctx, tx, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
}
|
||||
|
||||
// RunNoTx runs the migration without a transaction.
|
||||
func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error {
|
||||
switch m.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
return fmt.Errorf("db: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, db, direction)
|
||||
case TypeGo:
|
||||
return m.Go.runNoTx(ctx, db, direction)
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
}
|
||||
|
||||
// RunConn runs the migration without a transaction using the provided connection.
|
||||
func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error {
|
||||
switch m.Type {
|
||||
case TypeSQL:
|
||||
if m.SQL == nil || !m.SQLParsed {
|
||||
return fmt.Errorf("conn: sql migration has not been parsed")
|
||||
}
|
||||
return m.SQL.run(ctx, conn, direction)
|
||||
case TypeGo:
|
||||
return fmt.Errorf("conn: go migrations are not supported with *sql.Conn")
|
||||
}
|
||||
// This should never happen.
|
||||
return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath))
|
||||
}
|
|
@ -11,7 +11,7 @@ import (
|
|||
// There is a long outstanding issue to formalize a std lib interface, but alas... See:
|
||||
// https://github.com/golang/go/issues/14468
|
||||
type DBTxConn interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
)
|
||||
|
||||
// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive
|
||||
// session-level advisory lock mechanism.
|
||||
//
|
||||
// This function creates a SessionLocker that can be used to acquire and release locks for
|
||||
// synchronization purposes. The lock acquisition is retried until it is successfully acquired or
|
||||
// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the
|
||||
// default unlock duration is set to 1 minute.
|
||||
//
|
||||
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
|
||||
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
|
||||
cfg := sessionLockerConfig{
|
||||
lockID: DefaultLockID,
|
||||
lockTimeout: DefaultLockTimeout,
|
||||
unlockTimeout: DefaultUnlockTimeout,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &postgresSessionLocker{
|
||||
lockID: cfg.lockID,
|
||||
retryLock: retry.WithMaxDuration(
|
||||
cfg.lockTimeout,
|
||||
retry.NewConstant(2*time.Second),
|
||||
),
|
||||
retryUnlock: retry.WithMaxDuration(
|
||||
cfg.unlockTimeout,
|
||||
retry.NewConstant(2*time.Second),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type postgresSessionLocker struct {
|
||||
lockID int64
|
||||
retryLock retry.Backoff
|
||||
retryUnlock retry.Backoff
|
||||
}
|
||||
|
||||
var _ SessionLocker = (*postgresSessionLocker)(nil)
|
||||
|
||||
func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn) error {
|
||||
return retry.Do(ctx, l.retryLock, func(ctx context.Context) error {
|
||||
row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID)
|
||||
var locked bool
|
||||
if err := row.Scan(&locked); err != nil {
|
||||
return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err)
|
||||
}
|
||||
if locked {
|
||||
// A session-level advisory lock was acquired.
|
||||
return nil
|
||||
}
|
||||
// A session-level advisory lock could not be acquired. This is likely because another
|
||||
// process has already acquired the lock. We will continue retrying until the lock is
|
||||
// acquired or the maximum number of retries is reached.
|
||||
return retry.RetryableError(errors.New("failed to acquire lock"))
|
||||
})
|
||||
}
|
||||
|
||||
func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Conn) error {
|
||||
return retry.Do(ctx, l.retryUnlock, func(ctx context.Context) error {
|
||||
var unlocked bool
|
||||
row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID)
|
||||
if err := row.Scan(&unlocked); err != nil {
|
||||
return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err)
|
||||
}
|
||||
if unlocked {
|
||||
// A session-level advisory lock was released.
|
||||
return nil
|
||||
}
|
||||
/*
|
||||
TODO(mf): provide users with some documentation on how they can unlock the session
|
||||
manually.
|
||||
|
||||
This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will
|
||||
release all session level advisory locks held by the current session. This function is
|
||||
implicitly invoked at session end, even if the client disconnects ungracefully.
|
||||
|
||||
Here is output from a session that has a lock held:
|
||||
|
||||
SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks
|
||||
WHERE locktype='advisory';
|
||||
|
||||
| pid | granted | goose_lock_id |
|
||||
|-----|---------|---------------------|
|
||||
| 191 | t | 5887940537704921958 |
|
||||
|
||||
A forceful way to unlock the session is to terminate the backend with SIGTERM:
|
||||
|
||||
SELECT pg_terminate_backend(191);
|
||||
|
||||
Subsequent commands on the same connection will fail with:
|
||||
|
||||
Query 1 ERROR: FATAL: terminating connection due to administrator command
|
||||
*/
|
||||
return retry.RetryableError(errors.New("failed to unlock session"))
|
||||
})
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package lock_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3/internal/check"
|
||||
"github.com/pressly/goose/v3/internal/testdb"
|
||||
"github.com/pressly/goose/v3/lock"
|
||||
)
|
||||
|
||||
func TestPostgresSessionLocker(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip long running test")
|
||||
}
|
||||
db, cleanup, err := testdb.NewPostgres()
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
const (
|
||||
lockID int64 = 123456789
|
||||
)
|
||||
|
||||
// Do not run tests in parallel, because they are using the same database.
|
||||
|
||||
t.Run("lock_and_unlock", func(t *testing.T) {
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockID(lockID),
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
conn, err := db.Conn(ctx)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
check.NoError(t, conn.Close())
|
||||
})
|
||||
err = locker.SessionLock(ctx, conn)
|
||||
check.NoError(t, err)
|
||||
pgLocks, err := queryPgLocks(ctx, db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(pgLocks), 1)
|
||||
// Check that the lock was acquired.
|
||||
check.Bool(t, pgLocks[0].granted, true)
|
||||
// Check that the custom lock ID is the same as the one used by the locker.
|
||||
check.Equal(t, pgLocks[0].gooseLockID, lockID)
|
||||
check.NumberNotZero(t, pgLocks[0].pid)
|
||||
|
||||
// Check that the lock is released.
|
||||
err = locker.SessionUnlock(ctx, conn)
|
||||
check.NoError(t, err)
|
||||
pgLocks, err = queryPgLocks(ctx, db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(pgLocks), 0)
|
||||
})
|
||||
t.Run("lock_close_conn_unlock", func(t *testing.T) {
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
ctx := context.Background()
|
||||
conn, err := db.Conn(ctx)
|
||||
check.NoError(t, err)
|
||||
|
||||
err = locker.SessionLock(ctx, conn)
|
||||
check.NoError(t, err)
|
||||
pgLocks, err := queryPgLocks(ctx, db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(pgLocks), 1)
|
||||
check.Bool(t, pgLocks[0].granted, true)
|
||||
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
|
||||
// Simulate a connection close.
|
||||
err = conn.Close()
|
||||
check.NoError(t, err)
|
||||
// Check an error is returned when unlocking, because the connection is already closed.
|
||||
err = locker.SessionUnlock(ctx, conn)
|
||||
check.HasError(t, err)
|
||||
check.Bool(t, errors.Is(err, sql.ErrConnDone), true)
|
||||
})
|
||||
t.Run("multiple_connections", func(t *testing.T) {
|
||||
const (
|
||||
workers = 5
|
||||
)
|
||||
ch := make(chan error)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
conn, err := db.Conn(ctx)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
check.NoError(t, conn.Close())
|
||||
})
|
||||
// Exactly one connection should acquire the lock. While the other connections
|
||||
// should fail to acquire the lock and timeout.
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
ch <- locker.SessionLock(ctx, conn)
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
var errors []error
|
||||
for err := range ch {
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail.
|
||||
for _, err := range errors {
|
||||
check.HasError(t, err)
|
||||
check.Equal(t, err.Error(), "failed to acquire lock")
|
||||
}
|
||||
pgLocks, err := queryPgLocks(context.Background(), db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(pgLocks), 1)
|
||||
check.Bool(t, pgLocks[0].granted, true)
|
||||
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
|
||||
})
|
||||
t.Run("unlock_with_different_connection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const (
|
||||
lockID int64 = 999
|
||||
)
|
||||
locker, err := lock.NewPostgresSessionLocker(
|
||||
lock.WithLockID(lockID),
|
||||
lock.WithLockTimeout(4*time.Second),
|
||||
lock.WithUnlockTimeout(4*time.Second),
|
||||
)
|
||||
check.NoError(t, err)
|
||||
|
||||
conn1, err := db.Conn(ctx)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
check.NoError(t, conn1.Close())
|
||||
})
|
||||
err = locker.SessionLock(ctx, conn1)
|
||||
check.NoError(t, err)
|
||||
pgLocks, err := queryPgLocks(ctx, db)
|
||||
check.NoError(t, err)
|
||||
check.Number(t, len(pgLocks), 1)
|
||||
check.Bool(t, pgLocks[0].granted, true)
|
||||
check.Equal(t, pgLocks[0].gooseLockID, lockID)
|
||||
// Unlock with a different connection.
|
||||
conn2, err := db.Conn(ctx)
|
||||
check.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
check.NoError(t, conn2.Close())
|
||||
})
|
||||
// Check an error is returned when unlocking with a different connection.
|
||||
err = locker.SessionUnlock(ctx, conn2)
|
||||
check.HasError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type pgLock struct {
|
||||
pid int
|
||||
granted bool
|
||||
gooseLockID int64
|
||||
}
|
||||
|
||||
func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
|
||||
q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'`
|
||||
rows, err := db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pgLocks []pgLock
|
||||
for rows.Next() {
|
||||
var p pgLock
|
||||
if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pgLocks = append(pgLocks, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pgLocks, nil
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Package lock defines the Locker interface and implements the locking logic.
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrLockNotImplemented is returned when the database does not support locking.
|
||||
ErrLockNotImplemented = errors.New("lock not implemented")
|
||||
// ErrUnlockNotImplemented is returned when the database does not support unlocking.
|
||||
ErrUnlockNotImplemented = errors.New("unlock not implemented")
|
||||
)
|
||||
|
||||
// SessionLocker is the interface to lock and unlock the database for the duration of a session. The
|
||||
// session is defined as the duration of a single connection and both methods must be called on the
|
||||
// same connection.
|
||||
type SessionLocker interface {
|
||||
SessionLock(ctx context.Context, conn *sql.Conn) error
|
||||
SessionUnlock(ctx context.Context, conn *sql.Conn) error
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package lock
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLockID is the id used to lock the database for migrations. It is a crc64 hash of the
|
||||
// string "goose". This is used to ensure that the lock is unique to goose.
|
||||
//
|
||||
// crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA))
|
||||
DefaultLockID int64 = 5887940537704921958
|
||||
|
||||
// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
|
||||
// wait for the lock to be released) wait durations.
|
||||
DefaultLockTimeout time.Duration = 60 * time.Minute
|
||||
DefaultUnlockTimeout time.Duration = 1 * time.Minute
|
||||
)
|
||||
|
||||
// SessionLockerOption is used to configure a SessionLocker.
|
||||
type SessionLockerOption interface {
|
||||
apply(*sessionLockerConfig) error
|
||||
}
|
||||
|
||||
// WithLockID sets the lock ID to use when locking the database.
|
||||
//
|
||||
// If WithLockID is not called, the DefaultLockID is used.
|
||||
func WithLockID(lockID int64) SessionLockerOption {
|
||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||
c.lockID = lockID
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithLockTimeout sets the max duration to wait for the lock to be acquired.
|
||||
func WithLockTimeout(duration time.Duration) SessionLockerOption {
|
||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||
c.lockTimeout = duration
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// WithUnlockTimeout sets the max duration to wait for the lock to be released.
|
||||
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
|
||||
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
|
||||
c.unlockTimeout = duration
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type sessionLockerConfig struct {
|
||||
lockID int64
|
||||
lockTimeout time.Duration
|
||||
unlockTimeout time.Duration
|
||||
}
|
||||
|
||||
var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)
|
||||
|
||||
type sessionLockerConfigFunc func(*sessionLockerConfig) error
|
||||
|
||||
func (f sessionLockerConfigFunc) apply(cfg *sessionLockerConfig) error {
|
||||
return f(cfg)
|
||||
}
|
|
@ -3,6 +3,8 @@ package goose
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pressly/goose/v3/lock"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -38,13 +40,36 @@ func WithVerbose() ProviderOption {
|
|||
})
|
||||
}
|
||||
|
||||
// WithSessionLocker enables locking using the provided SessionLocker.
|
||||
//
|
||||
// If WithSessionLocker is not called, locking is disabled.
|
||||
func WithSessionLocker(locker lock.SessionLocker) ProviderOption {
|
||||
return configFunc(func(c *config) error {
|
||||
if c.lockEnabled {
|
||||
return errors.New("lock already enabled")
|
||||
}
|
||||
if c.sessionLocker != nil {
|
||||
return errors.New("session locker already set")
|
||||
}
|
||||
if locker == nil {
|
||||
return errors.New("session locker must not be nil")
|
||||
}
|
||||
c.lockEnabled = true
|
||||
c.sessionLocker = locker
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type config struct {
|
||||
tableName string
|
||||
verbose bool
|
||||
|
||||
lockEnabled bool
|
||||
sessionLocker lock.SessionLocker
|
||||
}
|
||||
|
||||
type configFunc func(*config) error
|
||||
|
||||
func (o configFunc) apply(cfg *config) error {
|
||||
return o(cfg)
|
||||
func (f configFunc) apply(cfg *config) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue