Skip to content

Commit 1d1c2a6

Browse files
committed
feat: add rdzv_conf to dist.ddp (#1071)
1 parent 24dc0d5 commit 1d1c2a6

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

torchx/components/dist.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -132,9 +132,6 @@ def spmd(
132132
j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host
133133
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
134134
max_retries: the number of scheduler retries allowed
135-
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
136-
Only takes effect when running multi-node. When running single node, this parameter
137-
is ignored and a random free port is chosen.
138135
mounts: (for docker based runs only) mounts to mount into the worker environment/container
139136
(ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
140137
debug: whether to run with preset debug flags enabled
@@ -174,6 +171,7 @@ def ddp(
174171
max_retries: int = 0,
175172
rdzv_port: int = 29500,
176173
rdzv_backend: str = "c10d",
174+
rdzv_conf: Optional[str] = None,
177175
mounts: Optional[List[str]] = None,
178176
debug: bool = False,
179177
tee: int = 3,
@@ -208,6 +206,7 @@ def ddp(
208206
Only takes effect when running multi-node. When running single node, this parameter
209207
is ignored and a random free port is chosen.
210208
rdzv_backend: the rendezvous backend to use. Only takes effect when running multi-node.
209+
rdzv_conf: the additional rendezvous configuration to use (ex. join_timeout=600,close_timeout=600,timeout=600).
211210
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
212211
See scheduler documentation for more info.
213212
debug: whether to run with preset debug flags enabled
@@ -258,6 +257,7 @@ def ddp(
258257
"torchrun",
259258
"--rdzv_backend",
260259
rdzv_backend,
260+
*(["--rdzv_conf", rdzv_conf] if rdzv_conf is not None else []),
261261
"--rdzv_endpoint",
262262
rdzv_endpoint,
263263
"--rdzv_id",

torchx/components/test/dist_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def test_ddp_debug(self) -> None:
4141
self.assertEqual(env[k], v)
4242

4343
def test_ddp_rdzv_backend_static(self) -> None:
44-
app = ddp(script="foo.py", rdzv_backend="static")
44+
rdzv_conf = "join_timeout=600,close_timeout=600,timeout=600"
45+
app = ddp(script="foo.py", rdzv_backend="static", rdzv_conf=rdzv_conf)
4546
cmd = app.roles[0].args[1]
47+
self.assertTrue(f"--rdzv_conf {rdzv_conf}" in cmd)
4648
self.assertTrue("--rdzv_backend static" in cmd)
4749
self.assertTrue("--node_rank" in cmd)
4850

0 commit comments

Comments
 (0)