mirror of https://github.com/pressly/goose.git
195 lines
5.2 KiB
Go
195 lines
5.2 KiB
Go
package lock_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pressly/goose/v3/internal/check"
|
|
"github.com/pressly/goose/v3/internal/testdb"
|
|
"github.com/pressly/goose/v3/lock"
|
|
)
|
|
|
|
func TestPostgresSessionLocker(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.Skip("skip long running test")
|
|
}
|
|
db, cleanup, err := testdb.NewPostgres()
|
|
check.NoError(t, err)
|
|
t.Cleanup(cleanup)
|
|
|
|
// Do not run tests in parallel, because they are using the same database.
|
|
|
|
t.Run("lock_and_unlock", func(t *testing.T) {
|
|
const (
|
|
lockID int64 = 123456789
|
|
)
|
|
locker, err := lock.NewPostgresSessionLocker(
|
|
lock.WithLockID(lockID),
|
|
lock.WithLockTimeout(4*time.Second),
|
|
lock.WithUnlockTimeout(4*time.Second),
|
|
)
|
|
check.NoError(t, err)
|
|
ctx := context.Background()
|
|
conn, err := db.Conn(ctx)
|
|
check.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
check.NoError(t, conn.Close())
|
|
})
|
|
err = locker.SessionLock(ctx, conn)
|
|
check.NoError(t, err)
|
|
pgLocks, err := queryPgLocks(ctx, db)
|
|
check.NoError(t, err)
|
|
check.Number(t, len(pgLocks), 1)
|
|
// Check that the lock was acquired.
|
|
check.Bool(t, pgLocks[0].granted, true)
|
|
// Check that the custom lock ID is the same as the one used by the locker.
|
|
check.Equal(t, pgLocks[0].gooseLockID, lockID)
|
|
check.NumberNotZero(t, pgLocks[0].pid)
|
|
|
|
// Check that the lock is released.
|
|
err = locker.SessionUnlock(ctx, conn)
|
|
check.NoError(t, err)
|
|
pgLocks, err = queryPgLocks(ctx, db)
|
|
check.NoError(t, err)
|
|
check.Number(t, len(pgLocks), 0)
|
|
})
|
|
t.Run("lock_close_conn_unlock", func(t *testing.T) {
|
|
locker, err := lock.NewPostgresSessionLocker(
|
|
lock.WithLockTimeout(4*time.Second),
|
|
lock.WithUnlockTimeout(4*time.Second),
|
|
)
|
|
check.NoError(t, err)
|
|
ctx := context.Background()
|
|
conn, err := db.Conn(ctx)
|
|
check.NoError(t, err)
|
|
|
|
err = locker.SessionLock(ctx, conn)
|
|
check.NoError(t, err)
|
|
pgLocks, err := queryPgLocks(ctx, db)
|
|
check.NoError(t, err)
|
|
check.Number(t, len(pgLocks), 1)
|
|
check.Bool(t, pgLocks[0].granted, true)
|
|
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
|
|
// Simulate a connection close.
|
|
err = conn.Close()
|
|
check.NoError(t, err)
|
|
// Check an error is returned when unlocking, because the connection is already closed.
|
|
err = locker.SessionUnlock(ctx, conn)
|
|
check.HasError(t, err)
|
|
check.Bool(t, errors.Is(err, sql.ErrConnDone), true)
|
|
})
|
|
t.Run("multiple_connections", func(t *testing.T) {
|
|
const (
|
|
workers = 5
|
|
)
|
|
ch := make(chan error)
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < workers; i++ {
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
ctx := context.Background()
|
|
conn, err := db.Conn(ctx)
|
|
check.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
check.NoError(t, conn.Close())
|
|
})
|
|
// Exactly one connection should acquire the lock. While the other connections
|
|
// should fail to acquire the lock and timeout.
|
|
locker, err := lock.NewPostgresSessionLocker(
|
|
lock.WithLockTimeout(4*time.Second),
|
|
lock.WithUnlockTimeout(4*time.Second),
|
|
)
|
|
check.NoError(t, err)
|
|
ch <- locker.SessionLock(ctx, conn)
|
|
}()
|
|
}
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
var errors []error
|
|
for err := range ch {
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail.
|
|
for _, err := range errors {
|
|
check.HasError(t, err)
|
|
check.Equal(t, err.Error(), "failed to acquire lock")
|
|
}
|
|
pgLocks, err := queryPgLocks(context.Background(), db)
|
|
check.NoError(t, err)
|
|
check.Number(t, len(pgLocks), 1)
|
|
check.Bool(t, pgLocks[0].granted, true)
|
|
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
|
|
})
|
|
t.Run("unlock_with_different_connection", func(t *testing.T) {
|
|
ctx := context.Background()
|
|
const (
|
|
lockID int64 = 999
|
|
)
|
|
locker, err := lock.NewPostgresSessionLocker(
|
|
lock.WithLockID(lockID),
|
|
lock.WithLockTimeout(4*time.Second),
|
|
lock.WithUnlockTimeout(4*time.Second),
|
|
)
|
|
check.NoError(t, err)
|
|
|
|
conn1, err := db.Conn(ctx)
|
|
check.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
check.NoError(t, conn1.Close())
|
|
})
|
|
err = locker.SessionLock(ctx, conn1)
|
|
check.NoError(t, err)
|
|
pgLocks, err := queryPgLocks(ctx, db)
|
|
check.NoError(t, err)
|
|
check.Number(t, len(pgLocks), 1)
|
|
check.Bool(t, pgLocks[0].granted, true)
|
|
check.Equal(t, pgLocks[0].gooseLockID, lockID)
|
|
// Unlock with a different connection.
|
|
conn2, err := db.Conn(ctx)
|
|
check.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
check.NoError(t, conn2.Close())
|
|
})
|
|
// Check an error is returned when unlocking with a different connection.
|
|
err = locker.SessionUnlock(ctx, conn2)
|
|
check.HasError(t, err)
|
|
})
|
|
}
|
|
|
|
// pgLock is one advisory-lock row read from the pg_locks view by queryPgLocks.
type pgLock struct {
	// pid is the value of the row's pid column.
	pid int
	// granted is the value of the row's granted column; the tests expect true for a lock that
	// has been acquired.
	granted bool
	// gooseLockID is the 64-bit lock ID rebuilt from the row as (classid << 32) | objid, which
	// the tests compare against the ID configured on the locker.
	gooseLockID int64
}
|
|
|
|
func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
|
|
q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'`
|
|
rows, err := db.QueryContext(ctx, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var pgLocks []pgLock
|
|
for rows.Next() {
|
|
var p pgLock
|
|
if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil {
|
|
return nil, err
|
|
}
|
|
pgLocks = append(pgLocks, p)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return pgLocks, nil
|
|
}
|