From 97d933f05502bbd648da406d3e7c8f1745c98c3a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 9 Aug 2013 16:40:04 -0500 Subject: [PATCH] Added pgx/migrate --- migrate/connection_settings_test.go.example | 7 + migrate/helper_test.go | 44 +++++ migrate/migrate.go | 130 ++++++++++++++ migrate/migrate_test.go | 183 ++++++++++++++++++++ 4 files changed, 364 insertions(+) create mode 100644 migrate/connection_settings_test.go.example create mode 100644 migrate/helper_test.go create mode 100644 migrate/migrate.go create mode 100644 migrate/migrate_test.go diff --git a/migrate/connection_settings_test.go.example b/migrate/connection_settings_test.go.example new file mode 100644 index 00000000..a9acac72 --- /dev/null +++ b/migrate/connection_settings_test.go.example @@ -0,0 +1,7 @@ +package migrate_test + +import ( + "github.com/JackC/pgx" +) + +var defaultConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/migrate/helper_test.go b/migrate/helper_test.go new file mode 100644 index 00000000..aa7095d9 --- /dev/null +++ b/migrate/helper_test.go @@ -0,0 +1,44 @@ +package migrate_test + +import ( + "github.com/JackC/pgx" + "github.com/JackC/pgx/migrate" +) + +type test interface { + Fatalf(format string, args ...interface{}) +} + +func mustConnect(t test, connectionParameters *pgx.ConnectionParameters) (conn *pgx.Connection) { + var err error + conn, err = pgx.Connect(*connectionParameters) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + return +} + +func mustCreateMigrator(t test, conn *pgx.Connection) (m *migrate.Migrator) { + var err error + m, err = migrate.NewMigrator(conn, versionTable) + if err != nil { + t.Fatalf("Unable to create migrator: %v", err) + } + return +} + +func mustExecute(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (commandTag string) { + var err error + if commandTag, err = conn.Execute(sql, arguments...); err != nil { + t.Fatalf("Execute unexpectedly failed with %v: %v", sql, err) + } + return +} + +func mustSelectValue(t test, conn *pgx.Connection, sql string, arguments ...interface{}) (value interface{}) { + var err error + if value, err = conn.SelectValue(sql, arguments...); err != nil { + t.Fatalf("SelectValue unexpectedly failed with %v: %v", sql, err) + } + return +} diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 00000000..da263154 --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,130 @@ +package migrate + +import ( + "fmt" + "github.com/JackC/pgx" +) + +type BadVersionError string + +func (e BadVersionError) Error() string { + return string(e) +} + +type Migration struct { + Sequence int32 + Name string + SQL string +} + +type Migrator struct { + conn *pgx.Connection + versionTable string + Migrations []*Migration + OnStart func(*Migration) `called when Migrate starts a migration` +} + +func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err error) { + m = &Migrator{conn: conn, versionTable: versionTable} + err = m.ensureSchemaVersionTableExists() + m.Migrations = make([]*Migration, 0) + return +} + +func (m *Migrator) AppendMigration(name, sql string) { + m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations)), Name: name, SQL: sql}) + return +} + +// Migrate runs pending migrations +// It calls m.OnStart when it begins a migration +func (m *Migrator) Migrate() error { + var done bool + + for !done { + var innerErr error + + var txErr error + _, txErr = m.conn.Transaction(func() bool { + // Lock version table for duration of transaction to ensure multiple migrations cannot occur simultaneously + if _, innerErr = m.conn.Execute("lock table " + m.versionTable); innerErr != nil { + return false + } + + // Get pending migrations + var pending []*Migration + if pending, innerErr = m.PendingMigrations(); innerErr != nil { + return false + } + + // If no migrations are pending set the done flag and return + if len(pending) == 0 { + done = true + return true + } + + // Fire on start callback + if m.OnStart != nil { + m.OnStart(pending[0]) + } + + // Execute the first pending migration + if _, innerErr = m.conn.Execute(pending[0].SQL); innerErr != nil { + return false + } + + // Add one to the version + if _, innerErr = m.conn.Execute("update " + m.versionTable + " set version=version+1"); innerErr != nil { + return false + } + + // A migration was completed successfully, return true to commit the transaction + return true + }) + + if txErr != nil { + return txErr + } + if innerErr != nil { + return innerErr + } + } + + return nil +} + +func (m *Migrator) PendingMigrations() ([]*Migration, error) { + if len(m.Migrations) == 0 { + return m.Migrations, nil + } + + if current, err := m.GetCurrentVersion(); err == nil { + current := int(current) + if current < 0 || len(m.Migrations) < current { + errMsg := fmt.Sprintf("%s version %d is outside the known migrations of 0 to %d", m.versionTable, current, len(m.Migrations)) + return nil, BadVersionError(errMsg) + } + return m.Migrations[current:len(m.Migrations)], nil + } else { + return nil, err + } +} + +func (m *Migrator) GetCurrentVersion() (int32, error) { + if v, err := m.conn.SelectValue("select version from " + m.versionTable); err == nil { + return v.(int32), nil + } else { + return 0, err + } +} + +func (m *Migrator) ensureSchemaVersionTableExists() (err error) { + _, err = m.conn.Execute(` + create table if not exists schema_version(version int4 not null); + + insert into schema_version(version) + select 0 + where 0=(select count(*) from schema_version); + `) + return +} diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go new file mode 100644 index 00000000..dbe689d9 --- /dev/null +++ b/migrate/migrate_test.go @@ -0,0 +1,183 @@ +package migrate_test + +import ( + "fmt" + "github.com/JackC/pgx" + "github.com/JackC/pgx/migrate" + "testing" +) + +var versionTable string = "schema_version" + +func clearMigrate(t *testing.T, conn *pgx.Connection) { + tables := []string{versionTable, "t", "t1", "t2"} + for _, table := range tables { + mustExecute(t, conn, "drop table if exists "+table) + } +} + +func TestNewMigrator(t *testing.T) { + conn := mustConnect(t, defaultConnectionParameters) + clearMigrate(t, conn) + + var m *migrate.Migrator + var err error + m, err = migrate.NewMigrator(conn, versionTable) + if err != nil { + t.Fatalf("Unable to create migrator: %v", err) + } + + schemaVersionExists := mustSelectValue(t, + conn, + "select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)", + defaultConnectionParameters.Database, + versionTable).(bool) + + if !schemaVersionExists { + t.Fatalf("NewMigrator did not create %v table", versionTable) + } + + m, err = migrate.NewMigrator(conn, versionTable) + if err != nil { + t.Fatalf("NewMigrator failed when %v table already exists: %v", versionTable, err) + } + + var initialVersion int32 + initialVersion, err = m.GetCurrentVersion() + if err != nil { + t.Fatalf("Failed to get current version: %v", err) + } + if initialVersion != 0 { + t.Fatalf("Expected initial version to be 0. but it was %v", initialVersion) + } +} + +func TestAppendMigration(t *testing.T) { + conn := mustConnect(t, defaultConnectionParameters) + clearMigrate(t, conn) + m := mustCreateMigrator(t, conn) + + name := "Update t" + sql := "update t set c=1" + m.AppendMigration(name, sql) + + if len(m.Migrations) != 1 { + t.Fatal("Expected AppendMigration to add a migration but it didn't") + } + if m.Migrations[0].Name != name { + t.Fatalf("expected first migration Name to be %v, but it was %v", name, m.Migrations[0].Name) + } + if m.Migrations[0].SQL != sql { + t.Fatalf("expected first migration SQL to be %v, but it was %v", sql, m.Migrations[0].SQL) + } +} + +func TestPendingMigrations(t *testing.T) { + conn := mustConnect(t, defaultConnectionParameters) + clearMigrate(t, conn) + m := mustCreateMigrator(t, conn) + + m.AppendMigration("update t", "update t set c=1") + m.AppendMigration("update z", "update z set c=1") + + mustExecute(t, conn, "update "+versionTable+" set version=1") + + pending, err := m.PendingMigrations() + if err != nil { + t.Fatalf("Unexpected error while getting pending migrations: %v", err) + } + if len(pending) != 1 { + t.Fatalf("Expected 1 pending migrations but there was %v", len(pending)) + } + if pending[0] != m.Migrations[1] { + t.Fatal("Did not include expected migration as pending") + } + + // Higher version than we know about + mustExecute(t, conn, "update "+versionTable+" set version=999") + _, err = m.PendingMigrations() + if _, ok := err.(migrate.BadVersionError); !ok { + t.Fatalf("Expected BadVersionError but received: %#v", err) + } + + // Lower version than is possible + mustExecute(t, conn, "update "+versionTable+" set version=-1") + _, err = m.PendingMigrations() + if _, ok := err.(migrate.BadVersionError); !ok { + t.Fatalf("Expected BadVersionError but received: %#v", err) + } +} + +func TestMigrate(t *testing.T) { + conn := mustConnect(t, defaultConnectionParameters) + clearMigrate(t, conn) + m := mustCreateMigrator(t, conn) + + m.AppendMigration("create t", "create table t(name text primary key)") + + if err := m.Migrate(); err != nil { + t.Fatalf("Unexpected error running Migrate: %v", err) + } + + if pending, err := m.PendingMigrations(); err != nil { + t.Fatalf("Unexpected error while getting pending migrations: %v", err) + } else if len(pending) != 0 { + t.Fatalf("Migrate did not do all migrations: %v pending", len(pending)) + } + + // Now test the OnStart callback and the Migrate when some are already done + var onStartCallCount int + m.OnStart = func(*migrate.Migration) { + onStartCallCount++ + } + m.AppendMigration("create t2", "create table t2(name text primary key)") + + if err := m.Migrate(); err != nil { + t.Fatalf("Unexpected error running Migrate: %v", err) + } + + if pending, err := m.PendingMigrations(); err != nil { + t.Fatalf("Unexpected error while getting pending migrations: %v", err) + } else if len(pending) != 0 { + t.Fatalf("Migrate did not do all migrations: %v pending", len(pending)) + } + + if onStartCallCount != 1 { + t.Fatalf("Expected OnStart to be called 1 time, but it was called %v times", onStartCallCount) + } + +} + +func Example_OnStartMigrationProgressLogging() { + conn, err := pgx.Connect(*defaultConnectionParameters) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + // Clear any previous runs + if _, err = conn.Execute("drop table if exists schema_version"); err != nil { + fmt.Printf("Unable to drop schema_version table: %v", err) + return + } + + var m *migrate.Migrator + m, err = migrate.NewMigrator(conn, "schema_version") + if err != nil { + fmt.Printf("Unable to create migrator: %v", err) + return + } + + m.OnStart = func(migration *migrate.Migration) { + fmt.Printf("Executing: %v", migration.Name) + } + + m.AppendMigration("create a table", "create temporary table foo(id serial primary key)") + + if err = m.Migrate(); err != nil { + fmt.Printf("Unexpected failure migrating: %v", err) + return + } + // Output: + // Executing: create a table +}