diff --git a/CHANGELOG.md b/CHANGELOG.md index e05ea71..6e2c27a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Add `TableExists` table existence check for the mysql dialect (#895) + ## [v3.24.1] - Fix regression (`v3.23.1` and `v3.24.0`) in postgres migration table existence check for diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index 49ce96f..389da24 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -1,5 +1,7 @@ package dialectquery +import "strings" + // Querier is the interface that wraps the basic methods to create a dialect specific query. type Querier interface { // CreateTable returns the SQL query string to create the db version table. @@ -47,3 +49,11 @@ func (c *QueryController) TableExists(tableName string) string { } return "" } + +func parseTableIdentifier(name string) (schema, table string) { + schema, table, found := strings.Cut(name, ".") + if !found { + return "", name + } + return schema, table +} diff --git a/internal/dialect/dialectquery/mysql.go b/internal/dialect/dialectquery/mysql.go index 1ce165c..ea2dcae 100644 --- a/internal/dialect/dialectquery/mysql.go +++ b/internal/dialect/dialectquery/mysql.go @@ -41,3 +41,13 @@ func (m *Mysql) GetLatestVersion(tableName string) string { q := `SELECT MAX(version_id) FROM %s` return fmt.Sprintf(q, tableName) } + +func (m *Mysql) TableExists(tableName string) string { + schemaName, tableName := parseTableIdentifier(tableName) + if schemaName != "" { + q := `SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = '%s' AND table_name = '%s' )` + return fmt.Sprintf(q, schemaName, tableName) + } + q := `SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE (database() IS NULL OR table_schema = database()) AND table_name = '%s' )` + return fmt.Sprintf(q, tableName) +} diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index 580c1a9..9c44999 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -2,7 +2,6 @@ package dialectquery import ( "fmt" - "strings" ) type Postgres struct{} @@ -53,11 +52,3 @@ func (p *Postgres) TableExists(tableName string) string { q := `SELECT EXISTS ( SELECT 1 FROM pg_tables WHERE (current_schema() IS NULL OR schemaname = current_schema()) AND tablename = '%s' )` return fmt.Sprintf(q, tableName) } - -func parseTableIdentifier(name string) (schema, table string) { - schema, table, found := strings.Cut(name, ".") - if !found { - return "", name - } - return schema, table -}