fix: use global table name in queries (#515)

pull/516/head
Michael Fridman 2023-05-08 08:17:14 -04:00 committed by GitHub
parent 7ce30b743d
commit 87592390b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 139 additions and 161 deletions

View File

@ -7,7 +7,7 @@ import (
) )
func init() { func init() {
store, _ = dialect.NewStore(dialect.Postgres, TableName()) store, _ = dialect.NewStore(dialect.Postgres)
} }
var store dialect.Store var store dialect.Store
@ -36,6 +36,6 @@ func SetDialect(s string) error {
return fmt.Errorf("%q: unknown dialect", s) return fmt.Errorf("%q: unknown dialect", s)
} }
var err error var err error
store, err = dialect.NewStore(d, TableName()) store, err = dialect.NewStore(d)
return err return err
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Clickhouse struct { type Clickhouse struct{}
Table string
}
var _ Querier = (*Clickhouse)(nil) var _ Querier = (*Clickhouse)(nil)
func (c *Clickhouse) CreateTable() string { func (c *Clickhouse) CreateTable(tableName string) string {
q := `CREATE TABLE IF NOT EXISTS %s ( q := `CREATE TABLE IF NOT EXISTS %s (
version_id Int64, version_id Int64,
is_applied UInt8, is_applied UInt8,
@ -17,25 +15,25 @@ func (c *Clickhouse) CreateTable() string {
) )
ENGINE = MergeTree() ENGINE = MergeTree()
ORDER BY (date)` ORDER BY (date)`
return fmt.Sprintf(q, c.Table) return fmt.Sprintf(q, tableName)
} }
func (c *Clickhouse) InsertVersion() string { func (c *Clickhouse) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, c.Table) return fmt.Sprintf(q, tableName)
} }
func (c *Clickhouse) DeleteVersion() string { func (c *Clickhouse) DeleteVersion(tableName string) string {
q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2` q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2`
return fmt.Sprintf(q, c.Table) return fmt.Sprintf(q, tableName)
} }
func (c *Clickhouse) GetMigrationByVersion() string { func (c *Clickhouse) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, c.Table) return fmt.Sprintf(q, tableName)
} }
func (c *Clickhouse) ListMigrations() string { func (c *Clickhouse) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC` q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
return fmt.Sprintf(q, c.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -4,25 +4,25 @@ package dialectquery
// specific query. // specific query.
type Querier interface { type Querier interface {
// CreateTable returns the SQL query string to create the db version table. // CreateTable returns the SQL query string to create the db version table.
CreateTable() string CreateTable(tableName string) string
// InsertVersion returns the SQL query string to insert a new version into // InsertVersion returns the SQL query string to insert a new version into
// the db version table. // the db version table.
InsertVersion() string InsertVersion(tableName string) string
// DeleteVersion returns the SQL query string to delete a version from // DeleteVersion returns the SQL query string to delete a version from
// the db version table. // the db version table.
DeleteVersion() string DeleteVersion(tableName string) string
// GetMigrationByVersion returns the SQL query string to get a single // GetMigrationByVersion returns the SQL query string to get a single
// migration by version. // migration by version.
// //
// The query should return the timestamp and is_applied columns. // The query should return the timestamp and is_applied columns.
GetMigrationByVersion() string GetMigrationByVersion(tableName string) string
// ListMigrations returns the SQL query string to list all migrations in // ListMigrations returns the SQL query string to list all migrations in
// descending order by id. // descending order by id.
// //
// The query should return the version_id and is_applied columns. // The query should return the version_id and is_applied columns.
ListMigrations() string ListMigrations(tableName string) string
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Mysql struct { type Mysql struct{}
Table string
}
var _ Querier = (*Mysql)(nil) var _ Querier = (*Mysql)(nil)
func (m *Mysql) CreateTable() string { func (m *Mysql) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id serial NOT NULL, id serial NOT NULL,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -16,25 +14,25 @@ func (m *Mysql) CreateTable() string {
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
)` )`
return fmt.Sprintf(q, m.Table) return fmt.Sprintf(q, tableName)
} }
func (m *Mysql) InsertVersion() string { func (m *Mysql) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, m.Table) return fmt.Sprintf(q, tableName)
} }
func (m *Mysql) DeleteVersion() string { func (m *Mysql) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?` q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, m.Table) return fmt.Sprintf(q, tableName)
} }
func (m *Mysql) GetMigrationByVersion() string { func (m *Mysql) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, m.Table) return fmt.Sprintf(q, tableName)
} }
func (m *Mysql) ListMigrations() string { func (m *Mysql) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, m.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Postgres struct { type Postgres struct{}
Table string
}
var _ Querier = (*Postgres)(nil) var _ Querier = (*Postgres)(nil)
func (p *Postgres) CreateTable() string { func (p *Postgres) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id serial NOT NULL, id serial NOT NULL,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -16,25 +14,25 @@ func (p *Postgres) CreateTable() string {
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
)` )`
return fmt.Sprintf(q, p.Table) return fmt.Sprintf(q, tableName)
} }
func (p *Postgres) InsertVersion() string { func (p *Postgres) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, p.Table) return fmt.Sprintf(q, tableName)
} }
func (p *Postgres) DeleteVersion() string { func (p *Postgres) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=$1` q := `DELETE FROM %s WHERE version_id=$1`
return fmt.Sprintf(q, p.Table) return fmt.Sprintf(q, tableName)
} }
func (p *Postgres) GetMigrationByVersion() string { func (p *Postgres) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, p.Table) return fmt.Sprintf(q, tableName)
} }
func (p *Postgres) ListMigrations() string { func (p *Postgres) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, p.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Redshift struct { type Redshift struct{}
Table string
}
var _ Querier = (*Redshift)(nil) var _ Querier = (*Redshift)(nil)
func (r *Redshift) CreateTable() string { func (r *Redshift) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id integer NOT NULL identity(1, 1), id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -16,25 +14,25 @@ func (r *Redshift) CreateTable() string {
tstamp timestamp NULL default sysdate, tstamp timestamp NULL default sysdate,
PRIMARY KEY(id) PRIMARY KEY(id)
)` )`
return fmt.Sprintf(q, r.Table) return fmt.Sprintf(q, tableName)
} }
func (r *Redshift) InsertVersion() string { func (r *Redshift) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, r.Table) return fmt.Sprintf(q, tableName)
} }
func (r *Redshift) DeleteVersion() string { func (r *Redshift) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=$1` q := `DELETE FROM %s WHERE version_id=$1`
return fmt.Sprintf(q, r.Table) return fmt.Sprintf(q, tableName)
} }
func (r *Redshift) GetMigrationByVersion() string { func (r *Redshift) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, r.Table) return fmt.Sprintf(q, tableName)
} }
func (r *Redshift) ListMigrations() string { func (r *Redshift) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, r.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,38 +2,36 @@ package dialectquery
import "fmt" import "fmt"
type Sqlite3 struct { type Sqlite3 struct{}
Table string
}
var _ Querier = (*Sqlite3)(nil) var _ Querier = (*Sqlite3)(nil)
func (s *Sqlite3) CreateTable() string { func (s *Sqlite3) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL, version_id INTEGER NOT NULL,
is_applied INTEGER NOT NULL, is_applied INTEGER NOT NULL,
tstamp TIMESTAMP DEFAULT (datetime('now')) tstamp TIMESTAMP DEFAULT (datetime('now'))
)` )`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlite3) InsertVersion() string { func (s *Sqlite3) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlite3) DeleteVersion() string { func (s *Sqlite3) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?` q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlite3) GetMigrationByVersion() string { func (s *Sqlite3) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlite3) ListMigrations() string { func (s *Sqlite3) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,33 +2,31 @@ package dialectquery
import "fmt" import "fmt"
type Sqlserver struct { type Sqlserver struct{}
Table string
}
var _ Querier = (*Sqlserver)(nil) var _ Querier = (*Sqlserver)(nil)
func (s *Sqlserver) CreateTable() string { func (s *Sqlserver) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL, version_id BIGINT NOT NULL,
is_applied BIT NOT NULL, is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
)` )`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlserver) InsertVersion() string { func (s *Sqlserver) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)` q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlserver) DeleteVersion() string { func (s *Sqlserver) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=@p1` q := `DELETE FROM %s WHERE version_id=@p1`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlserver) GetMigrationByVersion() string { func (s *Sqlserver) GetMigrationByVersion(tableName string) string {
q := ` q := `
WITH Migrations AS WITH Migrations AS
( (
@ -42,10 +40,10 @@ FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2 WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC ORDER BY tstamp DESC
` `
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }
func (s *Sqlserver) ListMigrations() string { func (s *Sqlserver) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC` q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC`
return fmt.Sprintf(q, s.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Tidb struct { type Tidb struct{}
Table string
}
var _ Querier = (*Tidb)(nil) var _ Querier = (*Tidb)(nil)
func (t *Tidb) CreateTable() string { func (t *Tidb) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -16,25 +14,25 @@ func (t *Tidb) CreateTable() string {
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
)` )`
return fmt.Sprintf(q, t.Table) return fmt.Sprintf(q, tableName)
} }
func (t *Tidb) InsertVersion() string { func (t *Tidb) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, t.Table) return fmt.Sprintf(q, tableName)
} }
func (t *Tidb) DeleteVersion() string { func (t *Tidb) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?` q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, t.Table) return fmt.Sprintf(q, tableName)
} }
func (t *Tidb) GetMigrationByVersion() string { func (t *Tidb) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, t.Table) return fmt.Sprintf(q, tableName)
} }
func (t *Tidb) ListMigrations() string { func (t *Tidb) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, t.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -2,13 +2,11 @@ package dialectquery
import "fmt" import "fmt"
type Vertica struct { type Vertica struct{}
Table string
}
var _ Querier = (*Vertica)(nil) var _ Querier = (*Vertica)(nil)
func (v *Vertica) CreateTable() string { func (v *Vertica) CreateTable(tableName string) string {
q := `CREATE TABLE %s ( q := `CREATE TABLE %s (
id identity(1,1) NOT NULL, id identity(1,1) NOT NULL,
version_id bigint NOT NULL, version_id bigint NOT NULL,
@ -16,25 +14,25 @@ func (v *Vertica) CreateTable() string {
tstamp timestamp NULL default now(), tstamp timestamp NULL default now(),
PRIMARY KEY(id) PRIMARY KEY(id)
)` )`
return fmt.Sprintf(q, v.Table) return fmt.Sprintf(q, tableName)
} }
func (v *Vertica) InsertVersion() string { func (v *Vertica) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, v.Table) return fmt.Sprintf(q, tableName)
} }
func (v *Vertica) DeleteVersion() string { func (v *Vertica) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?` q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, v.Table) return fmt.Sprintf(q, tableName)
} }
func (v *Vertica) GetMigrationByVersion() string { func (v *Vertica) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, v.Table) return fmt.Sprintf(q, tableName)
} }
func (v *Vertica) ListMigrations() string { func (v *Vertica) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC` q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, v.Table) return fmt.Sprintf(q, tableName)
} }

View File

@ -3,7 +3,6 @@ package dialect
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"time" "time"
@ -22,55 +21,50 @@ import (
type Store interface { type Store interface {
// CreateVersionTable creates the version table within a transaction. // CreateVersionTable creates the version table within a transaction.
// This table is used to store goose migrations. // This table is used to store goose migrations.
CreateVersionTable(ctx context.Context, tx *sql.Tx) error CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error
// InsertVersion inserts a version id into the version table within a transaction. // InsertVersion inserts a version id into the version table within a transaction.
InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
// InsertVersionNoTx inserts a version id into the version table without a transaction. // InsertVersionNoTx inserts a version id into the version table without a transaction.
InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
// DeleteVersion deletes a version id from the version table within a transaction. // DeleteVersion deletes a version id from the version table within a transaction.
DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
// DeleteVersionNoTx deletes a version id from the version table without a transaction. // DeleteVersionNoTx deletes a version id from the version table without a transaction.
DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error
// GetMigrationRow retrieves a single migration by version id. // GetMigrationRow retrieves a single migration by version id.
// //
// Returns the raw sql error if the query fails. It is the callers responsibility // Returns the raw sql error if the query fails. It is the callers responsibility
// to assert for the correct error, such as sql.ErrNoRows. // to assert for the correct error, such as sql.ErrNoRows.
GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error)
// ListMigrations retrieves all migrations sorted in descending order by id. // ListMigrations retrieves all migrations sorted in descending order by id.
// //
// If there are no migrations, an empty slice is returned with no error. // If there are no migrations, an empty slice is returned with no error.
ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error)
} }
// NewStore returns a new Store for the given dialect. // NewStore returns a new Store for the given dialect.
// func NewStore(d Dialect) (Store, error) {
// The table name is used to store the goose migrations.
func NewStore(d Dialect, table string) (Store, error) {
if table == "" {
return nil, errors.New("table name cannot be empty")
}
var querier dialectquery.Querier var querier dialectquery.Querier
switch d { switch d {
case Postgres: case Postgres:
querier = &dialectquery.Postgres{Table: table} querier = &dialectquery.Postgres{}
case Mysql: case Mysql:
querier = &dialectquery.Mysql{Table: table} querier = &dialectquery.Mysql{}
case Sqlite3: case Sqlite3:
querier = &dialectquery.Sqlite3{Table: table} querier = &dialectquery.Sqlite3{}
case Sqlserver: case Sqlserver:
querier = &dialectquery.Sqlserver{Table: table} querier = &dialectquery.Sqlserver{}
case Redshift: case Redshift:
querier = &dialectquery.Redshift{Table: table} querier = &dialectquery.Redshift{}
case Tidb: case Tidb:
querier = &dialectquery.Tidb{Table: table} querier = &dialectquery.Tidb{}
case Clickhouse: case Clickhouse:
querier = &dialectquery.Clickhouse{Table: table} querier = &dialectquery.Clickhouse{}
case Vertica: case Vertica:
querier = &dialectquery.Vertica{Table: table} querier = &dialectquery.Vertica{}
default: default:
return nil, fmt.Errorf("unknown querier dialect: %v", d) return nil, fmt.Errorf("unknown querier dialect: %v", d)
} }
@ -93,38 +87,38 @@ type store struct {
var _ Store = (*store)(nil) var _ Store = (*store)(nil)
func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx) error { func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error {
q := s.querier.CreateTable() q := s.querier.CreateTable(tableName)
_, err := tx.ExecContext(ctx, q) _, err := tx.ExecContext(ctx, q)
return err return err
} }
func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error { func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
q := s.querier.InsertVersion() q := s.querier.InsertVersion(tableName)
_, err := tx.ExecContext(ctx, q, version, true) _, err := tx.ExecContext(ctx, q, version, true)
return err return err
} }
func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
q := s.querier.InsertVersion() q := s.querier.InsertVersion(tableName)
_, err := db.ExecContext(ctx, q, version, true) _, err := db.ExecContext(ctx, q, version, true)
return err return err
} }
func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error { func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
q := s.querier.DeleteVersion() q := s.querier.DeleteVersion(tableName)
_, err := tx.ExecContext(ctx, q, version) _, err := tx.ExecContext(ctx, q, version)
return err return err
} }
func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
q := s.querier.DeleteVersion() q := s.querier.DeleteVersion(tableName)
_, err := db.ExecContext(ctx, q, version) _, err := db.ExecContext(ctx, q, version)
return err return err
} }
func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) { func (s *store) GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion() q := s.querier.GetMigrationByVersion(tableName)
var timestamp time.Time var timestamp time.Time
var isApplied bool var isApplied bool
err := db.QueryRowContext(ctx, q, version).Scan(&timestamp, &isApplied) err := db.QueryRowContext(ctx, q, version).Scan(&timestamp, &isApplied)
@ -137,8 +131,8 @@ func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*G
}, nil }, nil
} }
func (s *store) ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) { func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations() q := s.querier.ListMigrations(tableName)
rows, err := db.QueryContext(ctx, q) rows, err := db.QueryContext(ctx, q)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -296,7 +296,7 @@ func versionFilter(v, current, target int64) bool {
// Create and initialize the DB version table if it doesn't exist. // Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) { func EnsureDBVersion(db *sql.DB) (int64, error) {
ctx := context.Background() ctx := context.Background()
dbMigrations, err := store.ListMigrations(ctx, db) dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil { if err != nil {
return 0, createVersionTable(ctx, db) return 0, createVersionTable(ctx, db)
} }
@ -336,11 +336,11 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {
if err != nil { if err != nil {
return err return err
} }
if err := store.CreateVersionTable(ctx, txn); err != nil { if err := store.CreateVersionTable(ctx, txn, TableName()); err != nil {
_ = txn.Rollback() _ = txn.Rollback()
return err return err
} }
if err := store.InsertVersion(ctx, txn, 0); err != nil { if err := store.InsertVersion(ctx, txn, TableName(), 0); err != nil {
_ = txn.Rollback() _ = txn.Rollback()
return err return err
} }

View File

@ -188,16 +188,16 @@ func runGoMigration(
func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error { func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
if direction { if direction {
return store.InsertVersion(ctx, tx, version) return store.InsertVersion(ctx, tx, TableName(), version)
} }
return store.DeleteVersion(ctx, tx, version) return store.DeleteVersion(ctx, tx, TableName(), version)
} }
func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error { func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
if direction { if direction {
return store.InsertVersionNoTx(ctx, db, version) return store.InsertVersionNoTx(ctx, db, TableName(), version)
} }
return store.DeleteVersionNoTx(ctx, db, version) return store.DeleteVersionNoTx(ctx, db, TableName(), version)
} }
// NumericComponent looks for migration scripts with names in the form: // NumericComponent looks for migration scripts with names in the form:

View File

@ -45,13 +45,13 @@ func runSQLMigration(
if !noVersioning { if !noVersioning {
if direction { if direction {
if err := store.InsertVersion(ctx, tx, v); err != nil { if err := store.InsertVersion(ctx, tx, TableName(), v); err != nil {
verboseInfo("Rollback transaction") verboseInfo("Rollback transaction")
_ = tx.Rollback() _ = tx.Rollback()
return fmt.Errorf("failed to insert new goose version: %w", err) return fmt.Errorf("failed to insert new goose version: %w", err)
} }
} else { } else {
if err := store.DeleteVersion(ctx, tx, v); err != nil { if err := store.DeleteVersion(ctx, tx, TableName(), v); err != nil {
verboseInfo("Rollback transaction") verboseInfo("Rollback transaction")
_ = tx.Rollback() _ = tx.Rollback()
return fmt.Errorf("failed to delete goose version: %w", err) return fmt.Errorf("failed to delete goose version: %w", err)
@ -76,11 +76,11 @@ func runSQLMigration(
} }
if !noVersioning { if !noVersioning {
if direction { if direction {
if err := store.InsertVersionNoTx(ctx, db, v); err != nil { if err := store.InsertVersionNoTx(ctx, db, TableName(), v); err != nil {
return fmt.Errorf("failed to insert new goose version: %w", err) return fmt.Errorf("failed to insert new goose version: %w", err)
} }
} else { } else {
if err := store.DeleteVersionNoTx(ctx, db, v); err != nil { if err := store.DeleteVersionNoTx(ctx, db, TableName(), v); err != nil {
return fmt.Errorf("failed to delete goose version: %w", err) return fmt.Errorf("failed to delete goose version: %w", err)
} }
} }

View File

@ -41,7 +41,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
func dbMigrationsStatus(ctx context.Context, db *sql.DB) (map[int64]bool, error) { func dbMigrationsStatus(ctx context.Context, db *sql.DB) (map[int64]bool, error) {
dbMigrations, err := store.ListMigrations(ctx, db) dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -46,7 +46,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
} }
func printMigrationStatus(ctx context.Context, db *sql.DB, version int64, script string) error { func printMigrationStatus(ctx context.Context, db *sql.DB, version int64, script string) error {
m, err := store.GetMigration(ctx, db, version) m, err := store.GetMigration(ctx, db, TableName(), version)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("failed to query the latest migration: %w", err) return fmt.Errorf("failed to query the latest migration: %w", err)
} }

2
up.go
View File

@ -225,7 +225,7 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error {
// listAllDBVersions returns a list of all migrations, ordered ascending. // listAllDBVersions returns a list of all migrations, ordered ascending.
// TODO(mf): fairly cheap, but a nice-to-have is pagination support. // TODO(mf): fairly cheap, but a nice-to-have is pagination support.
func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) { func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) {
dbMigrations, err := store.ListMigrations(ctx, db) dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil { if err != nil {
return nil, err return nil, err
} }