package pgx

import (
	"database/sql/driver"
	"fmt"
	"math"
	"reflect"
	"time"

	"github.com/jackc/pgio"
	"github.com/jackc/pgtype"
	errors "golang.org/x/xerrors"
)

// PostgreSQL format codes
const (
	TextFormatCode   = 0
	BinaryFormatCode = 1
)

// SerializationError occurs on failure to encode or decode a value
type SerializationError string

func (e SerializationError) Error() string {
	return string(e)
}

func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
	if arg == nil {
		return nil, nil
	}

	refVal := reflect.ValueOf(arg)
	if refVal.Kind() == reflect.Ptr && refVal.IsNil() {
		return nil, nil
	}

	switch arg := arg.(type) {

	// https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface
	// []byte to database/sql instead of string. But that caused problems with the
	// simple protocol because the driver.Valuer case got taken before the
	// pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual
	// case because of https://github.com/jackc/pgx/issues/339. So instead we
	// special case JSON and JSONB.
	case *pgtype.JSON:
		buf, err := arg.EncodeText(ci, nil)
		if err != nil {
			return nil, err
		}
		if buf == nil {
			return nil, nil
		}
		return string(buf), nil
	case *pgtype.JSONB:
		buf, err := arg.EncodeText(ci, nil)
		if err != nil {
			return nil, err
		}
		if buf == nil {
			return nil, nil
		}
		return string(buf), nil

	case driver.Valuer:
		return callValuerValue(arg)
	case pgtype.TextEncoder:
		buf, err := arg.EncodeText(ci, nil)
		if err != nil {
			return nil, err
		}
		if buf == nil {
			return nil, nil
		}
		return string(buf), nil
	case float32:
		return float64(arg), nil
	case float64:
		return arg, nil
	case bool:
		return arg, nil
	case time.Time:
		return arg, nil
	case string:
		return arg, nil
	case []byte:
		return arg, nil
	case int8:
		return int64(arg), nil
	case int16:
		return int64(arg), nil
	case int32:
		return int64(arg), nil
	case int64:
		return arg, nil
	case int:
		return int64(arg), nil
	case uint8:
		return int64(arg), nil
	case uint16:
		return int64(arg), nil
	case uint32:
		return int64(arg), nil
	case uint64:
		if arg > math.MaxInt64 {
			return nil, errors.Errorf("arg too big for int64: %v", arg)
		}
		return int64(arg), nil
	case uint:
		if uint64(arg) > math.MaxInt64 {
			return nil, errors.Errorf("arg too big for int64: %v", arg)
		}
		return int64(arg), nil
	}

	if dt, found := ci.DataTypeForValue(arg); found {
		v := dt.Value
		err := v.Set(arg)
		if err != nil {
			return nil, err
		}
		buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
		if err != nil {
			return nil, err
		}
		if buf == nil {
			return nil, nil
		}
		return string(buf), nil
	}

	if refVal.Kind() == reflect.Ptr {
		arg = refVal.Elem().Interface()
		return convertSimpleArgument(ci, arg)
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return convertSimpleArgument(ci, strippedArg)
	}
	return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
}

func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
	if arg == nil {
		return pgio.AppendInt32(buf, -1), nil
	}

	switch arg := arg.(type) {
	case pgtype.BinaryEncoder:
		sp := len(buf)
		buf = pgio.AppendInt32(buf, -1)
		argBuf, err := arg.EncodeBinary(ci, buf)
		if err != nil {
			return nil, err
		}
		if argBuf != nil {
			buf = argBuf
			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
		}
		return buf, nil
	case pgtype.TextEncoder:
		sp := len(buf)
		buf = pgio.AppendInt32(buf, -1)
		argBuf, err := arg.EncodeText(ci, buf)
		if err != nil {
			return nil, err
		}
		if argBuf != nil {
			buf = argBuf
			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
		}
		return buf, nil
	case string:
		buf = pgio.AppendInt32(buf, int32(len(arg)))
		buf = append(buf, arg...)
		return buf, nil
	}

	refVal := reflect.ValueOf(arg)

	if refVal.Kind() == reflect.Ptr {
		if refVal.IsNil() {
			return pgio.AppendInt32(buf, -1), nil
		}
		arg = refVal.Elem().Interface()
		return encodePreparedStatementArgument(ci, buf, oid, arg)
	}

	if dt, ok := ci.DataTypeForOID(oid); ok {
		value := dt.Value
		err := value.Set(arg)
		if err != nil {
			{
				if arg, ok := arg.(driver.Valuer); ok {
					v, err := callValuerValue(arg)
					if err != nil {
						return nil, err
					}
					return encodePreparedStatementArgument(ci, buf, oid, v)
				}
			}

			return nil, err
		}

		sp := len(buf)
		buf = pgio.AppendInt32(buf, -1)
		argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
		if err != nil {
			return nil, err
		}
		if argBuf != nil {
			buf = argBuf
			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
		}
		return buf, nil
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
	}
	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}

// chooseParameterFormatCode determines the correct format code for an
// argument to a prepared statement. It defaults to TextFormatCode if no
// determination can be made.
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 {
	switch arg := arg.(type) {
	case pgtype.ParamFormatPreferrer:
		return arg.PreferredParamFormat()
	case pgtype.BinaryEncoder:
		return BinaryFormatCode
	case string, *string, pgtype.TextEncoder:
		return TextFormatCode
	}

	return ci.ParamFormatCodeForOID(oid)
}

func stripNamedType(val *reflect.Value) (interface{}, bool) {
	switch val.Kind() {
	case reflect.Int:
		convVal := int(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int8:
		convVal := int8(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int16:
		convVal := int16(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int32:
		convVal := int32(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int64:
		convVal := int64(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint:
		convVal := uint(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint8:
		convVal := uint8(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint16:
		convVal := uint16(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint32:
		convVal := uint32(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint64:
		convVal := uint64(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.String:
		convVal := val.String()
		return convVal, reflect.TypeOf(convVal) != val.Type()
	}

	return nil, false
}