diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 59b89cf7..5ff9632c 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) multiResult.closed = true - multiResult.err = batch.err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) multiResult.closed = true - multiResult.err = err - pgConn.unlock() + pgConn.asyncClose() return multiResult } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 2b582e24..b2d2f7f7 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } +type mockConnection struct { + net.Conn + writeLatency *time.Duration +} + +func (m mockConnection) Write(b []byte) (n int, err error) { + time.Sleep(*m.writeLatency) + return m.Conn.Write(b) +} + +func TestConnExecBatchWriteError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var mockConn mockConnection + writeLatency := 0 * time.Second + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := net.Dial(network, address) + mockConn = mockConnection{conn, &writeLatency} + return mockConn, err + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + pgConn.Conn() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel2() + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + writeLatency = 2 * time.Second + mrr := pgConn.ExecBatch(ctx2, batch) + err = mrr.Close() + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) +} + func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel()