Commit 34d6b83

(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API
Differential Revision: D76485112
Pull Request resolved: #1080
1 parent: 50b8c02
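
For context, a minimal sketch of how the newly filled-in hostnames surface through the scheduler's describe API; the session name and Slurm job id below are hypothetical:

from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler(session_name="demo")
desc = scheduler.describe(app_id="1234")  # hypothetical Slurm job id
if desc:
    for role_status in desc.roles_statuses:
        for replica in role_status.replicas:
            # before this change, replica.hostname was always "" for squeue-backed lookups
            print(role_status.role, replica.id, replica.state, replica.hostname)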

File tree

4 files changed: +1827, -105 lines

.github/workflows/slurm-local-integration-tests.yaml

Lines changed: 4 additions & 1 deletion
@@ -6,8 +6,11 @@ on:
       - main
   pull_request:
 
+
 env:
-  SLURM_VERSION: 21.08.6
+  # slurm tag should be one of https://github.com/SchedMD/slurm/tags
+  SLURM_TAG: slurm-23-11-11-1
+  SLURM_VERSION: 23.11.11
 
 jobs:
   slurm:

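The pinned tag and version must describe the same Slurm release; a toy consistency check, assuming SchedMD's tag naming scheme (version dots become dashes, plus a trailing release suffix):

tag, version = "slurm-23-11-11-1", "23.11.11"
# "slurm-23-11-11-1" -> "23-11-11-1" starts with "23-11-11"
assert tag.removeprefix("slurm-").startswith(version.replace(".", "-"))
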
torchx/schedulers/slurm_scheduler.py

Lines changed: 128 additions & 41 deletions
@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -482,16 +489,36 @@ def _cancel_existing(self, app_id: str) -> None:
         subprocess.run(["scancel", app_id], check=True)
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        # 1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # 2. -- or -- squeue returns an empty jobs list
+        # in either case, fall back to the less descriptive but more persistent sacct
+        # (slurm cluster must have accounting storage enabled for sacct to work)
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
-            return self._describe_squeue(app_id)
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
 
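For reference, `sacct --parsable2` prints pipe-delimited rows under a header line, and heterogeneous-job steps carry ids like `1234+0` (`<jobid>+<offset>`). A toy sketch of that format with fabricated, trimmed sample rows; the real parser splits the header manually rather than using csv:

import csv

sample = (
    "JobID|JobName|State\n"
    "1234+0|trainer-0|RUNNING\n"
    "1234+1|trainer-1|RUNNING\n"
)
for row in csv.DictReader(sample.splitlines(), delimiter="|"):
    role, _, replica_id = row["JobName"].rpartition("-")
    print(role, replica_id, row["State"])  # trainer 0 RUNNING ...
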
@@ -511,11 +538,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
         state = row["State"]
         msg = state
-        state_enum = SLURM_STATES.get(state)
-        assert (
-            state_enum
-        ), f"failed to translate slurm state {state} to torchx state"
-        app_state = state_enum
+        app_state = appstate_from_slurm_state(state)
 
         role, _, replica_id = row["JobName"].rpartition("-")
         if not replica_id or not role:
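
Since `rpartition` splits on the last dash, role names that themselves contain dashes still parse correctly:

role, _, replica_id = "dist-trainer-0".rpartition("-")
assert (role, replica_id) == ("dist-trainer", "0")
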
@@ -541,45 +564,109 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
         )
 
     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+        if not jobs:
+            return None
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources.get("scheduled_nodes", "")
+
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it is modeled as a trackable-resource
+                        # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(

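As the PENDING branch notes, `scheduled_nodes` can be a nodelist expression rather than a single host for jobs that were not launched by torchx. A hedged sketch of expanding such an expression with the standard `scontrol show hostnames` subcommand (assumed to be on PATH on the cluster):

import subprocess

def expand_nodelist(nodelist: str) -> list[str]:
    # `scontrol show hostnames` prints one hostname per line
    out = subprocess.check_output(
        ["scontrol", "show", "hostnames", nodelist], encoding="utf-8"
    )
    return out.split()

# e.g. expand_nodelist("slurm-compute-node[0-2]")
# -> ["slurm-compute-node0", "slurm-compute-node1", "slurm-compute-node2"]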