
Commit 256b6fe

kiukchung authored and facebook-github-bot committed
(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API (#1080)
Summary: Pull Request resolved: #1080. Additionally fill hostname, resource (cpu, memMB), image, and entrypoint in `describe_squeue` for each role/replica.

Reviewed By: d4l3k

Differential Revision: D76485112
1 parent 50b8c02 commit 256b6fe
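In effect, `describe()` on an active slurm job now reports which node each replica is running on, along with the per-role image (the job's working directory), entrypoint, and cpu/memMB resource. A rough usage sketch, assuming the module's `create_scheduler` factory and a placeholder job id:

```python
# Sketch only: "demo" is an arbitrary session name, "1234" a placeholder slurm
# job id, and create_scheduler is assumed to be the module's factory function.
from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler("demo")
resp = scheduler.describe("1234")
if resp:
    for role in resp.roles:
        print(role.name, role.image, role.entrypoint, role.resource)
    for role_status in resp.roles_statuses:
        for replica in role_status.replicas:
            # hostname is now filled in for PENDING and RUNNING jobs
            print(replica.role, replica.id, replica.state, replica.hostname)
```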

File tree: 4 files changed (+1808, −108 lines)


.github/workflows/slurm-local-integration-tests.yaml

Lines changed: 6 additions & 4 deletions
```diff
@@ -6,8 +6,11 @@ on:
       - main
   pull_request:
 
+
 env:
-  SLURM_VERSION: 21.08.6
+  # slurm tag should be one of https://github.com/SchedMD/slurm/tags
+  SLURM_TAG: slurm-23-11-11-1
+  SLURM_VERSION: 23.11.11
 
 jobs:
   slurm:
@@ -27,8 +30,7 @@ jobs:
         run: |
           set -ex
 
-          # TODO: switch to trunk once https://github.com/giovtorres/slurm-docker-cluster/pull/29 lands
-          git clone https://github.com/d4l3k/slurm-docker-cluster.git
+          git clone https://github.com/giovtorres/slurm-docker-cluster.git
       - name: Pull docker containers
         run: |
           set -ex
@@ -43,7 +45,7 @@ jobs:
         run: |
           set -ex
           cd slurm-docker-cluster
-          docker build -t slurm-docker-cluster:$SLURM_VERSION .
+          docker build --build-arg SLURM_TAG=$SLURM_TAG -t slurm-docker-cluster:$SLURM_VERSION .
       - name: Start slurm
         run: |
           set -ex
```

torchx/schedulers/slurm_scheduler.py

Lines changed: 109 additions & 41 deletions
```diff
@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
```
```diff
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. not in PENDING or RUNNING states)
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
```
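The move from `subprocess.run(..., check=True)` to `check_output(..., stderr=PIPE)` is what makes `e.stderr` carry slurm's actual error text in the log messages above; without redirecting stderr, `CalledProcessError.stderr` is `None`. A standalone sketch of the pattern (the command and job id are illustrative):

```python
import subprocess
from subprocess import CalledProcessError, PIPE

def run_squeue(app_id: str) -> str:
    try:
        # stderr=PIPE attaches slurm's message (e.g. "slurm_load_jobs error:
        # Invalid job id specified") to the raised CalledProcessError.
        return subprocess.check_output(
            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
        )
    except CalledProcessError as e:
        print(f"squeue failed for {app_id}: {e.stderr}")
        raise
```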

```diff
@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
@@ -540,46 +562,92 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )
 
-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+                for node_info in node_infos:
+                    # NOTE: we expect resource specs for all the nodes to be the same
+                    # NOTE: use allocated (not used/requested) memory since
+                    # users may only specify --cpu, in which case slurm
+                    # uses the (system) configured {mem-per-cpu} * {cpus}
+                    # to allocate memory.
+                    # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                    # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                    cpu = int(node_info["cpus_used"])
+                    memMB = int(node_info["memory_allocated"])
+
+                    hostname = node_info["nodename"]
+
+                    role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(
```
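For orientation, here is an abbreviated, illustrative shape of one entry in `squeue --json`'s `jobs` list, restricted to the fields the rewritten `_describe_squeue` reads (field names come from the diff above; the values are invented):

```python
job = {
    "name": "trainer-0",                        # parsed as "{role_name}-{replica_id}"
    "command": "/usr/bin/python -m train",      # -> Role.entrypoint
    "current_working_directory": "/home/user",  # -> Role.image
    "job_state": ["RUNNING"],                   # -> appstate_from_slurm_state(...)
    "job_resources": {
        # used when PENDING: a single hostname for torchx-launched jobs
        "scheduled_nodes": "slurm-compute-node-1",
        # used when RUNNING: one entry per allocated node
        "allocated_nodes": [
            {
                "nodename": "slurm-compute-node-1",  # -> ReplicaStatus.hostname
                "cpus_used": 4,                      # -> Resource.cpu
                "memory_allocated": 16384,           # -> Resource.memMB
            }
        ],
    },
}
```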
