package pgtype

import (
	"bytes"
	"database/sql/driver"
	"encoding/binary"
	"fmt"
	"reflect"

	"github.com/jackc/pgx/v5/internal/pgio"
)

// MultirangeGetter is a type that can be converted into a PostgreSQL multirange.
type MultirangeGetter interface {
	// IsNull returns true if the value is SQL NULL.
	IsNull() bool

	// Len returns the number of elements in the multirange.
	Len() int

	// Index returns the element at i.
	Index(i int) any

	// IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode.
	IndexType() any
}

// MultirangeSetter is a type can be set from a PostgreSQL multirange.
type MultirangeSetter interface {
	// ScanNull sets the value to SQL NULL.
	ScanNull() error

	// SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing
	// elements.
	SetLen(n int) error

	// ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex.
	ScanIndex(i int) any

	// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
	// MultirangeCodec.PlanScan.
	ScanIndexType() any
}

// MultirangeCodec is a codec for any multirange type.
type MultirangeCodec struct {
	ElementType *Type
}

func (c *MultirangeCodec) FormatSupported(format int16) bool {
	return c.ElementType.Codec.FormatSupported(format)
}

func (c *MultirangeCodec) PreferredFormat() int16 {
	return c.ElementType.Codec.PreferredFormat()
}

func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	multirangeValuer, ok := value.(MultirangeGetter)
	if !ok {
		return nil
	}

	elementType := multirangeValuer.IndexType()

	elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
	if elementEncodePlan == nil {
		return nil
	}

	switch format {
	case BinaryFormatCode:
		return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
	case TextFormatCode:
		return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
	}

	return nil
}

type encodePlanMultirangeCodecText struct {
	ac  *MultirangeCodec
	m   *Map
	oid uint32
}

func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
	multirange := value.(MultirangeGetter)

	if multirange.IsNull() {
		return nil, nil
	}

	elementCount := multirange.Len()

	buf = append(buf, '{')

	var encodePlan EncodePlan
	var lastElemType reflect.Type
	inElemBuf := make([]byte, 0, 32)
	for i := 0; i < elementCount; i++ {
		if i > 0 {
			buf = append(buf, ',')
		}

		elem := multirange.Index(i)
		var elemBuf []byte
		if elem != nil {
			elemType := reflect.TypeOf(elem)
			if lastElemType != elemType {
				lastElemType = elemType
				encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
				if encodePlan == nil {
					return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
				}
			}
			elemBuf, err = encodePlan.Encode(elem, inElemBuf)
			if err != nil {
				return nil, err
			}
		}

		if elemBuf == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		} else {
			buf = append(buf, elemBuf...)
		}
	}

	buf = append(buf, '}')

	return buf, nil
}

type encodePlanMultirangeCodecBinary struct {
	ac  *MultirangeCodec
	m   *Map
	oid uint32
}

func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
	multirange := value.(MultirangeGetter)

	if multirange.IsNull() {
		return nil, nil
	}

	elementCount := multirange.Len()

	buf = pgio.AppendInt32(buf, int32(elementCount))

	var encodePlan EncodePlan
	var lastElemType reflect.Type
	for i := 0; i < elementCount; i++ {
		sp := len(buf)
		buf = pgio.AppendInt32(buf, -1)

		elem := multirange.Index(i)
		var elemBuf []byte
		if elem != nil {
			elemType := reflect.TypeOf(elem)
			if lastElemType != elemType {
				lastElemType = elemType
				encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
				if encodePlan == nil {
					return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
				}
			}
			elemBuf, err = encodePlan.Encode(elem, buf)
			if err != nil {
				return nil, err
			}
		}

		if elemBuf == nil {
			return nil, fmt.Errorf("multirange cannot contain NULL element")
		} else {
			buf = elemBuf
			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
		}
	}

	return buf, nil
}

func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
	multirangeScanner, ok := target.(MultirangeSetter)
	if !ok {
		return nil
	}

	elementType := multirangeScanner.ScanIndexType()

	elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
	if _, ok := elementScanPlan.(*scanPlanFail); ok {
		return nil
	}

	return &scanPlanMultirangeCodec{
		multirangeCodec: c,
		m:               m,
		oid:             oid,
		formatCode:      format,
	}
}

func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
	rp := 0

	elementCount := int(binary.BigEndian.Uint32(src[rp:]))
	rp += 4

	err := multirange.SetLen(elementCount)
	if err != nil {
		return err
	}

	if elementCount == 0 {
		return nil
	}

	elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
	if elementScanPlan == nil {
		elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
	}

	for i := 0; i < elementCount; i++ {
		elem := multirange.ScanIndex(i)
		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
		rp += 4
		var elemSrc []byte
		if elemLen >= 0 {
			elemSrc = src[rp : rp+elemLen]
			rp += elemLen
		}
		err = elementScanPlan.Scan(elemSrc, elem)
		if err != nil {
			return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
		}
	}

	return nil
}

func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
	elements, err := parseUntypedTextMultirange(src)
	if err != nil {
		return err
	}

	err = multirange.SetLen(len(elements))
	if err != nil {
		return err
	}

	if len(elements) == 0 {
		return nil
	}

	elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
	if elementScanPlan == nil {
		elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
	}

	for i, s := range elements {
		elem := multirange.ScanIndex(i)
		err = elementScanPlan.Scan([]byte(s), elem)
		if err != nil {
			return err
		}
	}

	return nil
}

type scanPlanMultirangeCodec struct {
	multirangeCodec *MultirangeCodec
	m               *Map
	oid             uint32
	formatCode      int16
	elementScanPlan ScanPlan
}

func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
	c := spac.multirangeCodec
	m := spac.m
	oid := spac.oid
	formatCode := spac.formatCode

	multirange := dst.(MultirangeSetter)

	if src == nil {
		return multirange.ScanNull()
	}

	switch formatCode {
	case BinaryFormatCode:
		return c.decodeBinary(m, oid, src, multirange)
	case TextFormatCode:
		return c.decodeText(m, oid, src, multirange)
	default:
		return fmt.Errorf("unknown format code %d", formatCode)
	}
}

func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}

	switch format {
	case TextFormatCode:
		return string(src), nil
	case BinaryFormatCode:
		buf := make([]byte, len(src))
		copy(buf, src)
		return buf, nil
	default:
		return nil, fmt.Errorf("unknown format code %d", format)
	}
}

func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var multirange Multirange[Range[any]]
	err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange)
	return multirange, err
}

func parseUntypedTextMultirange(src []byte) ([]string, error) {
	elements := make([]string, 0)

	buf := bytes.NewBuffer(src)

	skipWhitespace(buf)

	r, _, err := buf.ReadRune()
	if err != nil {
		return nil, fmt.Errorf("invalid array: %w", err)
	}

	if r != '{' {
		return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r)
	}

parseValueLoop:
	for {
		r, _, err = buf.ReadRune()
		if err != nil {
			return nil, fmt.Errorf("invalid multirange: %w", err)
		}

		switch r {
		case ',': // skip range separator
		case '}':
			break parseValueLoop
		default:
			buf.UnreadRune()
			value, err := parseRange(buf)
			if err != nil {
				return nil, fmt.Errorf("invalid multirange value: %w", err)
			}
			elements = append(elements, value)
		}
	}

	skipWhitespace(buf)

	if buf.Len() > 0 {
		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
	}

	return elements, nil

}

func parseRange(buf *bytes.Buffer) (string, error) {
	s := &bytes.Buffer{}

	boundSepRead := false
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return "", err
		}

		switch r {
		case ',', '}':
			if r == ',' && !boundSepRead {
				boundSepRead = true
				break
			}
			buf.UnreadRune()
			return s.String(), nil
		}

		s.WriteRune(r)
	}
}

// Multirange is a generic multirange type.
//
// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
// enforce the RangeScanner constraint.
type Multirange[T RangeValuer] []T

func (r Multirange[T]) IsNull() bool {
	return r == nil
}

func (r Multirange[T]) Len() int {
	return len(r)
}

func (r Multirange[T]) Index(i int) any {
	return r[i]
}

func (r Multirange[T]) IndexType() any {
	var zero T
	return zero
}

func (r *Multirange[T]) ScanNull() error {
	*r = nil
	return nil
}

func (r *Multirange[T]) SetLen(n int) error {
	*r = make([]T, n)
	return nil
}

func (r Multirange[T]) ScanIndex(i int) any {
	return &r[i]
}

func (r Multirange[T]) ScanIndexType() any {
	return new(T)
}