mirror of https://github.com/jackc/pgx.git
parent
84cc10595c
commit
97f9fe2209
|
@ -1,7 +0,0 @@
|
|||
package migrate_test
|
||||
|
||||
import (
|
||||
"github.com/JackC/pgx"
|
||||
)
|
||||
|
||||
var defaultConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
|
@ -1,194 +0,0 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/JackC/pgx"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type BadVersionError string
|
||||
|
||||
func (e BadVersionError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
type IrreversibleMigrationError struct {
|
||||
m *Migration
|
||||
}
|
||||
|
||||
func (e IrreversibleMigrationError) Error() string {
|
||||
return fmt.Sprintf("Irreversible migration: %d - %s", e.m.Sequence, e.m.Name)
|
||||
}
|
||||
|
||||
type NoMigrationsFoundError struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func (e NoMigrationsFoundError) Error() string {
|
||||
return fmt.Sprintf("No migrations found at %s", e.Path)
|
||||
}
|
||||
|
||||
type Migration struct {
|
||||
Sequence int32
|
||||
Name string
|
||||
UpSQL string
|
||||
DownSQL string
|
||||
}
|
||||
|
||||
type Migrator struct {
|
||||
conn *pgx.Connection
|
||||
versionTable string
|
||||
Migrations []*Migration
|
||||
OnStart func(*Migration, string) // OnStart is called when a migration is run with the migration and direction
|
||||
}
|
||||
|
||||
func NewMigrator(conn *pgx.Connection, versionTable string) (m *Migrator, err error) {
|
||||
m = &Migrator{conn: conn, versionTable: versionTable}
|
||||
err = m.ensureSchemaVersionTableExists()
|
||||
m.Migrations = make([]*Migration, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Migrator) LoadMigrations(path string) error {
|
||||
paths, err := filepath.Glob(filepath.Join(path, "*.sql"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(paths) == 0 {
|
||||
return NoMigrationsFoundError{Path: path}
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
body, err := ioutil.ReadFile(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pieces := strings.SplitN(string(body), "---- create above / drop below ----", 2)
|
||||
var upSQL, downSQL string
|
||||
upSQL = strings.TrimSpace(pieces[0])
|
||||
if len(pieces) == 2 {
|
||||
downSQL = strings.TrimSpace(pieces[1])
|
||||
}
|
||||
m.AppendMigration(filepath.Base(p), upSQL, downSQL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) AppendMigration(name, upSQL, downSQL string) {
|
||||
m.Migrations = append(m.Migrations, &Migration{Sequence: int32(len(m.Migrations)) + 1, Name: name, UpSQL: upSQL, DownSQL: downSQL})
|
||||
return
|
||||
}
|
||||
|
||||
// Migrate runs pending migrations
|
||||
// It calls m.OnStart when it begins a migration
|
||||
func (m *Migrator) Migrate() error {
|
||||
return m.MigrateTo(int32(len(m.Migrations)))
|
||||
}
|
||||
|
||||
// MigrateTo migrates to targetVersion
|
||||
func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
|
||||
// Lock to ensure multiple migrations cannot occur simultaneously
|
||||
lockNum := int64(9628173550095224) // arbitrary random number
|
||||
if _, lockErr := m.conn.Execute("select pg_advisory_lock($1)", lockNum); lockErr != nil {
|
||||
return lockErr
|
||||
}
|
||||
defer func() {
|
||||
_, unlockErr := m.conn.Execute("select pg_advisory_unlock($1)", lockNum)
|
||||
if err == nil && unlockErr != nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
currentVersion, err := m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if targetVersion < 0 || int32(len(m.Migrations)) < targetVersion {
|
||||
errMsg := fmt.Sprintf("%s version %d is outside the valid versions of 0 to %d", m.versionTable, targetVersion, len(m.Migrations))
|
||||
return BadVersionError(errMsg)
|
||||
}
|
||||
|
||||
var direction int32
|
||||
if currentVersion < targetVersion {
|
||||
direction = 1
|
||||
} else {
|
||||
direction = -1
|
||||
}
|
||||
|
||||
for currentVersion != targetVersion {
|
||||
var current *Migration
|
||||
var sql, directionName string
|
||||
var sequence int32
|
||||
if direction == 1 {
|
||||
current = m.Migrations[currentVersion]
|
||||
sequence = current.Sequence
|
||||
sql = current.UpSQL
|
||||
directionName = "up"
|
||||
} else {
|
||||
current = m.Migrations[currentVersion-1]
|
||||
sequence = current.Sequence - 1
|
||||
sql = current.DownSQL
|
||||
directionName = "down"
|
||||
if current.DownSQL == "" {
|
||||
return IrreversibleMigrationError{m: current}
|
||||
}
|
||||
}
|
||||
|
||||
var innerErr error
|
||||
_, txErr := m.conn.Transaction(func() bool {
|
||||
|
||||
// Fire on start callback
|
||||
if m.OnStart != nil {
|
||||
m.OnStart(current, directionName)
|
||||
}
|
||||
|
||||
// Execute the migration
|
||||
if _, innerErr = m.conn.Execute(sql); innerErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add one to the version
|
||||
if _, innerErr = m.conn.Execute("update "+m.versionTable+" set version=$1", sequence); innerErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// A migration was completed successfully, return true to commit the transaction
|
||||
return true
|
||||
})
|
||||
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
if innerErr != nil {
|
||||
return innerErr
|
||||
}
|
||||
|
||||
currentVersion = currentVersion + direction
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) GetCurrentVersion() (int32, error) {
|
||||
if v, err := m.conn.SelectValue("select version from " + m.versionTable); err == nil {
|
||||
return v.(int32), nil
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
|
||||
_, err = m.conn.Execute(`
|
||||
create table if not exists schema_version(version int4 not null);
|
||||
|
||||
insert into schema_version(version)
|
||||
select 0
|
||||
where 0=(select count(*) from schema_version);
|
||||
`)
|
||||
return
|
||||
}
|
|
@ -1,287 +0,0 @@
|
|||
package migrate_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/JackC/pgx"
|
||||
"github.com/JackC/pgx/migrate"
|
||||
. "gopkg.in/check.v1"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MigrateSuite struct {
|
||||
conn *pgx.Connection
|
||||
}
|
||||
|
||||
func Test(t *testing.T) { TestingT(t) }
|
||||
|
||||
var _ = Suite(&MigrateSuite{})
|
||||
|
||||
var versionTable string = "schema_version"
|
||||
|
||||
func (s *MigrateSuite) SetUpTest(c *C) {
|
||||
var err error
|
||||
s.conn, err = pgx.Connect(*defaultConnectionParameters)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
s.cleanupSampleMigrator(c)
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) SelectValue(c *C, sql string, arguments ...interface{}) interface{} {
|
||||
value, err := s.conn.SelectValue(sql, arguments...)
|
||||
c.Assert(err, IsNil)
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) Execute(c *C, sql string, arguments ...interface{}) string {
|
||||
commandTag, err := s.conn.Execute(sql, arguments...)
|
||||
c.Assert(err, IsNil)
|
||||
return commandTag
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) tableExists(c *C, tableName string) bool {
|
||||
return s.SelectValue(c,
|
||||
"select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)",
|
||||
defaultConnectionParameters.Database,
|
||||
tableName).(bool)
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) createEmptyMigrator(c *C) *migrate.Migrator {
|
||||
var err error
|
||||
m, err := migrate.NewMigrator(s.conn, versionTable)
|
||||
c.Assert(err, IsNil)
|
||||
return m
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) createSampleMigrator(c *C) *migrate.Migrator {
|
||||
m := s.createEmptyMigrator(c)
|
||||
m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;")
|
||||
m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;")
|
||||
m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;")
|
||||
return m
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) cleanupSampleMigrator(c *C) {
|
||||
tables := []string{versionTable, "t1", "t2", "t3"}
|
||||
for _, table := range tables {
|
||||
s.Execute(c, "drop table if exists "+table)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestNewMigrator(c *C) {
|
||||
var m *migrate.Migrator
|
||||
var err error
|
||||
|
||||
// Initial run
|
||||
m, err = migrate.NewMigrator(s.conn, versionTable)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// Creates version table
|
||||
schemaVersionExists := s.tableExists(c, versionTable)
|
||||
c.Assert(schemaVersionExists, Equals, true)
|
||||
|
||||
// Succeeds when version table is already created
|
||||
m, err = migrate.NewMigrator(s.conn, versionTable)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
initialVersion, err := m.GetCurrentVersion()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(initialVersion, Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestAppendMigration(c *C) {
|
||||
m := s.createEmptyMigrator(c)
|
||||
|
||||
name := "Create t"
|
||||
upSQL := "create t..."
|
||||
downSQL := "drop t..."
|
||||
m.AppendMigration(name, upSQL, downSQL)
|
||||
|
||||
c.Assert(len(m.Migrations), Equals, 1)
|
||||
c.Assert(m.Migrations[0].Name, Equals, name)
|
||||
c.Assert(m.Migrations[0].UpSQL, Equals, upSQL)
|
||||
c.Assert(m.Migrations[0].DownSQL, Equals, downSQL)
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestLoadMigrationsMissingDirectory(c *C) {
|
||||
m := s.createEmptyMigrator(c)
|
||||
err := m.LoadMigrations("testdata/missing")
|
||||
c.Assert(err, ErrorMatches, "No migrations found at testdata/missing")
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestLoadMigrationsEmptyDirectory(c *C) {
|
||||
m := s.createEmptyMigrator(c)
|
||||
err := m.LoadMigrations("testdata/empty")
|
||||
c.Assert(err, ErrorMatches, "No migrations found at testdata/empty")
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestLoadMigrations(c *C) {
|
||||
m := s.createEmptyMigrator(c)
|
||||
err := m.LoadMigrations("testdata/sample")
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(m.Migrations, HasLen, 3)
|
||||
|
||||
c.Check(m.Migrations[0].Name, Equals, "001_create_t1.sql")
|
||||
c.Check(m.Migrations[0].UpSQL, Equals, `create table t1(
|
||||
id serial primary key
|
||||
);`)
|
||||
c.Check(m.Migrations[0].DownSQL, Equals, "drop table t1;")
|
||||
|
||||
c.Check(m.Migrations[1].Name, Equals, "002_create_t2.sql")
|
||||
c.Check(m.Migrations[1].UpSQL, Equals, `create table t2(
|
||||
id serial primary key
|
||||
);`)
|
||||
c.Check(m.Migrations[1].DownSQL, Equals, "drop table t2;")
|
||||
|
||||
c.Check(m.Migrations[2].Name, Equals, "003_irreversible.sql")
|
||||
c.Check(m.Migrations[2].UpSQL, Equals, "drop table t2;")
|
||||
c.Check(m.Migrations[2].DownSQL, Equals, "")
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestMigrate(c *C) {
|
||||
m := s.createSampleMigrator(c)
|
||||
|
||||
err := m.Migrate()
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion := s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestMigrateToLifeCycle(c *C) {
|
||||
m := s.createSampleMigrator(c)
|
||||
|
||||
var onStartCallUpCount int
|
||||
var onStartCallDownCount int
|
||||
m.OnStart = func(_ *migrate.Migration, direction string) {
|
||||
switch direction {
|
||||
case "up":
|
||||
onStartCallUpCount++
|
||||
case "down":
|
||||
onStartCallDownCount++
|
||||
default:
|
||||
c.Fatalf("Unexpected direction: %s", direction)
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate from 0 up to 1
|
||||
err := m.MigrateTo(1)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion := s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(1))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, false)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, false)
|
||||
c.Assert(onStartCallUpCount, Equals, 1)
|
||||
c.Assert(onStartCallDownCount, Equals, 0)
|
||||
|
||||
// Migrate from 1 up to 3
|
||||
err = m.MigrateTo(3)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion = s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(3))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, true)
|
||||
c.Assert(onStartCallUpCount, Equals, 3)
|
||||
c.Assert(onStartCallDownCount, Equals, 0)
|
||||
|
||||
// Migrate from 3 to 3 is no-op
|
||||
err = m.MigrateTo(3)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion = s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(3))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, true)
|
||||
c.Assert(onStartCallUpCount, Equals, 3)
|
||||
c.Assert(onStartCallDownCount, Equals, 0)
|
||||
|
||||
// Migrate from 3 down to 1
|
||||
err = m.MigrateTo(1)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion = s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(1))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, false)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, false)
|
||||
c.Assert(onStartCallUpCount, Equals, 3)
|
||||
c.Assert(onStartCallDownCount, Equals, 2)
|
||||
|
||||
// Migrate from 1 down to 0
|
||||
err = m.MigrateTo(0)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion = s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(0))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, false)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, false)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, false)
|
||||
c.Assert(onStartCallUpCount, Equals, 3)
|
||||
c.Assert(onStartCallDownCount, Equals, 3)
|
||||
|
||||
// Migrate back up to 3
|
||||
err = m.MigrateTo(3)
|
||||
c.Assert(err, IsNil)
|
||||
currentVersion = s.SelectValue(c, "select version from schema_version")
|
||||
c.Assert(currentVersion, Equals, int32(3))
|
||||
c.Assert(s.tableExists(c, "t1"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t2"), Equals, true)
|
||||
c.Assert(s.tableExists(c, "t3"), Equals, true)
|
||||
c.Assert(onStartCallUpCount, Equals, 6)
|
||||
c.Assert(onStartCallDownCount, Equals, 3)
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestMigrateToBoundaries(c *C) {
|
||||
m := s.createSampleMigrator(c)
|
||||
|
||||
// Migrate to -1 is error
|
||||
err := m.MigrateTo(-1)
|
||||
c.Assert(err, ErrorMatches, "schema_version version -1 is outside the valid versions of 0 to 3")
|
||||
|
||||
// Migrate past end is error
|
||||
err = m.MigrateTo(int32(len(m.Migrations)) + 1)
|
||||
c.Assert(err, ErrorMatches, "schema_version version 4 is outside the valid versions of 0 to 3")
|
||||
}
|
||||
|
||||
func (s *MigrateSuite) TestMigrateToIrreversible(c *C) {
|
||||
m := s.createEmptyMigrator(c)
|
||||
m.AppendMigration("Foo", "drop table if exists t3", "")
|
||||
|
||||
err := m.MigrateTo(1)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
err = m.MigrateTo(0)
|
||||
c.Assert(err, ErrorMatches, "Irreversible migration: 1 - Foo")
|
||||
}
|
||||
|
||||
func Example_OnStartMigrationProgressLogging() {
|
||||
conn, err := pgx.Connect(*defaultConnectionParameters)
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear any previous runs
|
||||
if _, err = conn.Execute("drop table if exists schema_version"); err != nil {
|
||||
fmt.Printf("Unable to drop schema_version table: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var m *migrate.Migrator
|
||||
m, err = migrate.NewMigrator(conn, "schema_version")
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to create migrator: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.OnStart = func(migration *migrate.Migration, direction string) {
|
||||
fmt.Printf("Migrating %s: %s", direction, migration.Name)
|
||||
}
|
||||
|
||||
m.AppendMigration("create a table", "create temporary table foo(id serial primary key)", "")
|
||||
|
||||
if err = m.Migrate(); err != nil {
|
||||
fmt.Printf("Unexpected failure migrating: %v", err)
|
||||
return
|
||||
}
|
||||
// Output:
|
||||
// Migrating up: create a table
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
create table t1(
|
||||
id serial primary key
|
||||
);
|
||||
|
||||
---- create above / drop below ----
|
||||
|
||||
drop table t1;
|
|
@ -1,7 +0,0 @@
|
|||
create table t2(
|
||||
id serial primary key
|
||||
);
|
||||
|
||||
---- create above / drop below ----
|
||||
|
||||
drop table t2;
|
|
@ -1 +0,0 @@
|
|||
drop table t2;
|
Loading…
Reference in New Issue