(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API #1080

Merged
merged 1 commit on Jun 13, 2025
Changes from all commits

5 changes: 4 additions & 1 deletion .github/workflows/slurm-local-integration-tests.yaml
@@ -6,8 +6,11 @@ on:
       - main
   pull_request:
 
 
 env:
-  SLURM_VERSION: 21.08.6
+  # slurm tag should be one of https://github.com/SchedMD/slurm/tags
+  SLURM_TAG: slurm-23-11-11-1
+  SLURM_VERSION: 23.11.11
 
 jobs:
   slurm:
169 changes: 128 additions & 41 deletions torchx/schedulers/slurm_scheduler.py
@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -482,16 +489,36 @@ def _cancel_existing(self, app_id: str) -> None:
         subprocess.run(["scancel", app_id], check=True)
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        # 1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # 2. -- or -- squeue returns an empty jobs list
+        # in either case, fall back to the less descriptive but more persistent sacct
+        # (slurm cluster must have accounting storage enabled for sacct to work)
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
-            return self._describe_squeue(app_id)
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
 
         if len(output) <= 1:
             return None
 
@@ -511,11 +538,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
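Before the _describe_squeue rewrite below, a minimal usage sketch (not part of the diff) of what the reordered describe path returns; it assumes the module's existing create_scheduler factory, and the session name and job id are hypothetical:

from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler("describe-demo")  # hypothetical session name
desc = scheduler.describe("1234")  # hypothetical slurm job id; squeue is tried first, then sacct
if desc is None:
    print("job not found via squeue or sacct")
else:
    for role_status in desc.roles_statuses:
        for replica in role_status.replicas:
            # with this PR, hostname is filled in for jobs described via squeue
            print(role_status.role, replica.id, replica.state, replica.hostname)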
@@ -541,45 +564,109 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
         )
 
     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+        if not jobs:
+            return None
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
+                ),
+            )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources.get("scheduled_nodes", "")
+
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                        # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
Review comment (Member):
msg isn't needed?

Reply (Contributor, author):
Yeah, msg defaults to an empty string if not specified. We were just setting msg=state, so it added no real functional value, and describe_sacct didn't set msg either.
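A small check of the behavior described in the reply, assuming DescribeAppResponse keeps its dataclass defaults and the import paths used elsewhere in torchx:

from torchx.schedulers.api import DescribeAppResponse
from torchx.specs.api import AppState

resp = DescribeAppResponse(app_id="1234", state=AppState.RUNNING)  # msg omitted on purpose
assert resp.msg == ""  # per the reply above, msg defaults to an empty string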

+            state=state,
         )

     def log_iter(
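For reference, an abridged and hypothetical example of the squeue --json payload shape that the new _describe_squeue parsing relies on; only the keys read above are shown, all values are made up, and real output carries many more fields:

example_squeue_payload = {
    "jobs": [
        {
            "name": "trainer-0",  # parsed as "{role_name}-{replica_id}"
            "command": "/home/user/run.sh",  # becomes Role.entrypoint
            "current_working_directory": "/home/user",  # becomes Role.image
            "job_state": ["RUNNING"],  # first element goes through appstate_from_slurm_state
            "job_resources": {
                "nodes": "slurm-compute-node[0-1]",  # nodelist spec fallback
                "scheduled_nodes": "slurm-compute-node0",  # read while the job is PENDING
                "allocated_nodes": [  # read while the job is RUNNING
                    {
                        "nodename": "slurm-compute-node0",
                        "cpus_used": 8,
                        "memory_allocated": 16384,
                    }
                ],
            },
        }
    ]
}

The PENDING branch reads scheduled_nodes (a single host for torchx-launched jobs), while the RUNNING branch walks allocated_nodes when it is a list and falls back to the nodes expression when it is not.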