diff --git a/dialect.go b/dialect.go index 8237849..60e4603 100644 --- a/dialect.go +++ b/dialect.go @@ -5,26 +5,28 @@ import ( "fmt" ) -// SqlDialect abstracts the details of specific SQL dialects +// SQLDialect abstracts the details of specific SQL dialects // for goose's few SQL specific statements -type SqlDialect interface { - createVersionTableSql() string // sql string to create the goose_db_version table - insertVersionSql() string // sql string to insert the initial version table row +type SQLDialect interface { + createVersionTableSQL() string // sql string to create the goose_db_version table + insertVersionSQL() string // sql string to insert the initial version table row dbVersionQuery(db *sql.DB) (*sql.Rows, error) } -var dialect SqlDialect = &PostgresDialect{} +var dialect SQLDialect = &PostgresDialect{} -func GetDialect() SqlDialect { +// GetDialect gets the SQLDialect +func GetDialect() SQLDialect { return dialect } +// SetDialect sets the SQLDialect func SetDialect(d string) error { switch d { case "postgres": dialect = &PostgresDialect{} case "mysql": - dialect = &MySqlDialect{} + dialect = &MySQLDialect{} case "sqlite3": dialect = &Sqlite3Dialect{} case "redshift": @@ -40,9 +42,10 @@ func SetDialect(d string) error { // Postgres //////////////////////////// +// PostgresDialect struct. type PostgresDialect struct{} -func (pg PostgresDialect) createVersionTableSql() string { +func (pg PostgresDialect) createVersionTableSQL() string { return `CREATE TABLE goose_db_version ( id serial NOT NULL, version_id bigint NOT NULL, @@ -52,7 +55,7 @@ func (pg PostgresDialect) createVersionTableSql() string { );` } -func (pg PostgresDialect) insertVersionSql() string { +func (pg PostgresDialect) insertVersionSQL() string { return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" } @@ -69,9 +72,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // MySQL //////////////////////////// -type MySqlDialect struct{} +// MySQLDialect struct. +type MySQLDialect struct{} -func (m MySqlDialect) createVersionTableSql() string { +func (m MySQLDialect) createVersionTableSQL() string { return `CREATE TABLE goose_db_version ( id serial NOT NULL, version_id bigint NOT NULL, @@ -81,11 +85,11 @@ func (m MySqlDialect) createVersionTableSql() string { );` } -func (m MySqlDialect) insertVersionSql() string { +func (m MySQLDialect) insertVersionSQL() string { return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);" } -func (m MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { +func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC") if err != nil { return nil, err @@ -98,9 +102,10 @@ func (m MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // sqlite3 //////////////////////////// +// Sqlite3Dialect struct. type Sqlite3Dialect struct{} -func (m Sqlite3Dialect) createVersionTableSql() string { +func (m Sqlite3Dialect) createVersionTableSQL() string { return `CREATE TABLE goose_db_version ( id INTEGER PRIMARY KEY AUTOINCREMENT, version_id INTEGER NOT NULL, @@ -109,7 +114,7 @@ func (m Sqlite3Dialect) createVersionTableSql() string { );` } -func (m Sqlite3Dialect) insertVersionSql() string { +func (m Sqlite3Dialect) insertVersionSQL() string { return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);" } @@ -126,9 +131,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { // Redshift //////////////////////////// +// RedshiftDialect struct. type RedshiftDialect struct{} -func (rs RedshiftDialect) createVersionTableSql() string { +func (rs RedshiftDialect) createVersionTableSQL() string { return `CREATE TABLE goose_db_version ( id integer NOT NULL identity(1, 1), version_id bigint NOT NULL, @@ -138,7 +144,7 @@ func (rs RedshiftDialect) createVersionTableSql() string { );` } -func (rs RedshiftDialect) insertVersionSql() string { +func (rs RedshiftDialect) insertVersionSQL() string { return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" } diff --git a/down.go b/down.go index 4246813..551c302 100644 --- a/down.go +++ b/down.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// Down rolls back a single migration from the current version. func Down(db *sql.DB, dir string) error { currentVersion, err := GetDBVersion(db) if err != nil { @@ -24,6 +25,7 @@ func Down(db *sql.DB, dir string) error { return current.Down(db) } +// DownTo rolls back migrations to a specific version. func DownTo(db *sql.DB, dir string, version int64) error { migrations, err := CollectMigrations(dir, minVersion, maxVersion) if err != nil { @@ -59,6 +61,4 @@ func DownTo(db *sql.DB, dir string, version int64) error { return err } } - - return nil } diff --git a/goose.go b/goose.go index 3573a31..3861e41 100644 --- a/goose.go +++ b/goose.go @@ -3,8 +3,8 @@ package goose import ( "database/sql" "fmt" - "sync" "strconv" + "sync" ) var ( @@ -13,6 +13,7 @@ var ( maxVersion = int64((1 << 63) - 1) ) +// Run runs a goose command. func Run(command string, db *sql.DB, dir string, args ...string) error { switch command { case "up": diff --git a/migrate.go b/migrate.go index 2778faf..c82bdaa 100644 --- a/migrate.go +++ b/migrate.go @@ -11,14 +11,17 @@ import ( ) var ( + // ErrNoCurrentVersion when a current migration version is not found. ErrNoCurrentVersion = errors.New("no current version found") - ErrNoNextVersion = errors.New("no next version found") - + // ErrNoNextVersion when the next migration version is not found. + ErrNoNextVersion = errors.New("no next version found") + // MaxVersion is the maximum allowed version. MaxVersion int64 = 9223372036854775807 // max(int64) goMigrations []*Migration ) +// Migrations slice. type Migrations []*Migration // helpers so we can use pkg sort @@ -31,6 +34,7 @@ func (ms Migrations) Less(i, j int) bool { return ms[i].Version < ms[j].Version } +// Current gets the current migration. func (ms Migrations) Current(current int64) (*Migration, error) { for i, migration := range ms { if migration.Version == current { @@ -41,6 +45,7 @@ func (ms Migrations) Current(current int64) (*Migration, error) { return nil, ErrNoCurrentVersion } +// Next gets the next migration. func (ms Migrations) Next(current int64) (*Migration, error) { for i, migration := range ms { if migration.Version > current { @@ -51,8 +56,9 @@ func (ms Migrations) Next(current int64) (*Migration, error) { return nil, ErrNoNextVersion } +// Previous : Get the previous migration. func (ms Migrations) Previous(current int64) (*Migration, error) { - for i := len(ms)-1; i >= 0; i-- { + for i := len(ms) - 1; i >= 0; i-- { if ms[i].Version < current { return ms[i], nil } @@ -61,6 +67,7 @@ func (ms Migrations) Previous(current int64) (*Migration, error) { return nil, ErrNoNextVersion } +// Last gets the last migration. func (ms Migrations) Last() (*Migration, error) { if len(ms) == 0 { return nil, ErrNoNextVersion @@ -77,11 +84,13 @@ func (ms Migrations) String() string { return str } +// AddMigration adds a migration. func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { _, filename, _, _ := runtime.Caller(1) AddNamedMigration(filename, up, down) } +// AddNamedMigration : Add a named migration. func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { v, _ := NumericComponent(filename) migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename} @@ -161,7 +170,7 @@ func versionFilter(v, current, target int64) bool { return false } -// retrieve the current version for this DB. +// EnsureDBVersion retrieves the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { rows, err := GetDialect().dbVersionQuery(db) @@ -178,14 +187,14 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { for rows.Next() { var row MigrationRecord - if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil { + if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil { log.Fatal("error scanning rows:", err) } // have we already marked this version to be skipped? skip := false for _, v := range toSkip { - if v == row.VersionId { + if v == row.VersionID { skip = true break } @@ -197,11 +206,11 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { // if version has been applied we're done if row.IsApplied { - return row.VersionId, nil + return row.VersionID, nil } // latest version of migration has not been applied. - toSkip = append(toSkip, row.VersionId) + toSkip = append(toSkip, row.VersionID) } return 0, ErrNoNextVersion @@ -217,14 +226,14 @@ func createVersionTable(db *sql.DB) error { d := GetDialect() - if _, err := txn.Exec(d.createVersionTableSql()); err != nil { + if _, err := txn.Exec(d.createVersionTableSQL()); err != nil { txn.Rollback() return err } version := 0 applied := true - if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil { + if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil { txn.Rollback() return err } @@ -232,8 +241,8 @@ func createVersionTable(db *sql.DB) error { return txn.Commit() } -// wrapper for EnsureDBVersion for callers that don't already have -// their own DB instance +// GetDBVersion is a wrapper for EnsureDBVersion for callers that don't already +// have their own DB instance func GetDBVersion(db *sql.DB) (int64, error) { version, err := EnsureDBVersion(db) if err != nil { diff --git a/migration.go b/migration.go index e37313c..b39946e 100644 --- a/migration.go +++ b/migration.go @@ -12,12 +12,14 @@ import ( "time" ) +// MigrationRecord struct. type MigrationRecord struct { - VersionId int64 + VersionID int64 TStamp time.Time IsApplied bool // was this a result of up() or down() } +// Migration struct. type Migration struct { Version int64 Next int64 // next version, or -1 if none @@ -31,10 +33,12 @@ func (m *Migration) String() string { return fmt.Sprintf(m.Source) } +// Up runs an up migration. func (m *Migration) Up(db *sql.DB) error { return m.run(db, true) } +// Down runs a down migration. func (m *Migration) Down(db *sql.DB) error { return m.run(db, false) } @@ -43,7 +47,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { switch filepath.Ext(m.Source) { case ".sql": if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil { - return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) + return fmt.Errorf("FAIL %v, quitting migration", err) } case ".go": @@ -74,9 +78,8 @@ func (m *Migration) run(db *sql.DB, direction bool) error { return nil } -// look for migration scripts with names in the form: -// XXX_descriptivename.ext -// where XXX specifies the version number +// NumericComponent looks for migration scripts with names in the form: +// XXX_descriptivename.ext where XXX specifies the version number // and ext specifies the type of migration func NumericComponent(name string) (int64, error) { @@ -99,6 +102,7 @@ func NumericComponent(name string) (int64, error) { return n, e } +// CreateMigration creates a migration. func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { if migrationType != "go" && migrationType != "sql" { @@ -111,7 +115,7 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string, fpath := filepath.Join(dir, filename) tmpl := sqlMigrationTemplate if migrationType == "go" { - tmpl = goSqlMigrationTemplate + tmpl = goSQLMigrationTemplate } path, err = writeTemplateToFile(fpath, tmpl, timestamp) @@ -119,12 +123,12 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string, return } -// Update the version table for the given migration, +// FinalizeMigration updates the version table for the given migration, // and finalize the transaction. func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error { // XXX: drop goose_db_version table on some minimum version number? - stmt := GetDialect().insertVersionSql() + stmt := GetDialect().insertVersionSQL() if _, err := tx.Exec(stmt, v, direction); err != nil { tx.Rollback() return err @@ -142,26 +146,26 @@ var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Par -- SQL section 'Down' is executed when this migration is rolled back `)) -var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` + +var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` package migration import ( - "database/sql" - - "github.com/pressly/goose" + "database/sql" + "github.com/pressly/goose" ) func init() { - goose.AddMigration(Up{{.}}, Down{{.}}) + goose.AddMigration(Up{{.}}, Down{{.}}) } // Up{{.}} updates the database to the new requirements func Up{{.}}(tx *sql.Tx) error { - return nil + return nil } // Down{{.}} should send the database back to the state it was from before Up was ran func Down{{.}}(tx *sql.Tx) error { - return nil + return nil } `)) diff --git a/redo.go b/redo.go index 2ec7a32..6f9049f 100644 --- a/redo.go +++ b/redo.go @@ -4,6 +4,7 @@ import ( "database/sql" ) +// Redo rolls back the most recently applied migration, then runs it again. func Redo(db *sql.DB, dir string) error { currentVersion, err := GetDBVersion(db) if err != nil { diff --git a/status.go b/status.go index e1fd05d..5f57421 100644 --- a/status.go +++ b/status.go @@ -8,6 +8,7 @@ import ( "time" ) +// Status prints the status of all migrations. func Status(db *sql.DB, dir string) error { // collect all migrations migrations, err := CollectMigrations(dir, minVersion, maxVersion) diff --git a/up.go b/up.go index 8f9d2e7..448a159 100644 --- a/up.go +++ b/up.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// UpTo migrates up to a specific version. func UpTo(db *sql.DB, dir string, version int64) error { migrations, err := CollectMigrations(dir, minVersion, version) if err != nil { @@ -30,14 +31,14 @@ func UpTo(db *sql.DB, dir string, version int64) error { return err } } - - return nil } +// Up applies all available migrations. func Up(db *sql.DB, dir string) error { return UpTo(db, dir, maxVersion) } +// UpByOne migrates up by a single version. func UpByOne(db *sql.DB, dir string) error { migrations, err := CollectMigrations(dir, minVersion, maxVersion) if err != nil { diff --git a/version.go b/version.go index d10d2e1..133df8e 100644 --- a/version.go +++ b/version.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// Version prints the current version of the database. func Version(db *sql.DB, dir string) error { current, err := GetDBVersion(db) if err != nil {