added mssql support.

pull/169/head
Aleksei Maslov 2019-04-12 10:47:35 +08:00
parent dcdfaa3d34
commit ea2101beb3
11 changed files with 96 additions and 8 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
.DS_Store
*.swp
*.test

View File

@ -3,6 +3,9 @@ language: go
go:
- 1.12
before_script:
- go get github.com/denisenkom/go-mssqldb
script:
- mkdir -p bin
- go test -v ./...

View File

@ -44,6 +44,7 @@ Drivers:
postgres
mysql
sqlite3
sqlserver
redshift
Examples:
@ -57,6 +58,7 @@ Examples:
goose mysql "user:password@/dbname?parseTime=true" status
goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status
goose tidb "user:password@/dbname?parseTime=true" status
goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status
Options:

View File

@ -0,0 +1,7 @@
// +build !no_sqlserver
package main
import (
_ "github.com/denisenkom/go-mssqldb"
)

View File

@ -6,7 +6,7 @@ import (
"log"
"os"
"github.com/pressly/goose"
"github.com/bandlab/goose"
)
var (
@ -83,6 +83,7 @@ Drivers:
postgres
mysql
sqlite3
sqlserver
redshift
Examples:
@ -96,6 +97,7 @@ Examples:
goose mysql "user:password@/dbname?parseTime=true" status
goose redshift "postgres://user:password@qwerty.us-east-1.redshift.amazonaws.com:5439/db" status
goose tidb "user:password@/dbname?parseTime=true" status
goose sqlserver "sqlserver://user:password@dbname:1433?database=master"" status
Options:
`

View File

@ -72,7 +72,7 @@ var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Pa
import (
"database/sql"
"github.com/pressly/goose"
"github.com/bandlab/goose"
)
func init() {

2
db.go
View File

@ -20,7 +20,7 @@ func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
}
switch driver {
case "postgres", "sqlite3", "mysql":
case "postgres", "sqlite3", "mysql", "sqlserver":
return sql.Open(driver, dbstring)
default:
return nil, fmt.Errorf("unsupported driver %s", driver)

View File

@ -11,6 +11,7 @@ type SQLDialect interface {
createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row
deleteVersionSQL() string // sql string to delete version
migrationSQL() string // sql string to retrieve migrations
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
}
@ -30,6 +31,8 @@ func SetDialect(d string) error {
dialect = &MySQLDialect{}
case "sqlite3":
dialect = &Sqlite3Dialect{}
case "sqlserver":
dialect = &SqlServerDialect{}
case "redshift":
dialect = &RedshiftDialect{}
case "tidb":
@ -71,6 +74,10 @@ func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}
func (m PostgresDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (pg PostgresDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}
@ -105,10 +112,64 @@ func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}
func (m MySQLDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m MySQLDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}
////////////////////////////
// MSSQL
////////////////////////////
// SqlServerDialect struct.
type SqlServerDialect struct{}
func (m SqlServerDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL,
is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
);`, TableName())
}
func (m SqlServerDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName())
}
func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}
return rows, err
}
func (m SqlServerDialect) migrationSQL() string {
const tpl = `
WITH Migrations AS
(
SELECT tstamp, is_applied,
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber'
FROM %s
WHERE version_id=@p1
)
SELECT tstamp, is_applied
FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC
`
return fmt.Sprintf(tpl, TableName())
}
func (m SqlServerDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName())
}
////////////////////////////
// sqlite3
////////////////////////////
@ -138,6 +199,10 @@ func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}
func (m Sqlite3Dialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m Sqlite3Dialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}
@ -172,6 +237,10 @@ func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}
func (m RedshiftDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (rs RedshiftDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}
@ -206,6 +275,10 @@ func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
return rows, err
}
func (m TiDBDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}
func (m TiDBDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}

View File

@ -3,7 +3,7 @@ package main
import (
"database/sql"
"github.com/pressly/goose"
"github.com/bandlab/goose"
)
func init() {

View File

@ -7,7 +7,7 @@ import (
"log"
"os"
"github.com/pressly/goose"
"github.com/bandlab/goose"
_ "github.com/mattn/go-sqlite3"
)

View File

@ -2,7 +2,6 @@ package goose
import (
"database/sql"
"fmt"
"path/filepath"
"time"
@ -34,10 +33,11 @@ func Status(db *sql.DB, dir string) error {
}
func printMigrationStatus(db *sql.DB, version int64, script string) error {
q := fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", TableName(), version)
q := GetDialect().migrationSQL()
var row MigrationRecord
err := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
err := db.QueryRow(q, version).Scan(&row.TStamp, &row.IsApplied)
if err != nil && err != sql.ErrNoRows {
return errors.Wrap(err, "failed to query the latest migration")
}