Use itertools batched to get long job lists #3815

Draft · wants to merge 6 commits into base: master · Changes from all commits
234 changes: 134 additions & 100 deletions parsl/providers/slurm/slurm.py
@@ -3,6 +3,7 @@
import os
import re
import time
from itertools import islice
from typing import Any, Dict, Optional

import typeguard
@@ -19,37 +20,51 @@

# From https://slurm.schedmd.com/sacct.html#SECTION_JOB-STATE-CODES
sacct_translate_table = {
"PENDING": JobState.PENDING,
"RUNNING": JobState.RUNNING,
"CANCELLED": JobState.CANCELLED,
"COMPLETED": JobState.COMPLETED,
"FAILED": JobState.FAILED,
"NODE_FAIL": JobState.FAILED,
"BOOT_FAIL": JobState.FAILED,
"DEADLINE": JobState.TIMEOUT,
"TIMEOUT": JobState.TIMEOUT,
"REVOKED": JobState.FAILED,
"OUT_OF_MEMORY": JobState.FAILED,
"SUSPENDED": JobState.HELD,
"PREEMPTED": JobState.TIMEOUT,
"REQUEUED": JobState.PENDING,
}

squeue_translate_table = {
"PD": JobState.PENDING,
"R": JobState.RUNNING,
"CA": JobState.CANCELLED,
"CF": JobState.PENDING,  # (configuring)
"CG": JobState.RUNNING,  # (completing)
"CD": JobState.COMPLETED,
"F": JobState.FAILED,  # (failed)
"TO": JobState.TIMEOUT,  # (timeout)
"NF": JobState.FAILED,  # (node failure)
"RV": JobState.FAILED,  # (revoked)
"SE": JobState.FAILED,  # (special exit state)
}


def batched(iterable, n):
"""Batched
Turns a list into a batch of size n. This code is from
https://docs.python.org/3.12/library/itertools.html#itertools.batched
and in versions 3.12+ this can be replaced with
itertools.batched
"""
if n < 1:
raise ValueError("n must be at least one")
iterator = iter(iterable)
while batch := tuple(islice(iterator, n)):
yield batch

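For reference, a minimal usage sketch of the helper (illustrative only, not part of this diff); the input string is arbitrary:

# Illustrative only: batched() yields tuples of at most n items.
print(list(batched("ABCDEFG", 3)))
# -> [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
# On Python 3.12+ the standard-library equivalent is:
#     from itertools import batched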

class SlurmProvider(ClusterProvider, RepresentationMixin):
"""Slurm Execution Provider

@@ -99,6 +114,12 @@ class SlurmProvider(ClusterProvider, RepresentationMixin):
symbolic group for the job ID.
worker_init : str
Command to be run before starting a worker, such as 'module load Anaconda; source activate env'.
cmd_timeout : int (Default = 10)
Number of seconds to wait for Slurm commands to finish. For schedulers managing many
jobs this may need to be increased to allow more time for scheduler information to be returned.
status_batch_size : int (Default = 50)
Number of jobs to batch together in each status call to the scheduler. For schedulers
with many jobs this may need to be decreased so that job status is requested in smaller batches.
exclusive : bool (Default = True)
Requests nodes which are not shared with other running jobs.
launcher : Launcher
@@ -109,36 +130,41 @@ class SlurmProvider(ClusterProvider, RepresentationMixin):
"""

@typeguard.typechecked
def __init__(
self,
partition: Optional[str] = None,
account: Optional[str] = None,
qos: Optional[str] = None,
constraint: Optional[str] = None,
clusters: Optional[str] = None,
nodes_per_block: int = 1,
cores_per_node: Optional[int] = None,
mem_per_node: Optional[int] = None,
init_blocks: int = 1,
min_blocks: int = 0,
max_blocks: int = 1,
parallelism: float = 1,
walltime: str = "00:10:00",
scheduler_options: str = "",
regex_job_id: str = r"Submitted batch job (?P<id>\S*)",
worker_init: str = "",
cmd_timeout: int = 10,
status_batch_size: int = 50,
exclusive: bool = True,
launcher: Launcher = SingleNodeLauncher(),
):
label = "slurm"
super().__init__(
label,
nodes_per_block,
init_blocks,
min_blocks,
max_blocks,
parallelism,
walltime,
cmd_timeout=cmd_timeout,
launcher=launcher,
)

self.partition = partition
self.cores_per_node = cores_per_node
@@ -148,7 +174,9 @@ def __init__(self,
self.qos = qos
self.constraint = constraint
self.clusters = clusters
# Used to batch requests to sacct/squeue when there is a long list of jobs
self.status_batch_size = status_batch_size
self.scheduler_options = scheduler_options + "\n"
if exclusive:
self.scheduler_options += "#SBATCH --exclusive\n"
if partition:
@@ -163,7 +191,7 @@ def __init__(self,
self.scheduler_options += "#SBATCH --clusters={}\n".format(clusters)

self.regex_job_id = regex_job_id
self.worker_init = worker_init + "\n"
# Check if sacct works and if not fall back to squeue
cmd = "sacct -X"
logger.debug("Executing %s", cmd)
@@ -191,33 +219,35 @@ def __init__(self,
self._translate_table = squeue_translate_table

def _status(self):
"""Returns the status list for a list of job_ids

Args:
self

Returns:
[status...] : Status list of all jobs
"""

active_jobs = [jid for jid, job in self.resources.items() if not job["status"].terminal]
if not active_jobs:
logger.debug("No active jobs, skipping status update")
return

# Query the scheduler in batches so the sacct/squeue command stays short even with many jobs
stdout = ""
for job_batch in batched(active_jobs, self.status_batch_size):
job_id_list = ",".join(job_batch)
cmd = self._cmd.format(job_id_list)
logger.debug("Executing %s", cmd)
retcode, _stdout, stderr = self.execute_wait(cmd)
logger.debug("sacct/squeue returned %s %s", _stdout, stderr)
stdout += _stdout
# execute_wait failed. Do not update.
if retcode != 0:
logger.warning("sacct/squeue failed with non-zero exit code {}".format(retcode))
return

jobs_missing = set(self.resources.keys())
for line in stdout.split("\n"):
if not line:
# Blank line
continue
@@ -229,19 +259,23 @@ def _status(self):
logger.warning(f"Slurm status {slurm_state} is not recognized")
status = self._translate_table.get(slurm_state, JobState.UNKNOWN)
logger.debug("Updating job {} with slurm status {} to parsl state {!s}".format(job_id, slurm_state, status))
self.resources[job_id]["status"] = JobStatus(
status,
stdout_path=self.resources[job_id]["job_stdout_path"],
stderr_path=self.resources[job_id]["job_stderr_path"],
)
jobs_missing.remove(job_id)

# sacct can return job info after jobs have completed, so this path shouldn't be hit.
# squeue, however, does not report on jobs that are no longer in the queue, so we fill in
# the blanks for missing jobs, though we may lose some information about why they failed.
for missing_job in jobs_missing:
logger.debug("Updating missing job {} to completed status".format(missing_job))
self.resources[missing_job]["status"] = JobStatus(
JobState.COMPLETED,
stdout_path=self.resources[missing_job]["job_stdout_path"],
stderr_path=self.resources[missing_job]["job_stderr_path"],
)

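To make the batching concrete, a small sketch (illustrative only, not part of this diff) of how job IDs are grouped into comma-separated lists, one scheduler call per group; the IDs and batch size are made up:

# Illustrative only: grouping hypothetical Slurm job IDs the way _status() does.
job_ids = [str(1000 + i) for i in range(7)]  # seven made-up job IDs
for batch in batched(job_ids, 3):            # a status_batch_size of 3 for the example
    print(",".join(batch))
# -> 1000,1001,1002
# -> 1003,1004,1005
# -> 1006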
def submit(self, command: str, tasks_per_node: int, job_name="parsl.slurm") -> str:
"""Submit the command as a slurm job.
@@ -263,12 +297,12 @@ def submit(self, command: str, tasks_per_node: int, job_name="parsl.slurm") -> s
scheduler_options = self.scheduler_options
worker_init = self.worker_init
if self.mem_per_node is not None:
scheduler_options += "#SBATCH --mem={}g\n".format(self.mem_per_node)
worker_init += "export PARSL_MEMORY_GB={}\n".format(self.mem_per_node)
if self.cores_per_node is not None:
cpus_per_task = math.floor(self.cores_per_node / tasks_per_node)
scheduler_options += "#SBATCH --cpus-per-task={}".format(cpus_per_task)
worker_init += "export PARSL_CORES={}\n".format(cpus_per_task)

job_name = "{0}.{1}".format(job_name, time.time())

@@ -292,25 +326,24 @@ def submit(self, command: str, tasks_per_node: int, job_name="parsl.slurm") -> s
job_config["job_stderr_path"] = job_stderr_path

# Wrap the command
job_config["user_script"] = self.launcher(command,
tasks_per_node,
self.nodes_per_block)
job_config["user_script"] = self.launcher(command, tasks_per_node, self.nodes_per_block)

logger.debug("Writing submit script")
self._write_submit_script(template_string, script_path, job_name, job_config)

retcode, stdout, stderr = self.execute_wait("sbatch {0}".format(script_path))

if retcode == 0:
for line in stdout.split("\n"):
match = re.match(self.regex_job_id, line)
if match:
job_id = match.group("id")
self.resources[job_id] = {
"job_id": job_id,
"status": JobStatus(JobState.PENDING),
"job_stdout_path": job_stdout_path,
"job_stderr_path": job_stderr_path,
}
return job_id
else:
logger.error("Could not read job ID from submit command standard output.")
@@ -320,29 +353,30 @@ def submit(self, command: str, tasks_per_node: int, job_name="parsl.slurm") -> s
"Could not read job ID from submit command standard output",
stdout=stdout,
stderr=stderr,
retcode=retcode,
)
else:
logger.error("Submit command failed")
logger.error("Retcode:%s STDOUT:%s STDERR:%s", retcode, stdout.strip(), stderr.strip())
raise SubmitException(
job_name, "Could not read job ID from submit command standard output",
job_name,
"Could not read job ID from submit command standard output",
stdout=stdout,
stderr=stderr,
retcode=retcode
retcode=retcode,
)

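Since submit() recovers the job ID from sbatch output via regex_job_id, here is a short extraction sketch (illustrative only, not part of this diff); the output line and ID are made up:

# Illustrative only: extracting a job ID with the default regex_job_id pattern.
import re

regex_job_id = r"Submitted batch job (?P<id>\S*)"
sample_output = "Submitted batch job 123456"  # hypothetical sbatch response

match = re.match(regex_job_id, sample_output)
if match:
    print(match.group("id"))  # -> 123456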
def cancel(self, job_ids):
"""Cancels the jobs specified by a list of job ids

Args:
job_ids : [<job_id> ...]

Returns:
[True/False...] : If the cancel operation fails, the entire list will be False.
"""

job_id_list = " ".join(job_ids)

# Make the command to cancel jobs
_cmd = "scancel"
@@ -354,7 +388,7 @@ def cancel(self, job_ids):
rets = None
if retcode == 0:
for jid in job_ids:
self.resources[jid]["status"] = JobStatus(JobState.CANCELLED)  # Setting state to cancelled
rets = [True for i in job_ids]
else:
rets = [False for i in job_ids]