import logging
import os
import os.path as osp
import subprocess
import sys
import time
import traceback
from multiprocessing import Manager, Pool
from multiprocessing.managers import SyncManager
from typing import Any, Dict, List, Tuple

import mmengine
from mmengine.config import ConfigDict
from tqdm import tqdm

from opencompass.registry import RUNNERS, TASKS
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)

from .base import BaseRunner
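

# This runner drives API inference tasks in local processes. `monkey_run`
# below stands in for `OpenICLInferTask.run` so that a shared semaphore can
# cap the number of concurrent API requests across all worker processes.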
def monkey_run(self, tokens: SyncManager.Semaphore):
    """Hack for the infer task's ``run``: attach a shared semaphore so that
    API requests issued from multiple subprocesses stay within a global
    concurrency limit."""
    self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
    for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
        self.max_out_len = model_cfg.get('max_out_len', None)
        self.min_out_len = model_cfg.get('min_out_len', None)
        self.batch_size = model_cfg.get('batch_size', None)
        self.model = build_model_from_cfg(model_cfg)

        assert self.model.is_api, 'Only API models are supported.'
        # Hand the semaphore to the model; the API model uses these tokens
        # to throttle its concurrent requests across subprocesses.
        self.model.tokens = tokens

        for dataset_cfg in dataset_cfgs:
            self.model_cfg = model_cfg
            self.dataset_cfg = dataset_cfg
            self.infer_cfg = self.dataset_cfg['infer_cfg']
            self.dataset = build_dataset_from_cfg(self.dataset_cfg)
            self.sub_cfg = {
                'models': [self.model_cfg],
                'datasets': [[self.dataset_cfg]],
            }
            out_path = get_infer_output_path(
                self.model_cfg, self.dataset_cfg,
                osp.join(self.work_dir, 'predictions'))
            # Skip datasets whose predictions already exist, so interrupted
            # runs can resume where they left off.
            if osp.exists(out_path):
                continue
            self._inference()


# Keep references to the original streams so reset_std() can restore them.
old_stdout = sys.stdout
old_stderr = sys.stderr
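

# The two helpers below swap the process-wide stdout/stderr so that each
# task's output and log records land in its own .out file, and restore them
# once the task finishes.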
def redirect_std_to_file(filename: str):
    """Redirect stdout and stderr to a file, and retarget the logger stream
    handlers to the same file."""
    f = open(filename, 'w', encoding='utf-8')
    sys.stdout = f
    sys.stderr = f

    # Reassigning sys.stdout is not enough: existing StreamHandlers keep a
    # reference to the old stream, so point them at the new one explicitly.
    logger = get_logger()
    for h in logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout

    gen_logger = logging.getLogger(
        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
    for h in gen_logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout


def reset_std():
    """Restore stdout and stderr, and point the logger stream handlers back
    at the original stdout."""
    sys.stdout.close()
    sys.stdout = old_stdout
    sys.stderr = old_stderr

    logger = get_logger()
    for h in logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout

    gen_logger = logging.getLogger(
        'opencompass.openicl.icl_inferencer.icl_gen_inferencer')
    for h in gen_logger.handlers:
        if isinstance(h, logging.StreamHandler):
            h.stream = sys.stdout


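# Worker-side entry point: runs one task with its output redirected to the
# task's own .out file and the shared semaphore threaded through to the
# API model.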
def launch(task: BaseTask, tokens: SyncManager.Semaphore):
    """Launch a single task.

    Args:
        task (BaseTask): Task to launch.
        tokens (SyncManager.Semaphore): Multiprocessing semaphore
            for every subprocess to follow.

    Returns:
        tuple[str, int]: Task name and exit code.
    """
    task_name = task.name
    returncode = 0
    logger = get_logger()
    # Resolve the log path outside the try block so the except clause can
    # always reference it.
    out_path = task.get_log_path(file_extension='out')

    try:
        mmengine.mkdir_or_exist(osp.split(out_path)[0])
        redirect_std_to_file(out_path)

        start_time = time.time()
        inferencer = OpenICLInferTask(task.cfg)
        # Temporarily swap in monkey_run so the semaphore reaches the API
        # model; monkey_run is an unbound function, hence the explicit
        # `inferencer` argument.
        origin_run = inferencer.run
        inferencer.run = monkey_run
        inferencer.run(inferencer, tokens)
        inferencer.run = origin_run
        end_time = time.time()
        logger.info(f'time elapsed: {end_time - start_time:.2f}s')
    except Exception:
        # The traceback lands in the redirected .out file for debugging.
        traceback.print_exc()
        reset_std()
        logger.warning(f'task {task_name} failed, see\n{out_path}')
        returncode = 1
    else:
        reset_std()
    return task_name, returncode


def submit(task, type, tokens):
    """Helper to build and launch a single task in a worker process."""
    task = TASKS.build(dict(cfg=task, type=type))
    tqdm.write(f'Launch {task.name} on CPU ')

    res = launch(task, tokens)
    return res


@RUNNERS.register_module()
class LocalAPIRunner(BaseRunner):
    """Local API Runner. Starts tasks in local python processes.

    A query-per-second limit alone cannot guarantee the number of concurrent
    requests, so this runner instead caps the number of concurrent users
    shared across multiple tasks. Intended for APIs that restrict how many
    requests may be in flight at once.

    Args:
        task (ConfigDict): Task type config.
        concurrent_users (int): Max number of concurrent workers to request
            the resources.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 16.
        debug (bool): Whether to run in debug mode.
        lark_bot_url (str): Lark bot url.
    """

    def __init__(self,
                 task: ConfigDict,
                 concurrent_users: int,
                 max_num_workers: int = 16,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.concurrent_users = concurrent_users
        assert task['type'] in [
            'OpenICLInferTask',
            'opencompass.tasks.OpenICLInferTask',
        ], 'Only the API infer task is supported.'

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        status = []
        if self.debug:
            # Debug mode runs the tasks sequentially in this process,
            # without the semaphore or output redirection.
            for task in tasks:
                task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
                task_name = task.name

                # Dump the task config to a temporary param file for the
                # task command to read, and clean it up afterwards.
                mmengine.mkdir_or_exist('tmp/')
                param_file = f'tmp/{os.getpid()}_params.py'
                try:
                    task.cfg.dump(param_file)
                    cmd = task.get_command(cfg_path=param_file,
                                           template='{task_cmd}')

                    if cmd.startswith('python'):
                        task.run()
                    else:
                        subprocess.run(cmd, shell=True, text=True)
                finally:
                    os.remove(param_file)
                status.append((task_name, 0))
        else:
            pbar = tqdm(total=len(tasks))

            get_logger().info('All the logs and outputs of each task'
                              ' can be found in its infer/*.out file.')
            with Manager() as manager:
                # A manager-backed semaphore can be shared with the pool
                # workers; every worker's API model draws tokens from it.
                tokens = manager.Semaphore(self.concurrent_users)

                pbar_counter = manager.Value('i', 0)
                status = []

                def update(args):
                    """Callback: bump the shared counter and record the
                    task's status."""
                    pbar_counter.value += 1
                    status.append(args)

                with Pool(processes=self.max_num_workers) as pool:
                    for task in tasks:
                        pool.apply_async(submit,
                                         (task, self.task_cfg['type'], tokens),
                                         callback=update)
                    pool.close()

                    # Poll the shared counter to drive the progress bar;
                    # the workers update it from other processes.
                    while True:
                        cur_count = pbar_counter.value
                        if cur_count > pbar.n:
                            pbar.update(cur_count - pbar.n)
                        # Stop once every task has reported back.
                        if cur_count >= pbar.total:
                            pbar.close()
                            break
                        # Sleep to avoid busy-waiting.
                        time.sleep(1)

                    pool.join()
        return status
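

# A minimal usage sketch (assumed shapes: `tasks` is the list of task
# configs produced by a partitioner; field values here are illustrative
# only and mirror the constructor above):
#
#   runner = LocalAPIRunner(
#       task=ConfigDict(type='OpenICLInferTask'),
#       concurrent_users=4,    # at most 4 API requests in flight at once
#       max_num_workers=16)    # up to 16 worker processes
#   status = runner.launch(tasks)  # -> list of (task name, exit code)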