package pgx

import (
	"errors"
	"fmt"
	"time"
)

// Row is the result of QueryRow: a Rows from which only the first row is
// read, via Scan.
type Row Rows

// Scan decodes the first row of the result into dest and then closes the
// underlying rows, discarding any further rows. It returns ErrNoRows when the
// result has no rows. Otherwise it returns the first error encountered while
// querying, decoding the row, or draining the rest of the result.
func (r *Row) Scan(dest ...interface{}) (err error) {
	rows := (*Rows)(r)

	// QueryRow does not return Query's error; it is reported here instead.
	if err = rows.Err(); err != nil {
		return err
	}

	if !rows.Next() {
		if err = rows.Err(); err != nil {
			return err
		}
		return ErrNoRows
	}

	// Rows.Scan can fail without recording the error on rows (an unsupported
	// destination type, or an error returned by a Scanner). Its result must
	// therefore be checked directly rather than relying on rows.Err.
	err = rows.Scan(dest...)
	rows.Close()
	if err != nil {
		return err
	}
	return rows.Err()
}

// Rows is the result set of a query. Advance with Next, read each row with
// Scan or Values, and check Err once iteration ends. Close drains any
// remaining result messages from the connection and, for pooled rows,
// releases the connection. Call it if iteration stops before Next returns
// false.
type Rows struct {
	// pool receives conn when rows is closed; nil when not pooled or once
	// the connection has already been released.
	pool *ConnPool
	// conn is the connection the result messages are read from.
	conn *Conn
	// mr reads the current data row; set by Next.
	mr *MsgReader
	// fields describes the result columns, one entry per column.
	fields []FieldDescription
	// rowCount is incremented on every Next call made while rows is open,
	// including a final call that finds no row.
	// NOTE(review): therefore not an exact count of rows returned.
	rowCount int
	// columnIdx is the index of the next column to read in the current row.
	columnIdx int
	// err is the first error recorded by abort or Fatal.
	err error
	// closed is set by close; Next and nextColumn do nothing once it is true.
	closed bool
}

// FieldDescriptions returns the descriptions of the result columns. The
// returned slice is the one held by rows, not a copy, so callers should treat
// it as read-only.
func (rows *Rows) FieldDescriptions() []FieldDescription {
	fields := rows.fields
	return fields
}

// MsgReader returns the reader for the current data row. It is presumably
// intended for Scanner implementations and other custom decoders that need
// to read a column's raw bytes.
func (rows *Rows) MsgReader() *MsgReader {
	reader := rows.mr
	return reader
}

// close marks rows as closed. When rows came from a pool, it also releases
// the connection back to that pool. It does not read any pending messages
// from the connection; Close does that. Calling it repeatedly is harmless
// because the pool reference is cleared after the first release.
func (rows *Rows) close() {
	if p := rows.pool; p != nil {
		p.Release(rows.conn)
		rows.pool = nil
	}
	rows.closed = true
}

// readUntilReadyForQuery reads and discards the rest of the current result
// until the server sends readyForQuery, then closes rows. Discarded messages
// are row descriptions, data rows, command completions and bind
// completions; any other message is handed to processContextFreeMsg.
//
// If receiving or processing a message fails, rows is closed immediately.
// The error is recorded on rows unless an earlier error is already stored,
// so callers of Close (and Row.Scan) can observe it instead of losing it.
func (rows *Rows) readUntilReadyForQuery() {
	// fail closes rows after an error. It keeps an error recorded earlier
	// (e.g. the one passed to Fatal) so the original cause is not masked.
	fail := func(err error) {
		if rows.err == nil {
			rows.err = err
		}
		rows.close()
	}

	for {
		t, r, err := rows.conn.rxMsg()
		if err != nil {
			fail(err)
			return
		}

		switch t {
		case readyForQuery:
			rows.conn.rxReadyForQuery(r)
			rows.close()
			return
		case rowDescription, dataRow, commandComplete, bindComplete:
			// Remaining result data is not wanted; skip it.
		default:
			if err := rows.conn.processContextFreeMsg(t, r); err != nil {
				fail(err)
				return
			}
		}
	}
}

// Close finishes reading the query's result from the connection and marks
// rows as closed, releasing a pooled connection. It is a no-op if rows is
// already closed.
func (rows *Rows) Close() {
	if !rows.closed {
		rows.readUntilReadyForQuery()
		rows.close()
	}
}

// Err reports the error recorded on rows, or nil if there is none. Only the
// first error passed to abort or Fatal is kept.
func (rows *Rows) Err() error {
	recorded := rows.err
	return recorded
}

// abort records err and marks rows closed without reading anything from the
// connection. Use it when the query could not be sent to the server, so no
// response needs draining; errors that occur after sending go through Fatal.
// If an error has already been recorded, abort does nothing.
func (rows *Rows) abort(err error) {
	if rows.err == nil {
		rows.err = err
		rows.close()
	}
}

// Fatal records err for an error that happened after the query reached the
// server. It then calls Close, which drains whatever the server still sends
// for the query. If an error has already been recorded, Fatal does nothing.
func (rows *Rows) Fatal(err error) {
	if rows.err == nil {
		rows.err = err
		rows.Close()
	}
}

// Next advances to the next data row and reports whether one is available
// for Scan or Values. It returns false once readyForQuery arrives (the result
// is exhausted) or when an error occurs; use Err to distinguish the two. A
// data row whose column count differs from the row description is treated
// as a protocol error.
func (rows *Rows) Next() bool {
	if rows.closed {
		return false
	}

	rows.rowCount++
	rows.columnIdx = 0

	for {
		msgType, reader, err := rows.conn.rxMsg()
		if err != nil {
			rows.Fatal(err)
			return false
		}

		switch msgType {
		case dataRow:
			columns := reader.ReadInt16()
			if int(columns) == len(rows.fields) {
				rows.mr = reader
				return true
			}
			rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), columns)))
			return false
		case readyForQuery:
			rows.conn.rxReadyForQuery(reader)
			rows.close()
			return false
		case commandComplete, bindComplete:
			// Nothing to do; keep reading.
		default:
			if err := rows.conn.processContextFreeMsg(msgType, reader); err != nil {
				rows.Fatal(err)
				return false
			}
		}
	}
}

// nextColumn advances the column cursor of the current row. It returns the
// column's description and the int32 length prefix read from the row, which
// per the wire protocol is the value's byte length (-1 for NULL).
// It reports false when rows is closed. It also reports false when every
// column has already been consumed, in which case it records a ProtocolError
// via Fatal.
func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) {
	switch {
	case rows.closed:
		return nil, 0, false
	case rows.columnIdx >= len(rows.fields):
		rows.Fatal(ProtocolError("No next column available"))
		return nil, 0, false
	}

	idx := rows.columnIdx
	rows.columnIdx = idx + 1
	return &rows.fields[idx], rows.mr.ReadInt32(), true
}

// Scan decodes the columns of the current row into dest, which must hold
// exactly one destination per column. Supported destinations:
//   - *bool, *[]byte, *int16, *int32, *int64, *string, *float32, *float64
//   - *time.Time for date, timestamp and timestamptz columns
//   - any Scanner
//
// A wrong argument count or an unsupported time column is recorded on rows
// via Fatal (closing it) and returned. An unsupported destination type or an
// error from a Scanner is returned without closing rows. If rows is already
// closed, its recorded error (or a "rows is closed" error) is returned
// instead of decoding. Finally, any error recorded on rows while decoding is
// returned rather than reporting success.
func (rows *Rows) Scan(dest ...interface{}) (err error) {
	if len(rows.fields) != len(dest) {
		err = errors.New("Scan received wrong number of arguments")
		rows.Fatal(err)
		return err
	}

	for _, d := range dest {
		fd, size, ok := rows.nextColumn()
		if !ok {
			// rows is closed (e.g. an earlier decode recorded an error) or
			// out of columns. Continuing would pass a nil *FieldDescription
			// to the decoders below.
			if err = rows.Err(); err != nil {
				return err
			}
			return errors.New("rows is closed")
		}

		switch d := d.(type) {
		case *bool:
			*d = decodeBool(rows, fd, size)
		case *[]byte:
			*d = decodeBytea(rows, fd, size)
		case *int64:
			*d = decodeInt8(rows, fd, size)
		case *int16:
			*d = decodeInt2(rows, fd, size)
		case *int32:
			*d = decodeInt4(rows, fd, size)
		case *string:
			*d = decodeText(rows, fd, size)
		case *float32:
			*d = decodeFloat4(rows, fd, size)
		case *float64:
			*d = decodeFloat8(rows, fd, size)
		case *time.Time:
			switch fd.DataType {
			case DateOid:
				*d = decodeDate(rows, fd, size)
			case TimestampTzOid:
				*d = decodeTimestampTz(rows, fd, size)
			case TimestampOid:
				*d = decodeTimestamp(rows, fd, size)
			default:
				err = fmt.Errorf("Can't convert OID %v to time.Time", fd.DataType)
				rows.Fatal(err)
				return err
			}
		case Scanner:
			if err = d.Scan(rows, fd, size); err != nil {
				return err
			}
		default:
			return errors.New("Unknown type")
		}
	}

	// A decoder may have recorded an error on rows while reading the last
	// column; report it instead of claiming success.
	return rows.Err()
}

// Values returns the values of the current row, one per column.
//
// Columns of an intrinsic type (bool, bytea, int2/int4/int8, varchar/text,
// float4/float8, date, timestamp, timestamptz) are decoded into the matching
// Go type. For any other column:
//   - NULL is returned as nil
//   - a text-format value is returned as its string representation
//   - a binary-format value is rejected with an error
//
// Any error recorded on rows is returned.
func (rows *Rows) Values() ([]interface{}, error) {
	if rows.closed {
		return nil, errors.New("rows is closed")
	}

	values := make([]interface{}, 0, len(rows.fields))

	for range rows.fields {
		if rows.Err() != nil {
			return nil, rows.Err()
		}

		fd, size, ok := rows.nextColumn()
		if !ok {
			// Never hand nextColumn's nil *FieldDescription to a decoder.
			if err := rows.Err(); err != nil {
				return nil, err
			}
			return nil, errors.New("rows is closed")
		}

		switch fd.DataType {
		case BoolOid:
			values = append(values, decodeBool(rows, fd, size))
		case ByteaOid:
			values = append(values, decodeBytea(rows, fd, size))
		case Int8Oid:
			values = append(values, decodeInt8(rows, fd, size))
		case Int2Oid:
			values = append(values, decodeInt2(rows, fd, size))
		case Int4Oid:
			values = append(values, decodeInt4(rows, fd, size))
		case VarcharOid, TextOid:
			values = append(values, decodeText(rows, fd, size))
		case Float4Oid:
			values = append(values, decodeFloat4(rows, fd, size))
		case Float8Oid:
			values = append(values, decodeFloat8(rows, fd, size))
		case DateOid:
			values = append(values, decodeDate(rows, fd, size))
		case TimestampTzOid:
			values = append(values, decodeTimestampTz(rows, fd, size))
		case TimestampOid:
			values = append(values, decodeTimestamp(rows, fd, size))
		default:
			// The wire protocol sends NULL as a length of -1 with no value
			// bytes, so there is nothing to read or interpret.
			if size == -1 {
				values = append(values, nil)
				continue
			}

			// Not an intrinsic type: fall back to the raw text.
			switch fd.FormatCode {
			case TextFormatCode:
				values = append(values, rows.MsgReader().ReadString(size))
			case BinaryFormatCode:
				return nil, errors.New("Values cannot handle binary format non-intrinsic types")
			default:
				return nil, errors.New("Unknown format code")
			}
		}
	}

	return values, rows.Err()
}

// Query sends sql with args to the server and returns a *Rows for reading
// the result.
//
// If sql is a key in c.preparedStatements, that prepared statement is
// executed and its known field descriptions are used. Otherwise sql is sent
// as a simple query. Because a simple query's columns are not known in
// advance, Query keeps reading messages until the row description arrives.
//
// The returned *Rows is never nil and is backed by storage inside c, so the
// next Query on c reuses it. Any returned error is also recorded on the
// Rows, so callers may defer error checking to Rows.Err.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
	c.rows = Rows{conn: c}
	rows := &c.rows

	ps, isPrepared := c.preparedStatements[sql]
	if isPrepared {
		rows.fields = ps.FieldDescriptions
		if err := c.sendPreparedQuery(ps, args...); err != nil {
			rows.abort(err)
		}
		return rows, rows.err
	}

	if err := c.sendSimpleQuery(sql, args...); err != nil {
		rows.abort(err)
		return rows, rows.err
	}

	// Wait for the rowDescription so rows knows its columns before it is
	// handed to the caller.
	for {
		msgType, reader, err := c.rxMsg()
		if err == nil {
			if msgType == rowDescription {
				rows.fields = c.rxRowDescription(reader)
				return rows, nil
			}
			err = c.processContextFreeMsg(msgType, reader)
		}
		if err != nil {
			rows.Fatal(err)
			return rows, rows.err
		}
	}
}

// QueryRow runs sql like Query and wraps the result as a *Row. It does not
// return an error itself; Query's error is kept on the rows and reported by
// Row.Scan, which checks Err before reading.
func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
	var rows *Rows
	// The error is intentionally ignored here: it is also recorded on rows.
	rows, _ = c.Query(sql, args...)
	return (*Row)(rows)
}