Skip to content

Commit 545467d

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Introduce non-advancing events (#151)
Summary: Pull Request resolved: #151 TLDR; We go from simulating 50 steps on a toy model D72348663 in ``` Simulated MNIST training completed in 341.3135597705841s ``` to ``` Simulated MNIST training completed in 7.598558664321899s ``` and can run fast enough to simulate llama3 and nanoGPT training in reasonable time ___________________________ Consider the following scenario with `self_message_with_delay`: 1. Event queue: [Sleep(10)], T = 0 2. Simnet dequeues Sleep(10) and advances to T=10 3. `self_message_with_delay` task wakes up and sends ControllerMessage::CheckSupervision 4. ControllerActor handles ControllerMessage::CheckSupervision and calls `self_message_with_delay(ControllerMessage::CheckSupervision)` 5. Event queue: [Sleep(10)], T = 20 This loops endlessly and we advance the time far further than we should. Before we were remedying this in the Simnet event loop by adding a short debounce interval that would cause us to only move on to handling the events once the debounce duration has elapsed and no more events have arrived. In the above scenario this would mean that the events between T=10 and T=20 would arrive and be handled instead of instantly advancing to T=20, T=30, T=40... The problem with this is this actually dramatically slows down the wall time of running the simulator such that examples like nanoGPT and llama3 would take very long to simulate a single step. The solution to this is to have 2 separate queues: queue 1: events that are allowed to advance the sim time on it's own queue 2: events that are allowed to advance the sim time only when the earliest event in queue 1 comes after the earliest event in queue 2, or if some debounce period such that only queue 2 has events has elapsed. The second scenario occurs when there is a supervision failure Doing this allows us to safely remove the debounce, dramatically speeding up the wall time it takes to run the simulator while also ensuring that events from `self_message_with_delay` are scheduled at the correct time Reviewed By: kaiyuan-li Differential Revision: D75900568 fbshipit-source-id: 4a0e1786953216992b87e22013958e66751387b3
1 parent 5f1cd75 commit 545467d

File tree

3 files changed

+151
-55
lines changed

3 files changed

+151
-55
lines changed

hyperactor/src/clock.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ pub trait Clock {
4444
&self,
4545
duration: tokio::time::Duration,
4646
) -> impl std::future::Future<Output = ()> + Send + Sync;
47+
/// Initiates a sleep for the specified duration
48+
fn non_advancing_sleep(
49+
&self,
50+
duration: tokio::time::Duration,
51+
) -> impl std::future::Future<Output = ()> + Send + Sync;
4752
/// Get the current time according to the clock
4853
fn now(&self) -> tokio::time::Instant;
4954
/// Sleep until the specified deadline.
@@ -72,6 +77,12 @@ impl Clock for ClockKind {
7277
Self::Real(clock) => clock.sleep(duration).await,
7378
}
7479
}
80+
async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
81+
match self {
82+
Self::Sim(clock) => clock.non_advancing_sleep(duration).await,
83+
Self::Real(clock) => clock.non_advancing_sleep(duration).await,
84+
}
85+
}
7586
async fn sleep_until(&self, deadline: tokio::time::Instant) {
7687
match self {
7788
Self::Sim(clock) => clock.sleep_until(deadline).await,
@@ -131,6 +142,22 @@ impl Clock for SimClock {
131142
.unwrap();
132143
rx.recv().await.unwrap();
133144
}
145+
146+
async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
147+
let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone());
148+
let (tx, rx) = mailbox.open_once_port::<()>();
149+
150+
simnet_handle()
151+
.unwrap()
152+
.send_nonadvanceable_event(SleepEvent::new(
153+
tx.bind(),
154+
mailbox,
155+
duration.as_millis() as u64,
156+
))
157+
.unwrap();
158+
rx.recv().await.unwrap();
159+
}
160+
134161
async fn sleep_until(&self, deadline: tokio::time::Instant) {
135162
let now = self.now();
136163
if deadline <= now {
@@ -170,6 +197,9 @@ impl Clock for RealClock {
170197
async fn sleep(&self, duration: tokio::time::Duration) {
171198
tokio::time::sleep(duration).await;
172199
}
200+
async fn non_advancing_sleep(&self, duration: tokio::time::Duration) {
201+
Self::sleep(self, duration).await;
202+
}
173203
#[allow(clippy::disallowed_methods)]
174204
async fn sleep_until(&self, deadline: tokio::time::Instant) {
175205
tokio::time::sleep_until(deadline).await;

hyperactor/src/proc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ impl<A: Actor> Instance<A> {
847847
let self_id = self.self_id().clone();
848848
let clock = self.proc.state().clock.clone();
849849
tokio::spawn(async move {
850-
clock.sleep(delay).await;
850+
clock.non_advancing_sleep(delay).await;
851851
if let Err(e) = port.send(message) {
852852
// TODO: this is a fire-n-forget thread. We need to
853853
// handle errors in a better way.

hyperactor/src/simnet.rs

Lines changed: 120 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,12 @@ pub enum SimNetError {
334334
}
335335

336336
struct State {
337+
// The simnet is allowed to advance to the time of the earliest event in this queue at any time
337338
scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
339+
// The simnet is allowed to advance to the time of the earliest event in this queue at any time
340+
// only if the earliest event in `scheduled_events` occurs after the earliest event in this queue
341+
// or some debounce period has passed where there are only events in this queue.
342+
unadvanceable_scheduled_events: BTreeMap<SimulatorTimeInstant, Vec<ScheduledEvent>>,
338343
}
339344

340345
/// The state of the python training script.
@@ -349,8 +354,7 @@ pub enum TrainingScriptState {
349354
/// A handle to a running [`SimNet`] instance.
350355
pub struct SimNetHandle {
351356
join_handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
352-
event_tx: UnboundedSender<Box<dyn Event>>,
353-
scheduled_event_tx: UnboundedSender<ScheduledEvent>,
357+
event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
354358
config: Arc<Mutex<SimNetConfig>>,
355359
records: Option<Arc<Mutex<Vec<SimulatorEventRecord>>>>,
356360
pending_event_count: Arc<AtomicUsize>,
@@ -370,23 +374,38 @@ impl SimNetHandle {
370374
/// Sends an event to be scheduled onto the simnet's event loop
371375
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
372376
pub fn send_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
377+
self.send_event_impl(event, true)
378+
}
379+
380+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
381+
fn send_event_impl(&self, event: Box<dyn Event>, advanceable: bool) -> Result<(), SimNetError> {
373382
self.pending_event_count
374383
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
375384
self.event_tx
376-
.send(event)
385+
.send((event, advanceable, None))
377386
.map_err(|err| SimNetError::Closed(err.to_string()))
378387
}
379388

389+
/// Sends an non-advanceable event to be scheduled onto the simnet's event loop
390+
/// A non-advanceable event is an event that cannot advance the simnet's time unless
391+
/// the earliest event in the simnet's advancing event queue occurs after the earliest
392+
/// event in the simnet's non-advancing event queue, or some debounce period has passed
393+
/// where there are only events in the simnet's non-advancing event queue.
394+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
395+
pub fn send_nonadvanceable_event(&self, event: Box<dyn Event>) -> Result<(), SimNetError> {
396+
self.send_event_impl(event, false)
397+
}
398+
380399
/// Sends an event that already has a scheduled time onto the simnet's event loop
381400
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`.
382401
pub(crate) fn send_scheduled_event(
383402
&self,
384-
scheduled_event: ScheduledEvent,
403+
ScheduledEvent { event, time }: ScheduledEvent,
385404
) -> Result<(), SimNetError> {
386405
self.pending_event_count
387406
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
388-
self.scheduled_event_tx
389-
.send(scheduled_event)
407+
self.event_tx
408+
.send((event, true, Some(time)))
390409
.map_err(|err| SimNetError::Closed(err.to_string()))
391410
}
392411

@@ -402,9 +421,13 @@ impl SimNetHandle {
402421
self.pending_event_count
403422
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
404423
self.event_tx
405-
.send(Box::new(NodeJoinEvent {
406-
channel_addr: address,
407-
}))
424+
.send((
425+
Box::new(NodeJoinEvent {
426+
channel_addr: address,
427+
}),
428+
true,
429+
None,
430+
))
408431
.map_err(|err| SimNetError::Closed(err.to_string()))
409432
}
410433

@@ -652,7 +675,7 @@ impl ProxyHandle {
652675
/// to_event: a function that specifies how to generate an Event from a forward message
653676
async fn start(
654677
proxy_addr: ChannelAddr,
655-
event_tx: UnboundedSender<Box<dyn Event>>,
678+
event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
656679
pending_event_count: Arc<AtomicUsize>,
657680
operational_message_tx: UnboundedSender<OperationalMessage>,
658681
) -> anyhow::Result<Self> {
@@ -683,7 +706,7 @@ impl ProxyHandle {
683706
}
684707
};
685708

686-
if let Err(e) = event_tx.send(event) {
709+
if let Err(e) = event_tx.send((event, true, None)) {
687710
tracing::error!("error sending message to simnet: {:?}", e);
688711
} else {
689712
pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
@@ -744,8 +767,8 @@ pub fn start(
744767

745768
let (training_script_state_tx, training_script_state_rx) =
746769
tokio::sync::watch::channel(TrainingScriptState::Running);
747-
let (event_tx, event_rx) = mpsc::unbounded_channel::<Box<dyn Event>>();
748-
let (scheduled_event_tx, scheduled_event_rx) = mpsc::unbounded_channel::<ScheduledEvent>();
770+
let (event_tx, event_rx) =
771+
mpsc::unbounded_channel::<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>();
749772
let records = Some(Arc::new(Mutex::new(vec![]))); // TODO remove optional
750773
let pending_event_count = Arc::new(AtomicUsize::new(0));
751774
let stop_signal = Arc::new(AtomicBool::new(false));
@@ -754,24 +777,19 @@ pub fn start(
754777
let config = config.clone();
755778
let pending_event_count = pending_event_count.clone();
756779
let stop_signal = stop_signal.clone();
757-
tokio::spawn(async move {
780+
tokio::task::spawn_blocking(move || {
758781
let mut net = SimNet {
759782
config,
760783
address_book,
761784
state: State {
762785
scheduled_events: BTreeMap::new(),
786+
unadvanceable_scheduled_events: BTreeMap::new(),
763787
},
764788
max_latency: Duration::from_millis(max_duration_ms),
765789
records,
766790
pending_event_count,
767791
};
768-
net.run(
769-
event_rx,
770-
scheduled_event_rx,
771-
training_script_state_rx,
772-
stop_signal,
773-
)
774-
.await;
792+
block_on(net.run(event_rx, training_script_state_rx, stop_signal));
775793
})
776794
};
777795
let join_handles = Arc::new(Mutex::new(vec![simnet_join_handle]));
@@ -789,7 +807,6 @@ pub fn start(
789807
HANDLE.get_or_init(|| SimNetHandle {
790808
join_handles,
791809
event_tx,
792-
scheduled_event_tx,
793810
config,
794811
records,
795812
pending_event_count,
@@ -850,60 +867,57 @@ impl SimNet {
850867
}
851868

852869
/// Schedule the event into the network.
853-
async fn schedule_event(&mut self, scheduled_event: ScheduledEvent) {
870+
async fn schedule_event(&mut self, scheduled_event: ScheduledEvent, advanceable: bool) {
854871
if let Some(records) = &self.records {
855872
records.lock().await.push(SimulatorEventRecord {
856873
summary: scheduled_event.event.summary(),
857874
start_at: SimClock.millis_since_start(SimClock.now()),
858875
end_at: scheduled_event.time,
859876
});
860877
}
861-
self.state
862-
.scheduled_events
863-
.entry(scheduled_event.time)
864-
.or_insert_with(Vec::new)
865-
.push(scheduled_event);
878+
if advanceable {
879+
self.state
880+
.scheduled_events
881+
.entry(scheduled_event.time)
882+
.or_insert_with(Vec::new)
883+
.push(scheduled_event);
884+
} else {
885+
self.state
886+
.unadvanceable_scheduled_events
887+
.entry(scheduled_event.time)
888+
.or_insert_with(Vec::new)
889+
.push(scheduled_event);
890+
}
866891
}
867892

868893
/// Run the simulation. This will dispatch all the messages in the network.
869894
/// And wait for new ones.
870895
async fn run(
871896
&mut self,
872-
mut event_rx: UnboundedReceiver<Box<dyn Event>>,
873-
mut scheduled_event_rx: UnboundedReceiver<ScheduledEvent>,
897+
mut event_rx: UnboundedReceiver<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
874898
training_script_state_rx: tokio::sync::watch::Receiver<TrainingScriptState>,
875899
stop_signal: Arc<AtomicBool>,
876900
) {
877901
// The simulated number of milliseconds the training script
878902
// has spent waiting for the backend to resolve a future
879903
let mut training_script_waiting_time: u64 = 0;
904+
// Duration elapsed while only non_advanceable_events has events
905+
let mut debounce_timer: Option<tokio::time::Instant> = None;
880906
'outer: loop {
881907
// Check if we should stop
882908
if stop_signal.load(Ordering::SeqCst) {
883909
break 'outer;
884910
}
885911

886-
// TODO: Find a way to drain all needed messages with better guarantees.
887-
//
888-
// Allow tiny grace period for messages to stop coming in before we actually advance time
889-
// to handle inflight events
890-
while let Ok(event) =
891-
tokio::time::timeout(tokio::time::Duration::from_millis(10), event_rx.recv()).await
892-
{
893-
if let Some(event) = event {
894-
let scheduled_event = self.create_scheduled_event(event).await;
895-
self.schedule_event(scheduled_event).await;
896-
} else {
897-
break 'outer;
898-
}
899-
}
900-
901-
while let Ok(ScheduledEvent { time, event }) = scheduled_event_rx.try_recv() {
902-
self.schedule_event(ScheduledEvent {
903-
time: time + training_script_waiting_time,
904-
event,
905-
})
906-
.await;
912+
while let Ok((event, advanceable, time)) = event_rx.try_recv() {
913+
let scheduled_event = match time {
914+
Some(time) => ScheduledEvent {
915+
time: time + training_script_waiting_time,
916+
event,
917+
},
918+
None => self.create_scheduled_event(event).await,
919+
};
920+
self.schedule_event(scheduled_event, advanceable).await;
907921
}
908922

909923
{
@@ -924,10 +938,62 @@ impl SimNet {
924938
{
925939
continue;
926940
}
941+
match (
942+
self.state.scheduled_events.first_key_value(),
943+
self.state.unadvanceable_scheduled_events.first_key_value(),
944+
) {
945+
(None, Some(_)) if debounce_timer.is_none() => {
946+
// Start debounce timer when only the non-advancedable
947+
// queue has events and the timer has not already started
948+
debounce_timer = Some(RealClock.now());
949+
}
950+
// Timer already active
951+
(None, Some(_)) => {}
952+
// Reset timer when non-advanceable queue is not the only queue with events
953+
_ => {
954+
debounce_timer = None;
955+
}
956+
}
927957
// process for next delivery time.
928-
let Some((scheduled_time, scheduled_events)) =
929-
self.state.scheduled_events.pop_first()
930-
else {
958+
let Some((scheduled_time, scheduled_events)) = (match (
959+
self.state.scheduled_events.first_key_value(),
960+
self.state.unadvanceable_scheduled_events.first_key_value(),
961+
) {
962+
(Some((advanceable_time, _)), Some((unadvanceable_time, _))) => {
963+
if unadvanceable_time < advanceable_time {
964+
self.state.unadvanceable_scheduled_events.pop_first()
965+
} else {
966+
self.state.scheduled_events.pop_first()
967+
}
968+
}
969+
(Some(_), None) => self.state.scheduled_events.pop_first(),
970+
(None, Some(_)) => match debounce_timer {
971+
Some(time) => {
972+
if time.elapsed() > tokio::time::Duration::from_millis(1000) {
973+
// debounce interval has elapsed, reset timer
974+
debounce_timer = None;
975+
self.state.unadvanceable_scheduled_events.pop_first()
976+
} else {
977+
None
978+
}
979+
}
980+
None => None,
981+
},
982+
(None, None) => None,
983+
}) else {
984+
tokio::select! {
985+
Some((event, advanceable, time)) = event_rx.recv() => {
986+
let scheduled_event = match time {
987+
Some(time) => ScheduledEvent {
988+
time: time + training_script_waiting_time,
989+
event,
990+
},
991+
None => self.create_scheduled_event(event).await,
992+
};
993+
self.schedule_event(scheduled_event, advanceable).await;
994+
},
995+
_ = RealClock.sleep(Duration::from_millis(10)) => {}
996+
}
931997
continue;
932998
};
933999
if training_script_state_rx.borrow().is_waiting() {

0 commit comments

Comments
 (0)