package pgtype

import (
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5/internal/pgio"
)

// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite.
type CompositeIndexGetter interface {
	// IsNull returns true if the value is SQL NULL.
	IsNull() bool

	// Index returns the element at i.
	Index(i int) any
}

// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite.
type CompositeIndexScanner interface {
	// ScanNull sets the value to SQL NULL.
	ScanNull() error

	// ScanIndex returns a value usable as a scan target for i.
	ScanIndex(i int) any
}

type CompositeCodecField struct {
	Name string
	Type *Type
}

type CompositeCodec struct {
	Fields []CompositeCodecField
}

func (c *CompositeCodec) FormatSupported(format int16) bool {
	for _, f := range c.Fields {
		if !f.Type.Codec.FormatSupported(format) {
			return false
		}
	}

	return true
}

func (c *CompositeCodec) PreferredFormat() int16 {
	if c.FormatSupported(BinaryFormatCode) {
		return BinaryFormatCode
	}
	return TextFormatCode
}

func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
	if _, ok := value.(CompositeIndexGetter); !ok {
		return nil
	}

	switch format {
	case BinaryFormatCode:
		return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m}
	case TextFormatCode:
		return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m}
	}

	return nil
}

type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
	cc *CompositeCodec
	m  *Map
}

func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
	getter := value.(CompositeIndexGetter)

	if getter.IsNull() {
		return nil, nil
	}

	builder := NewCompositeBinaryBuilder(plan.m, buf)
	for i, field := range plan.cc.Fields {
		builder.AppendValue(field.Type.OID, getter.Index(i))
	}

	return builder.Finish()
}

type encodePlanCompositeCodecCompositeIndexGetterToText struct {
	cc *CompositeCodec
	m  *Map
}

func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
	getter := value.(CompositeIndexGetter)

	if getter.IsNull() {
		return nil, nil
	}

	b := NewCompositeTextBuilder(plan.m, buf)
	for i, field := range plan.cc.Fields {
		b.AppendValue(field.Type.OID, getter.Index(i))
	}

	return b.Finish()
}

func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
	switch format {
	case BinaryFormatCode:
		switch target.(type) {
		case CompositeIndexScanner:
			return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
		}
	case TextFormatCode:
		switch target.(type) {
		case CompositeIndexScanner:
			return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
		}
	}

	return nil
}

type scanPlanBinaryCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec
	m  *Map
}

func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
	targetScanner := (target).(CompositeIndexScanner)

	if src == nil {
		return targetScanner.ScanNull()
	}

	scanner := NewCompositeBinaryScanner(plan.m, src)
	for i, field := range plan.cc.Fields {
		if scanner.Next() {
			fieldTarget := targetScanner.ScanIndex(i)
			if fieldTarget != nil {
				fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
				if fieldPlan == nil {
					return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID)
				}

				err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
				if err != nil {
					return err
				}
			}
		} else {
			return errors.New("read past end of composite")
		}
	}

	if err := scanner.Err(); err != nil {
		return err
	}

	return nil
}

type scanPlanTextCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec
	m  *Map
}

func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
	targetScanner := (target).(CompositeIndexScanner)

	if src == nil {
		return targetScanner.ScanNull()
	}

	scanner := NewCompositeTextScanner(plan.m, src)
	for i, field := range plan.cc.Fields {
		if scanner.Next() {
			fieldTarget := targetScanner.ScanIndex(i)
			if fieldTarget != nil {
				fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
				if fieldPlan == nil {
					return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID)
				}

				err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
				if err != nil {
					return err
				}
			}
		} else {
			return errors.New("read past end of composite")
		}
	}

	if err := scanner.Err(); err != nil {
		return err
	}

	return nil
}

func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}

	switch format {
	case TextFormatCode:
		return string(src), nil
	case BinaryFormatCode:
		buf := make([]byte, len(src))
		copy(buf, src)
		return buf, nil
	default:
		return nil, fmt.Errorf("unknown format code %d", format)
	}
}

func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	switch format {
	case TextFormatCode:
		scanner := NewCompositeTextScanner(m, src)
		values := make(map[string]any, len(c.Fields))
		for i := 0; scanner.Next() && i < len(c.Fields); i++ {
			var v any
			fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v)
			if fieldPlan == nil {
				return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v)
			}

			err := fieldPlan.Scan(scanner.Bytes(), &v)
			if err != nil {
				return nil, err
			}

			values[c.Fields[i].Name] = v
		}

		if err := scanner.Err(); err != nil {
			return nil, err
		}

		return values, nil
	case BinaryFormatCode:
		scanner := NewCompositeBinaryScanner(m, src)
		values := make(map[string]any, len(c.Fields))
		for i := 0; scanner.Next() && i < len(c.Fields); i++ {
			var v any
			fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v)
			if fieldPlan == nil {
				return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
			}

			err := fieldPlan.Scan(scanner.Bytes(), &v)
			if err != nil {
				return nil, err
			}

			values[c.Fields[i].Name] = v
		}

		if err := scanner.Err(); err != nil {
			return nil, err
		}

		return values, nil
	default:
		return nil, fmt.Errorf("unknown format code %d", format)
	}

}

type CompositeBinaryScanner struct {
	m   *Map
	rp  int
	src []byte

	fieldCount int32
	fieldBytes []byte
	fieldOID   uint32
	err        error
}

// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
	rp := 0
	if len(src[rp:]) < 4 {
		return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
	}

	fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
	rp += 4

	return &CompositeBinaryScanner{
		m:          m,
		rp:         rp,
		src:        src,
		fieldCount: fieldCount,
	}
}

// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeBinaryScanner) Next() bool {
	if cfs.err != nil {
		return false
	}

	if cfs.rp == len(cfs.src) {
		return false
	}

	if len(cfs.src[cfs.rp:]) < 8 {
		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
		return false
	}
	cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
	cfs.rp += 4

	fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
	cfs.rp += 4

	if fieldLen >= 0 {
		if len(cfs.src[cfs.rp:]) < fieldLen {
			cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
			return false
		}
		cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
		cfs.rp += fieldLen
	} else {
		cfs.fieldBytes = nil
	}

	return true
}

func (cfs *CompositeBinaryScanner) FieldCount() int {
	return int(cfs.fieldCount)
}

// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// OID returns the OID of the field most recently read by Scan().
func (cfs *CompositeBinaryScanner) OID() uint32 {
	return cfs.fieldOID
}

// Err returns any error encountered by the scanner.
func (cfs *CompositeBinaryScanner) Err() error {
	return cfs.err
}

type CompositeTextScanner struct {
	m   *Map
	rp  int
	src []byte

	fieldBytes []byte
	err        error
}

// NewCompositeTextScanner a scanner over a text encoded composite value.
func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
	if len(src) < 2 {
		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
	}

	if src[0] != '(' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
	}

	if src[len(src)-1] != ')' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
	}

	return &CompositeTextScanner{
		m:   m,
		rp:  1,
		src: src,
	}
}

// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
func (cfs *CompositeTextScanner) Next() bool {
	if cfs.err != nil {
		return false
	}

	if cfs.rp == len(cfs.src) {
		return false
	}

	switch cfs.src[cfs.rp] {
	case ',', ')': // null
		cfs.rp++
		cfs.fieldBytes = nil
		return true
	case '"': // quoted value
		cfs.rp++
		cfs.fieldBytes = make([]byte, 0, 16)
		for {
			ch := cfs.src[cfs.rp]

			if ch == '"' {
				cfs.rp++
				if cfs.src[cfs.rp] == '"' {
					cfs.fieldBytes = append(cfs.fieldBytes, '"')
					cfs.rp++
				} else {
					break
				}
			} else if ch == '\\' {
				cfs.rp++
				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
				cfs.rp++
			} else {
				cfs.fieldBytes = append(cfs.fieldBytes, ch)
				cfs.rp++
			}
		}
		cfs.rp++
		return true
	default: // unquoted value
		start := cfs.rp
		for {
			ch := cfs.src[cfs.rp]
			if ch == ',' || ch == ')' {
				break
			}
			cfs.rp++
		}
		cfs.fieldBytes = cfs.src[start:cfs.rp]
		cfs.rp++
		return true
	}
}

// Bytes returns the bytes of the field most recently read by Scan().
func (cfs *CompositeTextScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// Err returns any error encountered by the scanner.
func (cfs *CompositeTextScanner) Err() error {
	return cfs.err
}

type CompositeBinaryBuilder struct {
	m          *Map
	buf        []byte
	startIdx   int
	fieldCount uint32
	err        error
}

func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder {
	startIdx := len(buf)
	buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
	return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx}
}

func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) {
	if b.err != nil {
		return
	}

	if field == nil {
		b.buf = pgio.AppendUint32(b.buf, oid)
		b.buf = pgio.AppendInt32(b.buf, -1)
		b.fieldCount++
		return
	}

	plan := b.m.PlanEncode(oid, BinaryFormatCode, field)
	if plan == nil {
		b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
		return
	}

	b.buf = pgio.AppendUint32(b.buf, oid)
	lengthPos := len(b.buf)
	b.buf = pgio.AppendInt32(b.buf, -1)
	fieldBuf, err := plan.Encode(field, b.buf)
	if err != nil {
		b.err = err
		return
	}
	if fieldBuf != nil {
		binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
		b.buf = fieldBuf
	}

	b.fieldCount++
}

func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
	if b.err != nil {
		return nil, b.err
	}

	binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
	return b.buf, nil
}

type CompositeTextBuilder struct {
	m          *Map
	buf        []byte
	startIdx   int
	fieldCount uint32
	err        error
	fieldBuf   [32]byte
}

func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder {
	buf = append(buf, '(') // allocate room for number of fields
	return &CompositeTextBuilder{m: m, buf: buf}
}

func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) {
	if b.err != nil {
		return
	}

	if field == nil {
		b.buf = append(b.buf, ',')
		return
	}

	plan := b.m.PlanEncode(oid, TextFormatCode, field)
	if plan == nil {
		b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
		return
	}

	fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
	if err != nil {
		b.err = err
		return
	}
	if fieldBuf != nil {
		b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
	}

	b.buf = append(b.buf, ',')
}

func (b *CompositeTextBuilder) Finish() ([]byte, error) {
	if b.err != nil {
		return nil, b.err
	}

	b.buf[len(b.buf)-1] = ')'
	return b.buf, nil
}

var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

func quoteCompositeField(src string) string {
	return `"` + quoteCompositeReplacer.Replace(src) + `"`
}

func quoteCompositeFieldIfNeeded(src string) string {
	if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
		return quoteCompositeField(src)
	}
	return src
}

// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target.
// It cannot scan a NULL, but the composite fields can be NULL.
type CompositeFields []any

func (cf CompositeFields) SkipUnderlyingTypePlan() {}

func (cf CompositeFields) IsNull() bool {
	return cf == nil
}

func (cf CompositeFields) Index(i int) any {
	return cf[i]
}

func (cf CompositeFields) ScanNull() error {
	return fmt.Errorf("cannot scan NULL into CompositeFields")
}

func (cf CompositeFields) ScanIndex(i int) any {
	return cf[i]
}