Merge pull request #44 from duffn/gofmt

Reformat code to adhere to golint and gofmt standards
This commit is contained in:
Vojtech Vitek 2017-05-08 14:39:07 -04:00 committed by GitHub
commit ba3b7a9fde
9 changed files with 73 additions and 49 deletions

View File

@ -5,26 +5,28 @@ import (
"fmt" "fmt"
) )
// SqlDialect abstracts the details of specific SQL dialects // SQLDialect abstracts the details of specific SQL dialects
// for goose's few SQL specific statements // for goose's few SQL specific statements
type SqlDialect interface { type SQLDialect interface {
createVersionTableSql() string // sql string to create the goose_db_version table createVersionTableSQL() string // sql string to create the goose_db_version table
insertVersionSql() string // sql string to insert the initial version table row insertVersionSQL() string // sql string to insert the initial version table row
dbVersionQuery(db *sql.DB) (*sql.Rows, error) dbVersionQuery(db *sql.DB) (*sql.Rows, error)
} }
var dialect SqlDialect = &PostgresDialect{} var dialect SQLDialect = &PostgresDialect{}
func GetDialect() SqlDialect { // GetDialect gets the SQLDialect
func GetDialect() SQLDialect {
return dialect return dialect
} }
// SetDialect sets the SQLDialect
func SetDialect(d string) error { func SetDialect(d string) error {
switch d { switch d {
case "postgres": case "postgres":
dialect = &PostgresDialect{} dialect = &PostgresDialect{}
case "mysql": case "mysql":
dialect = &MySqlDialect{} dialect = &MySQLDialect{}
case "sqlite3": case "sqlite3":
dialect = &Sqlite3Dialect{} dialect = &Sqlite3Dialect{}
case "redshift": case "redshift":
@ -40,9 +42,10 @@ func SetDialect(d string) error {
// Postgres // Postgres
//////////////////////////// ////////////////////////////
// PostgresDialect struct.
type PostgresDialect struct{} type PostgresDialect struct{}
func (pg PostgresDialect) createVersionTableSql() string { func (pg PostgresDialect) createVersionTableSQL() string {
return `CREATE TABLE goose_db_version ( return `CREATE TABLE goose_db_version (
id serial NOT NULL, id serial NOT NULL,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -52,7 +55,7 @@ func (pg PostgresDialect) createVersionTableSql() string {
);` );`
} }
func (pg PostgresDialect) insertVersionSql() string { func (pg PostgresDialect) insertVersionSQL() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
} }
@ -69,9 +72,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// MySQL // MySQL
//////////////////////////// ////////////////////////////
type MySqlDialect struct{} // MySQLDialect struct.
type MySQLDialect struct{}
func (m MySqlDialect) createVersionTableSql() string { func (m MySQLDialect) createVersionTableSQL() string {
return `CREATE TABLE goose_db_version ( return `CREATE TABLE goose_db_version (
id serial NOT NULL, id serial NOT NULL,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -81,11 +85,11 @@ func (m MySqlDialect) createVersionTableSql() string {
);` );`
} }
func (m MySqlDialect) insertVersionSql() string { func (m MySQLDialect) insertVersionSQL() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
} }
func (m MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC") rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC")
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,9 +102,10 @@ func (m MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// sqlite3 // sqlite3
//////////////////////////// ////////////////////////////
// Sqlite3Dialect struct.
type Sqlite3Dialect struct{} type Sqlite3Dialect struct{}
func (m Sqlite3Dialect) createVersionTableSql() string { func (m Sqlite3Dialect) createVersionTableSQL() string {
return `CREATE TABLE goose_db_version ( return `CREATE TABLE goose_db_version (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL, version_id INTEGER NOT NULL,
@ -109,7 +114,7 @@ func (m Sqlite3Dialect) createVersionTableSql() string {
);` );`
} }
func (m Sqlite3Dialect) insertVersionSql() string { func (m Sqlite3Dialect) insertVersionSQL() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
} }
@ -126,9 +131,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
// Redshift // Redshift
//////////////////////////// ////////////////////////////
// RedshiftDialect struct.
type RedshiftDialect struct{} type RedshiftDialect struct{}
func (rs RedshiftDialect) createVersionTableSql() string { func (rs RedshiftDialect) createVersionTableSQL() string {
return `CREATE TABLE goose_db_version ( return `CREATE TABLE goose_db_version (
id integer NOT NULL identity(1, 1), id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -138,7 +144,7 @@ func (rs RedshiftDialect) createVersionTableSql() string {
);` );`
} }
func (rs RedshiftDialect) insertVersionSql() string { func (rs RedshiftDialect) insertVersionSQL() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);" return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string) error { func Down(db *sql.DB, dir string) error {
currentVersion, err := GetDBVersion(db) currentVersion, err := GetDBVersion(db)
if err != nil { if err != nil {
@ -24,6 +25,7 @@ func Down(db *sql.DB, dir string) error {
return current.Down(db) return current.Down(db)
} }
// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64) error { func DownTo(db *sql.DB, dir string, version int64) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion) migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil { if err != nil {
@ -59,6 +61,4 @@ func DownTo(db *sql.DB, dir string, version int64) error {
return err return err
} }
} }
return nil
} }

View File

@ -3,8 +3,8 @@ package goose
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"strconv" "strconv"
"sync"
) )
var ( var (
@ -13,6 +13,7 @@ var (
maxVersion = int64((1 << 63) - 1) maxVersion = int64((1 << 63) - 1)
) )
// Run runs a goose command.
func Run(command string, db *sql.DB, dir string, args ...string) error { func Run(command string, db *sql.DB, dir string, args ...string) error {
switch command { switch command {
case "up": case "up":

View File

@ -11,14 +11,17 @@ import (
) )
var ( var (
// ErrNoCurrentVersion when a current migration version is not found.
ErrNoCurrentVersion = errors.New("no current version found") ErrNoCurrentVersion = errors.New("no current version found")
// ErrNoNextVersion when the next migration version is not found.
ErrNoNextVersion = errors.New("no next version found") ErrNoNextVersion = errors.New("no next version found")
// MaxVersion is the maximum allowed version.
MaxVersion int64 = 9223372036854775807 // max(int64) MaxVersion int64 = 9223372036854775807 // max(int64)
goMigrations []*Migration goMigrations []*Migration
) )
// Migrations slice.
type Migrations []*Migration type Migrations []*Migration
// helpers so we can use pkg sort // helpers so we can use pkg sort
@ -31,6 +34,7 @@ func (ms Migrations) Less(i, j int) bool {
return ms[i].Version < ms[j].Version return ms[i].Version < ms[j].Version
} }
// Current gets the current migration.
func (ms Migrations) Current(current int64) (*Migration, error) { func (ms Migrations) Current(current int64) (*Migration, error) {
for i, migration := range ms { for i, migration := range ms {
if migration.Version == current { if migration.Version == current {
@ -41,6 +45,7 @@ func (ms Migrations) Current(current int64) (*Migration, error) {
return nil, ErrNoCurrentVersion return nil, ErrNoCurrentVersion
} }
// Next gets the next migration.
func (ms Migrations) Next(current int64) (*Migration, error) { func (ms Migrations) Next(current int64) (*Migration, error) {
for i, migration := range ms { for i, migration := range ms {
if migration.Version > current { if migration.Version > current {
@ -51,8 +56,9 @@ func (ms Migrations) Next(current int64) (*Migration, error) {
return nil, ErrNoNextVersion return nil, ErrNoNextVersion
} }
// Previous : Get the previous migration.
func (ms Migrations) Previous(current int64) (*Migration, error) { func (ms Migrations) Previous(current int64) (*Migration, error) {
for i := len(ms)-1; i >= 0; i-- { for i := len(ms) - 1; i >= 0; i-- {
if ms[i].Version < current { if ms[i].Version < current {
return ms[i], nil return ms[i], nil
} }
@ -61,6 +67,7 @@ func (ms Migrations) Previous(current int64) (*Migration, error) {
return nil, ErrNoNextVersion return nil, ErrNoNextVersion
} }
// Last gets the last migration.
func (ms Migrations) Last() (*Migration, error) { func (ms Migrations) Last() (*Migration, error) {
if len(ms) == 0 { if len(ms) == 0 {
return nil, ErrNoNextVersion return nil, ErrNoNextVersion
@ -77,11 +84,13 @@ func (ms Migrations) String() string {
return str return str
} }
// AddMigration adds a migration.
func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1) _, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down) AddNamedMigration(filename, up, down)
} }
// AddNamedMigration : Add a named migration.
func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := NumericComponent(filename) v, _ := NumericComponent(filename)
migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename} migration := &Migration{Version: v, Next: -1, Previous: -1, UpFn: up, DownFn: down, Source: filename}
@ -161,7 +170,7 @@ func versionFilter(v, current, target int64) bool {
return false return false
} }
// retrieve the current version for this DB. // EnsureDBVersion retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist. // Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) { func EnsureDBVersion(db *sql.DB) (int64, error) {
rows, err := GetDialect().dbVersionQuery(db) rows, err := GetDialect().dbVersionQuery(db)
@ -178,14 +187,14 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
for rows.Next() { for rows.Next() {
var row MigrationRecord var row MigrationRecord
if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil { if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil {
log.Fatal("error scanning rows:", err) log.Fatal("error scanning rows:", err)
} }
// have we already marked this version to be skipped? // have we already marked this version to be skipped?
skip := false skip := false
for _, v := range toSkip { for _, v := range toSkip {
if v == row.VersionId { if v == row.VersionID {
skip = true skip = true
break break
} }
@ -197,11 +206,11 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
// if version has been applied we're done // if version has been applied we're done
if row.IsApplied { if row.IsApplied {
return row.VersionId, nil return row.VersionID, nil
} }
// latest version of migration has not been applied. // latest version of migration has not been applied.
toSkip = append(toSkip, row.VersionId) toSkip = append(toSkip, row.VersionID)
} }
return 0, ErrNoNextVersion return 0, ErrNoNextVersion
@ -217,14 +226,14 @@ func createVersionTable(db *sql.DB) error {
d := GetDialect() d := GetDialect()
if _, err := txn.Exec(d.createVersionTableSql()); err != nil { if _, err := txn.Exec(d.createVersionTableSQL()); err != nil {
txn.Rollback() txn.Rollback()
return err return err
} }
version := 0 version := 0
applied := true applied := true
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil { if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil {
txn.Rollback() txn.Rollback()
return err return err
} }
@ -232,8 +241,8 @@ func createVersionTable(db *sql.DB) error {
return txn.Commit() return txn.Commit()
} }
// wrapper for EnsureDBVersion for callers that don't already have // GetDBVersion is a wrapper for EnsureDBVersion for callers that don't already
// their own DB instance // have their own DB instance
func GetDBVersion(db *sql.DB) (int64, error) { func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db) version, err := EnsureDBVersion(db)
if err != nil { if err != nil {

View File

@ -12,12 +12,14 @@ import (
"time" "time"
) )
// MigrationRecord struct.
type MigrationRecord struct { type MigrationRecord struct {
VersionId int64 VersionID int64
TStamp time.Time TStamp time.Time
IsApplied bool // was this a result of up() or down() IsApplied bool // was this a result of up() or down()
} }
// Migration struct.
type Migration struct { type Migration struct {
Version int64 Version int64
Next int64 // next version, or -1 if none Next int64 // next version, or -1 if none
@ -31,10 +33,12 @@ func (m *Migration) String() string {
return fmt.Sprintf(m.Source) return fmt.Sprintf(m.Source)
} }
// Up runs an up migration.
func (m *Migration) Up(db *sql.DB) error { func (m *Migration) Up(db *sql.DB) error {
return m.run(db, true) return m.run(db, true)
} }
// Down runs a down migration.
func (m *Migration) Down(db *sql.DB) error { func (m *Migration) Down(db *sql.DB) error {
return m.run(db, false) return m.run(db, false)
} }
@ -43,7 +47,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) { switch filepath.Ext(m.Source) {
case ".sql": case ".sql":
if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil { if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) return fmt.Errorf("FAIL %v, quitting migration", err)
} }
case ".go": case ".go":
@ -74,9 +78,8 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
return nil return nil
} }
// look for migration scripts with names in the form: // NumericComponent looks for migration scripts with names in the form:
// XXX_descriptivename.ext // XXX_descriptivename.ext where XXX specifies the version number
// where XXX specifies the version number
// and ext specifies the type of migration // and ext specifies the type of migration
func NumericComponent(name string) (int64, error) { func NumericComponent(name string) (int64, error) {
@ -99,6 +102,7 @@ func NumericComponent(name string) (int64, error) {
return n, e return n, e
} }
// CreateMigration creates a migration.
func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
if migrationType != "go" && migrationType != "sql" { if migrationType != "go" && migrationType != "sql" {
@ -111,7 +115,7 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
fpath := filepath.Join(dir, filename) fpath := filepath.Join(dir, filename)
tmpl := sqlMigrationTemplate tmpl := sqlMigrationTemplate
if migrationType == "go" { if migrationType == "go" {
tmpl = goSqlMigrationTemplate tmpl = goSQLMigrationTemplate
} }
path, err = writeTemplateToFile(fpath, tmpl, timestamp) path, err = writeTemplateToFile(fpath, tmpl, timestamp)
@ -119,12 +123,12 @@ func CreateMigration(name, migrationType, dir string, t time.Time) (path string,
return return
} }
// Update the version table for the given migration, // FinalizeMigration updates the version table for the given migration,
// and finalize the transaction. // and finalize the transaction.
func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error { func FinalizeMigration(tx *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number? // XXX: drop goose_db_version table on some minimum version number?
stmt := GetDialect().insertVersionSql() stmt := GetDialect().insertVersionSQL()
if _, err := tx.Exec(stmt, v, direction); err != nil { if _, err := tx.Exec(stmt, v, direction); err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -142,12 +146,12 @@ var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Par
-- SQL section 'Down' is executed when this migration is rolled back -- SQL section 'Down' is executed when this migration is rolled back
`)) `))
var goSqlMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
package migration package migration
import ( import (
"database/sql" "database/sql"
"github.com/pressly/goose" "github.com/pressly/goose"
) )

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
) )
// Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string) error { func Redo(db *sql.DB, dir string) error {
currentVersion, err := GetDBVersion(db) currentVersion, err := GetDBVersion(db)
if err != nil { if err != nil {

View File

@ -8,6 +8,7 @@ import (
"time" "time"
) )
// Status prints the status of all migrations.
func Status(db *sql.DB, dir string) error { func Status(db *sql.DB, dir string) error {
// collect all migrations // collect all migrations
migrations, err := CollectMigrations(dir, minVersion, maxVersion) migrations, err := CollectMigrations(dir, minVersion, maxVersion)

5
up.go
View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// UpTo migrates up to a specific version.
func UpTo(db *sql.DB, dir string, version int64) error { func UpTo(db *sql.DB, dir string, version int64) error {
migrations, err := CollectMigrations(dir, minVersion, version) migrations, err := CollectMigrations(dir, minVersion, version)
if err != nil { if err != nil {
@ -30,14 +31,14 @@ func UpTo(db *sql.DB, dir string, version int64) error {
return err return err
} }
} }
return nil
} }
// Up applies all available migrations.
func Up(db *sql.DB, dir string) error { func Up(db *sql.DB, dir string) error {
return UpTo(db, dir, maxVersion) return UpTo(db, dir, maxVersion)
} }
// UpByOne migrates up by a single version.
func UpByOne(db *sql.DB, dir string) error { func UpByOne(db *sql.DB, dir string) error {
migrations, err := CollectMigrations(dir, minVersion, maxVersion) migrations, err := CollectMigrations(dir, minVersion, maxVersion)
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// Version prints the current version of the database.
func Version(db *sql.DB, dir string) error { func Version(db *sql.DB, dir string) error {
current, err := GetDBVersion(db) current, err := GetDBVersion(db)
if err != nil { if err != nil {