@@ -142,7 +142,8 @@ def __init__(self,
142
142
max_batch_size : int = 8 ,
143
143
max_draft_tokens : int = 0 ,
144
144
kv_cache_transceiver : KvCacheTransceiver = None ,
145
- draft_model_engine : Optional [ModelEngine ] = None ):
145
+ draft_model_engine : Optional [ModelEngine ] = None ,
146
+ start_worker : bool = True ):
146
147
super (PyExecutor , self ).__init__ ()
147
148
self .device_id = torch .cuda .current_device ()
148
149
self .global_rank = global_mpi_rank ()
@@ -223,20 +224,34 @@ def __init__(self,
223
224
self .stats = []
224
225
self .start_times = {}
225
226
self .new_active_requests_queue_latency_ms = 0
227
+ self .gather_all_responses = False
226
228
227
229
self .kv_cache_transceiver = kv_cache_transceiver
228
230
if self .dist .pp_size > 1 :
229
- event_loop = self ._executor_loop_pp_overlap if enable_overlap_scheduler else self ._executor_loop_pp
231
+ self . event_loop = self ._executor_loop_pp_overlap if enable_overlap_scheduler else self ._executor_loop_pp
230
232
else :
231
- event_loop = self ._executor_loop_overlap if enable_overlap_scheduler else self ._executor_loop
233
+ self . event_loop = self ._executor_loop_overlap if enable_overlap_scheduler else self ._executor_loop
232
234
233
- if self .draft_model_engine is not None and event_loop .__name__ != self ._executor_loop .__name__ :
235
+ if self .draft_model_engine is not None and self . event_loop .__name__ != self ._executor_loop .__name__ :
234
236
raise NotImplementedError (
235
237
"Drafting is not supported for selected executor loop. "
236
238
"Please disable disagg/pipeline parallelism/overlap scheduler." )
237
239
238
- self .worker_thread = threading .Thread (target = event_loop , daemon = True )
239
- self .worker_thread .start ()
240
+ self .worker_started = False
241
+ self .worker_lock = threading .Lock ()
242
+ if start_worker :
243
+ self .start_worker ()
244
+
245
def start_worker(self):
    """Start the background executor event-loop thread, exactly once.

    Idempotent and thread-safe: ``worker_lock`` plus the
    ``worker_started`` flag guarantee that concurrent or repeated calls
    spawn at most one worker thread; later calls are no-ops while the
    worker is running.
    """
    # Use the lock as a context manager so it is released even if
    # Thread creation raises (the original manual acquire/try/finally
    # achieved the same but less idiomatically).
    with self.worker_lock:
        if not self.worker_started:
            # Daemon thread: the worker must not block interpreter exit.
            self.worker_thread = threading.Thread(target=self.event_loop,
                                                  daemon=True)
            self.worker_thread.start()
            self.worker_started = True
240
255
241
256
def __enter__ (self ):
242
257
return self
@@ -306,6 +321,7 @@ def shutdown(self):
306
321
self .enqueue_lock .release ()
307
322
self .shutdown_event .wait ()
308
323
self .worker_thread .join ()
324
+ self .worker_started = False
309
325
for manager in self .resource_manager .resource_managers .values ():
310
326
if manager :
311
327
manager .shutdown ()
@@ -372,6 +388,9 @@ def enqueue_request(self,
372
388
self .enqueue_lock .release ()
373
389
return req_id
374
390
391
def set_gather_responses(self, gather_all_responses):
    """Toggle whether all ranks (instead of only rank 0) collect responses.

    The flag is read later by ``_enqueue_responses`` to decide between
    ``tp_gather`` and ``allgather`` and whether non-zero ranks keep the
    gathered responses.
    """
    self.gather_all_responses = gather_all_responses
375
394
@contextmanager
376
395
def _profiler (self ):
377
396
it = - 1
def _process_iter_stats(self, finished_requests, scheduled_batch,
                        iter_start_time, iter_stats):
    """Finalize and record statistics for one executor iteration.

    Args:
        finished_requests: requests completed this iteration (only the
            count is used).
        scheduled_batch: the batch scheduled this iteration, forwarded
            verbatim to ``_update_iter_stats``.
        iter_start_time: ``time.time()`` timestamp taken when the
            iteration began.
        iter_stats: mutable per-iteration stats object, or ``None`` when
            stats collection is disabled for this iteration.
    """
    iter_end_time = time.time()
    # NOTE(review): despite the ``_ms`` suffix this is a raw
    # ``time.time()`` difference, i.e. seconds — confirm whether
    # ``_update_iter_stats`` converts, or the name is misleading.
    iter_latency_ms = iter_end_time - iter_start_time
    # Nothing to record when stats collection is disabled.
    if iter_stats is None:
        return

    self._append_iter_stats(
        self._update_iter_stats(iter_stats, iter_latency_ms,
                                len(finished_requests), scheduled_batch))
@@ -725,7 +747,6 @@ def _executor_loop(self):
725
747
new_requests ) or got_finish_signal
726
748
if got_finish_signal and len (self .active_requests ) == 0 :
727
749
break
728
-
729
750
if self .enable_iter_perf_stats :
730
751
iter_stats = self ._get_init_iter_stats (
731
752
len (new_requests ),
@@ -1097,6 +1118,7 @@ def _fetch_new_requests(self):
1097
1118
self .dist .tp_size - 1 ) // self .dist .tp_size ,
1098
1119
max (self .all_ranks_num_active_requests ),
1099
1120
)
1121
+
1100
1122
self .has_context_request = False
1101
1123
new_requests_cur_rank = []
1102
1124
if new_requests != [] and new_requests [
@@ -1783,21 +1805,26 @@ def _handle_cancelled_requests(self):
1783
1805
1784
1806
@nvtx_range ("_enqueue_responses" )
1785
1807
def _enqueue_responses (self , responses : Dict [int , ExecutorResponse ]):
1786
- if 0 not in self .dist .mapping .tp_group :
1808
+ if 0 not in self .dist .mapping .tp_group and not self . gather_all_responses :
1787
1809
return
1788
1810
1789
1811
logger .debug (
1790
1812
f'before gather, rank = { self .dist .rank } , responses = { responses } ' )
1791
1813
if self .enable_attention_dp :
1792
- responses_list = self .dist .tp_gather (responses )
1793
- if self .dist .rank == 0 :
1814
+ if not self .gather_all_responses :
1815
+ responses_list = self .dist .tp_gather (responses )
1816
+ else :
1817
+ responses_list = self .dist .allgather (responses )
1818
+ if self .dist .rank == 0 or self .gather_all_responses :
1794
1819
gather_responses = {}
1795
- for resp in responses_list :
1796
- gather_responses .update (resp )
1797
- responses = gather_responses
1820
+ if responses_list is not None :
1821
+ for resp in responses_list :
1822
+ gather_responses .update (resp )
1823
+ responses = gather_responses
1798
1824
logger .debug (
1799
1825
f'after gather, rank = { self .dist .rank } , responses = { responses } ' )
1800
- if self .dist .rank == 0 :
1826
+
1827
+ if self .dist .rank == 0 or self .gather_all_responses :
1801
1828
with self .response_cv :
1802
1829
for req_id , resp in responses .items ():
1803
1830
if req_id in self .responses .keys ():
@@ -1817,6 +1844,7 @@ def _handle_responses(self):
1817
1844
for request in self .active_requests :
1818
1845
req_id = request .py_request_id
1819
1846
#no responses for dummy request, and finish it
1847
+
1820
1848
if request .is_dummy == True :
1821
1849
requests_to_terminate .append (request )
1822
1850
continue
@@ -1825,7 +1853,6 @@ def _handle_responses(self):
1825
1853
request .decoding_iter = request .py_decoding_iter
1826
1854
response = request .create_response (False , self .dist .rank )
1827
1855
request_done = False
1828
-
1829
1856
if response :
1830
1857
request_done = response .result .is_final
1831
1858
new_responses .update ({req_id : response })
@@ -1840,7 +1867,6 @@ def _handle_responses(self):
1840
1867
self ._enqueue_responses (new_responses )
1841
1868
for request in requests_to_terminate :
1842
1869
self ._terminate_request (request )
1843
-
1844
1870
return requests_to_terminate
1845
1871
1846
1872
@nvtx_range ("_terminate_ctx_finished_requests" )
0 commit comments