Commit 293e5c6

kiukchung authored and facebook-github-bot committed
(3/n torchx-allocator)(monarch/tools) add commands.server_ready function and hostnames to mesh_spec (#296)
Summary: TorchX's `status` API returns a struct that has a `replica.hostname` field, but it is not always filled for all schedulers. pytorch/torchx#1080 makes the Slurm scheduler in TorchX fill out the hostname information. This PR adds a `hostnames` field to `monarch.tools.mesh_spec.MeshSpec` and fills it with the hostnames returned by TorchX. This information will be used in PR (5/n) to implement a `TorchXAllocator`.

Reviewed By: suo

Differential Revision: D76847192
1 parent bab6a91 commit 293e5c6
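For context, a minimal sketch (not part of this commit) of how a caller might consume the new `hostnames` field via `server_ready`; the `slurm:///123` handle is the placeholder used in the docstring below, and the printing is illustrative only:

    import asyncio

    from monarch.tools.commands import server_ready


    async def main() -> None:
        # hostnames are only filled (and valid) once the job reaches RUNNING
        server_info = await server_ready("slurm:///123")
        if server_info is None:
            print("job does not exist")
        elif server_info.is_running:
            for mesh in server_info.meshes:
                print(mesh.name, mesh.hostnames)
        else:
            print(f"job ended in {server_info.state} state")


    asyncio.run(main())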

5 files changed: +164 −8 lines

python/monarch/tools/commands.py

Lines changed: 64 additions & 2 deletions
@@ -9,7 +9,10 @@
 import argparse
 import functools
 import inspect
+import logging
 import os
+import time
+from datetime import timedelta
 from typing import Any, Callable, Mapping, Optional, Union

 from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults
@@ -18,12 +21,13 @@
 )

 from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec
-
 from torchx.runner import Runner
-from torchx.specs import AppDef, AppDryRunInfo, CfgVal
+from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal
 from torchx.specs.builders import parse_args
 from torchx.util.types import decode, decode_optional

+logger: logging.Logger = logging.getLogger(__name__)
+

 def torchx_runner() -> Runner:
     # namespace is currently unused so make it empty str
@@ -165,15 +169,73 @@ def info(server_handle: str) -> Optional[ServerSpec]:
     if appdef is None:
         return None

+    # host status grouped by mesh (role) names
+    replica_status = {r.role: r.replicas for r in status.roles}
+
     mesh_specs = []
     for role in appdef.roles:
         spec = mesh_spec_from_metadata(appdef, role.name)
         assert spec is not None, "cannot be 'None' since we iterate over appdef's roles"
+
+        # null-guard since some schedulers do not fill replica_status
+        if host_status := replica_status.get(role.name):
+            spec.hostnames = [h.hostname for h in host_status]
+
         mesh_specs.append(spec)

     return ServerSpec(name=appdef.name, state=status.state, meshes=mesh_specs)


+_5_SECONDS = timedelta(seconds=5)
+
+
+async def server_ready(
+    server_handle: str, check_interval: timedelta = _5_SECONDS
+) -> Optional[ServerSpec]:
+    """Waits until the server's job is in RUNNING state and returns the server spec.
+    Returns `None` if the server does not exist.
+
+    NOTE: Certain fields such as `hostnames` are only filled (and valid) when the server is RUNNING.
+
+    Usage:
+
+    .. code-block:: python
+
+        server_info = await server_ready("slurm:///123")
+        if not server_info:
+            print(f"Job does not exist")
+        else:
+            if server_info.is_running:
+                for mesh in server_info.meshes:
+                    connect_to(mesh.hostnames)
+            else:
+                print(f"Job in {server_info.state} state. Hostnames are not available")
+
+    """
+
+    while True:
+        server_spec = info(server_handle)
+
+        if not server_spec:  # server not found
+            return None
+
+        if server_spec.state <= AppState.PENDING:  # UNSUBMITTED or SUBMITTED or PENDING
+            # NOTE: TorchX currently does not have async APIs so need to loop-on-interval
+            # TODO maybe inverse exponential backoff instead of constant interval?
+            check_interval_seconds = check_interval.total_seconds()
+            logger.info(
+                "waiting for %s to be %s (current: %s), will check again in %g seconds...",
+                server_handle,
+                AppState.RUNNING,
+                server_spec.state,
+                check_interval_seconds,
+            )
+            time.sleep(check_interval_seconds)
+            continue
+        else:
+            return server_spec
+
+
 def kill(server_handle: str) -> None:
     with torchx_runner() as runner:
         runner.cancel(server_handle)
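The `state <= AppState.PENDING` check above relies on TorchX's `AppState` enum ordering, where the pre-running states (UNSUBMITTED, SUBMITTED, PENDING) compare below RUNNING; a quick sanity check of that assumption, using only the states named in the code comment:

    from torchx.specs import AppState

    # the polling loop keeps waiting only while the job is in a pre-running state
    for state in (AppState.UNSUBMITTED, AppState.SUBMITTED, AppState.PENDING):
        assert state <= AppState.PENDING < AppState.RUNNING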

python/monarch/tools/mesh_spec.py

Lines changed: 7 additions & 1 deletion
@@ -6,7 +6,7 @@

 # pyre-strict
 import string
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Optional

 from torchx import specs
@@ -29,6 +29,7 @@ class MeshSpec:
     host_type: str
     gpus: int
     port: int = DEFAULT_REMOTE_ALLOCATOR_PORT
+    hostnames: list[str] = field(default_factory=list)


 def _tag(mesh_name: str, tag_template: str) -> str:
@@ -84,6 +85,10 @@ class ServerSpec:
     state: specs.AppState
     meshes: list[MeshSpec]

+    @property
+    def is_running(self) -> bool:
+        return self.state == specs.AppState.RUNNING
+
     def get_mesh_spec(self, mesh_name: str) -> MeshSpec:
         for mesh_spec in self.meshes:
             if mesh_spec.name == mesh_name:
@@ -115,6 +120,7 @@ def to_json(self) -> dict[str, Any]:
                     "host_type": mesh.host_type,
                     "hosts": mesh.num_hosts,
                     "gpus": mesh.gpus,
+                    "hostnames": mesh.hostnames,
                 }
                 for mesh in self.meshes
             },

python/tests/tools/test_cli.py

Lines changed: 4 additions & 2 deletions
@@ -68,12 +68,14 @@ def test_info(self, mock_cmd_info: mock.MagicMock) -> None:
     "trainer": {
       "host_type": "gpu.medium",
       "hosts": 4,
-      "gpus": 2
+      "gpus": 2,
+      "hostnames": []
     },
     "generator": {
       "host_type": "gpu.small",
       "hosts": 16,
-      "gpus": 1
+      "gpus": 1,
+      "hostnames": []
     }
   }
 }

python/tests/tools/test_commands.py

Lines changed: 77 additions & 1 deletion
@@ -7,10 +7,11 @@
 # pyre-strict

 import unittest
+from datetime import timedelta
 from unittest import mock

 from monarch.tools import commands
-from monarch.tools.commands import component_args_from_cli
+from monarch.tools.commands import component_args_from_cli, server_ready

 from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults
     defaults,
@@ -101,3 +102,78 @@ def test_info(
             ),
             commands.info("slurm:///job-id"),
         )
+
+
+UNUSED = "__UNUSED__"
+_5_MS = timedelta(milliseconds=5)
+
+
+def server(state: AppState) -> ServerSpec:
+    mesh_x = MeshSpec(name="x", num_hosts=2, host_type=UNUSED, gpus=-1)
+    mesh_y = MeshSpec(name="y", num_hosts=4, host_type=UNUSED, gpus=-1)
+    meshes = [mesh_x, mesh_y]
+
+    if state == AppState.RUNNING:
+        for mesh in meshes:
+            mesh.hostnames = [f"node{i}" for i in range(mesh.num_hosts)]
+
+    return ServerSpec(name=UNUSED, state=state, meshes=meshes)
+
+
+class TestCommandsAsync(unittest.IsolatedAsyncioTestCase):
+    async def test_server_ready_server_does_not_exist(self) -> None:
+        with mock.patch(
+            "monarch.tools.commands.info",
+            return_value=None,
+        ):
+            server_info = await server_ready("slurm:///123", check_interval=_5_MS)
+            self.assertIsNone(server_info)
+
+    async def test_server_ready_pending_to_running(self) -> None:
+        with mock.patch(
+            "monarch.tools.commands.info",
+            side_effect=[
+                server(AppState.UNSUBMITTED),
+                server(AppState.SUBMITTED),
+                server(AppState.PENDING),
+                server(AppState.PENDING),
+                server(AppState.RUNNING),
+                server(AppState.CANCELLED),
+            ],
+        ) as mock_info:
+            server_info = await server_ready("slurm:///123", check_interval=_5_MS)
+
+            self.assertIsNotNone(server_info)
+            self.assertTrue(server_info.is_running)
+            self.assertEqual(server_info.state, AppState.RUNNING)
+
+            mesh_x = server_info.get_mesh_spec("x")
+            mesh_y = server_info.get_mesh_spec("y")
+            self.assertListEqual(mesh_x.hostnames, ["node0", "node1"])
+            self.assertListEqual(mesh_y.hostnames, ["node0", "node1", "node2", "node3"])
+
+            mock_info.assert_called()
+            # called 5 times, once for UNSUBMITTED, SUBMITTED, PENDING, PENDING, and RUNNING
+            self.assertEqual(mock_info.call_count, 5)
+
+    async def test_server_ready_pending_to_terminal(self) -> None:
+        for terminal_state in [AppState.SUCCEEDED, AppState.FAILED, AppState.CANCELLED]:
+            with self.subTest(terminal_state=terminal_state):
+                with mock.patch(
+                    "monarch.tools.commands.info",
+                    side_effect=[
+                        server(AppState.SUBMITTED),
+                        server(AppState.PENDING),
+                        server(AppState.PENDING),
+                        server(terminal_state),
+                    ],
+                ) as mock_info:
+                    server_info = await server_ready(
+                        "slurm:///123",
+                        check_interval=_5_MS,
+                    )
+
+                    self.assertIsNotNone(server_info)
+                    self.assertEqual(server_info.state, terminal_state)
+                    mock_info.assert_called()
+                    self.assertEqual(mock_info.call_count, 4)

python/tests/tools/test_mesh_spec.py

Lines changed: 12 additions & 2 deletions
@@ -82,15 +82,25 @@ def test_mesh_spec_from_metadata(self) -> None:

     def test_mesh_spec_can_dump_as_json(self) -> None:
         mesh_spec = MeshSpec(
-            name="trainer", num_hosts=4, host_type="gpu.medium", gpus=2
+            name="trainer",
+            num_hosts=4,
+            host_type="gpu.medium",
+            gpus=2,
+            hostnames=["n0", "n1", "n2", "n3"],
         )
         expected = """
 {
   "name": "trainer",
   "num_hosts": 4,
   "host_type": "gpu.medium",
   "gpus": 2,
-  "port": 26600
+  "port": 26600,
+  "hostnames": [
+    "n0",
+    "n1",
+    "n2",
+    "n3"
+  ]
 }
 """
         self.assertEqual(expected.strip("\n"), json.dumps(asdict(mesh_spec), indent=2))
