Added pgx/migrate

This commit is contained in:
Jack Christensen 2013-08-09 16:40:04 -05:00
parent f079d84728
commit 97d933f055
4 changed files with 364 additions and 0 deletions

View File

@ -0,0 +1,7 @@
package migrate_test
import (
"github.com/JackC/pgx"
)
var defaultConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}

44
migrate/helper_test.go Normal file
View File

@ -0,0 +1,44 @@
package migrate_test
import (
"github.com/JackC/pgx"
"github.com/JackC/pgx/migrate"
)
type test interface {
Fatalf(format string, args ...interface{})
}
func mustConnect(t test, connectionParameters *pgx.ConnectionParameters) (conn *pgx.Connection) {
var err error
conn, err = pgx.Connect(*connectionParameters)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
return
}
func mustCreateMigrator(t test, conn *pgx.Connection) (m *migrate.Migrator) {
var err error
m, err = migrate.NewMigrator(conn, versionTable)
if err != nil {
t.Fatalf("Unable to create migrator: %v", err)
}
return
}
func mustExecute(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (commandTag string) {
var err error
if commandTag, err = conn.Execute(sql, arguments...); err != nil {
t.Fatalf("Execute unexpectedly failed with %v: %v", sql, err)
}
return
}
func mustSelectValue(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (value interface{}) {
var err error
if value, err = conn.SelectValue(sql, arguments...); err != nil {
t.Fatalf("SelectValue unexpectedly failed with %v: %v", sql, err)
}
return
}

130
migrate/migrate.go Normal file
View File

@ -0,0 +1,130 @@
package migrate
import (
"fmt"
"github.com/JackC/pgx"
)
type BadVersionError string
func (e BadVersionError) Error() string {
return string(e)
}
type Migration struct {
Sequence int32
Name string
SQL string
}
type Migrator struct {
conn *pgx.Connection
versionTable string
Migrations []*Migration
OnStart func(*Migration) `called when Migrate starts a migration`
}
func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err error) {
m = &Migrator{conn: conn, versionTable: versionTable}
err = m.ensureSchemaVersionTableExists()
m.Migrations = make([]*Migration, 0)
return
}
func (m *Migrator) AppendMigration(name, sql string) {
m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations)), Name: name, SQL: sql})
return
}
// Migrate runs pending migrations
// It calls m.OnStart when it begins a migration
func (m *Migrator) Migrate() error {
var done bool
for !done {
var innerErr error
var txErr error
_, txErr = m.conn.Transaction(func() bool {
// Lock version table for duration of transaction to ensure multiple migrations cannot occur simultaneously
if _, innerErr = m.conn.Execute("lock table " + m.versionTable); innerErr != nil {
return false
}
// Get pending migrations
var pending []*Migration
if pending, innerErr = m.PendingMigrations(); innerErr != nil {
return false
}
// If no migrations are pending set the done flag and return
if len(pending) == 0 {
done = true
return true
}
// Fire on start callback
if m.OnStart != nil {
m.OnStart(pending[0])
}
// Execute the first pending migration
if _, innerErr = m.conn.Execute(pending[0].SQL); innerErr != nil {
return false
}
// Add one to the version
if _, innerErr = m.conn.Execute("update " + m.versionTable + " set version=version+1"); innerErr != nil {
return false
}
// A migration was completed successfully, return true to commit the transaction
return true
})
if txErr != nil {
return txErr
}
if innerErr != nil {
return innerErr
}
}
return nil
}
func (m *Migrator) PendingMigrations() ([]*Migration, error) {
if len(m.Migrations) == 0 {
return m.Migrations, nil
}
if current, err := m.GetCurrentVersion(); err == nil {
current := int(current)
if current < 0 || len(m.Migrations) < current {
errMsg := fmt.Sprintf("%s version %d is outside the known migrations of 0 to %d", m.versionTable, current, len(m.Migrations))
return nil, BadVersionError(errMsg)
}
return m.Migrations[current:len(m.Migrations)], nil
} else {
return nil, err
}
}
func (m *Migrator) GetCurrentVersion() (int32, error) {
if v, err := m.conn.SelectValue("select version from " + m.versionTable); err == nil {
return v.(int32), nil
} else {
return 0, err
}
}
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
_, err = m.conn.Execute(`
create table if not exists schema_version(version int4 not null);
insert into schema_version(version)
select 0
where 0=(select count(*) from schema_version);
`)
return
}

183
migrate/migrate_test.go Normal file
View File

@ -0,0 +1,183 @@
package migrate_test
import (
"fmt"
"github.com/JackC/pgx"
"github.com/JackC/pgx/migrate"
"testing"
)
var versionTable string = "schema_version"
func clearMigrate(t *testing.T, conn *pgx.Connection) {
tables := []string{versionTable, "t", "t1", "t2"}
for _, table := range tables {
mustExecute(t, conn, "drop table if exists "+table)
}
}
func TestNewMigrator(t *testing.T) {
conn := mustConnect(t, defaultConnectionParameters)
clearMigrate(t, conn)
var m *migrate.Migrator
var err error
m, err = migrate.NewMigrator(conn, versionTable)
if err != nil {
t.Fatalf("Unable to create migrator: %v", err)
}
schemaVersionExists := mustSelectValue(t,
conn,
"select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)",
defaultConnectionParameters.Database,
versionTable).(bool)
if !schemaVersionExists {
t.Fatalf("NewMigrator did not create %v table", versionTable)
}
m, err = migrate.NewMigrator(conn, versionTable)
if err != nil {
t.Fatalf("NewMigrator failed when %v table already exists: %v", versionTable, err)
}
var initialVersion int32
initialVersion, err = m.GetCurrentVersion()
if err != nil {
t.Fatalf("Failed to get current version: %v", err)
}
if initialVersion != 0 {
t.Fatalf("Expected initial version to be 0. but it was %v", initialVersion)
}
}
func TestAppendMigration(t *testing.T) {
conn := mustConnect(t, defaultConnectionParameters)
clearMigrate(t, conn)
m := mustCreateMigrator(t, conn)
name := "Update t"
sql := "update t set c=1"
m.AppendMigration(name, sql)
if len(m.Migrations) != 1 {
t.Fatal("Expected AppendMigration to add a migration but it didn't")
}
if m.Migrations[0].Name != name {
t.Fatalf("expected first migration Name to be %v, but it was %v", name, m.Migrations[0].Name)
}
if m.Migrations[0].SQL != sql {
t.Fatalf("expected first migration SQL to be %v, but it was %v", sql, m.Migrations[0].SQL)
}
}
func TestPendingMigrations(t *testing.T) {
conn := mustConnect(t, defaultConnectionParameters)
clearMigrate(t, conn)
m := mustCreateMigrator(t, conn)
m.AppendMigration("update t", "update t set c=1")
m.AppendMigration("update z", "update z set c=1")
mustExecute(t, conn, "update "+versionTable+" set version=1")
pending, err := m.PendingMigrations()
if err != nil {
t.Fatalf("Unexpected error while getting pending migrations: %v", err)
}
if len(pending) != 1 {
t.Fatalf("Expected 1 pending migrations but there was %v", len(pending))
}
if pending[0] != m.Migrations[1] {
t.Fatal("Did not include expected migration as pending")
}
// Higher version than we know about
mustExecute(t, conn, "update "+versionTable+" set version=999")
_, err = m.PendingMigrations()
if _, ok := err.(migrate.BadVersionError); !ok {
t.Fatalf("Expected BadVersionError but received: %#v", err)
}
// Lower version than is possible
mustExecute(t, conn, "update "+versionTable+" set version=-1")
_, err = m.PendingMigrations()
if _, ok := err.(migrate.BadVersionError); !ok {
t.Fatalf("Expected BadVersionError but received: %#v", err)
}
}
func TestMigrate(t *testing.T) {
conn := mustConnect(t, defaultConnectionParameters)
clearMigrate(t, conn)
m := mustCreateMigrator(t, conn)
m.AppendMigration("create t", "create table t(name text primary key)")
if err := m.Migrate(); err != nil {
t.Fatalf("Unexpected error running Migrate: %v", err)
}
if pending, err := m.PendingMigrations(); err != nil {
t.Fatalf("Unexpected error while getting pending migrations: %v", err)
} else if len(pending) != 0 {
t.Fatalf("Migrate did not do all migrations: %v pending", len(pending))
}
// Now test the OnStart callback and the Migrate when some are already done
var onStartCallCount int
m.OnStart = func(*migrate.Migration) {
onStartCallCount++
}
m.AppendMigration("create t2", "create table t2(name text primary key)")
if err := m.Migrate(); err != nil {
t.Fatalf("Unexpected error running Migrate: %v", err)
}
if pending, err := m.PendingMigrations(); err != nil {
t.Fatalf("Unexpected error while getting pending migrations: %v", err)
} else if len(pending) != 0 {
t.Fatalf("Migrate did not do all migrations: %v pending", len(pending))
}
if onStartCallCount != 1 {
t.Fatalf("Expected OnStart to be called 1 time, but it was called %v times", onStartCallCount)
}
}
func Example_OnStartMigrationProgressLogging() {
conn, err := pgx.Connect(*defaultConnectionParameters)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
// Clear any previous runs
if _, err = conn.Execute("drop table if exists schema_version"); err != nil {
fmt.Printf("Unable to drop schema_version table: %v", err)
return
}
var m *migrate.Migrator
m, err = migrate.NewMigrator(conn, "schema_version")
if err != nil {
fmt.Printf("Unable to create migrator: %v", err)
return
}
m.OnStart = func(migration *migrate.Migration) {
fmt.Printf("Executing: %v", migration.Name)
}
m.AppendMigration("create a table", "create temporary table foo(id serial primary key)")
if err = m.Migrate(); err != nil {
fmt.Printf("Unexpected failure migrating: %v", err)
return
}
// Output:
// Executing: create a table
}