package tracelog_test

import (
	"bytes"
	"context"
	"log"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxtest"
	"github.com/jackc/pgx/v5/tracelog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// defaultConnTestRunner is the pgxtest runner shared by the tests in this file.
// It is pgxtest's default runner with CreateConfig replaced so that every
// connection is configured from the PGX_TEST_DATABASE environment variable.
var defaultConnTestRunner = func() pgxtest.ConnTestRunner {
	runner := pgxtest.DefaultConnTestRunner()
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		connString := os.Getenv("PGX_TEST_DATABASE")
		config, err := pgx.ParseConfig(connString)
		require.NoError(t, err)
		return config
	}
	return runner
}()

type (
	// testLog is a single entry captured by testLogger.
	testLog struct {
		lvl  tracelog.LogLevel // level passed to Log
		msg  string            // message passed to Log, e.g. "Query"
		data map[string]any    // structured fields recorded by Log
	}

	// testLogger is an in-memory logger used as the Logger of a
	// tracelog.TraceLog so tests can inspect what was logged. It has no
	// locking of its own.
	testLogger struct {
		logs []testLog
	}
)

// Log records one entry; it is the method through which testLogger serves as
// a tracelog Logger.
//
// The fields are copied into a fresh map before "ctxdata" is added, so the
// caller's map is never mutated and a nil data map does not panic. "ctxdata"
// holds the value ctx carries under the "ctxdata" key (nil when absent), which
// lets tests verify that the query's context is forwarded to the logger.
func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
	entry := make(map[string]any, len(data)+1)
	for k, v := range data {
		entry[k] = v
	}
	entry["ctxdata"] = ctx.Value("ctxdata")
	l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: entry})
}

// Clear drops every recorded entry. The slice is truncated rather than
// reallocated, so later Log calls reuse its backing array.
func (l *testLogger) Clear() {
	l.logs = l.logs[:0]
}

// FilterByMsg returns, in logging order, the recorded entries whose message
// equals msg. It returns nil when no entry matches.
func (l *testLogger) FilterByMsg(msg string) []testLog {
	// The loop variable is named entry (not log) so it does not shadow the
	// imported log package.
	var matches []testLog
	for _, entry := range l.logs {
		if entry.msg == msg {
			matches = append(matches, entry)
		}
	}
	return matches
}

// TestContextGetsPassedToLogMethod checks that the context given to Exec is the
// one handed to the logger: a value is attached to the context and must show
// up in the recorded entry.
func TestContextGetsPassedToLogMethod(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	recorder := &testLogger{}
	traceLog := &tracelog.TraceLog{Logger: recorder, LogLevel: tracelog.LogLevelTrace}

	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = traceLog
		return cfg
	}

	pgxtest.RunWithQueryExecModes(ctx, t, runner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		// Discard entries produced while the connection was established.
		recorder.Clear()

		// NOTE(review): a plain string context key is flagged by staticcheck
		// (SA1029); it has to stay "ctxdata" to match what testLogger.Log reads.
		valueCtx := context.WithValue(ctx, "ctxdata", "foo")
		_, err := conn.Exec(valueCtx, `;`)
		require.NoError(t, err)

		require.Len(t, recorder.logs, 1)
		require.Equal(t, "foo", recorder.logs[0].data["ctxdata"])
	})
}

// TestLoggerFunc checks that a tracelog.LoggerFunc adapter can serve as the
// Logger of a TraceLog and is invoked exactly once for a single Exec.
func TestLoggerFunc(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	const testMsg = "foo"

	var buf bytes.Buffer
	logger := log.New(&buf, "", 0)

	// The adapter ignores its arguments and writes one fixed line per call, so
	// the buffer contents reveal how many times it was invoked.
	createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc {
		return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
			logger.Printf("%s", testMsg)
		}
	}

	config := defaultConnTestRunner.CreateConfig(ctx, t)
	config.Tracer = &tracelog.TraceLog{
		Logger:   createAdapterFn(logger),
		LogLevel: tracelog.LogLevelTrace,
	}

	conn, err := pgx.ConnectConfig(ctx, config)
	require.NoError(t, err)
	defer conn.Close(ctx)

	buf.Reset() // Clear logs written when establishing connection

	// Use ctx rather than context.TODO() so the query is bounded by the test's
	// timeout instead of potentially hanging forever.
	_, err = conn.Exec(ctx, ";")
	require.NoError(t, err)

	if strings.TrimSpace(buf.String()) != testMsg {
		t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
	}
}

// TestLogQuery checks that a successful Exec produces one "Query" entry at info
// level, and that a failing Exec produces one at error level carrying the
// returned error under the "err" key.
func TestLogQuery(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	recorder := &testLogger{}
	traceLog := &tracelog.TraceLog{Logger: recorder, LogLevel: tracelog.LogLevelTrace}

	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = traceLog
		return cfg
	}

	pgxtest.RunWithQueryExecModes(ctx, t, runner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		recorder.Clear() // discard entries from connection setup

		// Valid statement: logged at info level.
		_, err := conn.Exec(ctx, `select $1::text`, "testing")
		require.NoError(t, err)

		queryLogs := recorder.FilterByMsg("Query")
		require.Len(t, queryLogs, 1)
		require.Equal(t, tracelog.LogLevelInfo, queryLogs[0].lvl)

		recorder.Clear()

		// Invalid SQL: logged at error level together with the error itself.
		_, err = conn.Exec(ctx, `foo`, "testing")
		require.Error(t, err)

		queryLogs = recorder.FilterByMsg("Query")
		require.Len(t, queryLogs, 1)
		require.Equal(t, tracelog.LogLevelError, queryLogs[0].lvl)
		require.Equal(t, err, queryLogs[0].data["err"])
	})
}

// TestLogQueryArgsHandlesUTF8 checks that long string arguments recorded in the
// "args" log field are truncated without splitting a multi-byte UTF-8
// character.
//
// https://github.com/jackc/pgx/issues/1365
func TestLogQueryArgsHandlesUTF8(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.Clear() // Clear any logs written when establishing connection

		// firstArg extracts the first logged query argument, failing the test
		// cleanly (instead of panicking) if "args" is missing or mistyped.
		firstArg := func(entry testLog) any {
			args, ok := entry.data["args"].([]any)
			require.Truef(t, ok, "args field has type %T, want []any", entry.data["args"])
			require.NotEmpty(t, args)
			return args[0]
		}

		// 63 ASCII bytes followed by a 4-byte emoji, so the emoji occupies bytes
		// 63-66 and straddles the 64-byte mark (presumably tracelog's truncation
		// limit — confirm in tracelog). This string must be logged unchanged.
		s := strings.Repeat("0", 63) + "😊"

		_, err := conn.Exec(ctx, `select $1::text`, s)
		require.NoError(t, err)

		logs := logger.FilterByMsg("Query")
		require.Len(t, logs, 1)
		require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl)
		require.Equal(t, s, firstArg(logs[0]))

		logger.Clear()

		// With three more bytes the argument is truncated after the whole emoji,
		// never in the middle of it.
		_, err = conn.Exec(ctx, `select $1::text`, s+"000")
		require.NoError(t, err)

		logs = logger.FilterByMsg("Query")
		require.Len(t, logs, 1)
		require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl)
		require.Equal(t, s+" (truncated 3 bytes)", firstArg(logs[0]))
	})
}

// TestLogCopyFrom checks that a successful CopyFrom produces one "CopyFrom"
// entry at info level and that a failing one (a value that is not an integer
// for an int4 column) produces one at error level.
func TestLogCopyFrom(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	recorder := &testLogger{}
	traceLog := &tracelog.TraceLog{Logger: recorder, LogLevel: tracelog.LogLevelTrace}

	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = traceLog
		return cfg
	}

	pgxtest.RunWithQueryExecModes(ctx, t, runner, pgxtest.KnownOIDQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		_, err := conn.Exec(ctx, `create temporary table foo(a int4)`)
		require.NoError(t, err)

		recorder.Clear()

		// Rows that fit the column: the copy succeeds.
		goodRows := [][]any{
			{int32(1)},
			{nil},
		}
		copied, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(goodRows))
		require.NoError(t, err)
		require.EqualValues(t, len(goodRows), copied)

		copyLogs := recorder.FilterByMsg("CopyFrom")
		require.Len(t, copyLogs, 1)
		require.Equal(t, tracelog.LogLevelInfo, copyLogs[0].lvl)

		recorder.Clear()

		// A string where an int4 is expected: the copy fails and nothing is copied.
		badRows := [][]any{
			{"not an integer"},
			{nil},
		}
		copied, err = conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(badRows))
		require.Error(t, err)
		require.EqualValues(t, 0, copied)

		copyLogs = recorder.FilterByMsg("CopyFrom")
		require.Len(t, copyLogs, 1)
		require.Equal(t, tracelog.LogLevelError, copyLogs[0].lvl)
	})
}

// TestLogConnect checks that a successful connection attempt produces exactly
// one "Connect" entry at info level and a failed attempt exactly one at error
// level.
func TestLogConnect(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	recorder := &testLogger{}
	traceLog := &tracelog.TraceLog{Logger: recorder, LogLevel: tracelog.LogLevelTrace}

	// Connecting to the test database succeeds.
	goodConfig := defaultConnTestRunner.CreateConfig(ctx, t)
	goodConfig.Tracer = traceLog

	conn1, err := pgx.ConnectConfig(ctx, goodConfig)
	require.NoError(t, err)
	defer conn1.Close(ctx)
	require.Len(t, recorder.logs, 1)
	require.Equal(t, "Connect", recorder.logs[0].msg)
	require.Equal(t, tracelog.LogLevelInfo, recorder.logs[0].lvl)

	recorder.Clear()

	// Connecting to the unusable host "/invalid" fails.
	badConfig, err := pgx.ParseConfig("host=/invalid")
	require.NoError(t, err)
	badConfig.Tracer = traceLog

	conn2, err := pgx.ConnectConfig(ctx, badConfig)
	require.Nil(t, conn2)
	require.Error(t, err)
	require.Len(t, recorder.logs, 1)
	require.Equal(t, "Connect", recorder.logs[0].msg)
	require.Equal(t, tracelog.LogLevelError, recorder.logs[0].lvl)
}

// TestLogBatchStatementsOnExec checks that each batch statement consumed with
// BatchResults.Exec is logged as "BatchQuery" with its SQL, and that closing
// the results logs "BatchClose".
func TestLogBatchStatementsOnExec(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.Clear() // Clear any logs written when establishing connection

		// A temporary table is private to this session and disappears with it,
		// so a run that aborts between the create and the drop cannot leave a
		// "foo" table behind in the shared test database or clash with another
		// session's table.
		batch := &pgx.Batch{}
		batch.Queue("create temporary table foo (id bigint)")
		batch.Queue("drop table foo")

		br := conn.SendBatch(ctx, batch)

		_, err := br.Exec()
		require.NoError(t, err)

		_, err = br.Exec()
		require.NoError(t, err)

		err = br.Close()
		require.NoError(t, err)

		require.Len(t, logger.logs, 3)
		assert.Equal(t, "BatchQuery", logger.logs[0].msg)
		assert.Equal(t, "create temporary table foo (id bigint)", logger.logs[0].data["sql"])
		assert.Equal(t, "BatchQuery", logger.logs[1].msg)
		assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
		assert.Equal(t, "BatchClose", logger.logs[2].msg)
	})
}

// TestLogBatchStatementsOnBatchResultClose checks that closing BatchResults
// without reading any result still logs one "BatchQuery" entry per queued
// statement (with its SQL), followed by "BatchClose".
func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	recorder := &testLogger{}
	traceLog := &tracelog.TraceLog{Logger: recorder, LogLevel: tracelog.LogLevelTrace}

	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = traceLog
		return cfg
	}

	pgxtest.RunWithQueryExecModes(ctx, t, runner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		recorder.Clear() // discard entries from connection setup

		batch := &pgx.Batch{}
		batch.Queue("select generate_series(1,$1)", 100)
		batch.Queue("select 1 = 1;")

		require.NoError(t, conn.SendBatch(ctx, batch).Close())

		require.Len(t, recorder.logs, 3)
		for i, sql := range []string{"select generate_series(1,$1)", "select 1 = 1;"} {
			assert.Equal(t, "BatchQuery", recorder.logs[i].msg)
			assert.Equal(t, sql, recorder.logs[i].data["sql"])
		}
		assert.Equal(t, "BatchClose", recorder.logs[2].msg)
	})
}

// TestLogPrepare checks that statement preparation is logged as "Prepare".
// The first phase prepares implicitly through Exec in the listed query exec
// modes. The second prepares explicitly through Conn.Prepare in every mode.
// In both phases a failed preparation must carry the returned error under the
// "err" key.
func TestLogPrepare(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	// Phase 1: in these modes a plain Exec is expected to emit a "Prepare" entry.
	pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{
		pgx.QueryExecModeCacheStatement,
		pgx.QueryExecModeCacheDescribe,
		pgx.QueryExecModeDescribeExec,
	}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.Clear() // Clear any logs written when establishing connection

		_, err := conn.Exec(ctx, `select $1::text`, "testing")
		require.NoError(t, err)

		logs := logger.FilterByMsg("Prepare")
		require.Len(t, logs, 1)
		require.Equal(t, tracelog.LogLevelInfo, logs[0].lvl)

		logger.Clear()

		_, err = conn.Exec(ctx, `foo aaaa`, "testing")
		require.Error(t, err)

		logs = logger.FilterByMsg("Prepare")
		require.Len(t, logs, 1)
		require.Equal(t, err, logs[0].data["err"])
	})

	// Phase 2 gets its own 30s budget. It must be derived from
	// context.Background(): deriving it from ctx would inherit phase 1's
	// deadline and leave only whatever time phase 1 did not use. The first
	// cancel is already deferred above, so reassigning cancel leaks nothing.
	ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Phase 2: explicit Conn.Prepare calls, in every query exec mode.
	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.Clear() // Clear any logs written when establishing connection

		_, err := conn.Prepare(ctx, "test_query_1", `select $1::int`)
		require.NoError(t, err)

		require.Len(t, logger.logs, 1)
		require.Equal(t, "Prepare", logger.logs[0].msg)
		require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)

		logger.Clear()

		_, err = conn.Prepare(ctx, `test_query_2`, "foo aaaa")
		require.Error(t, err)

		require.Len(t, logger.logs, 1)
		require.Equal(t, "Prepare", logger.logs[0].msg)
		require.Equal(t, err, logger.logs[0].data["err"])
	})
}