package pgtype

import (
	"bytes"
	"database/sql/driver"
	"encoding/binary"
	"fmt"
	"math"
	"math/big"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5/internal/pgio"
)

// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
const nbase = 10000

const (
	pgNumericNaN     = 0x00000000c0000000
	pgNumericNaNSign = 0xc000

	pgNumericPosInf     = 0x00000000d0000000
	pgNumericPosInfSign = 0xd000

	pgNumericNegInf     = 0x00000000f0000000
	pgNumericNegInfSign = 0xf000
)

var big0 *big.Int = big.NewInt(0)
var big1 *big.Int = big.NewInt(1)
var big10 *big.Int = big.NewInt(10)
var big100 *big.Int = big.NewInt(100)
var big1000 *big.Int = big.NewInt(1000)

var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
var bigMinInt *big.Int = big.NewInt(int64(minInt))

var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))

var bigNBase *big.Int = big.NewInt(nbase)
var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)

type NumericScanner interface {
	ScanNumeric(v Numeric) error
}

type NumericValuer interface {
	NumericValue() (Numeric, error)
}

type Numeric struct {
	Int              *big.Int
	Exp              int32
	NaN              bool
	InfinityModifier InfinityModifier
	Valid            bool
}

func (n *Numeric) ScanNumeric(v Numeric) error {
	*n = v
	return nil
}

func (n Numeric) NumericValue() (Numeric, error) {
	return n, nil
}

func (n Numeric) Float64Value() (Float8, error) {
	if !n.Valid {
		return Float8{}, nil
	} else if n.NaN {
		return Float8{Float64: math.NaN(), Valid: true}, nil
	} else if n.InfinityModifier == Infinity {
		return Float8{Float64: math.Inf(1), Valid: true}, nil
	} else if n.InfinityModifier == NegativeInfinity {
		return Float8{Float64: math.Inf(-1), Valid: true}, nil
	}

	buf := make([]byte, 0, 32)

	if n.Int == nil {
		buf = append(buf, '0')
	} else {
		buf = append(buf, n.Int.String()...)
	}
	buf = append(buf, 'e')
	buf = append(buf, strconv.FormatInt(int64(n.Exp), 10)...)

	f, err := strconv.ParseFloat(string(buf), 64)
	if err != nil {
		return Float8{}, err
	}

	return Float8{Float64: f, Valid: true}, nil
}

func (n *Numeric) ScanInt64(v Int8) error {
	if !v.Valid {
		*n = Numeric{}
		return nil
	}

	*n = Numeric{Int: big.NewInt(v.Int64), Valid: true}
	return nil
}

func (n Numeric) Int64Value() (Int8, error) {
	if !n.Valid {
		return Int8{}, nil
	}

	bi, err := n.toBigInt()
	if err != nil {
		return Int8{}, err
	}

	if !bi.IsInt64() {
		return Int8{}, fmt.Errorf("cannot convert %v to int64", n)
	}

	return Int8{Int64: bi.Int64(), Valid: true}, nil
}

func (n *Numeric) toBigInt() (*big.Int, error) {
	if n.Exp == 0 {
		return n.Int, nil
	}

	num := &big.Int{}
	num.Set(n.Int)
	if n.Exp > 0 {
		mul := &big.Int{}
		mul.Exp(big10, big.NewInt(int64(n.Exp)), nil)
		num.Mul(num, mul)
		return num, nil
	}

	div := &big.Int{}
	div.Exp(big10, big.NewInt(int64(-n.Exp)), nil)
	remainder := &big.Int{}
	num.DivMod(num, div, remainder)
	if remainder.Cmp(big0) != 0 {
		return nil, fmt.Errorf("cannot convert %v to integer", n)
	}
	return num, nil
}

func parseNumericString(str string) (n *big.Int, exp int32, err error) {
	parts := strings.SplitN(str, ".", 2)
	digits := strings.Join(parts, "")

	if len(parts) > 1 {
		exp = int32(-len(parts[1]))
	} else {
		for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' {
			digits = digits[:len(digits)-1]
			exp++
		}
	}

	accum := &big.Int{}
	if _, ok := accum.SetString(digits, 10); !ok {
		return nil, 0, fmt.Errorf("%s is not a number", str)
	}

	return accum, exp, nil
}

func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
	digits := len(src) / 2
	if digits > 4 {
		digits = 4
	}

	rp := 0

	for i := 0; i < digits; i++ {
		if i > 0 {
			accum *= nbase
		}
		accum += int64(binary.BigEndian.Uint16(src[rp:]))
		rp += 2
	}

	return accum, rp, digits
}

// Scan implements the database/sql Scanner interface.
func (n *Numeric) Scan(src any) error {
	if src == nil {
		*n = Numeric{}
		return nil
	}

	switch src := src.(type) {
	case string:
		return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n)
	}

	return fmt.Errorf("cannot scan %T", src)
}

// Value implements the database/sql/driver Valuer interface.
func (n Numeric) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}

	buf, err := NumericCodec{}.PlanEncode(nil, 0, TextFormatCode, n).Encode(n, nil)
	if err != nil {
		return nil, err
	}
	return string(buf), err
}

func (n Numeric) MarshalJSON() ([]byte, error) {
	if !n.Valid {
		return []byte("null"), nil
	}

	if n.NaN {
		return []byte(`"NaN"`), nil
	}

	return n.numberTextBytes(), nil
}

// numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte {
	intStr := n.Int.String()
	buf := &bytes.Buffer{}
	exp := int(n.Exp)
	if exp > 0 {
		buf.WriteString(intStr)
		for i := 0; i < exp; i++ {
			buf.WriteByte('0')
		}
	} else if exp < 0 {
		if len(intStr) <= -exp {
			buf.WriteString("0.")
			leadingZeros := -exp - len(intStr)
			for i := 0; i < leadingZeros; i++ {
				buf.WriteByte('0')
			}
			buf.WriteString(intStr)
		} else if len(intStr) > -exp {
			dpPos := len(intStr) + exp
			buf.WriteString(intStr[:dpPos])
			buf.WriteByte('.')
			buf.WriteString(intStr[dpPos:])
		}
	} else {
		buf.WriteString(intStr)
	}

	return buf.Bytes()
}

type NumericCodec struct{}

func (NumericCodec) FormatSupported(format int16) bool {
	return format == TextFormatCode || format == BinaryFormatCode
}

func (NumericCodec) PreferredFormat() int16 {
	return BinaryFormatCode
}

func (NumericCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	switch format {
	case BinaryFormatCode:
		switch value.(type) {
		case NumericValuer:
			return encodePlanNumericCodecBinaryNumericValuer{}
		case Float64Valuer:
			return encodePlanNumericCodecBinaryFloat64Valuer{}
		case Int64Valuer:
			return encodePlanNumericCodecBinaryInt64Valuer{}
		}
	case TextFormatCode:
		switch value.(type) {
		case NumericValuer:
			return encodePlanNumericCodecTextNumericValuer{}
		case Float64Valuer:
			return encodePlanNumericCodecTextFloat64Valuer{}
		case Int64Valuer:
			return encodePlanNumericCodecTextInt64Valuer{}
		}
	}

	return nil
}

type encodePlanNumericCodecBinaryNumericValuer struct{}

func (encodePlanNumericCodecBinaryNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(NumericValuer).NumericValue()
	if err != nil {
		return nil, err
	}

	return encodeNumericBinary(n, buf)
}

type encodePlanNumericCodecBinaryFloat64Valuer struct{}

func (encodePlanNumericCodecBinaryFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(Float64Valuer).Float64Value()
	if err != nil {
		return nil, err
	}

	if !n.Valid {
		return nil, nil
	}

	if math.IsNaN(n.Float64) {
		return encodeNumericBinary(Numeric{NaN: true, Valid: true}, buf)
	} else if math.IsInf(n.Float64, 1) {
		return encodeNumericBinary(Numeric{InfinityModifier: Infinity, Valid: true}, buf)
	} else if math.IsInf(n.Float64, -1) {
		return encodeNumericBinary(Numeric{InfinityModifier: NegativeInfinity, Valid: true}, buf)
	}
	num, exp, err := parseNumericString(strconv.FormatFloat(n.Float64, 'f', -1, 64))
	if err != nil {
		return nil, err
	}

	return encodeNumericBinary(Numeric{Int: num, Exp: exp, Valid: true}, buf)
}

type encodePlanNumericCodecBinaryInt64Valuer struct{}

func (encodePlanNumericCodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(Int64Valuer).Int64Value()
	if err != nil {
		return nil, err
	}

	if !n.Valid {
		return nil, nil
	}

	return encodeNumericBinary(Numeric{Int: big.NewInt(n.Int64), Valid: true}, buf)
}

func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) {
	if !n.Valid {
		return nil, nil
	}

	if n.NaN {
		buf = pgio.AppendUint64(buf, pgNumericNaN)
		return buf, nil
	} else if n.InfinityModifier == Infinity {
		buf = pgio.AppendUint64(buf, pgNumericPosInf)
		return buf, nil
	} else if n.InfinityModifier == NegativeInfinity {
		buf = pgio.AppendUint64(buf, pgNumericNegInf)
		return buf, nil
	}

	var sign int16
	if n.Int.Cmp(big0) < 0 {
		sign = 16384
	}

	absInt := &big.Int{}
	wholePart := &big.Int{}
	fracPart := &big.Int{}
	remainder := &big.Int{}
	absInt.Abs(n.Int)

	// Normalize absInt and exp to where exp is always a multiple of 4. This makes
	// converting to 16-bit base 10,000 digits easier.
	var exp int32
	switch n.Exp % 4 {
	case 1, -3:
		exp = n.Exp - 1
		absInt.Mul(absInt, big10)
	case 2, -2:
		exp = n.Exp - 2
		absInt.Mul(absInt, big100)
	case 3, -1:
		exp = n.Exp - 3
		absInt.Mul(absInt, big1000)
	default:
		exp = n.Exp
	}

	if exp < 0 {
		divisor := &big.Int{}
		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
		wholePart.DivMod(absInt, divisor, fracPart)
		fracPart.Add(fracPart, divisor)
	} else {
		wholePart = absInt
	}

	var wholeDigits, fracDigits []int16

	for wholePart.Cmp(big0) != 0 {
		wholePart.DivMod(wholePart, bigNBase, remainder)
		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
	}

	if fracPart.Cmp(big0) != 0 {
		for fracPart.Cmp(big1) != 0 {
			fracPart.DivMod(fracPart, bigNBase, remainder)
			fracDigits = append(fracDigits, int16(remainder.Int64()))
		}
	}

	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))

	var weight int16
	if len(wholeDigits) > 0 {
		weight = int16(len(wholeDigits) - 1)
		if exp > 0 {
			weight += int16(exp / 4)
		}
	} else {
		weight = int16(exp/4) - 1 + int16(len(fracDigits))
	}
	buf = pgio.AppendInt16(buf, weight)

	buf = pgio.AppendInt16(buf, sign)

	var dscale int16
	if n.Exp < 0 {
		dscale = int16(-n.Exp)
	}
	buf = pgio.AppendInt16(buf, dscale)

	for i := len(wholeDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, wholeDigits[i])
	}

	for i := len(fracDigits) - 1; i >= 0; i-- {
		buf = pgio.AppendInt16(buf, fracDigits[i])
	}

	return buf, nil
}

type encodePlanNumericCodecTextNumericValuer struct{}

func (encodePlanNumericCodecTextNumericValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(NumericValuer).NumericValue()
	if err != nil {
		return nil, err
	}

	return encodeNumericText(n, buf)
}

type encodePlanNumericCodecTextFloat64Valuer struct{}

func (encodePlanNumericCodecTextFloat64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(Float64Valuer).Float64Value()
	if err != nil {
		return nil, err
	}

	if !n.Valid {
		return nil, nil
	}

	if math.IsNaN(n.Float64) {
		buf = append(buf, "NaN"...)
	} else if math.IsInf(n.Float64, 1) {
		buf = append(buf, "Infinity"...)
	} else if math.IsInf(n.Float64, -1) {
		buf = append(buf, "-Infinity"...)
	} else {
		buf = append(buf, strconv.FormatFloat(n.Float64, 'f', -1, 64)...)
	}
	return buf, nil
}

type encodePlanNumericCodecTextInt64Valuer struct{}

func (encodePlanNumericCodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
	n, err := value.(Int64Valuer).Int64Value()
	if err != nil {
		return nil, err
	}

	if !n.Valid {
		return nil, nil
	}

	buf = append(buf, strconv.FormatInt(n.Int64, 10)...)
	return buf, nil
}

func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) {
	if !n.Valid {
		return nil, nil
	}

	if n.NaN {
		buf = append(buf, "NaN"...)
		return buf, nil
	} else if n.InfinityModifier == Infinity {
		buf = append(buf, "Infinity"...)
		return buf, nil
	} else if n.InfinityModifier == NegativeInfinity {
		buf = append(buf, "-Infinity"...)
		return buf, nil
	}

	buf = append(buf, n.numberTextBytes()...)

	return buf, nil
}

func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {

	switch format {
	case BinaryFormatCode:
		switch target.(type) {
		case NumericScanner:
			return scanPlanBinaryNumericToNumericScanner{}
		case Float64Scanner:
			return scanPlanBinaryNumericToFloat64Scanner{}
		case Int64Scanner:
			return scanPlanBinaryNumericToInt64Scanner{}
		case TextScanner:
			return scanPlanBinaryNumericToTextScanner{}
		}
	case TextFormatCode:
		switch target.(type) {
		case NumericScanner:
			return scanPlanTextAnyToNumericScanner{}
		case Float64Scanner:
			return scanPlanTextAnyToFloat64Scanner{}
		case Int64Scanner:
			return scanPlanTextAnyToInt64Scanner{}
		}
	}

	return nil
}

type scanPlanBinaryNumericToNumericScanner struct{}

func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(NumericScanner)

	if src == nil {
		return scanner.ScanNumeric(Numeric{})
	}

	if len(src) < 8 {
		return fmt.Errorf("numeric incomplete %v", src)
	}

	rp := 0
	ndigits := binary.BigEndian.Uint16(src[rp:])
	rp += 2
	weight := int16(binary.BigEndian.Uint16(src[rp:]))
	rp += 2
	sign := binary.BigEndian.Uint16(src[rp:])
	rp += 2
	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
	rp += 2

	if sign == pgNumericNaNSign {
		return scanner.ScanNumeric(Numeric{NaN: true, Valid: true})
	} else if sign == pgNumericPosInfSign {
		return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true})
	} else if sign == pgNumericNegInfSign {
		return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true})
	}

	if ndigits == 0 {
		return scanner.ScanNumeric(Numeric{Int: big.NewInt(0), Valid: true})
	}

	if len(src[rp:]) < int(ndigits)*2 {
		return fmt.Errorf("numeric incomplete %v", src)
	}

	accum := &big.Int{}

	for i := 0; i < int(ndigits+3)/4; i++ {
		int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
		rp += bytesRead

		if i > 0 {
			var mul *big.Int
			switch digitsRead {
			case 1:
				mul = bigNBase
			case 2:
				mul = bigNBaseX2
			case 3:
				mul = bigNBaseX3
			case 4:
				mul = bigNBaseX4
			default:
				return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
			}
			accum.Mul(accum, mul)
		}

		accum.Add(accum, big.NewInt(int64accum))
	}

	exp := (int32(weight) - int32(ndigits) + 1) * 4

	if dscale > 0 {
		fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1)
		fracDecimalDigits := fracNBaseDigits * 4

		if dscale > fracDecimalDigits {
			multCount := int(dscale - fracDecimalDigits)
			for i := 0; i < multCount; i++ {
				accum.Mul(accum, big10)
				exp--
			}
		} else if dscale < fracDecimalDigits {
			divCount := int(fracDecimalDigits - dscale)
			for i := 0; i < divCount; i++ {
				accum.Div(accum, big10)
				exp++
			}
		}
	}

	reduced := &big.Int{}
	remainder := &big.Int{}
	if exp >= 0 {
		for {
			reduced.DivMod(accum, big10, remainder)
			if remainder.Cmp(big0) != 0 {
				break
			}
			accum.Set(reduced)
			exp++
		}
	}

	if sign != 0 {
		accum.Neg(accum)
	}

	return scanner.ScanNumeric(Numeric{Int: accum, Exp: exp, Valid: true})
}

type scanPlanBinaryNumericToFloat64Scanner struct{}

func (scanPlanBinaryNumericToFloat64Scanner) Scan(src []byte, dst any) error {
	scanner := (dst).(Float64Scanner)

	if src == nil {
		return scanner.ScanFloat64(Float8{})
	}

	var n Numeric

	err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n)
	if err != nil {
		return err
	}

	f8, err := n.Float64Value()
	if err != nil {
		return err
	}

	return scanner.ScanFloat64(f8)
}

type scanPlanBinaryNumericToInt64Scanner struct{}

func (scanPlanBinaryNumericToInt64Scanner) Scan(src []byte, dst any) error {
	scanner := (dst).(Int64Scanner)

	if src == nil {
		return scanner.ScanInt64(Int8{})
	}

	var n Numeric

	err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n)
	if err != nil {
		return err
	}

	bigInt, err := n.toBigInt()
	if err != nil {
		return err
	}

	if !bigInt.IsInt64() {
		return fmt.Errorf("%v is out of range for int64", bigInt)
	}

	return scanner.ScanInt64(Int8{Int64: bigInt.Int64(), Valid: true})
}

type scanPlanBinaryNumericToTextScanner struct{}

func (scanPlanBinaryNumericToTextScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(TextScanner)

	if src == nil {
		return scanner.ScanText(Text{})
	}

	var n Numeric

	err := scanPlanBinaryNumericToNumericScanner{}.Scan(src, &n)
	if err != nil {
		return err
	}

	sbuf, err := encodeNumericText(n, nil)
	if err != nil {
		return err
	}

	return scanner.ScanText(Text{String: string(sbuf), Valid: true})
}

type scanPlanTextAnyToNumericScanner struct{}

func (scanPlanTextAnyToNumericScanner) Scan(src []byte, dst any) error {
	scanner := (dst).(NumericScanner)

	if src == nil {
		return scanner.ScanNumeric(Numeric{})
	}

	if string(src) == "NaN" {
		return scanner.ScanNumeric(Numeric{NaN: true, Valid: true})
	} else if string(src) == "Infinity" {
		return scanner.ScanNumeric(Numeric{InfinityModifier: Infinity, Valid: true})
	} else if string(src) == "-Infinity" {
		return scanner.ScanNumeric(Numeric{InfinityModifier: NegativeInfinity, Valid: true})
	}

	num, exp, err := parseNumericString(string(src))
	if err != nil {
		return err
	}

	return scanner.ScanNumeric(Numeric{Int: num, Exp: exp, Valid: true})
}

func (c NumericCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}

	if format == TextFormatCode {
		return string(src), nil
	}

	var n Numeric
	err := codecScan(c, m, oid, format, src, &n)
	if err != nil {
		return nil, err
	}

	buf, err := m.Encode(oid, TextFormatCode, n, nil)
	if err != nil {
		return nil, err
	}
	return string(buf), nil
}

func (c NumericCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var n Numeric
	err := codecScan(c, m, oid, format, src, &n)
	if err != nil {
		return nil, err
	}
	return n, nil
}