package sqlparser

import (
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"

	"github.com/pressly/goose/v3/internal/check"
)

var (
	// debug is passed to ParseSQLMigration by most tests. It is set once in
	// TestMain from the DEBUG_TEST environment variable and only read afterwards.
	debug = false
)

// TestMain reads DEBUG_TEST before running the tests. A missing or malformed
// value leaves debug at false, which is why the parse error is discarded.
func TestMain(m *testing.M) {
	debug, _ = strconv.ParseBool(os.Getenv("DEBUG_TEST"))
	os.Exit(m.Run())
}

// TestSemicolons checks endsWithSemicolon against lines with and without
// trailing comments, including semicolons that appear only inside a comment
// or a quoted string.
func TestSemicolons(t *testing.T) {
	t.Parallel()

	type testData struct {
		line   string
		result bool
	}

	tests := []testData{
		{line: "END;", result: true},
		{line: "END; -- comment", result: true},
		{line: "END ; -- comment", result: true},
		{line: "END -- comment", result: false},
		{line: "END -- comment ;", result: false},
		{line: "END \" ; \" -- comment", result: false},
	}

	for _, test := range tests {
		r := endsWithSemicolon(test.line)
		if r != test.result {
			// Include the input so a failing case can be identified.
			t.Errorf("incorrect semicolon for %q. got %v, want %v", test.line, r, test.result)
		}
	}
}

// TestSplitStatements parses each SQL fixture in both directions and checks
// the number of statements returned for each.
func TestSplitStatements(t *testing.T) {
	t.Parallel()

	type testData struct {
		sql  string
		up   int
		down int
	}

	tt := []testData{
		{sql: multilineSQL, up: 4, down: 1},
		{sql: emptySQL, up: 0, down: 0},
		{sql: emptySQL2, up: 0, down: 0},
		{sql: functxt, up: 2, down: 2},
		{sql: mysqlChangeDelimiter, up: 4, down: 0},
		{sql: copyFromStdin, up: 1, down: 0},
		{sql: plpgsqlSyntax, up: 2, down: 2},
		{sql: plpgsqlSyntaxMixedStatements, up: 2, down: 2},
	}

	for i, test := range tt {
		// up
		stmts, _, err := ParseSQLMigration(strings.NewReader(test.sql), DirectionUp, debug)
		if err != nil {
			t.Errorf("tt[%v] unexpected error: %v", i, err)
		}
		if len(stmts) != test.up {
			t.Errorf("tt[%v] incorrect number of up stmts. got %v (%+v), want %v", i, len(stmts), stmts, test.up)
		}

		// down
		stmts, _, err = ParseSQLMigration(strings.NewReader(test.sql), DirectionDown, debug)
		if err != nil {
			t.Errorf("tt[%v] unexpected error: %v", i, err)
		}
		if len(stmts) != test.down {
			t.Errorf("tt[%v] incorrect number of down stmts. got %v (%+v), want %v", i, len(stmts), stmts, test.down)
		}
	}
}

// TestInvalidUp parses every file under testdata/invalid/up in the up
// direction and expects each one to produce an error.
func TestInvalidUp(t *testing.T) {
	t.Parallel()

	testdataDir := filepath.Join("testdata", "invalid", "up")
	entries, err := os.ReadDir(testdataDir)
	check.NoError(t, err)
	check.NumberNotZero(t, len(entries))

	for _, entry := range entries {
		by, err := os.ReadFile(filepath.Join(testdataDir, entry.Name()))
		check.NoError(t, err)
		_, _, err = ParseSQLMigration(strings.NewReader(string(by)), DirectionUp, false)
		check.HasError(t, err)
	}
}

// TestUseTransactions checks the transaction flag returned by
// ParseSQLMigration for each migration file in testdata/valid-txn.
func TestUseTransactions(t *testing.T) {
	t.Parallel()

	type testData struct {
		fileName        string
		useTransactions bool
	}

	tests := []testData{
		{fileName: "testdata/valid-txn/00001_create_users_table.sql", useTransactions: true},
		{fileName: "testdata/valid-txn/00002_rename_root.sql", useTransactions: true},
		{fileName: "testdata/valid-txn/00003_no_transaction.sql", useTransactions: false},
	}

	for _, test := range tests {
		f, err := os.Open(test.fileName)
		if err != nil {
			// Nothing to parse; move on rather than handing a nil file to the parser.
			t.Error(err)
			continue
		}
		_, useTx, err := ParseSQLMigration(f, DirectionUp, debug)
		// Close right after parsing so the handle is released on every path below.
		if closeErr := f.Close(); closeErr != nil {
			t.Error(closeErr)
		}
		if err != nil {
			// useTx is not meaningful when parsing failed.
			t.Error(err)
			continue
		}
		if useTx != test.useTransactions {
			t.Errorf("Failed transaction check. got %v, want %v", useTx, test.useTransactions)
		}
	}
}

// TestParsingErrors expects ParseSQLMigration to return an error for each
// malformed fixture.
func TestParsingErrors(t *testing.T) {
	t.Parallel()

	tt := []string{
		statementBeginNoStatementEnd,
		unfinishedSQL,
		noUpDownAnnotations,
		multiUpDown,
		downFirst,
	}

	for i, sql := range tt {
		_, _, err := ParseSQLMigration(strings.NewReader(sql), DirectionUp, debug)
		if err == nil {
			t.Errorf("expected error on tt[%v] %q", i, sql)
		}
	}
}

// multilineSQL mixes a multi-line statement, trailing comments and two
// statements on one line; expected counts are 4 up and 1 down.
var multilineSQL = `-- +goose Up
CREATE TABLE post (
    id int NOT NULL,
    title text,
    body text,
    PRIMARY KEY(id)
); -- 1st stmt

-- comment
SELECT 2; -- 2nd stmt
SELECT 3; SELECT 3; -- 3rd stmt
SELECT 4; -- 4th stmt

-- +goose Down
-- comment
DROP TABLE post; -- 1st stmt
`

// functxt contains a function body with inner semicolons wrapped in
// StatementBegin/StatementEnd; expected counts are 2 up and 2 down.
var functxt = `-- +goose Up
CREATE TABLE IF NOT EXISTS histories (
    id BIGSERIAL PRIMARY KEY,
    current_value varchar(2000) NOT NULL,
    created_at timestamp with time zone NOT NULL
);

-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
    create_query text;
BEGIN
    FOR create_query IN SELECT
            'CREATE TABLE IF NOT EXISTS histories_'
            || TO_CHAR( d, 'YYYY_MM' )
            || ' ( CHECK( created_at >= timestamp '''
            || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
            || ''' AND created_at < timestamp '''
            || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
            || ''' ) ) inherits ( histories );'
        FROM generate_series( $1, $2, '1 month' ) AS d
    LOOP
        EXECUTE create_query;
    END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;
-- +goose StatementEnd

-- +goose Down
drop function histories_partition_creation(DATE, DATE);
drop TABLE histories;
`

// multiUpDown has a second Up section after Down; used as an error case.
var multiUpDown = `-- +goose Up
CREATE TABLE post (
    id int NOT NULL,
    title text,
    body text,
    PRIMARY KEY(id)
);

-- +goose Down
DROP TABLE post;

-- +goose Up
CREATE TABLE fancier_post (
    id int NOT NULL,
    title text,
    body text,
    created_on timestamp without time zone,
    PRIMARY KEY(id)
);
`

// downFirst has a Down section with no preceding Up; used as an error case.
var downFirst = `-- +goose Down
DROP TABLE fancier_post;
`

// statementBeginNoStatementEnd opens a StatementBegin block that is never
// closed with StatementEnd; used as an error case.
var statementBeginNoStatementEnd = `-- +goose Up
CREATE TABLE IF NOT EXISTS histories (
    id BIGSERIAL PRIMARY KEY,
    current_value varchar(2000) NOT NULL,
    created_at timestamp with time zone NOT NULL
);

-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
    create_query text;
BEGIN
    FOR create_query IN SELECT
            'CREATE TABLE IF NOT EXISTS histories_'
            || TO_CHAR( d, 'YYYY_MM' )
            || ' ( CHECK( created_at >= timestamp '''
            || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
            || ''' AND created_at < timestamp '''
            || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
            || ''' ) ) inherits ( histories );'
        FROM generate_series( $1, $2, '1 month' ) AS d
    LOOP
        EXECUTE create_query;
    END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;

-- +goose Down
drop function histories_partition_creation(DATE, DATE);
drop TABLE histories;
`

// unfinishedSQL has an Up statement with no terminating semicolon; used as an
// error case.
var unfinishedSQL = `
-- +goose Up
ALTER TABLE post

-- +goose Down
`

// emptySQL has an Up annotation followed only by a comment; expected counts
// are 0 up and 0 down.
var emptySQL = `-- +goose Up
-- This is just a comment`

// emptySQL2 has Up and Down annotations surrounded only by comments; expected
// counts are 0 up and 0 down.
var emptySQL2 = `

-- comment
-- +goose Up

-- comment
-- +goose Down

`

// noUpDownAnnotations has SQL but no goose annotations; used as an error case.
var noUpDownAnnotations = `
CREATE TABLE post (
    id int NOT NULL,
    title text,
    body text,
    PRIMARY KEY(id)
);
`

// mysqlChangeDelimiter wraps MySQL DELIMITER changes and a function body in
// StatementBegin/StatementEnd blocks; expected counts are 4 up and 0 down.
var mysqlChangeDelimiter = `
-- +goose Up
-- +goose StatementBegin
DELIMITER |
-- +goose StatementEnd

-- +goose StatementBegin
CREATE FUNCTION my_func( str CHAR(255) ) RETURNS CHAR(255) DETERMINISTIC
BEGIN
  RETURN "Dummy Body";
END |
-- +goose StatementEnd

-- +goose StatementBegin
DELIMITER ;
-- +goose StatementEnd

select my_func("123") from dual;
-- +goose Down
`

// copyFromStdin wraps a COPY ... FROM stdin statement and its data rows in one
// StatementBegin/StatementEnd block; expected counts are 1 up and 0 down.
var copyFromStdin = `
-- +goose Up
-- +goose StatementBegin
COPY public.django_content_type (id, app_label, model) FROM stdin;
1 admin logentry
2 auth permission
3 auth group
4 auth user
5 contenttypes contenttype
6 sessions session
\.
-- +goose StatementEnd
`

// plpgsqlSyntax wraps every statement in StatementBegin/StatementEnd,
// including Down statements that have no trailing semicolon; expected counts
// are 2 up and 2 down.
var plpgsqlSyntax = `
-- +goose Up
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = now();
    RETURN NEW;
END;
$$ language 'plpgsql';
-- +goose StatementEnd
-- +goose StatementBegin
CREATE TRIGGER update_properties_updated_at BEFORE UPDATE ON properties FOR EACH ROW EXECUTE PROCEDURE update_updated_at_column();
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
DROP TRIGGER update_properties_updated_at
-- +goose StatementEnd
-- +goose StatementBegin
DROP FUNCTION update_updated_at_column()
-- +goose StatementEnd
`

// plpgsqlSyntaxMixedStatements combines one StatementBegin/StatementEnd block
// with plain semicolon-terminated statements; expected counts are 2 up and
// 2 down.
var plpgsqlSyntaxMixedStatements = `
-- +goose Up
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = now();
    RETURN NEW;
END;
$$ language 'plpgsql';
-- +goose StatementEnd

CREATE TRIGGER update_properties_updated_at
BEFORE UPDATE
ON properties
FOR EACH ROW EXECUTE PROCEDURE update_updated_at_column();

-- +goose Down
DROP TRIGGER update_properties_updated_at;
DROP FUNCTION update_updated_at_column();
`

// TestValidUp runs the golden-file comparison for every case directory under
// testdata/valid-up.
func TestValidUp(t *testing.T) {
	t.Parallel()
	// Test valid "up" parser logic.
	//
	// This test expects each directory, such as: internal/sqlparser/testdata/valid-up/test01
	//
	// to contain exactly one migration file called "input.sql". We read this file and pass it
	// to the parser. Then we compare the statements against the golden files.
	// Each golden file is equivalent to one statement.
	//
	// ├── 01.up.golden.sql
	// ├── 02.up.golden.sql
	// ├── 03.up.golden.sql
	// └── input.sql
	tests := []struct {
		Name            string
		StatementsCount int
	}{
		{Name: "test01", StatementsCount: 3},
		{Name: "test02", StatementsCount: 1},
		{Name: "test03", StatementsCount: 1},
		{Name: "test04", StatementsCount: 3},
		{Name: "test05", StatementsCount: 2},
		{Name: "test06", StatementsCount: 5},
		{Name: "test07", StatementsCount: 1},
		{Name: "test08", StatementsCount: 6},
		{Name: "test09", StatementsCount: 1},
	}
	for _, tc := range tests {
		path := filepath.Join("testdata", "valid-up", tc.Name)
		t.Run(tc.Name, func(t *testing.T) {
			testValid(t, path, tc.StatementsCount, DirectionUp)
		})
	}
}

// testValid parses dir/input.sql in the given direction, checks that it yields
// count statements, and compares each statement against the golden files in
// dir.
func testValid(t *testing.T, dir string, count int, direction Direction) {
	t.Helper()

	f, err := os.Open(filepath.Join(dir, "input.sql"))
	check.NoError(t, err)
	t.Cleanup(func() { f.Close() })
	statements, _, err := ParseSQLMigration(f, direction, debug)
	check.NoError(t, err)
	check.Number(t, len(statements), count)
	compareStatements(t, dir, statements, direction)
}

// compareStatements compares each parsed statement against the golden file
// NN.<direction>.golden.sql in dir, where NN is the statement's 1-based
// position. Outside CI, a mismatch writes the actual statement next to the
// golden file with a ".FAIL" suffix so the two can be diffed.
func compareStatements(t *testing.T, dir string, statements []string, direction Direction) {
	t.Helper()

	files, err := filepath.Glob(filepath.Join(dir, fmt.Sprintf("*.%s.golden.sql", direction)))
	check.NoError(t, err)
	if len(statements) != len(files) {
		t.Fatalf("mismatch between parsed statements (%d) and golden files (%d), did you check in NN.{up|down}.golden.sql file in %q?", len(statements), len(files), dir)
	}
	for _, goldenFile := range files {
		goldenFile = filepath.Base(goldenFile)
		before, _, ok := strings.Cut(goldenFile, ".")
		if !ok {
			t.Fatal(`failed to cut on file delimiter ".", must be of the format NN.{up|down}.golden.sql`)
		}
		index, err := strconv.Atoi(before)
		check.NoError(t, err)
		// File names are 1-based; the statements slice is 0-based.
		index--
		// Equal counts do not guarantee the numbering is contiguous, so guard
		// against a file name that points outside the slice.
		if index < 0 || index >= len(statements) {
			t.Fatalf("golden file %q has index %d outside the range 1..%d", goldenFile, index+1, len(statements))
		}

		goldenFilePath := filepath.Join(dir, goldenFile)
		by, err := os.ReadFile(goldenFilePath)
		check.NoError(t, err)

		got, want := statements[index], string(by)

		if got != want {
			if isCIEnvironment() {
				t.Errorf("input does not match expected golden file:\n\ngot:\n%s\n\nwant:\n%s\n", got, want)
			} else {
				t.Error("input does not match expected output; diff files with .FAIL to debug")
				t.Logf("\ndiff %v %v",
					filepath.Join("internal", "sqlparser", goldenFilePath+".FAIL"),
					filepath.Join("internal", "sqlparser", goldenFilePath),
				)
				err := os.WriteFile(goldenFilePath+".FAIL", []byte(got), 0644)
				check.NoError(t, err)
			}
		}
	}
}

// isCIEnvironment reports whether the CI environment variable parses as true.
// An unset or malformed value is treated as false.
func isCIEnvironment() bool {
	ok, _ := strconv.ParseBool(os.Getenv("CI"))
	return ok
}

// TestEnvsub runs the golden-file comparison in both directions for every case
// directory under testdata/envsub, with the GOOSE_ENV_* variables set below.
func TestEnvsub(t *testing.T) {
	// Do not run in parallel, as this test sets environment variables.

	// Test valid migrations containing ${var}-style expressions, which are
	// substituted across the whole migration when env substitution is on.
	t.Setenv("GOOSE_ENV_REGION", "us_east_")
	t.Setenv("GOOSE_ENV_SET_BUT_EMPTY_VALUE", "")
	t.Setenv("GOOSE_ENV_NAME", "foo")

	tests := []struct {
		Name      string
		DownCount int
		UpCount   int
	}{
		{Name: "test01", UpCount: 4, DownCount: 1},
		{Name: "test02", UpCount: 3, DownCount: 0},
		{Name: "test03", UpCount: 1, DownCount: 0},
	}
	for _, tc := range tests {
		t.Run(tc.Name, func(t *testing.T) {
			dir := filepath.Join("testdata", "envsub", tc.Name)
			testValid(t, dir, tc.UpCount, DirectionUp)
			testValid(t, dir, tc.DownCount, DirectionDown)
		})
	}
}

// TestEnvsubError expects parsing to fail with a substitution error when a
// migration with ENVSUB ON uses the ${VAR?message} form for an unset variable.
func TestEnvsubError(t *testing.T) {
	t.Parallel()

	s := `
-- +goose ENVSUB ON
-- +goose Up
CREATE TABLE post (
    id int NOT NULL,
    title text,
    ${SOME_UNSET_VAR?required env var not set} text,
    PRIMARY KEY(id)
);
`
	_, _, err := ParseSQLMigration(strings.NewReader(s), DirectionUp, debug)
	check.HasError(t, err)
	check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:")
}