package pgx_test

import (
	"context"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxtest"
	"github.com/stretchr/testify/require"
)

// testTracer is a tracer stub for tests. Each field is an optional hook; the
// exported Trace* method with the matching name calls the hook when it is
// non-nil. Tests assign the hooks they care about and leave the rest nil.
type testTracer struct {
	// Query hooks.
	traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context
	traceQueryEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData)

	// Batch hooks.
	traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context
	traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData)
	traceBatchEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData)

	// CopyFrom hooks.
	traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context
	traceCopyFromEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData)

	// Prepare hooks.
	tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context
	tracePrepareEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData)

	// Connect hooks. These take no *pgx.Conn argument.
	traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context
	traceConnectEnd   func(ctx context.Context, data pgx.TraceConnectEndData)
}

// ctxKey is the key type used for values that the tracer hooks in these tests
// store in, and read back from, a context.Context.
type ctxKey string

// TraceQueryStart delegates to the traceQueryStart hook and returns its
// context. When no hook is set, ctx is returned unchanged.
func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
	if tt.traceQueryStart == nil {
		return ctx
	}
	return tt.traceQueryStart(ctx, conn, data)
}

// TraceQueryEnd delegates to the traceQueryEnd hook. It is a no-op when no
// hook is set.
func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
	if tt.traceQueryEnd == nil {
		return
	}
	tt.traceQueryEnd(ctx, conn, data)
}

// TraceBatchStart delegates to the traceBatchStart hook and returns its
// context. When no hook is set, ctx is returned unchanged.
func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
	if tt.traceBatchStart == nil {
		return ctx
	}
	return tt.traceBatchStart(ctx, conn, data)
}

// TraceBatchQuery delegates to the traceBatchQuery hook. It is a no-op when
// no hook is set.
func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
	if tt.traceBatchQuery == nil {
		return
	}
	tt.traceBatchQuery(ctx, conn, data)
}

// TraceBatchEnd delegates to the traceBatchEnd hook. It is a no-op when no
// hook is set.
func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
	if tt.traceBatchEnd == nil {
		return
	}
	tt.traceBatchEnd(ctx, conn, data)
}

// TraceCopyFromStart delegates to the traceCopyFromStart hook and returns its
// context. When no hook is set, ctx is returned unchanged.
func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
	if tt.traceCopyFromStart == nil {
		return ctx
	}
	return tt.traceCopyFromStart(ctx, conn, data)
}

// TraceCopyFromEnd delegates to the traceCopyFromEnd hook. It is a no-op when
// no hook is set.
func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
	if tt.traceCopyFromEnd == nil {
		return
	}
	tt.traceCopyFromEnd(ctx, conn, data)
}

// TracePrepareStart delegates to the tracePrepareStart hook and returns its
// context. When no hook is set, ctx is returned unchanged.
func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
	if tt.tracePrepareStart == nil {
		return ctx
	}
	return tt.tracePrepareStart(ctx, conn, data)
}

// TracePrepareEnd delegates to the tracePrepareEnd hook. It is a no-op when
// no hook is set.
func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
	if tt.tracePrepareEnd == nil {
		return
	}
	tt.tracePrepareEnd(ctx, conn, data)
}

// TraceConnectStart delegates to the traceConnectStart hook and returns its
// context. When no hook is set, ctx is returned unchanged.
func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
	if tt.traceConnectStart == nil {
		return ctx
	}
	return tt.traceConnectStart(ctx, data)
}

// TraceConnectEnd delegates to the traceConnectEnd hook. It is a no-op when
// no hook is set.
func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
	if tt.traceConnectEnd == nil {
		return
	}
	tt.traceConnectEnd(ctx, data)
}

// TestTraceExec verifies that Conn.Exec calls the tracer's query start and end
// hooks, that the start hook sees the SQL and arguments, and that the context
// returned by the start hook is the one handed to the end hook.
func TestTraceExec(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceQueryStartCalled := false
		tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
			traceQueryStartCalled = true
			require.Equal(t, `select $1::text`, data.SQL)
			require.Len(t, data.Args, 1)
			require.Equal(t, `testing`, data.Args[0])
			// Tag the context so the end hook can prove it received it.
			return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo")
		}

		traceQueryEndCalled := false
		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
			traceQueryEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart")))
			require.Equal(t, `SELECT 1`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		_, err := conn.Exec(ctx, `select $1::text`, "testing")
		require.NoError(t, err)
		require.True(t, traceQueryStartCalled)
		require.True(t, traceQueryEndCalled)
	})
}

// TestTraceQuery verifies that Conn.QueryRow calls the tracer's query start
// and end hooks, and that the context returned by the start hook reaches the
// end hook.
func TestTraceQuery(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	tracer := new(testTracer)

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = tracer
		return cfg
	}

	const (
		query = `select $1::text`
		arg   = `testing`
		key   = ctxKey("fromTraceQueryStart")
	)

	pgxtest.RunWithQueryExecModes(ctx, t, runner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		var startSeen, endSeen bool

		tracer.traceQueryStart = func(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
			startSeen = true
			require.Equal(t, query, data.SQL)
			require.Len(t, data.Args, 1)
			require.Equal(t, arg, data.Args[0])
			// Tag the context so the end hook can prove it received it.
			return context.WithValue(ctx, key, "foo")
		}

		tracer.traceQueryEnd = func(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryEndData) {
			endSeen = true
			require.Equal(t, "foo", ctx.Value(key))
			require.Equal(t, `SELECT 1`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		var got string
		scanErr := conn.QueryRow(ctx, query, arg).Scan(&got)
		require.NoError(t, scanErr)
		require.Equal(t, "testing", got)
		require.True(t, startSeen)
		require.True(t, endSeen)
	})
}

// TestTraceBatchNormal checks the batch hooks when each queued query is read
// individually. The batch start hook must have run once SendBatch returns, the
// batch query hook must run once per result read, and the batch end hook must
// have run once Close returns. The value the start hook stores in the context
// must be visible to the later hooks.
func TestTraceBatchNormal(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 2, data.Batch.Len())
			// Tag the context so the later hooks can prove they received it.
			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.NoError(t, data.Err)
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.NoError(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2`)

		// Use the callback's ctx so the batch honors the test timeout.
		br := conn.SendBatch(ctx, batch)
		require.True(t, traceBatchStartCalled)

		var n int32
		err := br.QueryRow().Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 1, n)
		require.EqualValues(t, 1, traceBatchQueryCalledCount)

		err = br.QueryRow().Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 2, n)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)

		err = br.Close()
		require.NoError(t, err)

		require.True(t, traceBatchEndCalled)
	})
}

// TestTraceBatchClose checks the batch hooks when the batch results are closed
// without reading any result. After Close returns, the batch query hook must
// have run once per queued query and the batch end hook must have run.
func TestTraceBatchClose(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 2, data.Batch.Len())
			// Tag the context so the later hooks can prove they received it.
			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.NoError(t, data.Err)
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.NoError(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2`)

		// Use the callback's ctx so the batch honors the test timeout.
		br := conn.SendBatch(ctx, batch)
		require.True(t, traceBatchStartCalled)
		err := br.Close()
		require.NoError(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

// TestTraceBatchErrorWhileReadingResults checks the batch hooks when the second
// of three queued queries fails while results are read one at a time. The
// batch query hook is expected to run twice, with an error on the second call,
// and the batch end hook is expected to receive an error.
func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 3, data.Batch.Len())
			// Tag the context so the later hooks can prove they received it.
			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			// Only the second queued query is expected to fail.
			if traceBatchQueryCalledCount == 2 {
				require.Error(t, data.Err)
			} else {
				require.NoError(t, data.Err)
			}
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.Error(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		// Divides by n while the series starts at 0, so this query fails.
		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
		batch.Queue(`select 3`)

		// Use the callback's ctx so the batch honors the test timeout.
		br := conn.SendBatch(ctx, batch)
		require.True(t, traceBatchStartCalled)

		commandTag, err := br.Exec()
		require.NoError(t, err)
		require.Equal(t, "SELECT 1", commandTag.String())

		commandTag, err = br.Exec()
		require.Error(t, err)
		require.Equal(t, "", commandTag.String())

		// The third read is also expected to return an error.
		commandTag, err = br.Exec()
		require.Error(t, err)
		require.Equal(t, "", commandTag.String())

		err = br.Close()
		require.Error(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

// TestTraceBatchErrorWhileReadingResultsWhileClosing checks the batch hooks
// when the second of three queued queries fails and the results are closed
// without being read. The batch query hook is expected to run twice, with an
// error on the second call, and the batch end hook is expected to receive an
// error.
func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 3, data.Batch.Len())
			// Tag the context so the later hooks can prove they received it.
			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			// Only the second queued query is expected to fail.
			if traceBatchQueryCalledCount == 2 {
				require.Error(t, data.Err)
			} else {
				require.NoError(t, data.Err)
			}
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
			require.Error(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		// Divides by n while the series starts at 0, so this query fails.
		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
		batch.Queue(`select 3`)

		// Use the callback's ctx so the batch honors the test timeout.
		br := conn.SendBatch(ctx, batch)
		require.True(t, traceBatchStartCalled)
		err := br.Close()
		require.Error(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

// TestTraceCopyFrom verifies that Conn.CopyFrom calls the tracer's copy-from
// start and end hooks. The start hook must see the table and column names, and
// the context it returns must reach the end hook.
func TestTraceCopyFrom(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	tracer := new(testTracer)

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	runner := defaultConnTestRunner
	runner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		cfg := defaultConnTestRunner.CreateConfig(ctx, t)
		cfg.Tracer = tracer
		return cfg
	}

	const key = ctxKey("fromTraceCopyFromStart")

	pgxtest.RunWithQueryExecModes(ctx, t, runner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		// Each run gets its own, tighter deadline.
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		var startSeen, endSeen bool

		tracer.traceCopyFromStart = func(ctx context.Context, _ *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
			startSeen = true
			require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
			require.Equal(t, []string{"a"}, data.ColumnNames)
			// Tag the context so the end hook can prove it received it.
			return context.WithValue(ctx, key, "foo")
		}

		tracer.traceCopyFromEnd = func(ctx context.Context, _ *pgx.Conn, data pgx.TraceCopyFromEndData) {
			endSeen = true
			require.Equal(t, "foo", ctx.Value(key))
			require.Equal(t, `COPY 2`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		_, err := conn.Exec(ctx, `create temporary table foo(a int4)`)
		require.NoError(t, err)

		rows := [][]any{{int32(1)}, {nil}}

		copied, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(rows))
		require.NoError(t, err)
		require.EqualValues(t, len(rows), copied)
		require.True(t, startSeen)
		require.True(t, endSeen)
	})
}

// TestTracePrepare verifies that Conn.Prepare calls the tracer's prepare start
// and end hooks. It prepares the same statement twice: the first end hook call
// must report AlreadyPrepared false and the second must report it true. The
// context returned by the start hook must reach the end hook both times.
func TestTracePrepare(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	// Copy the default runner and install the stub tracer on every config it
	// creates.
	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		tracePrepareStartCalled := false
		tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
			tracePrepareStartCalled = true
			require.Equal(t, `ps`, data.Name)
			require.Equal(t, `select $1::text`, data.SQL)
			// Tag the context so the end hook can prove it received it.
			return context.WithValue(ctx, ctxKey("fromTracePrepareStart"), "foo")
		}

		tracePrepareEndCalled := false
		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
			tracePrepareEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTracePrepareStart")))
			require.False(t, data.AlreadyPrepared)
			require.NoError(t, data.Err)
		}

		_, err := conn.Prepare(ctx, "ps", `select $1::text`)
		require.NoError(t, err)
		require.True(t, tracePrepareStartCalled)
		require.True(t, tracePrepareEndCalled)

		// Second round: same name and SQL, so the end hook must now report
		// the statement as already prepared.
		tracePrepareStartCalled = false
		tracePrepareEndCalled = false
		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
			tracePrepareEndCalled = true
			require.Equal(t, "foo", ctx.Value(ctxKey("fromTracePrepareStart")))
			require.True(t, data.AlreadyPrepared)
			require.NoError(t, data.Err)
		}

		_, err = conn.Prepare(ctx, "ps", `select $1::text`)
		require.NoError(t, err)
		require.True(t, tracePrepareStartCalled)
		require.True(t, tracePrepareEndCalled)
	})
}

// TestTraceConnect verifies that pgx.ConnectConfig calls the tracer's connect
// start and end hooks for a successful connection and for a failed one. On
// success the end hook must receive a non-nil Conn and no error; on failure it
// must receive a nil Conn and an error. The context returned by the start hook
// must reach the end hook in both cases.
func TestTraceConnect(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	config := defaultConnTestRunner.CreateConfig(context.Background(), t)
	config.Tracer = tracer

	traceConnectStartCalled := false
	tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
		traceConnectStartCalled = true
		require.NotNil(t, data.ConnConfig)
		// Tag the context so the end hook can prove it received it.
		return context.WithValue(ctx, ctxKey("fromTraceConnectStart"), "foo")
	}

	traceConnectEndCalled := false
	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
		traceConnectEndCalled = true
		require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceConnectStart")))
		require.NotNil(t, data.Conn)
		require.NoError(t, data.Err)
	}

	conn1, err := pgx.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	defer conn1.Close(context.Background())
	require.True(t, traceConnectStartCalled)
	require.True(t, traceConnectEndCalled)

	// Second round: a config whose connection attempt is expected to fail.
	config, err = pgx.ParseConfig("host=/invalid")
	require.NoError(t, err)
	config.Tracer = tracer

	traceConnectStartCalled = false
	traceConnectEndCalled = false
	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
		traceConnectEndCalled = true
		require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceConnectStart")))
		require.Nil(t, data.Conn)
		require.Error(t, data.Err)
	}

	conn2, err := pgx.ConnectConfig(context.Background(), config)
	require.Nil(t, conn2)
	require.Error(t, err)
	require.True(t, traceConnectStartCalled)
	require.True(t, traceConnectEndCalled)
}