From 39c030eac9ef49afcb868cd2905d341bed850088 Mon Sep 17 00:00:00 2001 From: TomasBarry Date: Tue, 18 Sep 2018 09:33:47 +0100 Subject: [PATCH] Exit the program when migration can't be parsed --- .gitignore | 6 +++ migration_sql.go | 24 +++++++----- migration_sql_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 104 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 09c6c0f..c155879 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,9 @@ cmd/goose/goose* *.swp *.test + +# Files output by tests +custom-goose +go.db +goose +sql.db diff --git a/migration_sql.go b/migration_sql.go index 5029833..6bea400 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "database/sql" + "fmt" "io" "os" "strings" @@ -39,7 +40,7 @@ func endsWithSemicolon(line string) bool { // within a statement. For these cases, we provide the explicit annotations // 'StatementBegin' and 'StatementEnd' to allow the script to // tell us to ignore semicolons. -func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) { +func getSQLStatements(r io.Reader, direction bool) ([]string, bool, error) { var buf bytes.Buffer scanner := bufio.NewScanner(r) @@ -51,7 +52,8 @@ func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) { statementEnded := false ignoreSemicolons := false directionIsActive := false - tx = true + tx := true + stmts := []string{} for scanner.Scan() { @@ -95,7 +97,7 @@ func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) { } if _, err := buf.WriteString(line + "\n"); err != nil { - log.Fatalf("io err: %v", err) + return nil, false, fmt.Errorf("io err: %v", err) } // Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement @@ -109,24 +111,23 @@ func getSQLStatements(r io.Reader, direction bool) (stmts []string, tx bool) { } if err := scanner.Err(); err != nil { - log.Fatalf("scanning migration: %v", err) + return nil, false, fmt.Errorf("scanning migration: %v", err) } // diagnose likely migration script errors if ignoreSemicolons { - log.Println("WARNING: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'") + return nil, false, fmt.Errorf("parsing migration: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'") } if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 { - log.Printf("WARNING: Unexpected unfinished SQL query: %s. Missing a semicolon?\n", bufferRemaining) + return nil, false, fmt.Errorf("parsing migration: unexpected unfinished SQL query: %s. potential missing semicolon", bufferRemaining) } if upSections == 0 && downSections == 0 { - log.Fatalf(`ERROR: no Up/Down annotations found, so no statements were executed. - See https://bitbucket.org/liamstask/goose/overview for details.`) + return nil, false, fmt.Errorf("parsing migration: no Up/Down annotations found, so no statements were executed. See https://bitbucket.org/liamstask/goose/overview for details") } - return + return stmts, tx, nil } // Run a migration specified in raw SQL. @@ -144,7 +145,10 @@ func runSQLMigration(db *sql.DB, scriptFile string, v int64, direction bool) err } defer f.Close() - statements, useTx := getSQLStatements(f, direction) + statements, useTx, err := getSQLStatements(f, direction) + if err != nil { + return err + } if useTx { // TRANSACTION. diff --git a/migration_sql_test.go b/migration_sql_test.go index e72662a..a12a828 100644 --- a/migration_sql_test.go +++ b/migration_sql_test.go @@ -80,7 +80,10 @@ func TestSplitStatements(t *testing.T) { } for _, test := range tests { - stmts, _ := getSQLStatements(strings.NewReader(test.sql), test.direction) + stmts, _, err := getSQLStatements(strings.NewReader(test.sql), test.direction) + if err != nil { + t.Error(err) + } if len(stmts) != test.count { t.Errorf("incorrect number of stmts. got %v, want %v", len(stmts), test.count) } @@ -113,7 +116,10 @@ func TestUseTransactions(t *testing.T) { if err != nil { t.Error(err) } - _, useTx := getSQLStatements(f, true) + _, useTx, err := getSQLStatements(f, true) + if err != nil { + t.Error(err) + } if useTx != test.useTransactions { t.Errorf("Failed transaction check. got %v, want %v", useTx, test.useTransactions) } @@ -121,6 +127,33 @@ func TestUseTransactions(t *testing.T) { } } +func TestParsingErrors(t *testing.T) { + type testData struct { + sql string + error bool + } + tests := []testData{ + { + sql: statementBeginNoStatementEnd, + error: true, + }, + { + sql: unfinishedSQL, + error: true, + }, + { + sql: noUpDownAnnotations, + error: true, + }, + } + for _, test := range tests { + _, _, err := getSQLStatements(strings.NewReader(test.sql), true) + if err == nil { + t.Errorf("Failed transaction check. got %v, want %v", err, test.error) + } + } +} + var functxt = `-- +goose Up CREATE TABLE IF NOT EXISTS histories ( id BIGSERIAL PRIMARY KEY, @@ -180,3 +213,52 @@ CREATE TABLE fancier_post ( -- +goose Down DROP TABLE fancier_post; ` + +var statementBeginNoStatementEnd = `-- +goose Up +CREATE TABLE IF NOT EXISTS histories ( + id BIGSERIAL PRIMARY KEY, + current_value varchar(2000) NOT NULL, + created_at timestamp with time zone NOT NULL +); + +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE ) +returns void AS $$ +DECLARE + create_query text; +BEGIN + FOR create_query IN SELECT + 'CREATE TABLE IF NOT EXISTS histories_' + || TO_CHAR( d, 'YYYY_MM' ) + || ' ( CHECK( created_at >= timestamp ''' + || TO_CHAR( d, 'YYYY-MM-DD 00:00:00' ) + || ''' AND created_at < timestamp ''' + || TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' ) + || ''' ) ) inherits ( histories );' + FROM generate_series( $1, $2, '1 month' ) AS d + LOOP + EXECUTE create_query; + END LOOP; -- LOOP END +END; -- FUNCTION END +$$ +language plpgsql; + +-- +goose Down +drop function histories_partition_creation(DATE, DATE); +drop TABLE histories; +` + +var unfinishedSQL = ` +-- +goose Up +ALTER TABLE post + +-- +goose Down +` +var noUpDownAnnotations = ` +CREATE TABLE post ( + id int NOT NULL, + title text, + body text, + PRIMARY KEY(id) +); +`