File size: 5,354 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import argparse
import json
import os
import os.path as osp
import random
import time
from typing import List, Sequence
import mmengine
import torch
import torch.distributed as dist
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import init_dist
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.model.wrappers import MMDistributedDataParallel
from mmengine.utils import track_iter_progress
from opencompass.registry import MM_MODELS, TASKS
from opencompass.utils import get_logger
def build_model(cfg):
    """Build the model described by ``cfg`` and move it to the current device.

    The model is created from ``cfg['model']`` through the ``MM_MODELS``
    registry. When ``cfg`` carries a non-``None`` ``load_from`` entry, the
    checkpoint at that path is loaded onto the CPU and applied to the model
    with ``strict=False``; the message returned by ``load_state_dict`` is
    logged. Afterwards the model is moved to the device reported by
    ``get_device()`` and, if ``torch.distributed`` has been initialized,
    wrapped in ``MMDistributedDataParallel``.

    Args:
        cfg: Config mapping providing ``model`` and, optionally,
            ``load_from``.

    Returns:
        The built model, wrapped in ``MMDistributedDataParallel`` when the
        default process group is initialized.
    """
    model = MM_MODELS.build(cfg['model'])

    checkpoint_path = cfg.get('load_from', None)
    if checkpoint_path is not None:
        # NOTE(review): torch.load unpickles the file -- presumably only
        # trusted checkpoints are passed here; confirm with callers.
        weights = torch.load(checkpoint_path, map_location='cpu')
        # The weights may be nested under 'model' or 'state_dict';
        # 'model' wins when both keys are present.
        for wrapper_key in ('model', 'state_dict'):
            if wrapper_key in weights:
                weights = weights[wrapper_key]
                break
        print_log(model.load_state_dict(weights, strict=False))

    model.to(get_device())
    if not dist.is_initialized():
        return model

    # The process's GPU index is read from the LOCAL_RANK environment
    # variable.
    return MMDistributedDataParallel(
        model,
        device_ids=[int(os.environ['LOCAL_RANK'])],
        broadcast_buffers=False)
@TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
class MultimodalInferTask:
    """Multimodal Inference Task.

    This task is used to run the inference process: it builds a dataloader,
    a model and an evaluator from the task config, feeds every batch through
    the model, and dumps the evaluated metrics to a JSON file.

    Args:
        cfg (ConfigDict): Task config. The keys read here are ``num_gpus``,
            ``num_procs``, ``dataset``, ``model`` and ``evaluator``;
            ``work_dir`` and ``launcher`` are read later as attributes.
    """

    def __init__(self, cfg: ConfigDict):
        self.num_gpus = cfg.get('num_gpus', 0)
        self.num_procs = cfg.get('num_procs', 1)
        # The 'dataset' entry is used as the dataloader config in ``run``.
        self.dataloader = cfg.get('dataset')
        self.model = cfg.get('model')
        self.evaluator = cfg.get('evaluator')
        self.cfg = cfg
        self.logger = get_logger()

    def _component_names(self) -> tuple:
        """Return the ``(model, dataset, evaluator)`` type names.

        Only the first evaluator entry contributes its type name.
        """
        return (self.model['type'], self.dataloader['dataset']['type'],
                self.evaluator[0]['type'])

    def _result_path(self, file_extension: str) -> str:
        """Build ``<work_dir>/<model>/<dataset>/<evaluator>.<extension>``."""
        model_name, dataset_name, evaluator_name = self._component_names()
        return osp.join(self.cfg.work_dir, model_name, dataset_name,
                        f'{evaluator_name}.{file_extension}')

    @property
    def name(self) -> str:
        """str: Task name in the form ``<model>-<dataset>-<evaluator>``."""
        return '-'.join(self._component_names())

    def get_log_path(self, file_extension: str = 'json') -> str:
        """Get the path to the log file.

        Args:
            file_extension (str): The file extension of the log file.
                Default: 'json'.

        Returns:
            str: ``<work_dir>/<model>/<dataset>/<evaluator>.<extension>``.
        """
        return self._result_path(file_extension)

    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
        """Get the path to the output file.

        Args:
            file_extension (str): The file extension of the output file.
                Default: 'json'.

        Returns:
            List[str]: A single-element list holding
            ``<work_dir>/<model>/<dataset>/<evaluator>.<extension>``.
        """
        return [self._result_path(file_extension)]

    def get_command(self, cfg_path, template):
        """Get the command template for the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): The template which have '{task_cmd}' to format
                the command.

        Returns:
            str: ``template`` with ``{task_cmd}`` replaced by a ``torchrun``
            invocation when ``num_gpus > 0``, otherwise by a plain
            ``python`` invocation of this script.
        """
        script_path = __file__
        if self.num_gpus > 0:
            # A random master port is drawn for each generated command.
            port = random.randint(12000, 32000)
            command = (f'torchrun --master_port={port} '
                       f'--nproc_per_node {self.num_procs} '
                       f'{script_path} {cfg_path}')
        else:
            command = f'python {script_path} {cfg_path}'

        return template.format(task_cmd=command)

    def run(self):
        """Run inference over the whole dataset and write the metrics file.

        The metrics returned by the evaluator are dumped as JSON to the
        first path of ``get_output_paths()``; the parent directory is
        created when it does not exist.
        """
        from mmengine.runner import Runner

        # only support slurm, pytorch, mpi
        init_dist(self.cfg.launcher)
        self.logger.info(f'Task {self.name}')

        # build dataloader
        dataloader = Runner.build_dataloader(self.dataloader)
        # build model
        model = build_model(self.cfg)
        model.eval()
        # build evaluator
        evaluator = Evaluator(self.evaluator)

        # When the process group is initialized the model is reached through
        # ``.module`` and its ``forward`` is called directly. The check does
        # not change between batches, so it is resolved once.
        if dist.is_initialized():
            forward = model.module.forward
        else:
            forward = model.forward

        # Inference only: disable autograd so no graph is recorded per batch.
        with torch.no_grad():
            for batch in track_iter_progress(dataloader):
                data_samples = forward(batch)
                if not isinstance(data_samples, Sequence):
                    data_samples = [data_samples]
                evaluator.process(data_samples)

        metrics = evaluator.evaluate(len(dataloader.dataset))
        # NOTE(review): there is no rank guard here, so under a distributed
        # launcher every process writes to the same path -- confirm intended.
        metrics_file = self.get_output_paths()[0]
        mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
        with open(metrics_file, 'w') as f:
            json.dump(metrics, f)
def parse_args(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
        argv (Sequence[str], optional): Argument strings to parse, without
            the program name. Defaults to None, in which case ``argparse``
            reads them from ``sys.argv``.

    Returns:
        argparse.Namespace: Namespace whose ``config`` attribute holds the
        config file path.
    """
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: load the config named on the command line, run the
    # inference task once, and log how long construction plus run took.
    cfg = Config.fromfile(parse_args().config)
    start_time = time.time()
    MultimodalInferTask(cfg).run()
    elapsed = time.time() - start_time
    get_logger().info(f'time elapsed: {elapsed:.2f}s')
|