import tempfile
from dataclasses import dataclass
from datetime import datetime
+from subprocess import CalledProcessError, PIPE
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import torchx

    macros,
    NONE,
    ReplicaStatus,
+    Resource,
    Role,
    RoleStatus,
    runopts,

    "TIMEOUT": AppState.FAILED,
}

+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
SBATCH_JOB_OPTIONS = {
    "comment",
    "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
            return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. is not in PENDING or RUNNING state);
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (the slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)

    def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
        if len(output) <= 1:
            return None

@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:

            state = row["State"]
            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)

            role, _, replica_id = row["JobName"].rpartition("-")
            if not replica_id or not role:
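
A small sketch of the subprocess pattern used in `describe` / `_describe_sacct` above, using only the standard library: `check_output` with `stderr=PIPE` and `encoding="utf-8"` raises a `CalledProcessError` whose `e.stderr` is already decoded text, which is what the `log.info` calls interpolate. The command below is an arbitrary failing example, not an actual `squeue`/`sacct` invocation.

import subprocess
from subprocess import PIPE, CalledProcessError

def try_describe(cmd: list[str]) -> None:
    try:
        out = subprocess.check_output(cmd, stderr=PIPE, encoding="utf-8")
        print("stdout:", out.strip())
    except CalledProcessError as e:
        # e.stderr is a str (not bytes) because encoding="utf-8" was passed
        print("unable to get job info ({})".format(e.stderr.strip()))

try_describe(["ls", "/no/such/path"])  # stand-in for a failing squeue/sacct call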
@@ -540,46 +562,92 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
                msg=msg,
            )

-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
        )
-        output_json = json.loads(p.stdout.decode("utf-8"))

-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]

-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                ),
            )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host;
+                # otherwise, scheduled_nodes could be a node list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous jobs,
+                # where each replica is a "sub-job", so `allocated_nodes` will always be 1;
+                # but we also deal with jobs that were not launched with torchx,
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+                for node_info in node_infos:
+                    # NOTE: we expect resource specs for all the nodes to be the same
+                    # NOTE: use allocated (not used/requested) memory since
+                    # users may only specify --cpu, in which case slurm
+                    # uses the (system) configured {mem-per-cpu} * {cpus}
+                    # to allocate memory.
+                    # NOTE: getting gpus is tricky because it is modeled as a trackable resource
+                    # or not configured at all (use total-cpu-on-host as a proxy for gpus)
+                    cpu = int(node_info["cpus_used"])
+                    memMB = int(node_info["memory_allocated"])
+
+                    hostname = node_info["nodename"]
+
+                    role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )

        return DescribeAppResponse(
            app_id=app_id,
            roles=list(roles.values()),
            roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
        )

    def log_iter(
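
For reference, a self-contained sketch of the `squeue --json` payload shape that the new `_describe_squeue` assumes. The field values below are made up; only the key names (`jobs`, `name`, `command`, `current_working_directory`, `job_state`, `job_resources`, `scheduled_nodes`, `allocated_nodes`, `cpus_used`, `memory_allocated`, `nodename`) are taken from the code above, and it shows how a job name splits into role and replica id.

import json

# made-up payload; only the key names mirror what _describe_squeue reads
sample = json.loads("""
{
  "jobs": [
    {
      "name": "trainer-0",
      "command": "/usr/bin/python -m train",
      "current_working_directory": "/home/user/app",
      "job_state": ["RUNNING"],
      "job_resources": {
        "scheduled_nodes": "slurm-compute-node-0",
        "allocated_nodes": [
          {"nodename": "slurm-compute-node-0", "cpus_used": "8", "memory_allocated": "16384"}
        ]
      }
    }
  ]
}
""")

for job in sample["jobs"]:
    # job name is of the form "{role_name}-{replica_id}"
    role_name, _, replica_id = job["name"].rpartition("-")
    print(role_name, int(replica_id), job["job_state"][0])  # trainer 0 RUNNING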