From 1a9b2a53a543a9adbad04b08d3ebafaa4a6b3fdf Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Fri, 28 Jul 2023 18:04:26 -0500
Subject: [PATCH] Fix staticcheck issues

---
 conn.go        |  2 +-
 conn_test.go   |  7 ++++---
 query_test.go  | 14 +++-----------
 rows.go        |  2 +-
 rows_test.go   |  1 +
 tracer_test.go | 42 ++++++++++++++++++++++--------------------
 tx.go          |  1 -
 tx_test.go     |  2 ++
 8 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/conn.go b/conn.go
index a79cab04..7c7081b4 100644
--- a/conn.go
+++ b/conn.go
@@ -507,7 +507,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a
 
 	mrr := c.pgConn.Exec(ctx, sql)
 	for mrr.NextResult() {
-		commandTag, err = mrr.ResultReader().Close()
+		commandTag, _ = mrr.ResultReader().Close()
 	}
 	err = mrr.Close()
 	return commandTag, err
diff --git a/conn_test.go b/conn_test.go
index 4c35b648..f988ccb3 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -315,7 +315,7 @@ func TestExecContextWithoutCancelation(t *testing.T) {
 	defer cancel()
 
 	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()
 
 		commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
@@ -336,7 +336,7 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
 	defer cancel()
 
 	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()
 
 		_, err := conn.Exec(ctx, "selct;")
@@ -361,7 +361,7 @@ func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
 	defer cancel()
 
 	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()
 
 		_, err := conn.Exec(ctx, "selct $1;", 1)
@@ -565,6 +565,7 @@ func TestListenNotify(t *testing.T) {
 	defer cancel()
 	notification, err = listener.WaitForNotification(ctx)
 	assert.True(t, pgconn.Timeout(err))
+	assert.Nil(t, notification)
 
 	// listener can listen again after a timeout
 	mustExec(t, notifier, "notify chat")
diff --git a/query_test.go b/query_test.go
index 5ba4f2a9..6d7e91df 100644
--- a/query_test.go
+++ b/query_test.go
@@ -656,16 +656,9 @@ func TestQueryRowCoreIntegerEncoding(t *testing.T) {
 	defer closeConn(t, conn)
 
 	type allTypes struct {
-		ui   uint
-		ui8  uint8
-		ui16 uint16
-		ui32 uint32
-		ui64 uint64
-		i    int
-		i8   int8
-		i16  int16
-		i32  int32
-		i64  int64
+		i16 int16
+		i32 int32
+		i64 int64
 	}
 
 	var actual, zero allTypes
@@ -983,7 +976,6 @@ func TestQueryRowErrors(t *testing.T) {
 
 	type allTypes struct {
 		i16 int16
-		i   int
 		s   string
 	}
 
diff --git a/rows.go b/rows.go
index 5b823b1e..1b1c8ac9 100644
--- a/rows.go
+++ b/rows.go
@@ -306,7 +306,7 @@ func (rows *baseRows) Values() ([]any, error) {
 				copy(newBuf, buf)
 				values = append(values, newBuf)
 			default:
-				rows.fatal(errors.New("Unknown format code"))
+				rows.fatal(errors.New("unknown format code"))
 			}
 		}
 
diff --git a/rows_test.go b/rows_test.go
index bf3ed986..b2d1137a 100644
--- a/rows_test.go
+++ b/rows_test.go
@@ -270,6 +270,7 @@ func TestCollectOneRowPrefersPostgreSQLErrorOverErrNoRows(t *testing.T) {
 		var pgErr *pgconn.PgError
 		require.ErrorAs(t, err, &pgErr)
 		require.Equal(t, "23505", pgErr.Code)
+		require.Equal(t, "", name)
 	})
 }
 
diff --git a/tracer_test.go b/tracer_test.go
index 0eaf8e38..a0fea71e 100644
--- a/tracer_test.go
+++ b/tracer_test.go
@@ -24,6 +24,8 @@ type testTracer struct {
 	traceConnectEnd    func(ctx context.Context, data pgx.TraceConnectEndData)
 }
 
+type ctxKey string
+
 func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
 	if tt.traceQueryStart != nil {
 		return tt.traceQueryStart(ctx, conn, data)
@@ -117,13 +119,13 @@ func TestTraceExec(t *testing.T) {
 			require.Equal(t, `select $1::text`, data.SQL)
 			require.Len(t, data.Args, 1)
 			require.Equal(t, `testing`, data.Args[0])
-			return context.WithValue(ctx, "fromTraceQueryStart", "foo")
+			return context.WithValue(ctx, ctxKey(ctxKey("fromTraceQueryStart")), "foo")
 		}
 
 		traceQueryEndCalled := false
 		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
 			traceQueryEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceQueryStart"))))
 			require.Equal(t, `SELECT 1`, data.CommandTag.String())
 			require.NoError(t, data.Err)
 		}
@@ -157,13 +159,13 @@ func TestTraceQuery(t *testing.T) {
 			require.Equal(t, `select $1::text`, data.SQL)
 			require.Len(t, data.Args, 1)
 			require.Equal(t, `testing`, data.Args[0])
-			return context.WithValue(ctx, "fromTraceQueryStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo")
 		}
 
 		traceQueryEndCalled := false
 		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
 			traceQueryEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart")))
 			require.Equal(t, `SELECT 1`, data.CommandTag.String())
 			require.NoError(t, data.Err)
 		}
@@ -198,20 +200,20 @@ func TestTraceBatchNormal(t *testing.T) {
 			traceBatchStartCalled = true
 			require.NotNil(t, data.Batch)
 			require.Equal(t, 2, data.Batch.Len())
-			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
 		}
 
 		traceBatchQueryCalledCount := 0
 		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
 			traceBatchQueryCalledCount++
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.NoError(t, data.Err)
 		}
 
 		traceBatchEndCalled := false
 		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
 			traceBatchEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.NoError(t, data.Err)
 		}
 
@@ -261,20 +263,20 @@ func TestTraceBatchClose(t *testing.T) {
 			traceBatchStartCalled = true
 			require.NotNil(t, data.Batch)
 			require.Equal(t, 2, data.Batch.Len())
-			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
 		}
 
 		traceBatchQueryCalledCount := 0
 		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
 			traceBatchQueryCalledCount++
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.NoError(t, data.Err)
 		}
 
 		traceBatchEndCalled := false
 		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
 			traceBatchEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.NoError(t, data.Err)
 		}
 
@@ -312,13 +314,13 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
 			traceBatchStartCalled = true
 			require.NotNil(t, data.Batch)
 			require.Equal(t, 3, data.Batch.Len())
-			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
 		}
 
 		traceBatchQueryCalledCount := 0
 		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
 			traceBatchQueryCalledCount++
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			if traceBatchQueryCalledCount == 2 {
 				require.Error(t, data.Err)
 			} else {
@@ -329,7 +331,7 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
 		traceBatchEndCalled := false
 		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
 			traceBatchEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.Error(t, data.Err)
 		}
 
@@ -381,13 +383,13 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
 			traceBatchStartCalled = true
 			require.NotNil(t, data.Batch)
 			require.Equal(t, 3, data.Batch.Len())
-			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
 		}
 
 		traceBatchQueryCalledCount := 0
 		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
 			traceBatchQueryCalledCount++
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			if traceBatchQueryCalledCount == 2 {
 				require.Error(t, data.Err)
 			} else {
@@ -398,7 +400,7 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
 		traceBatchEndCalled := false
 		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
 			traceBatchEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
 			require.Error(t, data.Err)
 		}
 
@@ -440,13 +442,13 @@ func TestTraceCopyFrom(t *testing.T) {
 			traceCopyFromStartCalled = true
 			require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
 			require.Equal(t, []string{"a"}, data.ColumnNames)
-			return context.WithValue(ctx, "fromTraceCopyFromStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTraceCopyFromStart"), "foo")
 		}
 
 		traceCopyFromEndCalled := false
 		tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
 			traceCopyFromEndCalled = true
-			require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart"))
+			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceCopyFromStart")))
 			require.Equal(t, `COPY 2`, data.CommandTag.String())
 			require.NoError(t, data.Err)
 		}
@@ -488,7 +490,7 @@ func TestTracePrepare(t *testing.T) {
 			tracePrepareStartCalled = true
 			require.Equal(t, `ps`, data.Name)
 			require.Equal(t, `select $1::text`, data.SQL)
-			return context.WithValue(ctx, "fromTracePrepareStart", "foo")
+			return context.WithValue(ctx, ctxKey("fromTracePrepareStart"), "foo")
 		}
 
 		tracePrepareEndCalled := false
@@ -530,7 +532,7 @@ func TestTraceConnect(t *testing.T) {
 	tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
 		traceConnectStartCalled = true
 		require.NotNil(t, data.ConnConfig)
-		return context.WithValue(ctx, "fromTraceConnectStart", "foo")
+		return context.WithValue(ctx, ctxKey("fromTraceConnectStart"), "foo")
 	}
 
 	traceConnectEndCalled := false
diff --git a/tx.go b/tx.go
index 575c17a7..8feeb512 100644
--- a/tx.go
+++ b/tx.go
@@ -152,7 +152,6 @@ type Tx interface {
 // called on the dbTx.
 type dbTx struct {
 	conn         *Conn
-	err          error
 	savepointNum int64
 	closed       bool
 }
diff --git a/tx_test.go b/tx_test.go
index 3c0ed285..cd4fb207 100644
--- a/tx_test.go
+++ b/tx_test.go
@@ -555,6 +555,7 @@ func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
 				require.NoError(t, err)
 				return nil
 			})
+			require.NoError(t, err)
 
 			return nil
 		})
@@ -601,6 +602,7 @@ func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
 
 		return nil
 	})
+	require.NoError(t, err)
 
 	var n int64
 	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)