
Commit 3ade937

feat: Run PyExecutor's inference flow to estimate max_num_tokens for kv_cache_manager (#3092)
Signed-off-by: Hui Gao <[email protected]>
1 parent 10d2d16 commit 3ade937

File tree — 7 files changed (+411, -295 lines)


cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h (+1, -1)

@@ -118,7 +118,7 @@ class CacheTransceiver : public BaseCacheTransceiver
     std::unique_ptr<DataRequester> mDataRequester;
     std::map<LlmRequest*, std::future<void>> mResponderFutures;
     std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
-    mpi::MpiComm const *mMpiGroupComm{}, *mMpiWorldComm{};
+    mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
     std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
         mMpiGroupTPInDPComm;
     executor::kv_cache::CommState const* mCommState;
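Note on the change above: `{}` already value-initializes a raw pointer to `nullptr`, so `{nullptr}` does not change behavior; it only makes the initialization explicit at the declaration site.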

tensorrt_llm/_torch/pyexecutor/_util.py (+297, -88)

Large diffs are not rendered by default.
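The bulk of the feature lives in this file, whose diff is not rendered. Going only by the commit title, the general idea is to drive PyExecutor's real inference flow once and size the KV cache from the GPU memory that remains afterwards. The sketch below is a rough illustration of that pattern, not the code added by this commit; `run_dummy_forward` and `bytes_per_token` are hypothetical stand-ins for whatever the commit actually wires up.

```python
import torch

def estimate_max_kv_cache_tokens(run_dummy_forward, bytes_per_token: int,
                                 free_mem_fraction: float = 0.9) -> int:
    """Illustrative sketch: size the KV cache from measured GPU memory.

    `run_dummy_forward` is assumed to execute the executor's normal inference
    path once (at worst-case batch/sequence shape) so that weights, activations,
    and CUDA graph pools are all resident before free memory is measured.
    """
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    run_dummy_forward()          # exercise the real inference flow once

    torch.cuda.synchronize()
    torch.cuda.empty_cache()     # return cached blocks so mem_get_info is honest
    free_bytes, _total_bytes = torch.cuda.mem_get_info()

    # Leave headroom for fragmentation and transient allocations.
    budget = int(free_bytes * free_mem_fraction)
    return budget // bytes_per_token
```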

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (+3, -0)

@@ -64,6 +64,9 @@ def __init__(
         self._output = None
         self._graph = None
 
+    def __del__(self):
+        self._graph.reset()
+
     def capture(
         self,
         forward_fn: Callable[[Dict[str, Any]], torch.Tensor],
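The new `__del__` assumes `self._graph` already holds a captured `torch.cuda.CUDAGraph`; if the runner is collected before `capture` ever ran, `self._graph` is still `None` and `reset()` would raise inside the finalizer. A slightly more defensive variant (a sketch, not what the commit adds) could guard for that:

```python
def __del__(self):
    # Release the captured CUDA graph, if any, so its memory pool can be reclaimed.
    graph = getattr(self, "_graph", None)
    if graph is not None:
        graph.reset()
```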

tensorrt_llm/_torch/pyexecutor/llm_request.py (+2, -1)

@@ -141,7 +141,8 @@ def executor_request_to_llm_request(req_id: int,
         guided_decoding_params=executor_request.guided_decoding_params,
         encoder_input_tokens=None,
         return_encoder_output=False,
-        client_id=executor_request.client_id,
+        client_id=executor_request.client_id
+        if executor_request.client_id is not None else req_id,
         priority=0.5,
         llm_request_type=llm_request_type,
         context_phase_params=executor_request.context_phase_params)
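The explicit `is not None` test matters here: a client id of `0` is falsy but valid, so an `or`-style fallback would silently replace it with `req_id`. A minimal illustration of the difference:

```python
client_id = 0   # a legitimate client id
req_id = 7

wrong = client_id or req_id                              # 7 — the real id is lost
right = client_id if client_id is not None else req_id   # 0 — only None falls back
```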

tensorrt_llm/_torch/pyexecutor/model_engine.py (+8, -0)

@@ -824,6 +824,14 @@ def _init_model_capacity(self):
         self._init_max_seq_len()
         self._init_max_num_tokens()
 
+    def _release_cuda_graphs(self):
+        for _, graph in self._cuda_graphs.items():
+            del graph
+        self._cuda_graphs.clear()
+        torch.cuda.empty_cache()
+        del self._cuda_graph_mem_pool
+        self._cuda_graph_mem_pool = None
+
     def get_max_num_sequences(self) -> int:
         """
         Return the maximum number of sequences that the model supports. PyExecutor need this to compute max_num_active_requests
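One detail worth noting in `_release_cuda_graphs`: `del graph` only unbinds the loop variable; it is `self._cuda_graphs.clear()` and dropping `_cuda_graph_mem_pool` that release the last references, after which `torch.cuda.empty_cache()` can return the graph pool's memory to the driver. A rough way to see the effect, assuming an engine object that exposes the method added above, is a before/after reading of `torch.cuda.mem_get_info()`:

```python
import torch

def report_released_bytes(engine) -> int:
    """Sketch: how much free GPU memory _release_cuda_graphs() wins back."""
    torch.cuda.synchronize()
    free_before, _ = torch.cuda.mem_get_info()

    engine._release_cuda_graphs()  # drops graphs, clears the dict, empties the cache

    torch.cuda.synchronize()
    free_after, _ = torch.cuda.mem_get_info()
    return free_after - free_before
```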

tensorrt_llm/_torch/pyexecutor/py_executor.py (+42, -16)

@@ -142,7 +142,8 @@ def __init__(self,
                  max_batch_size: int = 8,
                  max_draft_tokens: int = 0,
                  kv_cache_transceiver: KvCacheTransceiver = None,
-                 draft_model_engine: Optional[ModelEngine] = None):
+                 draft_model_engine: Optional[ModelEngine] = None,
+                 start_worker: bool = True):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
         self.global_rank = global_mpi_rank()
@@ -223,20 +224,34 @@ def __init__(self,
         self.stats = []
         self.start_times = {}
         self.new_active_requests_queue_latency_ms = 0
+        self.gather_all_responses = False
 
         self.kv_cache_transceiver = kv_cache_transceiver
         if self.dist.pp_size > 1:
-            event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
+            self.event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
         else:
-            event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
+            self.event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
 
-        if self.draft_model_engine is not None and event_loop.__name__ != self._executor_loop.__name__:
+        if self.draft_model_engine is not None and self.event_loop.__name__ != self._executor_loop.__name__:
             raise NotImplementedError(
                 "Drafting is not supported for selected executor loop. "
                 "Please disable disagg/pipeline parallelism/overlap scheduler.")
 
-        self.worker_thread = threading.Thread(target=event_loop, daemon=True)
-        self.worker_thread.start()
+        self.worker_started = False
+        self.worker_lock = threading.Lock()
+        if start_worker:
+            self.start_worker()
+
+    def start_worker(self):
+        self.worker_lock.acquire()
+        try:
+            if self.worker_started == False:
+                self.worker_thread = threading.Thread(target=self.event_loop,
+                                                      daemon=True)
+                self.worker_thread.start()
+                self.worker_started = True
+        finally:
+            self.worker_lock.release()
 
     def __enter__(self):
         return self
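With `start_worker=False`, the executor can be constructed, used for a one-off warm-up or memory-estimation pass, and only then have its background loop launched; the `worker_started` flag and `worker_lock` make `start_worker()` safe to call more than once. A usage sketch — the construction helper and estimation step below are hypothetical placeholders, only the `start_worker` flow comes from this diff:

```python
# Hypothetical construction path; only start_worker=False / start_worker() are from this commit.
executor = create_py_executor(start_worker=False)   # background loop not running yet

estimate_kv_cache_capacity(executor)                # placeholder for the warm-up pass

executor.start_worker()                             # idempotent: later calls are no-ops
```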
@@ -306,6 +321,7 @@ def shutdown(self):
         self.enqueue_lock.release()
         self.shutdown_event.wait()
         self.worker_thread.join()
+        self.worker_started = False
         for manager in self.resource_manager.resource_managers.values():
             if manager:
                 manager.shutdown()
@@ -372,6 +388,9 @@ def enqueue_request(self,
         self.enqueue_lock.release()
         return req_id
 
+    def set_gather_responses(self, gather_all_responses):
+        self.gather_all_responses = gather_all_responses
+
     @contextmanager
     def _profiler(self):
         it = -1
@@ -474,6 +493,9 @@ def _process_iter_stats(self, finished_requests, scheduled_batch,
                             iter_start_time, iter_stats):
         iter_end_time = time.time()
         iter_latency_ms = iter_end_time - iter_start_time
+        if iter_stats is None:
+            return
+
         self._append_iter_stats(
             self._update_iter_stats(iter_stats, iter_latency_ms,
                                     len(finished_requests), scheduled_batch))
@@ -725,7 +747,6 @@ def _executor_loop(self):
                         new_requests) or got_finish_signal
                     if got_finish_signal and len(self.active_requests) == 0:
                         break
-
                     if self.enable_iter_perf_stats:
                         iter_stats = self._get_init_iter_stats(
                             len(new_requests),
@@ -1097,6 +1118,7 @@ def _fetch_new_requests(self):
                 self.dist.tp_size - 1) // self.dist.tp_size,
             max(self.all_ranks_num_active_requests),
         )
+
         self.has_context_request = False
         new_requests_cur_rank = []
         if new_requests != [] and new_requests[
@@ -1783,21 +1805,26 @@ def _handle_cancelled_requests(self):
 
     @nvtx_range("_enqueue_responses")
     def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]):
-        if 0 not in self.dist.mapping.tp_group:
+        if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
             return
 
         logger.debug(
             f'before gather, rank = {self.dist.rank}, responses = {responses}')
         if self.enable_attention_dp:
-            responses_list = self.dist.tp_gather(responses)
-            if self.dist.rank == 0:
+            if not self.gather_all_responses:
+                responses_list = self.dist.tp_gather(responses)
+            else:
+                responses_list = self.dist.allgather(responses)
+            if self.dist.rank == 0 or self.gather_all_responses:
                 gather_responses = {}
-                for resp in responses_list:
-                    gather_responses.update(resp)
-                responses = gather_responses
+                if responses_list is not None:
+                    for resp in responses_list:
+                        gather_responses.update(resp)
+                    responses = gather_responses
             logger.debug(
                 f'after gather, rank = {self.dist.rank}, responses = {responses}')
-        if self.dist.rank == 0:
+
+        if self.dist.rank == 0 or self.gather_all_responses:
             with self.response_cv:
                 for req_id, resp in responses.items():
                     if req_id in self.responses.keys():
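The toggle added via `set_gather_responses(True)` changes where merged responses end up: `tp_gather` materializes the per-rank response dicts only on rank 0, while `allgather` hands every rank the full list so any rank can build the merged dict. The project's `self.dist` wrapper is not shown in this diff; the sketch below illustrates the same gather/allgather distinction with plain `mpi4py`, which is an assumption about transport rather than what the executor actually uses.

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
local_responses = {rank: f"response from rank {rank}"}

# gather: only the root receives the list of per-rank dicts (other ranks get None).
gathered = comm.gather(local_responses, root=0)

# allgather: every rank receives the list, so each one can merge it locally.
merged = {}
for part in comm.allgather(local_responses):
    merged.update(part)
# On every rank, `merged` now maps each rank id to that rank's response.
```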
@@ -1817,6 +1844,7 @@ def _handle_responses(self):
         for request in self.active_requests:
             req_id = request.py_request_id
             #no responses for dummy request, and finish it
+
             if request.is_dummy == True:
                 requests_to_terminate.append(request)
                 continue
@@ -1825,7 +1853,6 @@ def _handle_responses(self):
             request.decoding_iter = request.py_decoding_iter
             response = request.create_response(False, self.dist.rank)
             request_done = False
-
             if response:
                 request_done = response.result.is_final
                 new_responses.update({req_id: response})
@@ -1840,7 +1867,6 @@ def _handle_responses(self):
         self._enqueue_responses(new_responses)
         for request in requests_to_terminate:
             self._terminate_request(request)
-
         return requests_to_terminate
 
     @nvtx_range("_terminate_ctx_finished_requests")
