diff --git a/.gitignore b/.gitignore index 7ef12f8..4ef7d62 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea .DS_Store *.swp *.test diff --git a/.travis.yml b/.travis.yml index 4b07f8c..4d62933 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,9 @@ language: go go: - 1.12 +before_script: +- go get github.com/denisenkom/go-mssqldb + script: - mkdir -p bin - go test -v ./... diff --git a/README.md b/README.md index 6136e90..e77057e 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ Drivers: postgres mysql sqlite3 + sqlserver redshift Examples: @@ -57,6 +58,7 @@ Examples: goose mysql "user:password@/dbname?parseTime=true" status goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status goose tidb "user:password@/dbname?parseTime=true" status + goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status Options: diff --git a/cmd/goose/driver_sqlserver.go b/cmd/goose/driver_sqlserver.go new file mode 100644 index 0000000..3985725 --- /dev/null +++ b/cmd/goose/driver_sqlserver.go @@ -0,0 +1,7 @@ +// +build !no_sqlserver + +package main + +import ( + _ "github.com/denisenkom/go-mssqldb" +) diff --git a/cmd/goose/main.go b/cmd/goose/main.go index f6e6707..ff0bd51 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -6,7 +6,7 @@ import ( "log" "os" - "github.com/pressly/goose" + "github.com/bandlab/goose" ) var ( @@ -83,6 +83,7 @@ Drivers: postgres mysql sqlite3 + sqlserver redshift Examples: @@ -96,6 +97,7 @@ Examples: goose mysql "user:password@/dbname?parseTime=true" status goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status goose tidb "user:password@/dbname?parseTime=true" status + goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status Options: ` diff --git a/create.go b/create.go index 55fe540..dfb75aa 100644 --- a/create.go +++ b/create.go @@ -72,7 +72,7 @@ var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Pa import ( "database/sql" - "github.com/pressly/goose" + "github.com/bandlab/goose" ) func init() { diff --git a/db.go b/db.go index b61f886..cb66451 100644 --- a/db.go +++ b/db.go @@ -20,7 +20,7 @@ func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) { } switch driver { - case "postgres", "sqlite3", "mysql": + case "postgres", "sqlite3", "mysql", "sqlserver": return sql.Open(driver, dbstring) default: return nil, fmt.Errorf("unsupported driver %s", driver) diff --git a/dialect.go b/dialect.go index 7142a27..0d5b3dc 100644 --- a/dialect.go +++ b/dialect.go @@ -11,6 +11,7 @@ type SQLDialect interface { createVersionTableSQL() string // sql string to create the db version table insertVersionSQL() string // sql string to insert the initial version table row deleteVersionSQL() string // sql string to delete version + migrationSQL() string // sql string to retrieve migrations dbVersionQuery(db *sql.DB) (*sql.Rows, error) } @@ -30,6 +31,8 @@ func SetDialect(d string) error { dialect = &MySQLDialect{} case "sqlite3": dialect = &Sqlite3Dialect{} + case "sqlserver": + dialect = &SqlServerDialect{} case "redshift": dialect = &RedshiftDialect{} case "tidb": @@ -71,6 +74,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m PostgresDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) +} + func (pg PostgresDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) } @@ -105,10 +112,64 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m MySQLDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +} + func (m MySQLDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } +//////////////////////////// +// MSSQL +//////////////////////////// + +// SqlServerDialect struct. +type SqlServerDialect struct{} + +func (m SqlServerDialect) createVersionTableSQL() string { + return fmt.Sprintf(`CREATE TABLE %s ( + id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + version_id BIGINT NOT NULL, + is_applied BIT NOT NULL, + tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP + );`, TableName()) +} + +func (m SqlServerDialect) insertVersionSQL() string { + return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName()) +} + +func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName())) + if err != nil { + return nil, err + } + + return rows, err +} + +func (m SqlServerDialect) migrationSQL() string { + const tpl = ` +WITH Migrations AS +( + SELECT tstamp, is_applied, + ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber' + FROM %s + WHERE version_id=@p1 +) +SELECT tstamp, is_applied +FROM Migrations +WHERE RowNumber BETWEEN 1 AND 2 +ORDER BY tstamp DESC +` + return fmt.Sprintf(tpl, TableName()) +} + +func (m SqlServerDialect) deleteVersionSQL() string { + return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName()) +} + //////////////////////////// // sqlite3 //////////////////////////// @@ -138,6 +199,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m Sqlite3Dialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +} + func (m Sqlite3Dialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } @@ -172,6 +237,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m RedshiftDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) +} + func (rs RedshiftDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) } @@ -206,6 +275,10 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { return rows, err } +func (m TiDBDialect) migrationSQL() string { + return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) +} + func (m TiDBDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } diff --git a/examples/go-migrations/00002_rename_root.go b/examples/go-migrations/00002_rename_root.go index 069d40a..78d4907 100644 --- a/examples/go-migrations/00002_rename_root.go +++ b/examples/go-migrations/00002_rename_root.go @@ -3,7 +3,7 @@ package main import ( "database/sql" - "github.com/pressly/goose" + "github.com/bandlab/goose" ) func init() { diff --git a/examples/go-migrations/main.go b/examples/go-migrations/main.go index 77c9997..f91a97a 100644 --- a/examples/go-migrations/main.go +++ b/examples/go-migrations/main.go @@ -7,7 +7,7 @@ import ( "log" "os" - "github.com/pressly/goose" + "github.com/bandlab/goose" _ "github.com/mattn/go-sqlite3" ) diff --git a/status.go b/status.go index 8844ade..8a24c22 100644 --- a/status.go +++ b/status.go @@ -2,7 +2,6 @@ package goose import ( "database/sql" - "fmt" "path/filepath" "time" @@ -34,10 +33,11 @@ func Status(db *sql.DB, dir string) error { } func printMigrationStatus(db *sql.DB, version int64, script string) error { - q := fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", TableName(), version) + q := GetDialect().migrationSQL() var row MigrationRecord - err := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied) + + err := db.QueryRow(q, version).Scan(&row.TStamp, &row.IsApplied) if err != nil && err != sql.ErrNoRows { return errors.Wrap(err, "failed to query the latest migration") }