Skip to content

Commit f55f386

Browse files
zdevito and facebook-github-bot
authored and committed
testing for tensor engine (#199)
Summary: Pull Request resolved: #199 hook in the actor mesh based controller to our test suite as an additional backend to suss out bugs ghstack-source-id: 289874844 Reviewed By: mariusae Differential Revision: D76171866 fbshipit-source-id: cae6846c4af24f735f1c0b293f2982735abebf84
1 parent 93c58be commit f55f386

File tree

14 files changed

+219
-62
lines changed

14 files changed

+219
-62
lines changed

monarch_extension/src/mesh_controller.rs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::collections::VecDeque;
910
use std::iter::repeat_n;
1011
use std::sync::Arc;
1112
use std::sync::atomic::AtomicUsize;
@@ -49,7 +50,7 @@ use crate::convert::convert;
4950
struct _Controller {
5051
controller_instance: Arc<Mutex<InstanceWrapper<ControllerMessage>>>,
5152
workers: RootActorMesh<'static, WorkerActor>,
52-
pending_messages: Vec<PyObject>,
53+
pending_messages: VecDeque<PyObject>,
5354
history: history::History,
5455
}
5556

@@ -64,7 +65,7 @@ impl _Controller {
6465
) -> PyResult<()> {
6566
for (seq, response) in responses {
6667
let message = crate::client::WorkerResponse::new(seq, response);
67-
self.pending_messages.push(message.into_py(py));
68+
self.pending_messages.push_back(message.into_py(py));
6869
}
6970
Ok(())
7071
}
@@ -86,7 +87,7 @@ impl _Controller {
8687
} => {
8788
let dm = crate::client::DebuggerMessage::new(debugger_actor_id.into(), action)?
8889
.into_py(py);
89-
self.pending_messages.push(dm);
90+
self.pending_messages.push_back(dm);
9091
}
9192
ControllerMessage::Status {
9293
seq,
@@ -112,15 +113,19 @@ impl _Controller {
112113
})
113114
}
114115
fn send_slice(&mut self, slice: Slice, message: WorkerMessage) -> PyResult<()> {
115-
let shape = Shape::new(
116-
(0..slice.sizes().len()).map(|i| format!("d{i}")).collect(),
117-
slice,
118-
)
119-
.unwrap();
120-
let worker_slice = SlicedActorMesh::new(&self.workers, shape);
121-
worker_slice
122-
.cast(ndslice::Selection::True, message)
116+
self.workers
117+
.cast_slices(vec![slice], message)
123118
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
119+
// let shape = Shape::new(
120+
// (0..slice.sizes().len()).map(|i| format!("d{i}")).collect(),
121+
// slice,
122+
// )
123+
// .unwrap();
124+
// println!("SENDING TO {:?} {:?}", &shape, &message);
125+
// let worker_slice = SlicedActorMesh::new(&self.workers, shape);
126+
// worker_slice
127+
// .cast(ndslice::Selection::True, message)
128+
// .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
124129
}
125130
}
126131

@@ -161,13 +166,17 @@ impl _Controller {
161166
let workers = py_proc_mesh
162167
.spawn(&format!("tensor_engine_workers_{}", id), &param)
163168
.await?;
164-
workers.cast(ndslice::Selection::True, AssignRankMessage::AssignRank())?;
169+
//workers.cast(ndslice::Selection::True, )?;
170+
workers.cast_slices(
171+
vec![py_proc_mesh.shape().slice().clone()],
172+
AssignRankMessage::AssignRank(),
173+
)?;
165174
Ok(workers)
166175
})?;
167176
Ok(Self {
168177
workers: workers?,
169178
controller_instance: Arc::new(Mutex::new(controller_instance)),
170-
pending_messages: Vec::new(),
179+
pending_messages: VecDeque::new(),
171180
history: history::History::new(world_size),
172181
})
173182
}
@@ -218,7 +227,7 @@ impl _Controller {
218227
if self.pending_messages.is_empty() {
219228
self.fill_messages(py, timeout_msec)?;
220229
}
221-
Ok(self.pending_messages.pop())
230+
Ok(self.pending_messages.pop_front())
222231
}
223232

224233
fn _debugger_attach(&mut self, pdb_actor: PyActorId) -> PyResult<()> {
@@ -246,14 +255,14 @@ impl _Controller {
246255
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
247256
Ok(())
248257
}
249-
fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<Vec<PyObject>> {
258+
fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<()> {
259+
self.send_slice(
260+
self.workers.proc_mesh().shape().slice().clone(),
261+
WorkerMessage::Exit { error: None },
262+
)?;
250263
let instance = self.controller_instance.clone();
251-
let result =
252-
signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??;
253-
for r in result {
254-
self.add_message(r)?;
255-
}
256-
Ok(std::mem::take(&mut self.pending_messages))
264+
let _ = signal_safe_block_on(py, async move { instance.lock().await.drain_and_stop() })??;
265+
Ok(())
257266
}
258267
}
259268

monarch_hyperactor/src/shape.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ impl From<Shape> for PyShape {
123123
frozen
124124
)]
125125

126-
struct PyPoint {
126+
pub struct PyPoint {
127127
rank: usize,
128128
shape: Py<PyShape>,
129129
}
130130

131131
#[pymethods]
132132
impl PyPoint {
133133
#[new]
134-
fn new(rank: usize, shape: Py<PyShape>) -> Self {
134+
pub fn new(rank: usize, shape: Py<PyShape>) -> Self {
135135
PyPoint { rank, shape }
136136
}
137137
fn __getitem__(&self, py: Python, label: &str) -> PyResult<usize> {
@@ -150,6 +150,19 @@ impl PyPoint {
150150
)))
151151
}
152152
}
153+
154+
fn size(&self, py: Python<'_>, label: &str) -> PyResult<usize> {
155+
let shape = &self.shape.bind(py).get().inner;
156+
if let Some(index) = shape.labels().iter().position(|l| l == label) {
157+
Ok(shape.slice().sizes()[index])
158+
} else {
159+
Err(PyErr::new::<PyValueError, _>(format!(
160+
"Dimension '{}' not found",
161+
label
162+
)))
163+
}
164+
}
165+
153166
fn __len__(&self, py: Python) -> usize {
154167
self.shape.bind(py).get().__len__()
155168
}

monarch_tensor_worker/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ hyperactor = { version = "0.0.0", path = "../hyperactor" }
1919
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
2020
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
2121
itertools = "0.14.0"
22+
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
2223
monarch_messages = { version = "0.0.0", path = "../monarch_messages" }
2324
monarch_types = { version = "0.0.0", path = "../monarch_types" }
2425
ndslice = { version = "0.0.0", path = "../ndslice" }

monarch_tensor_worker/src/lib.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ use hyperactor::message::Unbind;
6969
use hyperactor::reference::ActorId;
7070
use hyperactor_mesh::actor_mesh::Cast;
7171
use itertools::Itertools;
72+
use monarch_hyperactor::shape::PyPoint;
73+
use monarch_hyperactor::shape::PyShape;
7274
use monarch_messages::controller::ControllerActor;
7375
use monarch_messages::controller::ControllerMessageClient;
7476
use monarch_messages::controller::Seq;
@@ -89,6 +91,9 @@ use monarch_types::PyTree;
8991
use ndslice::Slice;
9092
use pipe::PipeActor;
9193
use pipe::PipeParams;
94+
use pyo3::Py;
95+
use pyo3::Python;
96+
use pyo3::types::PyAnyMethods;
9297
use serde::Deserialize;
9398
use serde::Serialize;
9499
use sorted_vec::SortedVec;
@@ -253,10 +258,19 @@ impl Actor for WorkerActor {
253258
impl Handler<Cast<AssignRankMessage>> for WorkerActor {
254259
async fn handle(
255260
&mut self,
256-
_this: &Instance<Self>,
261+
this: &Instance<Self>,
257262
message: Cast<AssignRankMessage>,
258263
) -> anyhow::Result<()> {
259264
self.rank = message.rank.0;
265+
Python::with_gil(|py| {
266+
let mesh_controller = py.import_bound("monarch.mesh_controller").unwrap();
267+
let shape: PyShape = message.shape.into();
268+
let shape: Py<PyShape> = Py::new(py, shape).unwrap();
269+
let p: PyPoint = PyPoint::new(message.rank.0, shape);
270+
mesh_controller
271+
.call_method1("_initialize_env", (p, this.proc().proc_id().to_string()))
272+
.unwrap();
273+
});
260274
Ok(())
261275
}
262276
}

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class Point(collections.abc.Mapping):
151151
def __new__(cls, rank: int, shape: "Shape") -> "Point": ...
152152
def __getitem__(self, label: str) -> int: ...
153153
def __len__(self) -> int: ...
154+
def size(self, label: str) -> int: ...
154155
@property
155156
def rank(self) -> int: ...
156157
@property

python/monarch/_testing.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
import tempfile
1111
import time
1212
from contextlib import contextmanager, ExitStack
13-
from typing import Callable, Generator, Optional
13+
from typing import Any, Callable, Dict, Generator, Literal, Optional
1414

1515
import monarch_supervisor
1616
from monarch.common.client import Client
1717
from monarch.common.device_mesh import DeviceMesh
1818
from monarch.common.invocation import DeviceException, RemoteException
1919
from monarch.common.shape import NDSlice
2020
from monarch.controller.backend import ProcessBackend
21+
from monarch.mesh_controller import spawn_tensor_engine
22+
from monarch.proc_mesh import proc_mesh, ProcMesh
2123
from monarch.python_local_mesh import PythonLocalContext
2224
from monarch.rust_local_mesh import (
2325
local_mesh,
@@ -50,6 +52,7 @@ def __init__(self):
5052
self.cleanup = ExitStack()
5153
self._py_process_cache = {}
5254
self._rust_process_cache = None
55+
self._proc_mesh_cache: Dict[Any, ProcMesh] = {}
5356

5457
@contextmanager
5558
def _get_context(self, num_hosts, gpu_per_host):
@@ -75,16 +78,14 @@ def _processes(self, num_hosts, gpu_per_host):
7578

7679
@contextmanager
7780
def local_py_device_mesh(
78-
self, num_hosts, gpu_per_host, activate=True
81+
self,
82+
num_hosts,
83+
gpu_per_host,
7984
) -> Generator[DeviceMesh, None, None]:
8085
ctx, hosts, processes = self._processes(num_hosts, gpu_per_host)
8186
dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes)
8287
try:
83-
if activate:
84-
with dm.activate():
85-
yield dm
86-
else:
87-
yield dm
88+
yield dm
8889
dm.client.shutdown(destroy_pg=False)
8990
except Exception:
9091
# abnormal exit, so we just make sure we do not try to communicate in destructors,
@@ -97,7 +98,6 @@ def local_rust_device_mesh(
9798
self,
9899
num_hosts,
99100
gpu_per_host,
100-
activate: bool = True,
101101
controller_params=None,
102102
) -> Generator[DeviceMesh, None, None]:
103103
# Create a new system and mesh for test.
@@ -115,11 +115,7 @@ def local_rust_device_mesh(
115115
controller_params=controller_params,
116116
) as dm:
117117
try:
118-
if activate:
119-
with dm.activate():
120-
yield dm
121-
else:
122-
yield dm
118+
yield dm
123119
dm.exit()
124120
except Exception:
125121
dm.client._shutdown = True
@@ -129,21 +125,57 @@ def local_rust_device_mesh(
129125
# pyre-ignore: Undefined attribute
130126
dm.client.inner._actor.stop()
131127

128+
@contextmanager
129+
def local_engine_on_proc_mesh(
130+
self,
131+
num_hosts,
132+
gpu_per_host,
133+
) -> Generator[DeviceMesh, None, None]:
134+
key = (num_hosts, gpu_per_host)
135+
if key not in self._proc_mesh_cache:
136+
self._proc_mesh_cache[key] = proc_mesh(
137+
hosts=num_hosts, gpus=gpu_per_host
138+
).get()
139+
140+
dm = spawn_tensor_engine(self._proc_mesh_cache[key])
141+
dm = dm.rename(hosts="host", gpus="gpu")
142+
try:
143+
yield dm
144+
dm.exit()
145+
except Exception as e:
146+
# abnormal exit, so we just make sure we do not try to communicate in destructors,
147+
# but we do not wait for workers to exit since we do not know what state they are in.
148+
dm.client._shutdown = True
149+
raise
150+
132151
@contextmanager
133152
def local_device_mesh(
134-
self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
153+
self,
154+
num_hosts,
155+
gpu_per_host,
156+
activate=True,
157+
backend: Literal["py", "rs", "mesh"] = "py",
158+
controller_params=None,
135159
) -> Generator[DeviceMesh, None, None]:
136160
start = time.time()
137-
if rust:
161+
if backend == "rs":
138162
generator = self.local_rust_device_mesh(
139-
num_hosts, gpu_per_host, activate, controller_params=controller_params
163+
num_hosts, gpu_per_host, controller_params=controller_params
140164
)
165+
elif backend == "py":
166+
generator = self.local_py_device_mesh(num_hosts, gpu_per_host)
167+
elif backend == "mesh":
168+
generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host)
141169
else:
142-
generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
170+
raise ValueError(f"invalid backend: {backend}")
143171
with generator as dm:
144172
end = time.time()
145173
logging.info("initialized mesh in {:.2f}s".format(end - start))
146-
yield dm
174+
if activate:
175+
with dm.activate():
176+
yield dm
177+
else:
178+
yield dm
147179
start = time.time()
148180
end = time.time()
149181
logging.info("shutdown mesh in {:.2f}s".format(end - start))

python/monarch/common/client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ def __init__(
103103
# workers.
104104
self.last_processed_seq = -1
105105

106+
# an error that we have received but know for certain has not
107+
# been propagated to a future. This will be reported on shutdown
108+
# to avoid hiding the error. This is best effort: we only keep
109+
# the error until the point that a future is dependent on
110+
# _any_ error, not particularly the tracked one.
111+
self._pending_shutdown_error = None
112+
106113
self.recorder = Recorder()
107114

108115
self.pending_results: Dict[
@@ -174,6 +181,8 @@ def shutdown(
174181
destroy_pg: bool = True,
175182
error_reason: Optional[RemoteException | DeviceException | Exception] = None,
176183
) -> None:
184+
if self.has_shutdown:
185+
return
177186
logger.info("shutting down the client gracefully")
178187

179188
atexit.unregister(self._atexit)
@@ -303,6 +312,7 @@ def _handle_pending_result(self, output: MessageResult) -> None:
303312

304313
if error is not None:
305314
logging.info("Received error for seq %s: %s", seq, error)
315+
self._pending_shutdown_error = error
306316
# We should not have set result if we have an error.
307317
assert result is None
308318
if not isinstance(error, RemoteException):
@@ -326,7 +336,11 @@ def _handle_pending_result(self, output: MessageResult) -> None:
326336

327337
fut, _ = self.pending_results[seq]
328338
if fut is not None:
329-
fut._set_result(result if error is None else error)
339+
if error is None:
340+
fut._set_result(result)
341+
else:
342+
fut._set_result(error)
343+
self._pending_shutdown_error = None
330344
elif result is not None:
331345
logger.debug(f"{seq}: unused result {result}")
332346
elif error is not None:

0 commit comments

Comments
 (0)