mirror of https://github.com/pressly/goose.git
fix: use global table name in queries (#515)
parent
7ce30b743d
commit
87592390b9
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
func init() {
|
||||
store, _ = dialect.NewStore(dialect.Postgres, TableName())
|
||||
store, _ = dialect.NewStore(dialect.Postgres)
|
||||
}
|
||||
|
||||
var store dialect.Store
|
||||
|
@ -36,6 +36,6 @@ func SetDialect(s string) error {
|
|||
return fmt.Errorf("%q: unknown dialect", s)
|
||||
}
|
||||
var err error
|
||||
store, err = dialect.NewStore(d, TableName())
|
||||
store, err = dialect.NewStore(d)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Clickhouse struct {
|
||||
Table string
|
||||
}
|
||||
type Clickhouse struct{}
|
||||
|
||||
var _ Querier = (*Clickhouse)(nil)
|
||||
|
||||
func (c *Clickhouse) CreateTable() string {
|
||||
func (c *Clickhouse) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE IF NOT EXISTS %s (
|
||||
version_id Int64,
|
||||
is_applied UInt8,
|
||||
|
@ -17,25 +15,25 @@ func (c *Clickhouse) CreateTable() string {
|
|||
)
|
||||
ENGINE = MergeTree()
|
||||
ORDER BY (date)`
|
||||
return fmt.Sprintf(q, c.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (c *Clickhouse) InsertVersion() string {
|
||||
func (c *Clickhouse) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
|
||||
return fmt.Sprintf(q, c.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (c *Clickhouse) DeleteVersion() string {
|
||||
func (c *Clickhouse) DeleteVersion(tableName string) string {
|
||||
q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2`
|
||||
return fmt.Sprintf(q, c.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (c *Clickhouse) GetMigrationByVersion() string {
|
||||
func (c *Clickhouse) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, c.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (c *Clickhouse) ListMigrations() string {
|
||||
func (c *Clickhouse) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
|
||||
return fmt.Sprintf(q, c.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -4,25 +4,25 @@ package dialectquery
|
|||
// specific query.
|
||||
type Querier interface {
|
||||
// CreateTable returns the SQL query string to create the db version table.
|
||||
CreateTable() string
|
||||
CreateTable(tableName string) string
|
||||
|
||||
// InsertVersion returns the SQL query string to insert a new version into
|
||||
// the db version table.
|
||||
InsertVersion() string
|
||||
InsertVersion(tableName string) string
|
||||
|
||||
// DeleteVersion returns the SQL query string to delete a version from
|
||||
// the db version table.
|
||||
DeleteVersion() string
|
||||
DeleteVersion(tableName string) string
|
||||
|
||||
// GetMigrationByVersion returns the SQL query string to get a single
|
||||
// migration by version.
|
||||
//
|
||||
// The query should return the timestamp and is_applied columns.
|
||||
GetMigrationByVersion() string
|
||||
GetMigrationByVersion(tableName string) string
|
||||
|
||||
// ListMigrations returns the SQL query string to list all migrations in
|
||||
// descending order by id.
|
||||
//
|
||||
// The query should return the version_id and is_applied columns.
|
||||
ListMigrations() string
|
||||
ListMigrations(tableName string) string
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Mysql struct {
|
||||
Table string
|
||||
}
|
||||
type Mysql struct{}
|
||||
|
||||
var _ Querier = (*Mysql)(nil)
|
||||
|
||||
func (m *Mysql) CreateTable() string {
|
||||
func (m *Mysql) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id serial NOT NULL,
|
||||
version_id bigint NOT NULL,
|
||||
|
@ -16,25 +14,25 @@ func (m *Mysql) CreateTable() string {
|
|||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
)`
|
||||
return fmt.Sprintf(q, m.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (m *Mysql) InsertVersion() string {
|
||||
func (m *Mysql) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
|
||||
return fmt.Sprintf(q, m.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (m *Mysql) DeleteVersion() string {
|
||||
func (m *Mysql) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=?`
|
||||
return fmt.Sprintf(q, m.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (m *Mysql) GetMigrationByVersion() string {
|
||||
func (m *Mysql) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, m.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (m *Mysql) ListMigrations() string {
|
||||
func (m *Mysql) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, m.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Postgres struct {
|
||||
Table string
|
||||
}
|
||||
type Postgres struct{}
|
||||
|
||||
var _ Querier = (*Postgres)(nil)
|
||||
|
||||
func (p *Postgres) CreateTable() string {
|
||||
func (p *Postgres) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id serial NOT NULL,
|
||||
version_id bigint NOT NULL,
|
||||
|
@ -16,25 +14,25 @@ func (p *Postgres) CreateTable() string {
|
|||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
)`
|
||||
return fmt.Sprintf(q, p.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (p *Postgres) InsertVersion() string {
|
||||
func (p *Postgres) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
|
||||
return fmt.Sprintf(q, p.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (p *Postgres) DeleteVersion() string {
|
||||
func (p *Postgres) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=$1`
|
||||
return fmt.Sprintf(q, p.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (p *Postgres) GetMigrationByVersion() string {
|
||||
func (p *Postgres) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, p.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (p *Postgres) ListMigrations() string {
|
||||
func (p *Postgres) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, p.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Redshift struct {
|
||||
Table string
|
||||
}
|
||||
type Redshift struct{}
|
||||
|
||||
var _ Querier = (*Redshift)(nil)
|
||||
|
||||
func (r *Redshift) CreateTable() string {
|
||||
func (r *Redshift) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id integer NOT NULL identity(1, 1),
|
||||
version_id bigint NOT NULL,
|
||||
|
@ -16,25 +14,25 @@ func (r *Redshift) CreateTable() string {
|
|||
tstamp timestamp NULL default sysdate,
|
||||
PRIMARY KEY(id)
|
||||
)`
|
||||
return fmt.Sprintf(q, r.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (r *Redshift) InsertVersion() string {
|
||||
func (r *Redshift) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
|
||||
return fmt.Sprintf(q, r.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (r *Redshift) DeleteVersion() string {
|
||||
func (r *Redshift) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=$1`
|
||||
return fmt.Sprintf(q, r.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (r *Redshift) GetMigrationByVersion() string {
|
||||
func (r *Redshift) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, r.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (r *Redshift) ListMigrations() string {
|
||||
func (r *Redshift) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, r.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,38 +2,36 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Sqlite3 struct {
|
||||
Table string
|
||||
}
|
||||
type Sqlite3 struct{}
|
||||
|
||||
var _ Querier = (*Sqlite3)(nil)
|
||||
|
||||
func (s *Sqlite3) CreateTable() string {
|
||||
func (s *Sqlite3) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
version_id INTEGER NOT NULL,
|
||||
is_applied INTEGER NOT NULL,
|
||||
tstamp TIMESTAMP DEFAULT (datetime('now'))
|
||||
)`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlite3) InsertVersion() string {
|
||||
func (s *Sqlite3) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlite3) DeleteVersion() string {
|
||||
func (s *Sqlite3) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=?`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlite3) GetMigrationByVersion() string {
|
||||
func (s *Sqlite3) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlite3) ListMigrations() string {
|
||||
func (s *Sqlite3) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,33 +2,31 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Sqlserver struct {
|
||||
Table string
|
||||
}
|
||||
type Sqlserver struct{}
|
||||
|
||||
var _ Querier = (*Sqlserver)(nil)
|
||||
|
||||
func (s *Sqlserver) CreateTable() string {
|
||||
func (s *Sqlserver) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
|
||||
version_id BIGINT NOT NULL,
|
||||
is_applied BIT NOT NULL,
|
||||
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlserver) InsertVersion() string {
|
||||
func (s *Sqlserver) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlserver) DeleteVersion() string {
|
||||
func (s *Sqlserver) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=@p1`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlserver) GetMigrationByVersion() string {
|
||||
func (s *Sqlserver) GetMigrationByVersion(tableName string) string {
|
||||
q := `
|
||||
WITH Migrations AS
|
||||
(
|
||||
|
@ -42,10 +40,10 @@ FROM Migrations
|
|||
WHERE RowNumber BETWEEN 1 AND 2
|
||||
ORDER BY tstamp DESC
|
||||
`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (s *Sqlserver) ListMigrations() string {
|
||||
func (s *Sqlserver) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, s.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Tidb struct {
|
||||
Table string
|
||||
}
|
||||
type Tidb struct{}
|
||||
|
||||
var _ Querier = (*Tidb)(nil)
|
||||
|
||||
func (t *Tidb) CreateTable() string {
|
||||
func (t *Tidb) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
|
||||
version_id bigint NOT NULL,
|
||||
|
@ -16,25 +14,25 @@ func (t *Tidb) CreateTable() string {
|
|||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
)`
|
||||
return fmt.Sprintf(q, t.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (t *Tidb) InsertVersion() string {
|
||||
func (t *Tidb) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
|
||||
return fmt.Sprintf(q, t.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (t *Tidb) DeleteVersion() string {
|
||||
func (t *Tidb) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=?`
|
||||
return fmt.Sprintf(q, t.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (t *Tidb) GetMigrationByVersion() string {
|
||||
func (t *Tidb) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, t.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (t *Tidb) ListMigrations() string {
|
||||
func (t *Tidb) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, t.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -2,13 +2,11 @@ package dialectquery
|
|||
|
||||
import "fmt"
|
||||
|
||||
type Vertica struct {
|
||||
Table string
|
||||
}
|
||||
type Vertica struct{}
|
||||
|
||||
var _ Querier = (*Vertica)(nil)
|
||||
|
||||
func (v *Vertica) CreateTable() string {
|
||||
func (v *Vertica) CreateTable(tableName string) string {
|
||||
q := `CREATE TABLE %s (
|
||||
id identity(1,1) NOT NULL,
|
||||
version_id bigint NOT NULL,
|
||||
|
@ -16,25 +14,25 @@ func (v *Vertica) CreateTable() string {
|
|||
tstamp timestamp NULL default now(),
|
||||
PRIMARY KEY(id)
|
||||
)`
|
||||
return fmt.Sprintf(q, v.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (v *Vertica) InsertVersion() string {
|
||||
func (v *Vertica) InsertVersion(tableName string) string {
|
||||
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
|
||||
return fmt.Sprintf(q, v.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (v *Vertica) DeleteVersion() string {
|
||||
func (v *Vertica) DeleteVersion(tableName string) string {
|
||||
q := `DELETE FROM %s WHERE version_id=?`
|
||||
return fmt.Sprintf(q, v.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (v *Vertica) GetMigrationByVersion() string {
|
||||
func (v *Vertica) GetMigrationByVersion(tableName string) string {
|
||||
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
|
||||
return fmt.Sprintf(q, v.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
||||
func (v *Vertica) ListMigrations() string {
|
||||
func (v *Vertica) ListMigrations(tableName string) string {
|
||||
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
|
||||
return fmt.Sprintf(q, v.Table)
|
||||
return fmt.Sprintf(q, tableName)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package dialect
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -22,55 +21,50 @@ import (
|
|||
type Store interface {
|
||||
// CreateVersionTable creates the version table within a transaction.
|
||||
// This table is used to store goose migrations.
|
||||
CreateVersionTable(ctx context.Context, tx *sql.Tx) error
|
||||
CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error
|
||||
|
||||
// InsertVersion inserts a version id into the version table within a transaction.
|
||||
InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error
|
||||
InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
|
||||
// InsertVersionNoTx inserts a version id into the version table without a transaction.
|
||||
InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error
|
||||
InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
|
||||
|
||||
// DeleteVersion deletes a version id from the version table within a transaction.
|
||||
DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error
|
||||
DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
|
||||
// DeleteVersionNoTx deletes a version id from the version table without a transaction.
|
||||
DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error
|
||||
DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
|
||||
|
||||
// GetMigrationRow retrieves a single migration by version id.
|
||||
//
|
||||
// Returns the raw sql error if the query fails. It is the callers responsibility
|
||||
// to assert for the correct error, such as sql.ErrNoRows.
|
||||
GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error)
|
||||
GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error)
|
||||
|
||||
// ListMigrations retrieves all migrations sorted in descending order by id.
|
||||
//
|
||||
// If there are no migrations, an empty slice is returned with no error.
|
||||
ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error)
|
||||
ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error)
|
||||
}
|
||||
|
||||
// NewStore returns a new Store for the given dialect.
|
||||
//
|
||||
// The table name is used to store the goose migrations.
|
||||
func NewStore(d Dialect, table string) (Store, error) {
|
||||
if table == "" {
|
||||
return nil, errors.New("table name cannot be empty")
|
||||
}
|
||||
func NewStore(d Dialect) (Store, error) {
|
||||
var querier dialectquery.Querier
|
||||
switch d {
|
||||
case Postgres:
|
||||
querier = &dialectquery.Postgres{Table: table}
|
||||
querier = &dialectquery.Postgres{}
|
||||
case Mysql:
|
||||
querier = &dialectquery.Mysql{Table: table}
|
||||
querier = &dialectquery.Mysql{}
|
||||
case Sqlite3:
|
||||
querier = &dialectquery.Sqlite3{Table: table}
|
||||
querier = &dialectquery.Sqlite3{}
|
||||
case Sqlserver:
|
||||
querier = &dialectquery.Sqlserver{Table: table}
|
||||
querier = &dialectquery.Sqlserver{}
|
||||
case Redshift:
|
||||
querier = &dialectquery.Redshift{Table: table}
|
||||
querier = &dialectquery.Redshift{}
|
||||
case Tidb:
|
||||
querier = &dialectquery.Tidb{Table: table}
|
||||
querier = &dialectquery.Tidb{}
|
||||
case Clickhouse:
|
||||
querier = &dialectquery.Clickhouse{Table: table}
|
||||
querier = &dialectquery.Clickhouse{}
|
||||
case Vertica:
|
||||
querier = &dialectquery.Vertica{Table: table}
|
||||
querier = &dialectquery.Vertica{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown querier dialect: %v", d)
|
||||
}
|
||||
|
@ -93,38 +87,38 @@ type store struct {
|
|||
|
||||
var _ Store = (*store)(nil)
|
||||
|
||||
func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx) error {
|
||||
q := s.querier.CreateTable()
|
||||
func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error {
|
||||
q := s.querier.CreateTable(tableName)
|
||||
_, err := tx.ExecContext(ctx, q)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error {
|
||||
q := s.querier.InsertVersion()
|
||||
func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
|
||||
q := s.querier.InsertVersion(tableName)
|
||||
_, err := tx.ExecContext(ctx, q, version, true)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error {
|
||||
q := s.querier.InsertVersion()
|
||||
func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
|
||||
q := s.querier.InsertVersion(tableName)
|
||||
_, err := db.ExecContext(ctx, q, version, true)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error {
|
||||
q := s.querier.DeleteVersion()
|
||||
func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
|
||||
q := s.querier.DeleteVersion(tableName)
|
||||
_, err := tx.ExecContext(ctx, q, version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error {
|
||||
q := s.querier.DeleteVersion()
|
||||
func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
|
||||
q := s.querier.DeleteVersion(tableName)
|
||||
_, err := db.ExecContext(ctx, q, version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) {
|
||||
q := s.querier.GetMigrationByVersion()
|
||||
func (s *store) GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error) {
|
||||
q := s.querier.GetMigrationByVersion(tableName)
|
||||
var timestamp time.Time
|
||||
var isApplied bool
|
||||
err := db.QueryRowContext(ctx, q, version).Scan(×tamp, &isApplied)
|
||||
|
@ -137,8 +131,8 @@ func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*G
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *store) ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) {
|
||||
q := s.querier.ListMigrations()
|
||||
func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) {
|
||||
q := s.querier.ListMigrations(tableName)
|
||||
rows, err := db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -296,7 +296,7 @@ func versionFilter(v, current, target int64) bool {
|
|||
// Create and initialize the DB version table if it doesn't exist.
|
||||
func EnsureDBVersion(db *sql.DB) (int64, error) {
|
||||
ctx := context.Background()
|
||||
dbMigrations, err := store.ListMigrations(ctx, db)
|
||||
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
|
||||
if err != nil {
|
||||
return 0, createVersionTable(ctx, db)
|
||||
}
|
||||
|
@ -336,11 +336,11 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := store.CreateVersionTable(ctx, txn); err != nil {
|
||||
if err := store.CreateVersionTable(ctx, txn, TableName()); err != nil {
|
||||
_ = txn.Rollback()
|
||||
return err
|
||||
}
|
||||
if err := store.InsertVersion(ctx, txn, 0); err != nil {
|
||||
if err := store.InsertVersion(ctx, txn, TableName(), 0); err != nil {
|
||||
_ = txn.Rollback()
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -188,16 +188,16 @@ func runGoMigration(
|
|||
|
||||
func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
|
||||
if direction {
|
||||
return store.InsertVersion(ctx, tx, version)
|
||||
return store.InsertVersion(ctx, tx, TableName(), version)
|
||||
}
|
||||
return store.DeleteVersion(ctx, tx, version)
|
||||
return store.DeleteVersion(ctx, tx, TableName(), version)
|
||||
}
|
||||
|
||||
func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
|
||||
if direction {
|
||||
return store.InsertVersionNoTx(ctx, db, version)
|
||||
return store.InsertVersionNoTx(ctx, db, TableName(), version)
|
||||
}
|
||||
return store.DeleteVersionNoTx(ctx, db, version)
|
||||
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
|
||||
}
|
||||
|
||||
// NumericComponent looks for migration scripts with names in the form:
|
||||
|
|
|
@ -45,13 +45,13 @@ func runSQLMigration(
|
|||
|
||||
if !noVersioning {
|
||||
if direction {
|
||||
if err := store.InsertVersion(ctx, tx, v); err != nil {
|
||||
if err := store.InsertVersion(ctx, tx, TableName(), v); err != nil {
|
||||
verboseInfo("Rollback transaction")
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to insert new goose version: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := store.DeleteVersion(ctx, tx, v); err != nil {
|
||||
if err := store.DeleteVersion(ctx, tx, TableName(), v); err != nil {
|
||||
verboseInfo("Rollback transaction")
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to delete goose version: %w", err)
|
||||
|
@ -76,11 +76,11 @@ func runSQLMigration(
|
|||
}
|
||||
if !noVersioning {
|
||||
if direction {
|
||||
if err := store.InsertVersionNoTx(ctx, db, v); err != nil {
|
||||
if err := store.InsertVersionNoTx(ctx, db, TableName(), v); err != nil {
|
||||
return fmt.Errorf("failed to insert new goose version: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := store.DeleteVersionNoTx(ctx, db, v); err != nil {
|
||||
if err := store.DeleteVersionNoTx(ctx, db, TableName(), v); err != nil {
|
||||
return fmt.Errorf("failed to delete goose version: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
2
reset.go
2
reset.go
|
@ -41,7 +41,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
|
|||
}
|
||||
|
||||
func dbMigrationsStatus(ctx context.Context, db *sql.DB) (map[int64]bool, error) {
|
||||
dbMigrations, err := store.ListMigrations(ctx, db)
|
||||
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
|
|||
}
|
||||
|
||||
func printMigrationStatus(ctx context.Context, db *sql.DB, version int64, script string) error {
|
||||
m, err := store.GetMigration(ctx, db, version)
|
||||
m, err := store.GetMigration(ctx, db, TableName(), version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("failed to query the latest migration: %w", err)
|
||||
}
|
||||
|
|
2
up.go
2
up.go
|
@ -225,7 +225,7 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error {
|
|||
// listAllDBVersions returns a list of all migrations, ordered ascending.
|
||||
// TODO(mf): fairly cheap, but a nice-to-have is pagination support.
|
||||
func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) {
|
||||
dbMigrations, err := store.ListMigrations(ctx, db)
|
||||
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue