package ksql

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"sync"
	"unicode"

	"github.com/pkg/errors"
	"github.com/vingarcia/ksql/structs"
)

// selectQueryCache memoizes the "SELECT <columns> " prefixes generated by
// buildSelectQuery. It is indexed first by driver name and then by the
// struct type being selected.
var selectQueryCache = make(map[string]map[reflect.Type]string)

// init allocates one (empty) inner cache for each supported dialect, so
// callers can look up selectQueryCache[driverName] without creating it.
func init() {
	for driverName := range supportedDialects {
		selectQueryCache[driverName] = make(map[reflect.Type]string)
	}
}

// DB represents the ksql client responsible for
// interfacing with the "database/sql" package implementing
// the KissSQL interface `SQLProvider`.
//
// All of its methods use a value receiver, and Transaction works on a copy
// of the DB whose db field is replaced by the *sql.Tx; copies of a DB
// therefore share the same underlying connection pool.
type DB struct {
	// driver is the driver name given to New; it also appears in error messages.
	driver    string
	// dialect holds the driver specific SQL rules (escaping, placeholders,
	// insert strategy); New looks it up in supportedDialects.
	dialect   dialect
	// tableName comes from Config.TableName; Insert, Delete and Update
	// refuse to run when it is empty.
	tableName string
	// db is the *sql.DB created by New, or the *sql.Tx inside Transaction.
	db        sqlProvider

	// idCols lists the ID (primary key) columns, from Config.IDColumns.
	// It may hold more than one column for composite keys, which Delete
	// and Update handle.
	idCols []string

	// insertMethod is the dialect's insert strategy; New downgrades it to
	// insertWithNoIDRetrieval when there are several idCols and the dialect
	// relies on LastInsertId.
	insertMethod insertMethod
}

// sqlProvider is the subset of the database/sql API that DB uses to run
// statements. New stores a *sql.DB in it and Transaction replaces it with a
// *sql.Tx, so the same helpers work both inside and outside a transaction.
type sqlProvider interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

// Config describes the optional arguments accepted
// by the ksql.New() function.
type Config struct {
	// MaxOpenConns is passed to sql.DB.SetMaxOpenConns by New.
	// It defaults to 1 if not set (i.e. left as 0).
	MaxOpenConns int

	// TableName must be set in order to use the Insert, Delete and Update helper
	// functions. If you only intend to make queries or to use the Exec function
	// it is safe to leave this field unset.
	TableName string

	// IDColumns lists the ID (primary key) columns used by Insert, Delete
	// and Update. It defaults to []string{"id"} if unset.
	//
	// NOTE: with more than one column, dialects that retrieve generated IDs
	// through LastInsertId will not write them back into the record on Insert.
	IDColumns []string
}

// New instantiates a new KissSQL client.
//
// dbDriver must be a driver ksql has a dialect for, and it must also be
// registered with database/sql (usually by importing the driver package).
// The connection is verified with Ping before New returns; on any error no
// connection pool is left open.
//
// Defaults applied to config:
//   - MaxOpenConns: 1 when left as 0;
//   - IDColumns: []string{"id"} when empty.
//
// When more than one ID column is configured for a dialect that reports
// generated IDs through LastInsertId, Insert will not write the IDs back
// into the record.
func New(
	dbDriver string,
	connectionString string,
	config Config,
) (DB, error) {
	// Check the driver before opening anything, so an unsupported driver
	// never causes a connection pool to be created and then leaked.
	dialect := supportedDialects[dbDriver]
	if dialect == nil {
		return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
	}

	db, err := sql.Open(dbDriver, connectionString)
	if err != nil {
		return DB{}, err
	}
	if err = db.Ping(); err != nil {
		// The pool is never handed to the caller, so release it here.
		// Best-effort: the Ping error is the one worth reporting.
		_ = db.Close()
		return DB{}, err
	}

	if config.MaxOpenConns == 0 {
		config.MaxOpenConns = 1
	}
	db.SetMaxOpenConns(config.MaxOpenConns)

	idCols := config.IDColumns
	if len(idCols) == 0 {
		idCols = []string{"id"}
	} else {
		// Copy so later changes to the caller's slice can't affect this client.
		idCols = append([]string(nil), idCols...)
	}

	insertMethod := dialect.InsertMethod()
	if len(idCols) > 1 && insertMethod == insertWithLastInsertID {
		// LastInsertId can only report a single column.
		insertMethod = insertWithNoIDRetrieval
	}

	return DB{
		dialect:   dialect,
		driver:    dbDriver,
		db:        db,
		tableName: config.TableName,

		idCols:       idCols,
		insertMethod: insertMethod,
	}, nil
}

// Query queries several rows from the database and stores them in the
// slice pointed to by records. records must be a non-nil pointer to a slice
// of structs or to a slice of pointers to structs.
//
// The slice is rebuilt from length zero (its capacity is reused), so after a
// successful call it holds exactly the returned rows: no element from its
// previous contents survives. For slices of pointers every row gets a newly
// allocated struct, so structs the caller kept from earlier are never
// overwritten.
//
// If query starts with the FROM keyword, a SELECT clause listing the
// struct's mapped fields is generated and prepended to it.
//
// Note: it is very important to make sure the query will
// return a small known number of results, otherwise you risk
// of overloading the available memory.
func (c DB) Query(
	ctx context.Context,
	records interface{},
	query string,
	params ...interface{},
) error {
	slicePtr := reflect.ValueOf(records)
	// Kind() is Invalid for a nil interface, so records == nil is rejected
	// here instead of panicking:
	if slicePtr.Kind() != reflect.Ptr {
		return fmt.Errorf("ksql: expected to receive a pointer to slice of structs, but got: %T", records)
	}
	if slicePtr.IsNil() {
		return fmt.Errorf("ksql: expected to receive a non-nil pointer to slice of structs, but got a nil %T", records)
	}

	structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(slicePtr.Type().Elem())
	if err != nil {
		return err
	}

	if strings.ToUpper(getFirstToken(query)) == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		query = selectPrefix + query
	}

	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return fmt.Errorf("error running query: %s", err.Error())
	}
	defer rows.Close()

	slice := slicePtr.Elem().Slice(0, 0)
	for idx := 0; rows.Next(); idx++ {
		var elemPtr reflect.Value
		if isSliceOfPtrs {
			elemPtr = reflect.New(structType)
			slice = reflect.Append(slice, elemPtr)
		} else {
			// Append a zeroed struct and scan straight into its slot, so no
			// field value from the slice's previous contents leaks through:
			slice = reflect.Append(slice, reflect.Zero(structType))
			elemPtr = slice.Index(idx).Addr()
		}

		if err := scanRows(c.dialect, rows, elemPtr.Interface()); err != nil {
			return err
		}
	}

	if err := rows.Close(); err != nil {
		return err
	}

	if rows.Err() != nil {
		return rows.Err()
	}

	// Update the original slice passed by reference:
	slicePtr.Elem().Set(slice)

	return nil
}

// QueryOne runs query and scans its first row into record, which must be a
// non-nil pointer to struct. The query should return only one result; any
// additional rows are ignored.
//
// If query starts with the FROM keyword, a SELECT clause listing the
// struct's mapped fields is generated and prepended to it.
//
// QueryOne returns ErrRecordNotFound if the query returns no results.
func (c DB) QueryOne(
	ctx context.Context,
	record interface{},
	query string,
	params ...interface{},
) error {
	v := reflect.ValueOf(record)
	// Kind() is Invalid for a nil interface, so record == nil is rejected
	// here instead of panicking:
	if v.Kind() != reflect.Ptr || v.Type().Elem().Kind() != reflect.Struct {
		return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
	}
	if v.IsNil() {
		return fmt.Errorf("ksql: expected to receive a non-nil pointer to struct, but got a nil %T", record)
	}
	t := v.Type().Elem()

	if strings.ToUpper(getFirstToken(query)) == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, t, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		query = selectPrefix + query
	}

	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both on errors and when there are no rows:
		if rows.Err() != nil {
			return rows.Err()
		}
		return ErrRecordNotFound
	}

	err = scanRows(c.dialect, rows, record)
	if err != nil {
		return err
	}

	return rows.Close()
}

// QueryChunks is meant to perform queries that returns
// more results than would normally fit on memory,
// for others cases the Query and QueryOne functions are indicated.
//
// The ChunkParser argument has 4 attributes:
// (1) The Query;
// (2) The query args;
// (3) The chunk size (values below 1 are treated as 1);
// (4) A callback function called ForEachChunk, that will be called
// to process each chunk loaded from the database.
//
// Note that the signature of the ForEachChunk callback can be
// any function that receives a slice of structs or a slice of
// pointers to struct as its only argument and that reflection
// will be used to instantiate this argument and to fill it
// with the database rows.
//
// For slices of structs, the same backing array is reused between calls,
// so the callback must not retain the slice after it returns. For slices of
// pointers, every chunk holds newly allocated structs, so the pointers may
// be kept.
//
// Returning ErrAbortIteration from the callback stops the iteration and
// makes QueryChunks return nil.
//
// NOTE: the rows stay open while ForEachChunk runs. With MaxOpenConns set
// to 1 (New's default), querying the same DB from inside the callback would
// block waiting for a free connection.
func (c DB) QueryChunks(
	ctx context.Context,
	parser ChunkParser,
) error {
	fnValue := reflect.ValueOf(parser.ForEachChunk)
	chunkType, err := parseInputFunc(parser.ForEachChunk)
	if err != nil {
		return err
	}

	// reflect.MakeSlice panics on a negative capacity. A ChunkSize of 0
	// already behaved like 1 (one callback per row), so clamp to that.
	chunkSize := parser.ChunkSize
	if chunkSize < 1 {
		chunkSize = 1
	}

	chunk := reflect.MakeSlice(chunkType, 0, chunkSize)

	structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType)
	if err != nil {
		return err
	}

	if strings.ToUpper(getFirstToken(parser.Query)) == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		parser.Query = selectPrefix + parser.Query
	}

	rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
	if err != nil {
		return err
	}
	defer rows.Close()

	var idx = 0
	for rows.Next() {
		if chunk.Len() <= idx {
			// Grow the chunk. For slices of pointers this happens on every
			// row, since a new chunk is started after each callback.
			elemValue := reflect.New(structType)
			if !isSliceOfPtrs {
				elemValue = elemValue.Elem()
			}
			chunk = reflect.Append(chunk, elemValue)
		} else if !isSliceOfPtrs {
			// Reusing a struct from the previous chunk: reset it so nothing
			// from the previous row can leak into this one.
			chunk.Index(idx).Set(reflect.Zero(structType))
		}

		// scanRows expects a *struct: that is the element's address for
		// slices of structs, but the element itself for slices of pointers.
		elemPtr := chunk.Index(idx).Addr()
		if isSliceOfPtrs {
			elemPtr = elemPtr.Elem()
		}

		err = scanRows(c.dialect, rows, elemPtr.Interface())
		if err != nil {
			return err
		}

		if idx < chunkSize-1 {
			idx++
			continue
		}

		idx = 0
		err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
		if err != nil {
			if err == ErrAbortIteration {
				return nil
			}
			return err
		}

		if isSliceOfPtrs {
			// The callback may have kept the pointers it received, so the
			// next chunk must not scan into those same structs.
			chunk = reflect.MakeSlice(chunkType, 0, chunkSize)
		}
	}

	if err := rows.Close(); err != nil {
		return err
	}

	// If Next() returned false because of an error:
	if rows.Err() != nil {
		return rows.Err()
	}

	// If no rows were found or idx was reset to 0
	// on the last iteration skip this last call to ForEachChunk:
	if idx > 0 {
		chunk = chunk.Slice(0, idx)

		err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
		if err != nil {
			if err == ErrAbortIteration {
				return nil
			}
			return err
		}
	}

	return nil
}

// Insert writes record, which must be a pointer to a struct, into the table
// configured through Config.TableName.
//
// Depending on the dialect (and on the number of ID columns), the ID(s)
// generated by the database are written back into record's ID fields.
func (c DB) Insert(
	ctx context.Context,
	record interface{},
) error {
	if c.tableName == "" {
		return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method")
	}

	query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
	if err != nil {
		return err
	}

	switch c.insertMethod {
	case insertWithReturning, insertWithOutput:
		return c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols)
	case insertWithLastInsertID:
		return c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
	case insertWithNoIDRetrieval:
		return c.insertWithNoIDRetrieval(ctx, record, query, params)
	}

	// New() rejects unsupported drivers, so getting here means the switch
	// above is missing a case for one of the insert methods.
	return fmt.Errorf("code error: unsupported driver `%s`", c.driver)
}

// insertReturningIDs runs an INSERT whose query already asks the database
// to return the ID columns (via RETURNING or OUTPUT) and scans the single
// returned row into scanValues, which point at the record's ID fields.
//
// record and idNames are not used here; they are kept so all the insert
// strategies share a similar signature.
func (c DB) insertReturningIDs(
	ctx context.Context,
	record interface{},
	query string,
	params []interface{},
	scanValues []interface{},
	idNames []string,
) error {
	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if rows.Next() {
		if scanErr := rows.Scan(scanValues...); scanErr != nil {
			return scanErr
		}
		return rows.Close()
	}

	// No row came back: report the iteration error if there is one.
	if iterErr := rows.Err(); iterErr != nil {
		return iterErr
	}
	return fmt.Errorf("unexpected error retrieving the id columns from the database")
}

// insertWithLastInsertID runs the INSERT with ExecContext and writes the
// value reported by sql.Result.LastInsertId into record's idName field.
//
// record must be a pointer to struct. The destination field is validated
// before the INSERT runs, so a record that cannot receive the ID is rejected
// without writing anything to the database. The field must exist, be
// settable and have a type int64 converts to. String fields are rejected
// because reflect would convert the integer into the character with that
// code point.
func (c DB) insertWithLastInsertID(
	ctx context.Context,
	record interface{},
	query string,
	params []interface{},
	idName string,
) error {
	v := reflect.ValueOf(record)
	t := v.Type()
	if err := assertStructPtr(t); err != nil {
		return errors.Wrap(err, "can't write to `"+idName+"` field")
	}

	// Without this check ByName's zero value (Index 0) would silently make
	// us write the ID into the struct's first field.
	fieldInfo := structs.GetTagInfo(t.Elem()).ByName(idName)
	if !fieldInfo.Valid {
		return fmt.Errorf("can't write to `%s` field: struct %s has no such field", idName, t.Elem())
	}

	field := v.Elem().Field(fieldInfo.Index)
	if !field.CanSet() {
		return fmt.Errorf("can't write to `%s` field: the struct field is not settable", idName)
	}

	fieldType := field.Type()
	idType := reflect.TypeOf(int64(0))
	if !idType.ConvertibleTo(fieldType) || fieldType.Kind() == reflect.String {
		return fmt.Errorf(
			"Can't convert last insert id of type int64 into field `%s` of type %s",
			idName,
			fieldType,
		)
	}

	result, err := c.db.ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return err
	}

	field.Set(reflect.ValueOf(id).Convert(fieldType))
	return nil
}

// insertWithNoIDRetrieval runs the INSERT and ignores its sql.Result, so the
// generated ID(s) are never written back into record (which is unused).
func (c DB) insertWithNoIDRetrieval(
	ctx context.Context,
	record interface{},
	query string,
	params []interface{},
) error {
	if _, execErr := c.db.ExecContext(ctx, query, params...); execErr != nil {
		return execErr
	}
	return nil
}

// assertStructPtr returns nil when t is a pointer to a struct, and an error
// naming t otherwise. t must not be nil.
func assertStructPtr(t reflect.Type) error {
	switch {
	case t.Kind() != reflect.Ptr:
		return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
	case t.Elem().Kind() != reflect.Struct:
		return fmt.Errorf("expected a Kind of Ptr to Struct but got: %s", t)
	default:
		return nil
	}
}

// Delete removes from the configured table every row identified by ids.
//
// Each element of ids may be a plain ID value (matched against the first ID
// column), a struct, or a map[string]interface{} containing the ID columns;
// see normalizeIDsAsMaps. Calling Delete without ids is a no-op.
func (c DB) Delete(
	ctx context.Context,
	ids ...interface{},
) error {
	switch {
	case c.tableName == "":
		return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Delete method")
	case len(ids) == 0:
		return nil
	}

	idMaps, err := normalizeIDsAsMaps(c.idCols, ids)
	if err != nil {
		return err
	}

	var (
		query  string
		params []interface{}
	)
	switch len(c.idCols) {
	case 1:
		query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps)
	default:
		query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps)
	}

	if _, err := c.db.ExecContext(ctx, query, params...); err != nil {
		return err
	}
	return nil
}

// normalizeIDsAsMaps converts each element of ids into a map from ID column
// name to value, so single and composite keys can be handled uniformly:
//   - structs are converted with structs.StructToMap;
//   - map[string]interface{} values are used as they are (other map types
//     are rejected);
//   - any other value is taken as the value of the first ID column,
//     idNames[0].
//
// Every resulting map must contain all of idNames. A nil element is
// reported as an error instead of causing a panic.
func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]interface{}, error) {
	if len(idNames) == 0 {
		return nil, fmt.Errorf("internal ksql error: missing idNames")
	}

	idMaps := make([]map[string]interface{}, 0, len(ids))
	for i, id := range ids {
		// reflect.TypeOf(nil) returns a nil Type and calling Kind on it
		// would panic:
		if id == nil {
			return nil, fmt.Errorf("ksql: expected a valid id on idx %d, but got nil", i)
		}

		switch reflect.TypeOf(id).Kind() {
		case reflect.Struct:
			m, err := structs.StructToMap(id)
			if err != nil {
				return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i)
			}
			idMaps = append(idMaps, m)
		case reflect.Map:
			m, ok := id.(map[string]interface{})
			if !ok {
				return nil, fmt.Errorf("expected map[string]interface{} but got %T", id)
			}
			idMaps = append(idMaps, m)
		default:
			idMaps = append(idMaps, map[string]interface{}{
				idNames[0]: id,
			})
		}
	}

	for i, m := range idMaps {
		for _, id := range idNames {
			if _, found := m[id]; !found {
				return nil, fmt.Errorf("missing required id field `%s` on record with idx %d", id, i)
			}
		}
	}

	return idMaps, nil
}

// Update writes the attributes of record to the row identified by its ID
// columns, in the table configured through Config.TableName.
//
// Partial updates are supported, i.e. it will ignore nil pointer attributes.
func (c DB) Update(
	ctx context.Context,
	record interface{},
) error {
	if c.tableName == "" {
		return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method")
	}

	query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
	if err == nil {
		_, err = c.db.ExecContext(ctx, query, params...)
	}
	return err
}

// buildInsertQuery builds the INSERT statement for record, which must be a
// non-nil pointer to struct. It returns:
//   - query: the INSERT statement, with an OUTPUT or RETURNING clause
//     selecting the ID columns when the dialect's insert method uses one;
//   - params: the value for each placeholder, in placeholder order;
//   - scanValues: pointers to the record's ID fields, in idNames order, to
//     receive the values returned by OUTPUT/RETURNING (nil otherwise).
//
// ID columns holding their zero value, or absent from the attribute map
// (e.g. nil pointer fields), are left out so the database can generate them.
// Columns are emitted in sorted order, so the same record shape always
// produces the same SQL text.
func buildInsertQuery(
	dialect dialect,
	tableName string,
	record interface{},
	idNames ...string,
) (query string, params []interface{}, scanValues []interface{}, err error) {
	v := reflect.ValueOf(record)
	// v.Kind() is Invalid for a nil record, so that case also fails here:
	if v.Kind() != reflect.Ptr || v.IsNil() || assertStructPtr(v.Type()) != nil {
		return "", nil, nil, fmt.Errorf(
			"ksql: expected record to be a pointer to struct, but got: %T",
			record,
		)
	}
	t := v.Type()

	info := structs.GetTagInfo(t.Elem())

	recordMap, err := structs.StructToMap(record)
	if err != nil {
		return "", nil, nil, err
	}

	for _, fieldName := range idNames {
		// Remove any ID field that was not set. A key missing from the map
		// yields an invalid reflect.Value, on which IsZero would panic.
		idValue := reflect.ValueOf(recordMap[fieldName])
		if !idValue.IsValid() || idValue.IsZero() {
			delete(recordMap, fieldName)
		}
	}

	columnNames := make([]string, 0, len(recordMap))
	for col := range recordMap {
		columnNames = append(columnNames, col)
	}
	// Map iteration order is random; sorting keeps the SQL text stable,
	// which also lets drivers reuse cached prepared statements.
	sort.Strings(columnNames)

	params = make([]interface{}, len(columnNames))
	valuesQuery := make([]string, len(columnNames))
	escapedColumnNames := make([]string, len(columnNames))
	for i, col := range columnNames {
		recordValue := recordMap[col]
		params[i] = recordValue
		if info.ByName(col).SerializeAsJSON {
			params[i] = jsonSerializable{
				DriverName: dialect.DriverName(),
				Attr:       recordValue,
			}
		}

		valuesQuery[i] = dialect.Placeholder(i)
		// Escape all cols to be sure they will be interpreted as column names:
		escapedColumnNames[i] = dialect.Escape(col)
	}

	var returningQuery, outputQuery string
	method := dialect.InsertMethod()
	if method == insertWithReturning || method == insertWithOutput {
		escapedIDNames := make([]string, 0, len(idNames))
		for _, id := range idNames {
			// Without this check ByName's zero value (Index 0) would make us
			// scan the returned ID into the struct's first field.
			fieldInfo := info.ByName(id)
			if !fieldInfo.Valid {
				return "", nil, nil, fmt.Errorf(
					"ksql: can't retrieve the generated id: struct %s has no field for id column `%s`",
					t.Elem(),
					id,
				)
			}
			scanValues = append(
				scanValues,
				v.Elem().Field(fieldInfo.Index).Addr().Interface(),
			)

			escapedID := dialect.Escape(id)
			if method == insertWithOutput {
				escapedID = "INSERTED." + escapedID
			}
			escapedIDNames = append(escapedIDNames, escapedID)
		}

		if method == insertWithReturning {
			returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
		} else {
			outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
		}
	}

	// Note that the outputQuery and the returningQuery depend
	// on the selected driver, thus, they might be empty strings.
	query = fmt.Sprintf(
		"INSERT INTO %s (%s)%s VALUES (%s)%s",
		dialect.Escape(tableName),
		strings.Join(escapedColumnNames, ", "),
		outputQuery,
		strings.Join(valuesQuery, ", "),
		returningQuery,
	)

	return query, params, scanValues, nil
}

// buildUpdateQuery builds an UPDATE statement that sets every attribute
// returned by structs.StructToMap(record), except the ID columns, on the row
// matched by all of idFieldNames (joined with AND).
//
// args holds the SET values first, in the same (sorted) order as the SET
// columns, followed by the ID values, so they line up with the placeholder
// indexes.
//
// It returns an error when:
//   - idFieldNames is empty, so an UPDATE without a WHERE clause is never built;
//   - an ID column is missing from the record's attributes (e.g. it is a nil
//     pointer), which would otherwise misalign the arguments;
//   - nothing is left to SET once the ID columns are removed.
func buildUpdateQuery(
	dialect dialect,
	tableName string,
	record interface{},
	idFieldNames ...string,
) (query string, args []interface{}, err error) {
	if len(idFieldNames) == 0 {
		return "", nil, fmt.Errorf("internal ksql error: missing idNames")
	}

	recordMap, err := structs.StructToMap(record)
	if err != nil {
		return "", nil, err
	}

	// Pull the ID values out first: they belong to the WHERE clause and must
	// not be part of the SET list.
	idValues := make([]interface{}, len(idFieldNames))
	for i, fieldName := range idFieldNames {
		value, found := recordMap[fieldName]
		if !found {
			return "", nil, fmt.Errorf("missing required id field `%s` on record", fieldName)
		}
		idValues[i] = value
		delete(recordMap, fieldName)
	}

	if len(recordMap) == 0 {
		return "", nil, fmt.Errorf("ksql: the record has no attributes to update besides its id columns")
	}

	keys := make([]string, 0, len(recordMap))
	for key := range recordMap {
		keys = append(keys, key)
	}
	// Map iteration order is random; sorting keeps the SQL text stable.
	sort.Strings(keys)

	t := reflect.TypeOf(record)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	info := structs.GetTagInfo(t)

	args = make([]interface{}, 0, len(keys)+len(idValues))
	setQuery := make([]string, 0, len(keys))
	for i, k := range keys {
		recordValue := recordMap[k]
		if info.ByName(k).SerializeAsJSON {
			recordValue = jsonSerializable{
				DriverName: dialect.DriverName(),
				Attr:       recordValue,
			}
		}
		args = append(args, recordValue)
		setQuery = append(setQuery, fmt.Sprintf(
			"%s = %s",
			dialect.Escape(k),
			dialect.Placeholder(i),
		))
	}

	// The ID placeholders come right after the SET ones:
	whereQuery := make([]string, len(idFieldNames))
	for i, fieldName := range idFieldNames {
		args = append(args, idValues[i])
		whereQuery[i] = fmt.Sprintf(
			"%s = %s",
			dialect.Escape(fieldName),
			dialect.Placeholder(len(keys)+i),
		)
	}

	query = fmt.Sprintf(
		"UPDATE %s SET %s WHERE %s",
		dialect.Escape(tableName),
		strings.Join(setQuery, ", "),
		strings.Join(whereQuery, " AND "),
	)

	return query, args, nil
}

// Exec runs query with params on the underlying database (or on the current
// transaction inside Transaction). The sql.Result is discarded, so only the
// error is reported.
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error {
	if _, execErr := c.db.ExecContext(ctx, query, params...); execErr != nil {
		return execErr
	}
	return nil
}

// Transaction runs fn inside a database transaction.
//
// fn receives a copy of the client bound to the transaction:
//   - if fn returns an error, the transaction is rolled back and that error
//     is returned. If the rollback itself fails, the rollback error is
//     returned instead, wrapped with a message containing fn's error;
//   - if fn panics, the transaction is rolled back and the panic is
//     re-raised. If the rollback fails, the rollback error, wrapped with
//     the panic value, is raised instead;
//   - otherwise the transaction is committed and Commit's error is returned.
//
// When c is already bound to a transaction, fn simply runs on it: no new
// transaction is started and nothing is committed or rolled back here.
func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error {
	if _, alreadyInTx := c.db.(*sql.Tx); alreadyInTx {
		return fn(c)
	}

	pool, isPool := c.db.(*sql.DB)
	if !isPool {
		return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type")
	}

	tx, err := pool.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			r = errors.Wrap(rollbackErr,
				fmt.Sprintf("unable to rollback after panic with value: %v", r),
			)
		}
		panic(r)
	}()

	txClient := c
	txClient.db = tx

	if fnErr := fn(txClient); fnErr != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return errors.Wrap(rollbackErr,
				fmt.Sprintf("unable to rollback after error: %s", fnErr.Error()),
			)
		}
		return fnErr
	}

	return tx.Commit()
}

// errType is the reflect.Type of the built-in error interface, used to check
// the callback's return type.
var errType = reflect.TypeOf(new(error)).Elem()

// parseInputFunc validates the ForEachChunk callback given to QueryChunks
// and returns the type of its single argument.
//
// The callback must be a non-nil function with exactly one argument, which
// must be a slice, and exactly one return value of type error. Whether the
// slice elements are structs (or pointers to structs) is checked later by
// QueryChunks.
func parseInputFunc(fn interface{}) (reflect.Type, error) {
	if fn == nil {
		return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil")
	}

	t := reflect.TypeOf(fn)

	if t.Kind() != reflect.Func {
		return nil, fmt.Errorf("the ForEachChunk callback must be a function")
	}

	// A typed nil function (e.g. `var f func([]User) error`) has a valid
	// type, but reflect.Value.Call would panic on it later:
	if reflect.ValueOf(fn).IsNil() {
		return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil")
	}

	if t.NumIn() != 1 {
		return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument")
	}

	if t.NumOut() != 1 {
		return nil, fmt.Errorf("the ForEachChunk callback must have a single return value")
	}

	if t.Out(0) != errType {
		return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error")
	}

	argsType := t.In(0)
	if argsType.Kind() != reflect.Slice {
		return nil, fmt.Errorf("the argument of the ForEachChunk callback must be a slice of structs")
	}

	return argsType, nil
}

// nopScanner is a sql.Scanner that discards whatever value it receives.
// scanRows uses it for result columns that have no matching struct field.
type nopScanner struct{}

// nopScannerValue is the shared *nopScanner handed to rows.Scan. It stores
// nothing, so one instance can be reused for every column and row. It is
// built directly rather than through a reflect round-trip, which produced
// the same value.
var nopScannerValue interface{} = &nopScanner{}

// Scan implements sql.Scanner by ignoring value.
func (nopScanner) Scan(value interface{}) error {
	return nil
}

// scanRows scans the current row of rows into record, which must be a
// pointer to struct.
//
// Each result column is matched by name against the struct's mapped fields.
// Columns without a matching field are read and discarded. Fields flagged
// SerializeAsJSON are wrapped in a jsonSerializable before scanning.
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	ptrValue := reflect.ValueOf(record)
	ptrType := ptrValue.Type()
	if ptrType.Kind() != reflect.Ptr || ptrType.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
	}

	structValue := ptrValue.Elem()
	info := structs.GetTagInfo(ptrType.Elem())

	scanArgs := make([]interface{}, len(columns))
	for i, column := range columns {
		fieldInfo := info.ByName(column)
		if !fieldInfo.Valid {
			scanArgs[i] = nopScannerValue
			continue
		}

		target := structValue.Field(fieldInfo.Index).Addr().Interface()
		if fieldInfo.SerializeAsJSON {
			target = &jsonSerializable{
				DriverName: dialect.DriverName(),
				Attr:       target,
			}
		}
		scanArgs[i] = target
	}

	return rows.Scan(scanArgs...)
}

// buildSingleKeyDeleteQuery builds
// "DELETE FROM <table> WHERE <idName> IN (<placeholders>)" with one
// placeholder per entry of idMaps. It returns the value of idName from each
// map as params, in the same order.
func buildSingleKeyDeleteQuery(
	dialect dialect,
	table string,
	idName string,
	idMaps []map[string]interface{},
) (query string, params []interface{}) {
	placeholders := make([]string, len(idMaps))
	for i := range idMaps {
		placeholders[i] = dialect.Placeholder(i)
		params = append(params, idMaps[i][idName])
	}

	query = fmt.Sprintf(
		"DELETE FROM %s WHERE %s IN (%s)",
		dialect.Escape(table),
		dialect.Escape(idName),
		strings.Join(placeholders, ","),
	)
	return query, params
}

// buildCompositeKeyDeleteQuery builds a DELETE statement matching every
// record in idMaps by all of idNames, in the form:
//
//	DELETE FROM t WHERE (a = ? AND b = ?) OR (a = ? AND b = ?)
//
// params holds the ID values in placeholder order. Each placeholder index
// is the position of its param, as required by dialects with numbered
// placeholders. An OR of ANDs is used instead of a row-value
// `(a, b) IN (VALUES ...)` because it is valid on every supported dialect;
// SQL Server has no row-value comparisons.
func buildCompositeKeyDeleteQuery(
	dialect dialect,
	table string,
	idNames []string,
	idMaps []map[string]interface{},
) (query string, params []interface{}) {
	escapedNames := make([]string, len(idNames))
	for i, name := range idNames {
		escapedNames[i] = dialect.Escape(name)
	}

	conditions := make([]string, 0, len(idMaps))
	for _, m := range idMaps {
		equalities := make([]string, 0, len(idNames))
		for i, name := range idNames {
			// len(params) before appending is this param's 0-based position:
			equalities = append(equalities, escapedNames[i]+" = "+dialect.Placeholder(len(params)))
			params = append(params, m[name])
		}
		conditions = append(conditions, "("+strings.Join(equalities, " AND ")+")")
	}

	return fmt.Sprintf(
		"DELETE FROM %s WHERE %s",
		dialect.Escape(table),
		strings.Join(conditions, " OR "),
	), params
}

// getFirstToken returns the first whitespace-delimited word of s, ignoring
// leading whitespace, or "" when s contains no such word. It is used to
// detect queries that start with FROM.
//
// We implemented this function instead of using a regex or strings.Fields
// to preserve the performance of the package. The result is a sub-slice of
// s, so no new string is allocated.
func getFirstToken(s string) string {
	s = strings.TrimLeftFunc(s, unicode.IsSpace)
	if end := strings.IndexFunc(s, unicode.IsSpace); end >= 0 {
		return s[:end]
	}
	return s
}

// selectQueryCacheMutex guards the inner maps passed to buildSelectQuery,
// i.e. the per-dialect maps of selectQueryCache. In this file those maps are
// only read and written inside buildSelectQuery. DB methods may be called
// from several goroutines at once, and unsynchronized map writes would be a
// data race that the Go runtime aborts the program on.
var selectQueryCacheMutex sync.RWMutex

// buildSelectQuery returns the "SELECT <cols> " prefix (note the trailing
// space) for structType, with one escaped column per field reported by
// structs.GetTagInfo.
//
// The result is memoized in selectQueryCache. A nil cache is tolerated: the
// prefix is then just not cached. The error is currently always nil.
func buildSelectQuery(
	dialect dialect,
	structType reflect.Type,
	selectQueryCache map[reflect.Type]string,
) (string, error) {
	selectQueryCacheMutex.RLock()
	selectQuery, found := selectQueryCache[structType]
	selectQueryCacheMutex.RUnlock()
	if found {
		return selectQuery, nil
	}

	info := structs.GetTagInfo(structType)
	var fields []string
	for _, field := range info.Fields() {
		fields = append(fields, dialect.Escape(field.Name))
	}

	query := "SELECT " + strings.Join(fields, ", ") + " "

	// Two goroutines may both miss and compute the prefix; they store the
	// same value, so the last write winning is harmless.
	if selectQueryCache != nil {
		selectQueryCacheMutex.Lock()
		selectQueryCache[structType] = query
		selectQueryCacheMutex.Unlock()
	}

	return query, nil
}