|
import os |
|
import os.path as osp |
|
import re |
|
import subprocess |
|
import time |
|
import traceback |
|
from functools import partial |
|
from multiprocessing import Pipe, Pool |
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
import mmengine |
|
from mmengine.config import ConfigDict |
|
from tqdm import tqdm |
|
|
|
from opencompass.registry import RUNNERS, TASKS |
|
from opencompass.utils import batched, get_logger |
|
|
|
from .base import BaseRunner |
|
|
|
|
|
@RUNNERS.register_module()
class SlurmSequentialRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    This runner launches tasks one by one for execution. A new task will only
    be launched when and only when max_num_workers is not met, and the previous
    task has been successfully allocated to a machine. Therefore, unlike the
    `SlurmRunner`, at most only one task will be in the PENDING status at the
    same time during a run, making the random_sleep strategy no longer
    necessary. In addition, this runner also includes a feature to
    automatically kill all jobs by the job_id on exit.

    The runner will obtain the job_id by reading the srun output similar to
    `srun: Job 123456 scheduled successfully!`. If the output of srun does not
    match this pattern, the runner will not work properly.

    Args:
        task (ConfigDict): Task type config.
        task_prefix (str): Prefix prepended to every task name (e.g. to tell
            runs apart). Defaults to ''.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name. Defaults to None.
        quotatype (str): Slurm quota type. Defaults to None.
        qos (str): Slurm quality of service. Defaults to None.
        debug (bool): Whether to run in debug mode. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
        extra_command (List, optional): Extra slurm command.
            For example ['-c 12', '-w node1']. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 task_prefix: str = '',
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: Optional[str] = None,
                 quotatype: Optional[str] = None,
                 qos: Optional[str] = None,
                 debug: bool = False,
                 lark_bot_url: Optional[str] = None,
                 extra_command: Optional[List[str]] = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos
        self.task_prefix = task_prefix
        # Normalize falsy extra_command (None, []) to an empty list; anything
        # else must already be a list of srun argument strings.
        if not extra_command:
            extra_command = []
        assert isinstance(extra_command, list)
        self.extra_command = extra_command

        logger = get_logger()
        if self.quotatype in ['spot', 'auto']:
            logger.warning(
                'Quotatype spot or auto may cause stability issues, '
                'reserved is recommended.')

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch all tasks and collect a (task_name, returncode) result for
        each.

        In debug mode tasks run sequentially in this process; otherwise they
        are dispatched through a process pool with at most
        ``max_num_workers`` concurrent tasks.
        """
        if not self.debug:
            return self._launch_wo_debug(tasks)
        else:
            return [self._launch(task) for task in tasks]

    def _launch_wo_debug(self,
                         tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch tasks via a process pool (non-debug path).

        Each worker reports its Slurm job_id back through a ``Pipe`` once the
        job is scheduled; the main loop blocks on that handshake before
        submitting the next task, which is what keeps at most one job in the
        PENDING state at a time. On exit (normal or interrupted) every
        collected job_id is fed to ``scancel`` so no stray jobs are left
        running.
        """
        launched_bar = tqdm(total=len(tasks), desc='Launched')
        finished_bar = tqdm(total=len(tasks), desc='Finished')
        job_ids = []
        status = []

        # Pool success callback: record the worker's (task_name, returncode).
        def _update(result):
            finished_bar.update()
            status.append(result)
            return result

        # Pool error callback: mark the task as failed with returncode -1.
        # NOTE(review): traceback.print_exc() prints the *currently handled*
        # exception, which inside an error_callback is typically none — the
        # worker's exception is in ``err`` instead. Likely prints
        # "NoneType: None"; verify and consider printing ``err``.
        def _err_update(err):
            finished_bar.update()
            traceback.print_exc()
            status.append(('', -1))

        try:
            parent_conns = []
            # Never spawn more workers than tasks, but always at least one.
            num_workers = max(min(self.max_num_workers, len(tasks)), 1)
            with Pool(processes=num_workers) as pool:
                for task in tasks:
                    parent_conn, child_conn = Pipe()
                    _ = pool.apply_async(self._launch,
                                         kwds={
                                             'cfg': task,
                                             'child_conn': child_conn
                                         },
                                         callback=_update,
                                         error_callback=_err_update)
                    time.sleep(0.5)

                    # Handshake: block until the worker sends its job_id (or
                    # None on failure) before launching the next task.
                    job_id = parent_conn.recv()
                    launched_bar.update()
                    parent_conns.append(parent_conn)
                    job_ids.append(job_id)

                pool.close()
                pool.join()
            return status
        except KeyboardInterrupt:
            raise
        finally:
            launched_bar.close()
            finished_bar.close()
            # Drain any job_ids still buffered in the pipes (each worker
            # sends a trailing None on exit) so cleanup sees every job.
            for parent_conn in parent_conns:
                while parent_conn.poll():
                    try:
                        job_id = parent_conn.recv()
                        job_ids.append(job_id)
                    except EOFError:
                        break
                parent_conn.close()

            # Cancel all recorded jobs in batches of 4; if the user hits
            # Ctrl-C during cleanup itself, the batch is retried instead of
            # aborting and leaking jobs.
            tbar = tqdm(total=len(job_ids), desc='clear sruns')
            for batched_job_ids in batched(job_ids, 4):
                while True:
                    ps = []
                    try:
                        for job_id in batched_job_ids:
                            tbar.update()
                            if job_id is None:
                                continue
                            cmd = f'scancel {job_id}'
                            p = subprocess.Popen(cmd,
                                                 shell=True,
                                                 stdout=subprocess.PIPE,
                                                 stderr=subprocess.STDOUT)
                            ps.append(p)
                        break
                    except KeyboardInterrupt:
                        logger = get_logger()
                        logger.error('Ignoring KeyboardInterrupt...')
                for p in ps:
                    p.wait()
            tbar.close()

    def _launch(self, cfg: ConfigDict, child_conn: Pipe = None):
        """Build and run the ``srun`` command for a single task.

        Dumps ``cfg`` to a per-pid param file, assembles the srun command
        from the runner's Slurm options, then runs it with up to
        ``self.retry`` retries on failure. In non-debug mode the srun stderr
        is scanned for ``srun: Job <id> scheduled successfully!`` and the
        job_id is sent through ``child_conn`` (followed by a trailing None in
        the cleanup path).

        Args:
            cfg (ConfigDict): The task config to run.
            child_conn (Pipe): Write end of the pipe back to the parent
                process, or None when called directly (debug mode).
                NOTE(review): ``Pipe`` is a factory function, not the
                connection type; the actual object is a
                ``multiprocessing.connection.Connection`` — annotation kept
                as-is to avoid a new import.

        Returns:
            Tuple[str, int]: The (prefixed) task name and the returncode of
            the last attempt.
        """
        logger = get_logger()

        task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
        num_gpus = task.num_gpus
        task_name = task.name
        task_name = self.task_prefix + task_name

        # Dump the task config to a temp file; the pid suffix keeps files of
        # concurrent pool workers from colliding.
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        process = None
        try:
            cfg.dump(param_file)

            # Build up the srun command from the configured Slurm options.
            tmpl = 'srun'
            if self.partition:
                tmpl += f' -p {self.partition}'
            if self.quotatype:
                tmpl += f' --quotatype={self.quotatype}'
            if self.qos:
                tmpl += f' --qos={self.qos}'
            if num_gpus > 0:
                tmpl += f' --gres=gpu:{num_gpus}'
            for extra_cmd in self.extra_command:
                tmpl += f' {extra_cmd}'
            # -N1: single node; -u: unbuffered output (needed for the
            # line-by-line job_id scan below). Job name truncated to 512
            # chars to stay within Slurm's name limit.
            tmpl += f" -N1 -u -J '{task_name[:512]}'" + ' {task_cmd}'
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger.debug(f'Running command: {cmd}')

            retry = self.retry
            output_paths = task.get_output_paths()

            if self.debug:
                # Debug mode: run in the foreground, inheriting stdio.
                while True:
                    process = subprocess.Popen(cmd, shell=True, text=True)
                    process.communicate()
                    process.wait()
                    if self._job_failed(process.returncode, output_paths):
                        if retry > 0:
                            logger.warning(
                                f'task {task_name} failed, retrying...')
                            retry -= 1
                            # Rebuild the command for the next attempt.
                            cmd = get_cmd()
                        else:
                            break
                    else:
                        break
            else:
                # Non-debug: redirect output to the task's log file and
                # capture stderr so the scheduled job_id can be parsed out.
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                stdout = open(out_path, 'w', encoding='utf-8')
                stderr = subprocess.PIPE
                while True:
                    process = subprocess.Popen(cmd,
                                               shell=True,
                                               text=True,
                                               stdout=stdout,
                                               stderr=stderr)
                    job_id = None
                    while True:
                        line = process.stderr.readline()
                        if not line:
                            break
                        match = re.search(
                            r'srun: Job (\d+) scheduled successfully!', line)
                        if match and job_id is None:
                            job_id = match.group(1)
                            # Handshake: unblock the parent so it may launch
                            # the next task.
                            child_conn.send(job_id)
                        # Mirror srun's stderr into the task log file too.
                        stdout.write(line)
                    process.wait()
                    if self._job_failed(process.returncode, output_paths):
                        if retry > 0:
                            retry -= 1
                            cmd = get_cmd()
                        else:
                            logger.warning(
                                f'task {task_name} fail, see\n{out_path}')
                            break
                    else:
                        break
        except KeyboardInterrupt:
            raise
        finally:
            # Clean up: always unblock the parent (trailing None means "no
            # more job_ids coming"), kill any live srun, remove the temp
            # param file.
            if child_conn is not None:
                child_conn.send(None)
                child_conn.close()
            if process is not None:
                process.kill()
            os.remove(param_file)
        return task_name, process.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        """A job is considered failed unless srun exited with 0 AND every
        expected output file exists."""
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)
|
|