Support overlapped srun commands in Slurm Ray #263

Merged 4 commits on Jun 12, 2025
1 change: 0 additions & 1 deletion nemo_run/core/execution/slurm.py
@@ -414,7 +414,6 @@ def merge(
)
)

main_executor.env_vars = {}
return main_executor

def __post_init__(self):
63 changes: 63 additions & 0 deletions nemo_run/run/ray/slurm.py
@@ -117,6 +117,7 @@ class SlurmRayRequest:
command: Optional[str] = None
workdir: Optional[str] = None
nemo_run_dir: Optional[str] = None
command_groups: Optional[list[list[str]]] = None
launch_cmd: list[str]

@staticmethod
@@ -234,6 +235,60 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
"gres_specification": get_gres_specification(),
}

if self.command_groups:
srun_commands: list[str] = []
group_env_vars: list[list[str]] = []

for idx, group in enumerate(self.command_groups):
if idx == 0:
continue

if self.executor.run_as_group and len(self.executor.resource_group) == len(
self.command_groups
):
req = self.executor.resource_group[idx]
env_list = [f"export {k.upper()}={v}" for k, v in req.env_vars.items()]
group_env_vars.append(env_list)
container_flags = get_srun_flags(req.container_mounts, req.container_image)
srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"]
srun_args.extend(req.srun_args or [])
else:
container_flags = get_srun_flags(
self.executor.container_mounts, self.executor.container_image
)
srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"]
srun_args.extend(self.executor.srun_args or [])
group_env_vars.append([])

stdout_path = os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.out")
stderr_flags = []
if not self.executor.stderr_to_stdout:
stderr_flags = [
"--error",
os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.err"),
]

srun_cmd = " ".join(
list(
map(
lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
[
"srun",
"--output",
noquote(stdout_path),
*stderr_flags,
container_flags,
*srun_args,
],
)
)
)
command = " ".join(group)
srun_commands.append(f"{srun_cmd} {command} &")

vars_to_fill["srun_commands"] = srun_commands
vars_to_fill["group_env_vars"] = group_env_vars

if self.pre_ray_start_commands:
vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands)

@@ -398,6 +453,7 @@ def create(
dryrun: bool = False,
command: Optional[str] = None,
workdir: Optional[str] = None,
command_groups: Optional[list[list[str]]] = None,
) -> Any:
"""Create (or reuse) a Slurm-backed Ray cluster and return its job-id.

@@ -416,6 +472,9 @@
Optional command executed after the Ray head node is ready (e.g. ``ray job submit``).
workdir : str | None
Remote working directory that becomes the CWD inside the container.
command_groups : list[list[str]] | None
Additional commands (one per group) executed via ``srun`` with ``--overlap``
after the cluster is started.

Returns
-------
@@ -433,6 +492,7 @@
pre_ray_start_commands=pre_ray_start_commands,
command=command,
workdir=workdir,
command_groups=command_groups,
launch_cmd=["sbatch", "--requeue", "--parsable", "--dependency=singleton"],
).materialize()

@@ -1094,6 +1154,7 @@ def start(
runtime_env_yaml: Optional[str] | None = None,
pre_ray_start_commands: Optional[list[str]] = None,
dryrun: bool = False,
command_groups: Optional[list[list[str]]] = None,
):
"""Submit a Ray job via Slurm and return a *live* SlurmRayJob helper.

@@ -1106,6 +1167,7 @@
executor=my_slurm_executor,
command="python train.py",
workdir="./src",
command_groups=[["echo", "hello"]],
)
"""
# ------------------------------------------------------------------
@@ -1212,6 +1274,7 @@ def start(
dryrun=dryrun,
command=command,
workdir=remote_workdir,
command_groups=command_groups,
)

self.job_id = job_id
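Note (not part of the diff): a minimal sketch of how command_groups flows into SlurmRayRequest, following the pattern used in the tests added below; the account, paths, and tunnel setup are illustrative placeholders, and the SSHTunnel import path is assumed.

from unittest.mock import Mock

from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.core.tunnel.client import SSHTunnel  # import path assumed
from nemo_run.run.ray.slurm import SlurmRayRequest

executor = SlurmExecutor(account="my_account")  # placeholder account
executor.tunnel = Mock(spec=SSHTunnel)          # real runs use a configured tunnel
executor.tunnel.job_dir = "/tmp/jobs"

request = SlurmRayRequest(
    name="demo-cluster",
    cluster_dir="/tmp/jobs/demo-cluster",
    template_name="ray.sub.j2",
    executor=executor,
    command_groups=[
        ["python train.py"],            # index 0: main command, skipped by the overlap loop
        ["bash ./scripts/sidecar.sh"],  # index >= 1: emitted as `srun --overlap ... &`
    ],
    launch_cmd=["sbatch", "--parsable"],
)
script = request.materialize()  # sbatch script containing the overlapped srun section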
14 changes: 14 additions & 0 deletions nemo_run/run/ray/templates/ray.sub.j2
@@ -295,6 +295,20 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json

########################################################

{% if srun_commands %}
# Run extra commands
{% for srun_command in srun_commands %}
{%- if loop.index <= group_env_vars|length %}
{%- for env_var in group_env_vars[loop.index - 1] %}
{{env_var}}
{%- endfor %}
{%- endif %}

{{srun_command}}
{% endfor %}
########################################################
{% endif -%}

# We can now launch a job on this cluster
# We do so by launching a driver process on the physical node that the head node is on
# This driver process is responsible for launching a job on the Ray cluster
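Note (not part of the diff): roughly the shape of the two context variables this template block consumes, as assembled in SlurmRayRequest.materialize(); paths, env vars, and flags below are placeholders, and the exact container flags come from get_srun_flags. Because the builder skips group 0, srun_commands[0] and group_env_vars[0] describe command_groups[1].

# Illustrative template context for a request with two command groups and
# per-group resource requests (values are placeholders, not rendered output).
vars_to_fill = {
    "srun_commands": [
        "srun --output /cluster/logs/ray-overlap-1.out "
        "--error /cluster/logs/ray-overlap-1.err "
        "<container flags for group 1> "
        "--wait=60 --kill-on-bad-exit=1 --overlap bash ./scripts/sidecar.sh &",
    ],
    "group_env_vars": [
        ["export GROUP1_ENV=group1_value"],  # exported just before the matching srun
    ],
}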
9 changes: 7 additions & 2 deletions nemo_run/run/torchx_backend/packaging.py
@@ -233,15 +233,20 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
assert isinstance(executor, SlurmExecutor), (
f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor"
)
assert len(app_def.roles) == 1, "Only one command is supported for Ray jobs."

app_def.metadata = metadata
return app_def


def merge_executables(app_defs: Iterator[specs.AppDef], name: str) -> specs.AppDef:
result = specs.AppDef(name=name, roles=[])
for app_def in app_defs:
result.metadata = {}
for idx, app_def in enumerate(app_defs):
metadata = app_def.metadata or {}
if USE_WITH_RAY_CLUSTER_KEY in metadata:
assert idx == 0, f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for the first command"

result.metadata.update(metadata)
result.roles.extend(app_def.roles)
return result

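Note (not part of the diff): a small sketch of the merge semantics above, assuming torchx's specs.Role and specs.AppDef constructors accept these keyword arguments and that USE_WITH_RAY_CLUSTER_KEY is importable from the packaging module; names and scripts are placeholders.

from torchx import specs

from nemo_run.run.torchx_backend.packaging import (  # constant's import path assumed
    USE_WITH_RAY_CLUSTER_KEY,
    merge_executables,
)

# Only the first AppDef may carry USE_WITH_RAY_CLUSTER_KEY; its metadata is
# propagated to the merged AppDef while the roles of every AppDef are concatenated.
ray_app = specs.AppDef(
    name="trainer",
    roles=[specs.Role(name="trainer", image="", entrypoint="python", args=["train.py"])],
    metadata={USE_WITH_RAY_CLUSTER_KEY: True},
)
sidecar_app = specs.AppDef(
    name="sidecar",
    roles=[specs.Role(name="sidecar", image="", entrypoint="bash", args=["sidecar.sh"])],
)

merged = merge_executables(iter([ray_app, sidecar_app]), name="combined")
assert merged.metadata.get(USE_WITH_RAY_CLUSTER_KEY)
assert len(merged.roles) == 2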
13 changes: 11 additions & 2 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -102,8 +102,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]:  # t

executor.package(packager=executor.packager, job_name=Path(job_dir).name)

values = executor.macro_values()

if app.metadata and app.metadata.get(USE_WITH_RAY_CLUSTER_KEY, False):
assert len(app.roles) == 1, "Only one command is supported for Ray jobs."
srun_cmds: list[list[str]] = []

for role in app.roles:
if values:
role = values.apply(role)
srun_cmd = [role.entrypoint] + role.args
srun_cmds.append([" ".join(srun_cmd)])

command = [app.roles[0].entrypoint] + app.roles[0].args
req = SlurmRayRequest(
name=app.roles[0].name,
Expand All @@ -114,12 +123,12 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
executor=executor,
workdir=f"/{RUNDIR_NAME}/code",
nemo_run_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name),
command_groups=srun_cmds,
)
else:
srun_cmds: list[list[str]] = []
jobs = []
envs = {}
values = executor.macro_values()

if values:
executor.env_vars = {
1 change: 1 addition & 0 deletions test/core/execution/artifacts/group_resource_req_slurm.sh
@@ -30,6 +30,7 @@ nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

export CUSTOM_ENV_1=some_value_1
export ENV_VAR=value


176 changes: 176 additions & 0 deletions test/run/ray/test_slurm_ray_request.py
@@ -405,3 +405,179 @@ def test_array_assertion(self):

with pytest.raises(AssertionError, match="array is not supported"):
request.materialize()

def test_command_groups_env_vars(self):
"""Test environment variables are properly set for each command group."""
# Create executor with environment variables
executor = SlurmExecutor(
account="test_account",
env_vars={"GLOBAL_ENV": "global_value"},
)
executor.run_as_group = True

# Create resource groups with different env vars
resource_group = [
SlurmExecutor.ResourceRequest(
packager=Mock(),
nodes=1,
ntasks_per_node=1,
container_image="image1",
env_vars={"GROUP1_ENV": "group1_value"},
container_mounts=["/mount1"],
),
SlurmExecutor.ResourceRequest(
packager=Mock(),
nodes=1,
ntasks_per_node=1,
container_image="image2",
env_vars={"GROUP2_ENV": "group2_value"},
container_mounts=["/mount2"],
),
]
executor.resource_group = resource_group
executor.tunnel = Mock(spec=SSHTunnel)
executor.tunnel.job_dir = "/tmp/test_jobs"

request = SlurmRayRequest(
name="test-ray-cluster",
cluster_dir="/tmp/test_jobs/test-ray-cluster",
template_name="ray.sub.j2",
executor=executor,
command_groups=[["cmd0"], ["cmd1"], ["cmd2"]],
launch_cmd=["sbatch", "--parsable"],
)

script = request.materialize()

# Check global env vars are set in setup section
assert "export GLOBAL_ENV=global_value" in script

# Check that command groups generate srun commands (excluding the first one)
# The template should have a section for srun_commands
assert "# Run extra commands" in script
assert "srun" in script
assert "cmd1" in script # First command group after skipping index 0
assert "cmd2" in script # Second command group

def test_command_groups_without_resource_group(self):
"""Test command groups work without resource groups."""
executor = SlurmExecutor(
account="test_account",
env_vars={"GLOBAL_ENV": "global_value"},
)
executor.tunnel = Mock(spec=SSHTunnel)
executor.tunnel.job_dir = "/tmp/test_jobs"

request = SlurmRayRequest(
name="test-ray-cluster",
cluster_dir="/tmp/test_jobs/test-ray-cluster",
template_name="ray.sub.j2",
executor=executor,
command_groups=[["cmd0"], ["cmd1"]],
launch_cmd=["sbatch", "--parsable"],
)

script = request.materialize()

# Should have global env vars
assert "export GLOBAL_ENV=global_value" in script

# Should have srun commands for overlapping groups (skipping first)
assert "srun" in script
assert "--overlap" in script
assert "cmd1" in script # Second command in the list (index 1)

def test_env_vars_formatting(self):
"""Test that environment variables are properly formatted as export statements."""
executor = SlurmExecutor(
account="test_account",
env_vars={
"VAR_WITH_SPACES": "value with spaces",
"PATH_VAR": "/usr/bin:/usr/local/bin",
"EMPTY_VAR": "",
"NUMBER_VAR": "123",
},
)
executor.tunnel = Mock(spec=SSHTunnel)
executor.tunnel.job_dir = "/tmp/test_jobs"

request = SlurmRayRequest(
name="test-ray-cluster",
cluster_dir="/tmp/test_jobs/test-ray-cluster",
template_name="ray.sub.j2",
executor=executor,
launch_cmd=["sbatch", "--parsable"],
)

script = request.materialize()

# Check all environment variables are properly exported
assert "export VAR_WITH_SPACES=value with spaces" in script
assert "export PATH_VAR=/usr/bin:/usr/local/bin" in script
assert "export EMPTY_VAR=" in script
assert "export NUMBER_VAR=123" in script

def test_group_env_vars_integration(self):
"""Test full integration of group environment variables matching the artifact pattern."""
# This test verifies the behavior seen in group_resource_req_slurm.sh
executor = SlurmExecutor(
account="your_account",
partition="your_partition",
time="00:30:00",
nodes=1,
ntasks_per_node=8,
gpus_per_node=8,
container_image="some-image",
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
env_vars={"ENV_VAR": "value"},
)
executor.run_as_group = True

# Set up resource groups with specific env vars
resource_group = [
# First group (index 0) - for the head/main command
SlurmExecutor.ResourceRequest(
packager=Mock(),
nodes=1,
ntasks_per_node=8,
container_image="some-image",
env_vars={"CUSTOM_ENV_1": "some_value_1"},
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
),
# Second group (index 1)
SlurmExecutor.ResourceRequest(
packager=Mock(),
nodes=1,
ntasks_per_node=8,
container_image="different_container_image",
env_vars={"CUSTOM_ENV_1": "some_value_1"},
container_mounts=["/some/job/dir/sample_job:/nemo_run"],
),
]
executor.resource_group = resource_group

# Mock tunnel
tunnel_mock = Mock(spec=SSHTunnel)
tunnel_mock.job_dir = "/some/job/dir"
executor.tunnel = tunnel_mock

request = SlurmRayRequest(
name="sample_job",
cluster_dir="/some/job/dir/sample_job",
template_name="ray.sub.j2",
executor=executor,
command_groups=[
["bash ./scripts/start_server.sh"],
["bash ./scripts/echo.sh server_host=$het_group_host_0"],
],
launch_cmd=["sbatch", "--parsable"],
)

script = request.materialize()

# Verify the pattern matches the artifact:
# 1. Global env vars should be exported in setup
assert "export ENV_VAR=value" in script

# The template should include group_env_vars for proper env var handling per command
# (The actual env var exports per command happen in the template rendering)