package structs

import (
	"fmt"
	"reflect"
	"strings"
	"sync"

	"github.com/vingarcia/ksql/internal/modifiers"
	"github.com/vingarcia/ksql/ksqlmodifiers"
)

// StructInfo stores metainformation of the struct
// parser in order to help the ksql library to work
// efectively and efficiently with reflection.
type StructInfo struct {
	IsNestedStruct bool
	byIndex        map[int]*FieldInfo
	byName         map[string]*FieldInfo
}

// FieldInfo contains reflection and tags
// information regarding a specific field
// of a struct.
type FieldInfo struct {
	// AttrName is the name of struct attribute of the struct.
	AttrName string

	// ColumnName is the name of the database column described by the ksql tag.
	ColumnName string

	// Index indexes the position of this attribute on the struct.
	// This field is meant to be used together with the
	// `reflect.Value.Field()` and `reflect.Type.Field()` methods.
	Index int

	// Valid will only be set to false if the instance
	// of this field was not initialized, i.e.
	// it denotes the zero value of a FieldInfo.
	Valid bool

	// Modifier contains the AttrModifier associated with this field.
	Modifier ksqlmodifiers.AttrModifier
}

// ByIndex returns either the *FieldInfo of a valid
// empty struct with Valid set to false
func (s StructInfo) ByIndex(idx int) *FieldInfo {
	field, found := s.byIndex[idx]
	if !found {
		return &FieldInfo{}
	}
	return field
}

// ByName returns either the *FieldInfo of a valid
// empty struct with Valid set to false
func (s StructInfo) ByName(name string) *FieldInfo {
	field, found := s.byName[name]
	if !found {
		return &FieldInfo{}
	}
	return field
}

func (s StructInfo) add(field FieldInfo) {
	field.Valid = true
	s.byIndex[field.Index] = &field
	s.byName[field.ColumnName] = &field

	// Make sure to save a lowercased version because
	// some databases will set these keys to lowercase.
	if _, found := s.byName[strings.ToLower(field.ColumnName)]; !found {
		s.byName[strings.ToLower(field.ColumnName)] = &field
	}
}

// NumFields ...
func (s StructInfo) NumFields() int {
	return len(s.byIndex)
}

// This cache is kept as a pkg variable
// because the total number of types on a program
// should be finite. So keeping a single cache here
// works fine.
var tagInfoCache = &sync.Map{}

// GetTagInfo efficiently returns the type information
// using a global private cache
//
// In the future we might move this cache inside
// a struct, but for now this accessor is the one
// we are using
func GetTagInfo(key reflect.Type) (StructInfo, error) {
	return getCachedTagInfo(tagInfoCache, key)
}

func getCachedTagInfo(tagInfoCache *sync.Map, key reflect.Type) (StructInfo, error) {
	if data, found := tagInfoCache.Load(key); found {
		info, ok := data.(StructInfo)
		if !ok {
			return StructInfo{}, fmt.Errorf("invalid cache entry, expected type StructInfo, found %T", data)
		}
		return info, nil
	}

	info, err := getTagNames(key)
	if err != nil {
		return StructInfo{}, err
	}

	tagInfoCache.Store(key, info)
	return info, nil
}

// StructToMap converts any struct type to a map based on
// the tag named `ksql`, i.e. `ksql:"map_key_name"`
//
// Valid pointers are dereferenced and copied to the map,
// null pointers are ignored.
//
// This function is efficient in the fact that it caches
// the slower steps of the reflection required to perform
// this task.
func StructToMap(obj interface{}) (map[string]interface{}, error) {
	v := reflect.ValueOf(obj)
	t := v.Type()

	if t.Kind() == reflect.Ptr {
		v = v.Elem()
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("input must be a struct or struct pointer")
	}

	info, err := getCachedTagInfo(tagInfoCache, t)
	if err != nil {
		return nil, err
	}

	m := map[string]interface{}{}
	for i := 0; i < v.NumField(); i++ {
		fieldInfo := info.ByIndex(i)
		if !fieldInfo.Valid {
			continue
		}

		field := v.Field(i)
		ft := field.Type()
		if ft.Kind() == reflect.Ptr {
			if !field.IsNil() {
				field = field.Elem()
			} else {
				if !fieldInfo.Modifier.Nullable {
					continue
				}
			}
		}

		m[fieldInfo.ColumnName] = field.Interface()
	}

	return m, nil
}

// PtrConverter was created to make it easier
// to handle conversion between ptr and non ptr types, e.g.:
//
// - *type to *type
// - type to *type
// - *type to type
// - type to type
type PtrConverter struct {
	BaseType  reflect.Type
	BaseValue reflect.Value
	ElemType  reflect.Type
	ElemValue reflect.Value
}

// NewPtrConverter instantiates a PtrConverter from
// an empty interface.
//
// The input argument can be of any type, but
// if it is a pointer then its Elem() will be
// used as source value for the PtrConverter.Convert()
// method.
func NewPtrConverter(v interface{}) PtrConverter {
	if v == nil {
		// This is necessary so that reflect.ValueOf
		// returns a valid reflect.Value
		v = (*interface{})(nil)
	}

	baseValue := reflect.ValueOf(v)
	baseType := reflect.TypeOf(v)

	elemType := baseType
	elemValue := baseValue
	if baseType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
		elemValue = elemValue.Elem()
	}
	return PtrConverter{
		BaseType:  baseType,
		BaseValue: baseValue,
		ElemType:  elemType,
		ElemValue: elemValue,
	}
}

// Convert attempts to convert the ElemValue to the destType received
// as argument and then returns the converted reflect.Value or an error
func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) {
	destElemType := destType
	if destType.Kind() == reflect.Ptr {
		destElemType = destType.Elem()
	}

	// Return 0 valued destType instance:
	if p.BaseType.Kind() == reflect.Ptr && p.BaseValue.IsNil() {
		// Note that if destType is a ptr it will return a nil ptr.
		return reflect.New(destType).Elem(), nil
	}

	if !p.ElemType.ConvertibleTo(destElemType) {
		return reflect.Value{}, fmt.Errorf(
			"cannot convert from type %v to type %v", p.BaseType, destType,
		)
	}

	destValue := p.ElemValue.Convert(destElemType)

	// Get the address of destValue if necessary:
	if destType.Kind() == reflect.Ptr {
		if !destValue.CanAddr() {
			tmp := reflect.New(destElemType)
			tmp.Elem().Set(destValue)
			destValue = tmp
		} else {
			destValue = destValue.Addr()
		}
	}

	return destValue, nil
}

// This function collects only the names
// that will be used from the input type.
//
// This should save several calls to `Field(i).Tag.Get("foo")`
// which improves performance by a lot.
func getTagNames(t reflect.Type) (_ StructInfo, err error) {
	info := StructInfo{
		byIndex: map[int]*FieldInfo{},
		byName:  map[string]*FieldInfo{},
	}
	for i := 0; i < t.NumField(); i++ {
		// If this field is private:
		if t.Field(i).PkgPath != "" {
			return StructInfo{}, fmt.Errorf("all fields using the ksql tags must be exported, but %v is unexported", t)
		}

		attrName := t.Field(i).Name
		name := t.Field(i).Tag.Get("ksql")
		if name == "" {
			continue
		}

		tags := strings.Split(name, ",")
		var modifier ksqlmodifiers.AttrModifier
		if len(tags) > 1 {
			name = tags[0]
			modifier, err = modifiers.LoadGlobalModifier(tags[1])
			if err != nil {
				return StructInfo{}, fmt.Errorf("attribute contains invalid modifier name: %w", err)
			}
		}

		if _, found := info.byName[name]; found {
			return StructInfo{}, fmt.Errorf(
				"struct contains multiple attributes with the same ksql tag name: '%s'",
				name,
			)
		}

		info.add(FieldInfo{
			AttrName:   attrName,
			ColumnName: name,
			Index:      i,
			Modifier:   modifier,
		})
	}

	// If there were `ksql` tags present, then we are finished:
	if len(info.byIndex) > 0 {
		return info, nil
	}

	// If there are no `ksql` tags in the struct, lets assume
	// it is a struct tagged with `tablename` for allowing JOINs
	for i := 0; i < t.NumField(); i++ {
		name := t.Field(i).Tag.Get("tablename")
		if name == "" {
			continue
		}

		info.add(FieldInfo{
			AttrName:   t.Field(i).Name,
			ColumnName: name,
			Index:      i,
		})
	}

	if len(info.byIndex) == 0 {
		return StructInfo{}, fmt.Errorf("the struct must contain at least one attribute with the ksql tag")
	}

	info.IsNestedStruct = true

	return info, nil
}

// DecodeAsSliceOfStructs makes several checks
// while decoding an input type and returns
// useful information so that it is easier
// to manipulate the original slice later.
func DecodeAsSliceOfStructs(slice reflect.Type) (
	structType reflect.Type,
	isSliceOfPtrs bool,
	err error,
) {
	if slice.Kind() != reflect.Slice {
		err = fmt.Errorf(
			"expected input kind to be a slice but got %v",
			slice,
		)
		return
	}

	elemType := slice.Elem()
	isPtr := elemType.Kind() == reflect.Ptr

	if isPtr {
		elemType = elemType.Elem()
	}

	if elemType.Kind() != reflect.Struct {
		err = fmt.Errorf(
			"expected input to be a slice of structs but got %v",
			slice,
		)
		return
	}

	return elemType, isPtr, nil
}