added mssql support.

pull/169/head
Aleksei Maslov 2019-04-12 10:47:35 +08:00
parent dcdfaa3d34
commit ea2101beb3
11 changed files with 96 additions and 8 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
.DS_Store .DS_Store
*.swp *.swp
*.test *.test

View File

@ -3,6 +3,9 @@ language: go
go: go:
- 1.12 - 1.12
before_script:
- go get github.com/denisenkom/go-mssqldb
script: script:
- mkdir -p bin - mkdir -p bin
- go test -v ./... - go test -v ./...

View File

@ -44,6 +44,7 @@ Drivers:
postgres postgres
mysql mysql
sqlite3 sqlite3
sqlserver
redshift redshift
Examples: Examples:
@ -57,6 +58,7 @@ Examples:
goose mysql "user:password@/dbname?parseTime=true" status goose mysql "user:password@/dbname?parseTime=true" status
goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status
goose tidb "user:password@/dbname?parseTime=true" status goose tidb "user:password@/dbname?parseTime=true" status
goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status
Options: Options:

View File

@ -0,0 +1,7 @@
// +build !no_sqlserver
package main
import (
_ "github.com/denisenkom/go-mssqldb"
)

View File

@ -6,7 +6,7 @@ import (
"log" "log"
"os" "os"
"github.com/pressly/goose" "github.com/bandlab/goose"
) )
var ( var (
@ -83,6 +83,7 @@ Drivers:
postgres postgres
mysql mysql
sqlite3 sqlite3
sqlserver
redshift redshift
Examples: Examples:
@ -96,6 +97,7 @@ Examples:
goose mysql "user:password@/dbname?parseTime=true" status goose mysql "user:password@/dbname?parseTime=true" status
goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status
goose tidb "user:password@/dbname?parseTime=true" status goose tidb "user:password@/dbname?parseTime=true" status
goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status
Options: Options:
` `

View File

@ -72,7 +72,7 @@ var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Pa
import ( import (
"database/sql" "database/sql"
"github.com/pressly/goose" "github.com/bandlab/goose"
) )
func init() { func init() {

2
db.go
View File

@ -20,7 +20,7 @@ func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
} }
switch driver { switch driver {
case "postgres", "sqlite3", "mysql": case "postgres", "sqlite3", "mysql", "sqlserver":
return sql.Open(driver, dbstring) return sql.Open(driver, dbstring)
default: default:
return nil, fmt.Errorf("unsupported driver %s", driver) return nil, fmt.Errorf("unsupported driver %s", driver)

View File

@ -11,6 +11,7 @@ type SQLDialect interface {
createVersionTableSQL() string // sql string to create the db version table createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row insertVersionSQL() string // sql string to insert the initial version table row
deleteVersionSQL() string // sql string to delete version deleteVersionSQL() string // sql string to delete version
migrationSQL() string // sql string to retrieve migrations
dbVersionQuery(db *sql.DB) (*sql.Rows, error) dbVersionQuery(db *sql.DB) (*sql.Rows, error)
} }
@ -30,6 +31,8 @@ func SetDialect(d string) error {
dialect = &MySQLDialect{} dialect = &MySQLDialect{}
case "sqlite3": case "sqlite3":
dialect = &Sqlite3Dialect{} dialect = &Sqlite3Dialect{}
case "sqlserver":
dialect = &SqlServerDialect{}
case "redshift": case "redshift":
dialect = &RedshiftDialect{} dialect = &RedshiftDialect{}
case "tidb": case "tidb":
@ -71,6 +74,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err return rows, err
} }
func (m PostgresDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (pg PostgresDialect) deleteVersionSQL() string { func (pg PostgresDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
} }
@ -105,10 +112,64 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err return rows, err
} }
func (m MySQLDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m MySQLDialect) deleteVersionSQL() string { func (m MySQLDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
} }
////////////////////////////
// MSSQL
////////////////////////////
// SqlServerDialect struct.
type SqlServerDialect struct{}
func (m SqlServerDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL,
is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
);`, TableName())
}
func (m SqlServerDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName())
}
func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m SqlServerDialect) migrationSQL() string {
const tpl = `
WITH Migrations AS
(
SELECT tstamp, is_applied,
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber'
FROM %s
WHERE version_id=@p1
)
SELECT tstamp, is_applied
FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC
`
return fmt.Sprintf(tpl, TableName())
}
func (m SqlServerDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName())
}
//////////////////////////// ////////////////////////////
// sqlite3 // sqlite3
//////////////////////////// ////////////////////////////
@ -138,6 +199,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err return rows, err
} }
func (m Sqlite3Dialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m Sqlite3Dialect) deleteVersionSQL() string { func (m Sqlite3Dialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
} }
@ -172,6 +237,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err return rows, err
} }
func (m RedshiftDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (rs RedshiftDialect) deleteVersionSQL() string { func (rs RedshiftDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
} }
@ -206,6 +275,10 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err return rows, err
} }
func (m TiDBDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m TiDBDialect) deleteVersionSQL() string { func (m TiDBDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
} }

View File

@ -3,7 +3,7 @@ package main
import ( import (
"database/sql" "database/sql"
"github.com/pressly/goose" "github.com/bandlab/goose"
) )
func init() { func init() {

View File

@ -7,7 +7,7 @@ import (
"log" "log"
"os" "os"
"github.com/pressly/goose" "github.com/bandlab/goose"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )

View File

@ -2,7 +2,6 @@ package goose
import ( import (
"database/sql" "database/sql"
"fmt"
"path/filepath" "path/filepath"
"time" "time"
@ -34,10 +33,11 @@ func Status(db *sql.DB, dir string) error {
} }
func printMigrationStatus(db *sql.DB, version int64, script string) error { func printMigrationStatus(db *sql.DB, version int64, script string) error {
q := fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", TableName(), version) q := GetDialect().migrationSQL()
var row MigrationRecord var row MigrationRecord
err := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
err := db.QueryRow(q, version).Scan(&row.TStamp, &row.IsApplied)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return errors.Wrap(err, "failed to query the latest migration") return errors.Wrap(err, "failed to query the latest migration")
} }