@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -482,16 +489,36 @@ def _cancel_existing(self, app_id: str) -> None:
         subprocess.run(["scancel", app_id], check=True)
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        # 1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # 2. -- or -- squeue returns an empty jobs list
+        # in either case, fall back to the less descriptive but more persistent sacct
+        # (slurm cluster must have accounting storage enabled for sacct to work)
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
-            return self._describe_squeue(app_id)
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
 
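For reference, a minimal sketch of how the new `appstate_from_slurm_state` helper (added earlier in this commit) behaves compared to the assert-based translation it replaces in the next hunk. Only the `TIMEOUT` entry of `SLURM_STATES` is visible in this excerpt, and the second state string is made up for illustration:

    appstate_from_slurm_state("TIMEOUT")            # -> AppState.FAILED, via SLURM_STATES
    appstate_from_slurm_state("SOME_FUTURE_STATE")  # -> AppState.UNKNOWN for unmapped states,
                                                    #    where the old code tripped an assert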
@@ -511,11 +538,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
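For context, `sacct --parsable2` prints pipe-delimited rows preceded by a header line, which is what the `len(output) <= 1` check and the `State`/`JobName` lookups above assume. The rows below are illustrative only (the job id and column set are not taken from this commit):

    JobID|JobName|State|...
    1234+0|trainer-0|RUNNING|...
    1234+1|trainer-1|RUNNING|...

The role/replica split then relies on `str.rpartition`:

    "trainer-0".rpartition("-")  # -> ("trainer", "-", "0")
    "sbatch".rpartition("-")     # -> ("", "", "sbatch"); empty role, so the row is skipped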
@@ -541,45 +564,109 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
         )
 
     def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+        if not jobs:
+            return None
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx-launched jobs point to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (e.g. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources.get("scheduled_nodes", "")
+
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogeneous job,
+                # where each replica is a "sub-job", so `allocated_nodes` will always be 1;
+                # but we also deal with jobs that have not been launched with torchx,
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it is modeled as a trackable resource
+                        # or not configured at all (use total-cpu-on-host as a proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(
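For context, a minimal sketch (as a Python literal) of the `squeue --json -j <job_id>` payload shape that the new `_describe_squeue` reads. Only the keys accessed in the hunk above are shown; the values and hostnames are illustrative, and the exact schema varies across Slurm versions:

    example_squeue_payload = {
        "jobs": [
            {
                "name": "trainer-0",                # "{role_name}-{replica_id}"
                "command": "/path/to/entrypoint",   # becomes Role.entrypoint
                "current_working_directory": "/path/to/workdir",  # becomes Role.image
                "job_state": ["RUNNING"],
                "job_resources": {
                    # PENDING jobs: a single host, or a nodelist expression
                    "scheduled_nodes": "slurm-compute-node0",
                    # some slurm versions expose only a nodelist spec here
                    "nodes": "slurm-compute-node[0-1]",
                    # RUNNING jobs: per-node allocation info
                    "allocated_nodes": [
                        {
                            "nodename": "slurm-compute-node0",
                            "cpus_used": 8,
                            "memory_allocated": 16000,
                        },
                    ],
                },
            },
        ],
    }

This is why the code branches on `isinstance(node_infos, list)`: per the commit's own comments, some Slurm versions return a mapping of nodelist specs for `allocated_nodes` instead of a per-node list, in which case `job_resources["nodes"]` is used as the hostname.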