import argparse import json import os import os.path as osp import random import time from typing import List, Sequence import mmengine import torch import torch.distributed as dist from mmengine.config import Config, ConfigDict from mmengine.device import get_device from mmengine.dist import init_dist from mmengine.evaluator import Evaluator from mmengine.logging import print_log from mmengine.model.wrappers import MMDistributedDataParallel from mmengine.utils import track_iter_progress from opencompass.registry import MM_MODELS, TASKS from opencompass.utils import get_logger def build_model(cfg): model = MM_MODELS.build(cfg['model']) load_from = cfg.get('load_from', None) if load_from is not None: state_dict = torch.load(cfg['load_from'], map_location='cpu') if 'model' in state_dict: state_dict = state_dict['model'] elif 'state_dict' in state_dict: state_dict = state_dict['state_dict'] msg = model.load_state_dict(state_dict, strict=False) print_log(msg) model.to(get_device()) if dist.is_initialized(): model = MMDistributedDataParallel( model, device_ids=[int(os.environ['LOCAL_RANK'])], broadcast_buffers=False) return model @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run class MultimodalInferTask: """Multimodal Inference Task. This task is used to run the inference process. """ def __init__(self, cfg: ConfigDict): self.num_gpus = cfg.get('num_gpus', 0) self.num_procs = cfg.get('num_procs', 1) self.dataloader = cfg.get('dataset') self.model = cfg.get('model') self.evaluator = cfg.get('evaluator') self.cfg = cfg self.logger = get_logger() @property def name(self) -> str: model_name = self.model['type'] dataset_name = self.dataloader['dataset']['type'] evaluator_name = self.evaluator[0]['type'] return f'{model_name}-{dataset_name}-{evaluator_name}' def get_log_path(self, file_extension: str = 'json') -> str: """Get the path to the log file. Args: file_extension (str): The file extension of the log file. Default: 'json'. """ model_name = self.model['type'] dataset_name = self.dataloader['dataset']['type'] evaluator_name = self.evaluator[0]['type'] return osp.join(self.cfg.work_dir, model_name, dataset_name, f'{evaluator_name}.{file_extension}') def get_output_paths(self, file_extension: str = 'json') -> List[str]: """Get the path to the output file. Args: file_extension (str): The file extension of the log file. Default: 'json'. """ model_name = self.model['type'] dataset_name = self.dataloader['dataset']['type'] evaluator_name = self.evaluator[0]['type'] return [ osp.join(self.cfg.work_dir, model_name, dataset_name, f'{evaluator_name}.{file_extension}') ] def get_command(self, cfg_path, template): """Get the command template for the task. Args: cfg_path (str): The path to the config file of the task. template (str): The template which have '{task_cmd}' to format the command. """ script_path = __file__ if self.num_gpus > 0: port = random.randint(12000, 32000) command = (f'torchrun --master_port={port} ' f'--nproc_per_node {self.num_procs} ' f'{script_path} {cfg_path}') else: command = f'python {script_path} {cfg_path}' return template.format(task_cmd=command) def run(self): from mmengine.runner import Runner # only support slurm, pytorch, mpi init_dist(self.cfg.launcher) self.logger.info(f'Task {self.name}') # build dataloader dataloader = Runner.build_dataloader(self.dataloader) # build model model = build_model(self.cfg) model.eval() # build evaluator evaluator = Evaluator(self.evaluator) for batch in track_iter_progress(dataloader): if dist.is_initialized(): data_samples = model.module.forward(batch) else: data_samples = model.forward(batch) if not isinstance(data_samples, Sequence): data_samples = [data_samples] evaluator.process(data_samples) metrics = evaluator.evaluate(len(dataloader.dataset)) metrics_file = self.get_output_paths()[0] mmengine.mkdir_or_exist(osp.split(metrics_file)[0]) with open(metrics_file, 'w') as f: json.dump(metrics, f) def parse_args(): parser = argparse.ArgumentParser(description='Model Inferencer') parser.add_argument('config', help='Config file path') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() cfg = Config.fromfile(args.config) start_time = time.time() inferencer = MultimodalInferTask(cfg) inferencer.run() end_time = time.time() get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')