package dialect

import (
	"context"
	"database/sql"
	"fmt"
	"time"

	"github.com/pressly/goose/v3/internal/dialect/dialectquery"
)

// Store is the interface that wraps the basic methods for a database dialect.
//
// A dialect is a set of SQL statements that are specific to a database.
//
// By defining a store interface, we can support multiple databases
// with a single codebase.
//
// The underlying implementation does not modify the error. It is the callers
// responsibility to assert for the correct error, such as sql.ErrNoRows.
type Store interface {
	// CreateVersionTable creates the version table within a transaction.
	// This table is used to store goose migrations.
	CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error

	// InsertVersion inserts a version id into the version table within a transaction.
	InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
	// InsertVersionNoTx inserts a version id into the version table without a transaction.
	InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error

	// DeleteVersion deletes a version id from the version table within a transaction.
	DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error
	// DeleteVersionNoTx deletes a version id from the version table without a transaction.
	DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error

	// GetMigrationRow retrieves a single migration by version id.
	//
	// Returns the raw sql error if the query fails. It is the callers responsibility
	// to assert for the correct error, such as sql.ErrNoRows.
	GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error)

	// ListMigrations retrieves all migrations sorted in descending order by id.
	//
	// If there are no migrations, an empty slice is returned with no error.
	ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error)
}

// NewStore returns a new Store for the given dialect.
func NewStore(d Dialect) (Store, error) {
	var querier dialectquery.Querier
	switch d {
	case Postgres:
		querier = &dialectquery.Postgres{}
	case Mysql:
		querier = &dialectquery.Mysql{}
	case Sqlite3:
		querier = &dialectquery.Sqlite3{}
	case Sqlserver:
		querier = &dialectquery.Sqlserver{}
	case Redshift:
		querier = &dialectquery.Redshift{}
	case Tidb:
		querier = &dialectquery.Tidb{}
	case Clickhouse:
		querier = &dialectquery.Clickhouse{}
	case Vertica:
		querier = &dialectquery.Vertica{}
	default:
		return nil, fmt.Errorf("unknown querier dialect: %v", d)
	}
	return &store{querier: querier}, nil
}

type GetMigrationResult struct {
	IsApplied bool
	Timestamp time.Time
}

type ListMigrationsResult struct {
	VersionID int64
	IsApplied bool
}

type store struct {
	querier dialectquery.Querier
}

var _ Store = (*store)(nil)

func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error {
	q := s.querier.CreateTable(tableName)
	_, err := tx.ExecContext(ctx, q)
	return err
}

func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
	q := s.querier.InsertVersion(tableName)
	_, err := tx.ExecContext(ctx, q, version, true)
	return err
}

func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
	q := s.querier.InsertVersion(tableName)
	_, err := db.ExecContext(ctx, q, version, true)
	return err
}

func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error {
	q := s.querier.DeleteVersion(tableName)
	_, err := tx.ExecContext(ctx, q, version)
	return err
}

func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error {
	q := s.querier.DeleteVersion(tableName)
	_, err := db.ExecContext(ctx, q, version)
	return err
}

func (s *store) GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error) {
	q := s.querier.GetMigrationByVersion(tableName)
	var timestamp time.Time
	var isApplied bool
	err := db.QueryRowContext(ctx, q, version).Scan(&timestamp, &isApplied)
	if err != nil {
		return nil, err
	}
	return &GetMigrationResult{
		IsApplied: isApplied,
		Timestamp: timestamp,
	}, nil
}

func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) {
	q := s.querier.ListMigrations(tableName)
	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var migrations []*ListMigrationsResult
	for rows.Next() {
		var version int64
		var isApplied bool
		if err := rows.Scan(&version, &isApplied); err != nil {
			return nil, err
		}
		migrations = append(migrations, &ListMigrationsResult{
			VersionID: version,
			IsApplied: isApplied,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return migrations, nil
}