package sqlparser

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"strings"
	"sync"

	"github.com/mfridman/interpolate"
)

// Direction names the section of a migration file that a caller wants to
// obtain: the "up" section or the "down" section.
type Direction string

const (
	// DirectionUp selects the statements that follow "-- +goose Up".
	DirectionUp Direction = "up"

	// DirectionDown selects the statements that follow "-- +goose Down".
	DirectionDown Direction = "down"
)

// FromBool converts a boolean flag into a Direction: true maps to
// DirectionUp and false maps to DirectionDown.
func FromBool(b bool) Direction {
	if !b {
		return DirectionDown
	}
	return DirectionUp
}

// String returns the underlying string value of d, which is "up" or "down"
// for the predefined directions.
func (d Direction) String() string {
	return string(d)
}

// ToBool reports whether d is DirectionUp. Every other value, including
// DirectionDown and unknown strings, yields false.
func (d Direction) ToBool() bool {
	switch d {
	case DirectionUp:
		return true
	default:
		return false
	}
}

// parserState identifies where the migration parser currently is within a
// SQL file. The numeric values show up in error and debug messages.
type parserState int

const (
	start                   parserState = iota // 0: initial state, no "+goose Up" seen yet
	gooseUp                                    // 1: in the Up section, outside StatementBegin/StatementEnd
	gooseStatementBeginUp                      // 2: in the Up section, after StatementBegin
	gooseStatementEndUp                        // 3: in the Up section, after StatementEnd
	gooseDown                                  // 4: in the Down section, outside StatementBegin/StatementEnd
	gooseStatementBeginDown                    // 5: in the Down section, after StatementBegin
	gooseStatementEndDown                      // 6: in the Down section, after StatementEnd
)

// stateMachine holds the parser's current state and, when verbose is set,
// logs every transition and diagnostic message.
type stateMachine struct {
	state   parserState
	verbose bool
}

// newStateMachine returns a state machine positioned at begin. verbose
// enables logging through the standard log package.
func newStateMachine(begin parserState, verbose bool) *stateMachine {
	sm := new(stateMachine)
	sm.state = begin
	sm.verbose = verbose
	return sm
}

// get returns the current state.
func (s *stateMachine) get() parserState {
	return s.state
}

// set moves the machine to next and, in verbose mode, logs the transition
// as "set <old> => <new>".
func (s *stateMachine) set(next parserState) {
	previous := s.state
	s.state = next
	s.print("set %d => %d", previous, next)
}

// ANSI escape sequences that wrap verbose output so it is rendered in gray.
const (
	grayColor  = "\033[90m"
	resetColor = "\033[00m"
)

// print logs a Printf-style message prefixed with "StateMachine: " and
// wrapped in the gray color codes. It does nothing unless verbose is set.
func (s *stateMachine) print(msg string, args ...interface{}) {
	if !s.verbose {
		return
	}
	log.Printf(grayColor+"StateMachine: "+msg+resetColor, args...)
}

// scanBufSize is the size in bytes (4 MiB) of each pooled scan buffer. It is
// passed to bufio.Scanner.Buffer as both the initial buffer and the maximum
// token size, which bounds how long a single scanned token may be.
const scanBufSize = 4 * 1024 * 1024

// bufferPool recycles scan buffers so that a new scanBufSize-byte slice does
// not have to be allocated every time one is needed.
//
// The pool holds *[]byte rather than []byte: a pointer fits in an interface
// value directly, whereas putting a slice into the pool would allocate a copy
// of its header on every Put.
var bufferPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, scanBufSize)
		return &buf
	},
}

// ParseSQLMigration splits the SQL migration read from r into individual
// statements and returns only those that belong to the requested direction.
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, the explicit annotations
// '-- +goose StatementBegin' and '-- +goose StatementEnd' allow the script to
// tell the parser to ignore semicolons and treat everything in between as one
// statement.
//
// Parameters:
//   - r: the migration script; it is scanned line by line.
//   - direction: DirectionUp or DirectionDown; selects which section's
//     statements are returned.
//   - debug: when true, every input line and every state transition is
//     logged through the standard log package.
//
// Returns:
//   - stmts: the statements of the requested section in file order, each with
//     surrounding whitespace trimmed.
//   - useTx: true unless a '-- +goose NO TRANSACTION' annotation was seen.
//   - err: non-nil for malformed annotations, annotations in an invalid order,
//     unterminated statements, failed variable substitution, or a scanner
//     error. On error stmts is nil and useTx is false.
func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []string, useTx bool, err error) {
	// Borrow a scan buffer from the pool for the duration of this call.
	scanBufPtr := bufferPool.Get().(*[]byte)
	scanBuf := *scanBufPtr
	defer bufferPool.Put(scanBufPtr)

	scanner := bufio.NewScanner(r)
	scanner.Buffer(scanBuf, scanBufSize)

	stateMachine := newStateMachine(start, debug)
	// useTx is reported as true unless NO TRANSACTION is seen; variable
	// substitution stays off until ENVSUB ON is seen.
	useTx = true
	useEnvsub := false

	// buf accumulates the lines of the statement currently being assembled.
	var buf bytes.Buffer
	for scanner.Scan() {
		line := scanner.Text()
		if debug {
			log.Println(line)
		}
		// Skip blank lines while the parser is still in the start state.
		if stateMachine.get() == start && strings.TrimSpace(line) == "" {
			continue
		}

		// Check for annotations.
		// All annotations must be in format: "-- +goose [annotation]"
		// Any comment line that contains "+goose" is treated as an annotation and must parse as one.
		if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
			var cmd annotation

			cmd, err = extractAnnotation(line)
			if err != nil {
				return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
			}

			switch cmd {
			case annotationUp:
				// Up is only valid as the very first state transition.
				switch stateMachine.get() {
				case start:
					stateMachine.set(gooseUp)
				default:
					return nil, false, fmt.Errorf("duplicate '-- +goose Up' annotations; stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
				}
				continue

			case annotationDown:
				switch stateMachine.get() {
				case gooseUp, gooseStatementEndUp:
					// If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
					// previous up annotation. This is an error, because we expect the SQL query to be terminated by a semicolon
					// and the buffer to have been reset.
					if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
						return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
					}
					stateMachine.set(gooseDown)
				default:
					return nil, false, fmt.Errorf("must start with '-- +goose Up' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
				}
				continue

			case annotationStatementBegin:
				// A block may open directly inside a section or right after a previous block ended.
				switch stateMachine.get() {
				case gooseUp, gooseStatementEndUp:
					stateMachine.set(gooseStatementBeginUp)
				case gooseDown, gooseStatementEndDown:
					stateMachine.set(gooseStatementBeginDown)
				default:
					return nil, false, fmt.Errorf("'-- +goose StatementBegin' must be defined after '-- +goose Up' or '-- +goose Down' annotation, stateMachine=%d, see https://github.com/pressly/goose#sql-migrations", stateMachine.state)
				}
				continue

			case annotationStatementEnd:
				switch stateMachine.get() {
				case gooseStatementBeginUp:
					stateMachine.set(gooseStatementEndUp)
				case gooseStatementBeginDown:
					stateMachine.set(gooseStatementEndDown)
				default:
					return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
				}
				// No "continue" here: the annotation line falls through to the code below so that the statement
				// collected since StatementBegin is stored during this same iteration.
				//
				// If the buffer is empty at this point, the comment filter below drops the line first and the state
				// stays at gooseStatementEnd*. This always happens in the section that is not being requested, because
				// its buffer is reset after every line.
				// NOTE(review): it also appears to happen in the requested section when StatementBegin is directly
				// followed by StatementEnd; the next SQL line would then not be buffered and an empty statement would
				// be stored. Confirm whether that is intended.

			case annotationNoTransaction:
				// Accepted in any state; it does not move the state machine.
				useTx = false
				continue

			case annotationEnvsubOn:
				// Applies to the lines that follow, until ENVSUB OFF.
				useEnvsub = true
				continue

			case annotationEnvsubOff:
				useEnvsub = false
				continue

			default:
				return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
			}
		}
		// Once we've started parsing a statement the buffer is no longer empty,
		// we keep all comments up until the end of the statement (the buffer will be reset).
		// All other comments in the file are ignored.
		if buf.Len() == 0 {
			// This check ensures leading comments and empty lines prior to a statement are ignored.
			// Only a completely empty line counts as empty here; a whitespace-only line is buffered.
			if strings.HasPrefix(strings.TrimSpace(line), "--") || line == "" {
				stateMachine.print("ignore comment")
				continue
			}
		}
		switch stateMachine.get() {
		case gooseStatementEndDown, gooseStatementEndUp:
			// Do not include the "+goose StatementEnd" annotation in the final statement.
		default:
			// Substitution is applied before the direction filter below, so it also runs on lines of the
			// section that is not being requested.
			if useEnvsub {
				expanded, err := interpolate.Interpolate(&envWrapper{}, line)
				if err != nil {
					return nil, false, fmt.Errorf("variable substitution failed: %w:\n%s", err, line)
				}
				line = expanded
			}
			// Write SQL line to a buffer.
			if _, err := buf.WriteString(line + "\n"); err != nil {
				return nil, false, fmt.Errorf("failed to write to buf: %w", err)
			}
		}
		// Read the SQL body one line at a time, keeping it only if we're in the requested direction.
		//
		// 1) basic query with semicolon; 2) psql statement
		//
		// Export statement once we hit end of statement.
		switch stateMachine.get() {
		case gooseUp, gooseStatementBeginUp, gooseStatementEndUp:
			// Lines of the Up section are discarded when the Down section was requested.
			if direction == DirectionDown {
				buf.Reset()
				stateMachine.print("ignore down")
				continue
			}
		case gooseDown, gooseStatementBeginDown, gooseStatementEndDown:
			// Lines of the Down section are discarded when the Up section was requested.
			if direction == DirectionUp {
				buf.Reset()
				stateMachine.print("ignore up")
				continue
			}
		default:
			// SQL text was found before any '-- +goose Up' annotation.
			return nil, false, fmt.Errorf("failed to parse migration: unexpected state %d on line %q, see https://github.com/pressly/goose#sql-migrations", stateMachine.state, line)
		}

		switch stateMachine.get() {
		case gooseUp:
			// Outside a StatementBegin/StatementEnd block a trailing semicolon ends the statement.
			if endsWithSemicolon(line) {
				stmts = append(stmts, cleanupStatement(buf.String()))
				buf.Reset()
				stateMachine.print("store simple Up query")
			}
		case gooseDown:
			if endsWithSemicolon(line) {
				stmts = append(stmts, cleanupStatement(buf.String()))
				buf.Reset()
				stateMachine.print("store simple Down query")
			}
		case gooseStatementEndUp:
			// The block is complete: store it and return to plain semicolon splitting.
			stmts = append(stmts, cleanupStatement(buf.String()))
			buf.Reset()
			stateMachine.print("store Up statement")
			stateMachine.set(gooseUp)
		case gooseStatementEndDown:
			stmts = append(stmts, cleanupStatement(buf.String()))
			buf.Reset()
			stateMachine.print("store Down statement")
			stateMachine.set(gooseDown)
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, false, fmt.Errorf("failed to scan migration: %w", err)
	}
	// EOF: validate the state the file ended in.

	switch stateMachine.get() {
	case start:
		return nil, false, errors.New("failed to parse migration: must start with '-- +goose Up' annotation, see https://github.com/pressly/goose#sql-migrations")
	case gooseStatementBeginUp, gooseStatementBeginDown:
		return nil, false, errors.New("failed to parse migration: missing '-- +goose StatementEnd' annotation")
	}

	// Anything other than whitespace left in the buffer is a statement that was never terminated.
	if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
		return nil, false, missingSemicolonError(stateMachine.state, direction, bufferRemaining)
	}

	return stmts, useTx, nil
}

// annotation is a directive that accompanies the "+goose" marker on a
// comment line of a migration file.
type annotation string

// Canonical spellings of the supported annotations. Input is matched against
// them case-insensitively by extractAnnotation.
const (
	annotationUp             annotation = "Up"
	annotationDown           annotation = "Down"
	annotationStatementBegin annotation = "StatementBegin"
	annotationStatementEnd   annotation = "StatementEnd"
	annotationNoTransaction  annotation = "NO TRANSACTION"
	annotationEnvsubOn       annotation = "ENVSUB ON"
	annotationEnvsubOff      annotation = "ENVSUB OFF"
)

// supportedAnnotations is the set of annotations extractAnnotation accepts.
var supportedAnnotations = map[annotation]struct{}{
	annotationUp:             {},
	annotationDown:           {},
	annotationStatementBegin: {},
	annotationStatementEnd:   {},
	annotationNoTransaction:  {},
	annotationEnvsubOn:       {},
	annotationEnvsubOff:      {},
}

var (
	// errEmptyAnnotation is returned when nothing but whitespace remains
	// after the "--" and "+goose" markers have been removed.
	errEmptyAnnotation = errors.New("empty annotation")
	// errInvalidAnnotation is wrapped by every other annotation parse error.
	errInvalidAnnotation = errors.New("invalid annotation")
)

// extractAnnotation returns the annotation named on line, in its canonical
// spelling.
// All annotations must be in format: "-- +goose [annotation]"
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF
//
// The name is compared case-insensitively and surrounding whitespace is
// ignored. An error wrapping errInvalidAnnotation is returned when the line
// starts with a space or tab, names "+goose" more than once, or names an
// unsupported annotation; errEmptyAnnotation is returned when no name is
// present.
func extractAnnotation(line string) (annotation, error) {
	// The comment marker has to start in the very first column.
	if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
		return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
	}

	// Drop every "--" on the line, then the first "+goose" marker.
	stripped := strings.Replace(strings.ReplaceAll(line, "--", ""), "+goose", "", 1)

	// A marker that survives the single replacement was written more than once.
	if strings.Contains(stripped, "+goose") {
		return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", stripped, errInvalidAnnotation)
	}

	name := strings.TrimSpace(stripped)
	if name == "" {
		return "", errEmptyAnnotation
	}

	// Match case-insensitively and hand back the canonical constant.
	for known := range supportedAnnotations {
		if !strings.EqualFold(string(known), name) {
			continue
		}
		return known, nil
	}

	return "", fmt.Errorf("%q not supported: %w", name, errInvalidAnnotation)
}

// missingSemicolonError builds the error that is reported when SQL text is
// still buffered at a point where a complete statement was expected. The
// message includes the numeric parser state, the requested direction and the
// quoted leftover text s.
func missingSemicolonError(state parserState, direction Direction, s string) error {
	const format = "failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?"
	return fmt.Errorf(format, state, direction, s)
}

// envWrapper exposes the process environment through the interpolate.Env
// interface so that it can be passed to interpolate.Interpolate.
type envWrapper struct{}

// Compile-time check that *envWrapper implements interpolate.Env.
var _ interpolate.Env = (*envWrapper)(nil)

// Get looks key up in the process environment. The boolean result reports
// whether the variable is set, which distinguishes an empty value from an
// unset variable.
func (*envWrapper) Get(key string) (string, bool) {
	value, found := os.LookupEnv(key)
	return value, found
}

// cleanupStatement normalizes a collected statement before it is returned to
// the caller by removing all leading and trailing white space, including the
// newline appended after each buffered line.
func cleanupStatement(input string) string {
	trimmed := strings.TrimSpace(input)
	return trimmed
}

// endsWithSemicolon reports whether line terminates a SQL statement, that is
// whether the last whitespace-separated word in front of any double-dash
// comment ends with a semicolon.
//
// A comment is recognized only where a word starts with "--"; that word and
// everything after it are ignored. A "--" attached to the end of another
// word (for example "1;--note") does not start a comment, so such a line is
// not considered terminated.
//
// strings.Fields is used rather than a bufio.Scanner so that no scan buffer
// is needed per line and no word-length limit applies.
func endsWithSemicolon(line string) bool {
	last := ""
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			break
		}
		last = word
	}
	return strings.HasSuffix(last, ";")
}