// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package dbtest

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"

	"gogs.io/gogs/internal/conf"
	"gogs.io/gogs/internal/dbutil"
)

// NewDB creates a new test database and initializes the given list of tables
// for the suite. The test database is dropped after testing is completed unless
// failed.
func NewDB(t *testing.T, suite string, tables ...any) *gorm.DB {
	dbType := os.Getenv("GOGS_DATABASE_TYPE")

	var dbName string
	var dbOpts conf.DatabaseOpts
	var cleanup func(db *gorm.DB)
	switch dbType {
	case "mysql":
		dbOpts = conf.DatabaseOpts{
			Type:     "mysql",
			Host:     os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
			Name:     dbName,
			User:     os.Getenv("MYSQL_USER"),
			Password: os.Getenv("MYSQL_PASSWORD"),
		}

		dsn, err := dbutil.NewDSN(dbOpts)
		require.NoError(t, err)

		sqlDB, err := sql.Open("mysql", dsn)
		require.NoError(t, err)

		// Set up test database
		dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
		_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
		require.NoError(t, err)

		_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
		require.NoError(t, err)

		dbOpts.Name = dbName

		cleanup = func(db *gorm.DB) {
			testDB, err := db.DB()
			if err == nil {
				_ = testDB.Close()
			}

			_, _ = sqlDB.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
			_ = sqlDB.Close()
		}
	case "postgres":
		dbOpts = conf.DatabaseOpts{
			Type:     "postgres",
			Host:     os.ExpandEnv("$PGHOST:$PGPORT"),
			Name:     dbName,
			Schema:   "public",
			User:     os.Getenv("PGUSER"),
			Password: os.Getenv("PGPASSWORD"),
			SSLMode:  os.Getenv("PGSSLMODE"),
		}

		dsn, err := dbutil.NewDSN(dbOpts)
		require.NoError(t, err)

		sqlDB, err := sql.Open("pgx", dsn)
		require.NoError(t, err)

		// Set up test database
		dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
		_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
		require.NoError(t, err)

		_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
		require.NoError(t, err)

		dbOpts.Name = dbName

		cleanup = func(db *gorm.DB) {
			testDB, err := db.DB()
			if err == nil {
				_ = testDB.Close()
			}

			_, _ = sqlDB.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
			_ = sqlDB.Close()
		}
	case "sqlite":
		dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
		dbOpts = conf.DatabaseOpts{
			Type: "sqlite",
			Path: dbName,
		}
		cleanup = func(db *gorm.DB) {
			sqlDB, err := db.DB()
			if err == nil {
				_ = sqlDB.Close()
			}
			_ = os.Remove(dbName)
		}
	default:
		dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
		dbOpts = conf.DatabaseOpts{
			Type: "sqlite3",
			Path: dbName,
		}
		cleanup = func(db *gorm.DB) {
			sqlDB, err := db.DB()
			if err == nil {
				_ = sqlDB.Close()
			}
			_ = os.Remove(dbName)
		}
	}

	now := time.Now().UTC().Truncate(time.Second)
	db, err := dbutil.OpenDB(
		dbOpts,
		&gorm.Config{
			SkipDefaultTransaction: true,
			NamingStrategy: schema.NamingStrategy{
				SingularTable: true,
			},
			NowFunc: func() time.Time {
				return now
			},
		},
	)
	require.NoError(t, err)

	t.Cleanup(func() {
		if t.Failed() {
			t.Logf("Database %q left intact for inspection", dbName)
			return
		}

		cleanup(db)
	})

	err = db.Migrator().AutoMigrate(tables...)
	require.NoError(t, err)

	return db
}