Skip to content

Commit 4be469b

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(monarch_hyperactor) Create python binding for a RemoteAllocator that takes a list of remote channel addresses (#170)
Summary: Pull Request resolved: #170. To support multi-node actor meshes in OSS without having to write a custom allocator for each scheduler (e.g. `SlurmAllocator`, `KubernetesAllocator`), we take advantage of the infrastructure we already have in TorchX and TorchElastic. This Diff creates Python bindings for `RemoteAllocatorBase`, which takes a list of server addresses (in channel_addr format, e.g. `metatls!devgpu032.nha1.facebook.com:26600` or `tcp!devgpu032.nha1.facebook.com:26601`) of remote-process-allocator servers and connects to them. The internals reuse the existing `RemoteProcessAlloc` with a custom `PyRemoteProcessAllocInitializer` that simply returns a `Vec<RemoteProcessAllocHost>` given the user-provided list of server addresses. It is recommended to start the review at `monarch/python/tests/test_allocator.py` to get a sense of what the API/usage looks like. The next diff will provide a function that gets the list of server addresses given a job-id (more specifically, a monarch server handle of the form `{scheduler}://{namespace}/{job_id}`, e.g. `slurm://default/monarch-kiuk-123`) and returns an Allocator that can be used to create a `ProcMesh` as usual. NOTE: WIP fixing type-checking failures, so ignore those... Differential Revision: D75928565
1 parent 545467d commit 4be469b

File tree

11 files changed

+472
-14
lines changed

11 files changed

+472
-14
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ jobs:
5454
# Install test dependencies
5555
pip install -r python/tests/requirements.txt
5656
57+
# Install remote process_allocator binary (some tests use it)
58+
cargo install --path monarch_hyperactor
59+
5760
# Build and install monarch
5861
# NB: monarch currently can't be built in isolated builds (e.g. not PEP 517 compatible)
5962
# because 'torch-sys' needs to be compiled against 'torch' in the main python environment

monarch_extension/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
9797
module,
9898
"monarch_hyperactor.alloc",
9999
)?)?;
100+
monarch_hyperactor::channel::register_python_bindings(&get_or_add_new_module(
101+
module,
102+
"monarch_hyperactor.channel",
103+
)?)?;
100104
monarch_hyperactor::actor_mesh::register_python_bindings(&get_or_add_new_module(
101105
module,
102106
"monarch_hyperactor.actor_mesh",

monarch_hyperactor/src/alloc.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@
77
*/
88

99
use std::collections::HashMap;
10+
use std::str::FromStr;
1011
use std::sync::Arc;
12+
use std::time::Duration;
1113

14+
use async_trait::async_trait;
15+
use hyperactor::WorldId;
16+
use hyperactor::channel::ChannelAddr;
17+
use hyperactor::channel::ChannelTransport;
1218
use hyperactor_extension::alloc::PyAlloc;
1319
use hyperactor_extension::alloc::PyAllocSpec;
1420
use hyperactor_mesh::alloc::Allocator;
1521
use hyperactor_mesh::alloc::LocalAllocator;
1622
use hyperactor_mesh::alloc::ProcessAllocator;
23+
use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAlloc;
24+
use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocHost;
25+
use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocInitializer;
1726
use pyo3::exceptions::PyRuntimeError;
1827
use pyo3::prelude::*;
1928
use tokio::process::Command;
@@ -132,9 +141,159 @@ impl PyProcessAllocator {
132141
}
133142
}
134143

144+
struct PyRemoteProcessAllocInitializer {
145+
addrs: Vec<String>,
146+
}
147+
148+
#[async_trait]
149+
impl RemoteProcessAllocInitializer for PyRemoteProcessAllocInitializer {
150+
async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error> {
151+
self.addrs
152+
.iter()
153+
.map(|channel_addr| {
154+
let addr = ChannelAddr::from_str(channel_addr)?;
155+
let remote_host = match addr {
156+
ChannelAddr::Tcp(socket_addr) => RemoteProcessAllocHost {
157+
id: socket_addr.ip().to_string(),
158+
hostname: socket_addr.ip().to_string(),
159+
},
160+
ChannelAddr::MetaTls(hostname, _) => RemoteProcessAllocHost {
161+
id: hostname.clone(),
162+
hostname: hostname.clone(),
163+
},
164+
ChannelAddr::Unix(_) => RemoteProcessAllocHost {
165+
id: addr.to_string(),
166+
hostname: addr.to_string(),
167+
},
168+
_ => {
169+
anyhow::bail!("Unsupported transport for channel address: `{addr:?}`")
170+
}
171+
};
172+
Ok(remote_host)
173+
})
174+
.collect()
175+
}
176+
}
177+
178+
#[pyclass(
179+
name = "RemoteAllocatorBase",
180+
module = "monarch._rust_bindings.monarch_hyperactor.alloc",
181+
subclass
182+
)]
183+
pub struct PyRemoteAllocator {
184+
world_id: String,
185+
addrs: Vec<String>,
186+
heartbeat_interval_millis: u64,
187+
}
188+
189+
const DEFAULT_REMOTE_ALLOCATOR_PORT: u16 = 26600;
190+
const DEFAULT_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_MILLIS: u64 = 5000;
191+
192+
#[pymethods]
193+
impl PyRemoteAllocator {
194+
#[classattr]
195+
const DEFAULT_PORT: u16 = DEFAULT_REMOTE_ALLOCATOR_PORT;
196+
197+
#[classattr]
198+
const DEFAULT_HEARTBEAT_INTERVAL_MILLIS: u64 =
199+
DEFAULT_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_MILLIS;
200+
201+
#[new]
202+
#[pyo3(signature = (
203+
world_id,
204+
addrs,
205+
heartbeat_interval_millis = DEFAULT_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_MILLIS,
206+
))]
207+
fn new(world_id: String, addrs: Vec<String>, heartbeat_interval_millis: u64) -> PyResult<Self> {
208+
Ok(Self {
209+
world_id,
210+
addrs,
211+
heartbeat_interval_millis,
212+
})
213+
}
214+
215+
fn allocate_nonblocking<'py>(
216+
&self,
217+
py: Python<'py>,
218+
spec: &PyAllocSpec,
219+
) -> PyResult<Bound<'py, PyAny>> {
220+
let addrs = self.addrs.clone();
221+
let world_id = self.world_id.clone();
222+
let spec_inner = spec.inner.clone();
223+
let heartbeat_interval_millis = self.heartbeat_interval_millis;
224+
225+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
226+
// all addrs expected to have the same transport; use the first one
227+
let first_addr = addrs.first().expect("addrs should not be empty");
228+
let first_addr = ChannelAddr::from_str(first_addr)?;
229+
let transport = first_addr.transport();
230+
let port = match first_addr {
231+
ChannelAddr::Tcp(socket_addr) => socket_addr.port(),
232+
ChannelAddr::MetaTls(_, port) => port,
233+
ChannelAddr::Unix(_) => 0,
234+
ChannelAddr::Local(_) => 0,
235+
ChannelAddr::Sim(_) => {
236+
return Err(PyRuntimeError::new_err(format!(
237+
"Unsupported channel_addr: {first_addr:?}"
238+
)));
239+
}
240+
};
241+
242+
let alloc = RemoteProcessAlloc::new(
243+
spec_inner,
244+
WorldId(world_id),
245+
transport,
246+
port,
247+
Duration::from_millis(heartbeat_interval_millis),
248+
PyRemoteProcessAllocInitializer { addrs },
249+
)
250+
.await?;
251+
252+
Ok(PyAlloc::new(Box::new(alloc)))
253+
})
254+
}
255+
fn allocate_blocking<'py>(&self, py: Python<'py>, spec: &PyAllocSpec) -> PyResult<PyAlloc> {
256+
let addrs = self.addrs.clone();
257+
let world_id = self.world_id.clone();
258+
let spec_inner = spec.inner.clone();
259+
let heartbeat_interval_millis = self.heartbeat_interval_millis;
260+
261+
signal_safe_block_on(py, async move {
262+
// all addrs expected to have the same transport; use the first one
263+
let first_addr = addrs.first().expect("addrs should not be empty");
264+
let first_addr = ChannelAddr::from_str(first_addr)?;
265+
let transport = first_addr.transport();
266+
let port = match first_addr {
267+
ChannelAddr::Tcp(socket_addr) => socket_addr.port(),
268+
ChannelAddr::MetaTls(_, port) => port,
269+
ChannelAddr::Unix(_) => 0,
270+
ChannelAddr::Local(_) => 0,
271+
ChannelAddr::Sim(_) => {
272+
return Err(PyRuntimeError::new_err(format!(
273+
"Unsupported channel_addr: {first_addr:?}"
274+
)));
275+
}
276+
};
277+
278+
let alloc = RemoteProcessAlloc::new(
279+
spec_inner,
280+
WorldId(world_id),
281+
transport,
282+
port,
283+
Duration::from_millis(heartbeat_interval_millis),
284+
PyRemoteProcessAllocInitializer { addrs },
285+
)
286+
.await?;
287+
288+
Ok(PyAlloc::new(Box::new(alloc)))
289+
})?
290+
}
291+
}
292+
135293
/// Adds the allocator classes defined in this file (`PyProcessAllocator`,
/// `PyLocalAllocator` and `PyRemoteAllocator`) to the given Python module.
///
/// Registration stops at, and returns, the first error reported by
/// `add_class`.
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    hyperactor_mod.add_class::<PyProcessAllocator>()?;
    hyperactor_mod.add_class::<PyLocalAllocator>()?;
    // The last registration doubles as the function's result.
    hyperactor_mod.add_class::<PyRemoteAllocator>()
}

monarch_hyperactor/src/bin/process_allocator/common.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,24 @@ use clap::command;
1313
use hyperactor::channel::ChannelAddr;
1414
use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocator;
1515
use tokio::process::Command;
16+
1617
#[derive(Parser, Debug)]
1718
#[command(about = "Runs hyperactor's process allocator")]
1819
pub struct Args {
1920
#[arg(
2021
long,
21-
default_value = "[::]",
22-
help = "The address bind to. The process allocator runs on `bind_addr:port`"
22+
default_value_t = 26600,
23+
help = "The port to bind to on [::] (all network interfaces on this host). Same as specifying `--addr=[::]:{port}`"
2324
)]
24-
pub addr: String,
25+
pub port: u16,
2526

2627
#[arg(
2728
long,
28-
default_value_t = 26600,
29-
help = "Port to bind to. The process allocator runs on `bind_addr:port`"
29+
help = "The address to bind to in the form: \
30+
`{transport}!{address}:{port}` (e.g. `tcp!127.0.0.1:26600`). \
31+
If specified, `--port` argument is ignored"
3032
)]
31-
pub port: u16,
33+
pub addr: Option<String>,
3234

3335
#[arg(
3436
long,
@@ -72,8 +74,8 @@ mod tests {
7274

7375
let parsed_args = Args::parse_from(args);
7476

75-
assert_eq!(parsed_args.addr, "[::]");
7677
assert_eq!(parsed_args.port, 26600);
78+
assert_eq!(parsed_args.addr, None);
7779
assert_eq!(parsed_args.program, "monarch_bootstrap");
7880
Ok(())
7981
}
@@ -82,15 +84,13 @@ mod tests {
8284
async fn test_args() -> Result<(), anyhow::Error> {
8385
let args = vec![
8486
"process_allocator",
85-
"--addr=127.0.0.1",
86-
"--port=29500",
87+
"--addr=tcp!127.0.0.1:29501",
8788
"--program=/bin/echo",
8889
];
8990

9091
let parsed_args = Args::parse_from(args);
9192

92-
assert_eq!(parsed_args.addr, "127.0.0.1");
93-
assert_eq!(parsed_args.port, 29500);
93+
assert_eq!(parsed_args.addr, Some("tcp!127.0.0.1:29501".to_string()));
9494
assert_eq!(parsed_args.program, "/bin/echo");
9595
Ok(())
9696
}

monarch_hyperactor/src/bin/process_allocator/main.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
mod common;
1010

11+
use std::str::FromStr;
12+
1113
use clap::Parser;
1214
use common::Args;
1315
use common::main_impl;
@@ -18,9 +20,11 @@ async fn main() {
1820
let args = Args::parse();
1921
hyperactor::initialize();
2022

21-
let bind = format!("{}:{}", args.addr, args.port);
22-
let socket_addr: std::net::SocketAddr = bind.parse().unwrap();
23-
let serve_address = ChannelAddr::Tcp(socket_addr);
23+
let bind = args
24+
.addr
25+
.unwrap_or_else(|| format!("tcp![::]:{}", args.port));
26+
27+
let serve_address = ChannelAddr::from_str(&bind).unwrap();
2428

2529
let _ = main_impl(serve_address, args.program).await.unwrap();
2630
}

monarch_hyperactor/src/channel.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use hyperactor::channel::ChannelAddr;
10+
use hyperactor::channel::ChannelTransport;
11+
use pyo3::prelude::*;
12+
13+
/// Python binding for [`hyperactor::channel::ChannelTransport`]
14+
#[pyclass(
15+
name = "ChannelTransport",
16+
module = "monarch._rust_bindings.monarch_hyperactor.channel",
17+
eq
18+
)]
19+
#[derive(PartialEq, Clone, Copy)]
20+
pub enum PyChannelTransport {
21+
Tcp,
22+
MetaTls,
23+
Local,
24+
Unix,
25+
// Sim(/*proxy address:*/ ChannelAddr), TODO kiuk@ add support
26+
}
27+
28+
#[pyclass(
29+
name = "ChannelAddr",
30+
module = "monarch._rust_bindings.monarch_hyperactor.channel"
31+
)]
32+
pub struct PyChannelAddr;
33+
34+
#[pymethods]
35+
impl PyChannelAddr {
36+
/// Returns an "any" address for the given transport type.
37+
/// Primarily used to bind servers
38+
#[staticmethod]
39+
fn any(transport: PyChannelTransport) -> PyResult<String> {
40+
Ok(ChannelAddr::any(transport.into()).to_string())
41+
}
42+
}
43+
44+
impl From<PyChannelTransport> for ChannelTransport {
45+
fn from(val: PyChannelTransport) -> Self {
46+
match val {
47+
PyChannelTransport::Tcp => ChannelTransport::Tcp,
48+
PyChannelTransport::MetaTls => ChannelTransport::MetaTls,
49+
PyChannelTransport::Local => ChannelTransport::Local,
50+
PyChannelTransport::Unix => ChannelTransport::Unix,
51+
}
52+
}
53+
}
54+
55+
/// Adds the channel classes defined in this file (`PyChannelTransport` and
/// `PyChannelAddr`) to the given Python module.
///
/// Registration stops at, and returns, the first error reported by
/// `add_class`.
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    hyperactor_mod.add_class::<PyChannelTransport>()?;
    // The last registration doubles as the function's result.
    hyperactor_mod.add_class::<PyChannelAddr>()
}

monarch_hyperactor/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub mod actor;
1212
pub mod actor_mesh;
1313
pub mod alloc;
1414
pub mod bootstrap;
15+
pub mod channel;
1516
pub mod mailbox;
1617
pub mod ndslice;
1718
pub mod proc;

0 commit comments

Comments
 (0)