import argparse
import json
import os
import os.path as osp
import random
import time
from typing import List, Sequence

import mmengine
import torch
import torch.distributed as dist
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import init_dist
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.model.wrappers import MMDistributedDataParallel
from mmengine.utils import track_iter_progress

from opencompass.registry import MM_MODELS, TASKS
from opencompass.utils import get_logger
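
# A typical task config for this script looks like the sketch below
# (values are illustrative; the concrete types come from the OpenCompass
# registries):
#
#   model = dict(type=...)        # built through MM_MODELS
#   load_from = 'ckpt.pth'        # optional checkpoint path
#   dataset = dict(...)           # dataloader config for Runner.build_dataloader
#   evaluator = [dict(type=...)]  # metric configs passed to Evaluator
#   work_dir = 'work_dirs/demo'   # root for the metrics output
#   launcher = 'pytorch'          # forwarded to init_dist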


def build_model(cfg):
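    """Build the model from ``cfg['model']``, optionally load the checkpoint
    given by ``cfg['load_from']``, and wrap it in
    ``MMDistributedDataParallel`` when a distributed run is initialized.
    """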
    model = MM_MODELS.build(cfg['model'])
    load_from = cfg.get('load_from', None)
    if load_from is not None:
        state_dict = torch.load(load_from, map_location='cpu')
        # Checkpoints may store the weights under a 'model' or 'state_dict'
        # key depending on how they were saved; unwrap if so.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        elif 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        msg = model.load_state_dict(state_dict, strict=False)
        print_log(msg)
    model.to(get_device())
    if dist.is_initialized():
        model = MMDistributedDataParallel(
            model,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False)
    return model


@TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
class MultimodalInferTask:
    """Multimodal Inference Task.

    This task is used to run the inference process.
    """

    def __init__(self, cfg: ConfigDict):
        self.num_gpus = cfg.get('num_gpus', 0)
        self.num_procs = cfg.get('num_procs', 1)
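        # The 'dataset' field is a full dataloader config; it is consumed
        # by Runner.build_dataloader in `run`.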
        self.dataloader = cfg.get('dataset')
        self.model = cfg.get('model')
        self.evaluator = cfg.get('evaluator')
        self.cfg = cfg
        self.logger = get_logger()

    @property
    def name(self) -> str:
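        """str: Task name composed of the model, dataset and evaluator
        types."""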
        model_name = self.model['type']
        dataset_name = self.dataloader['dataset']['type']
        evaluator_name = self.evaluator[0]['type']
        return f'{model_name}-{dataset_name}-{evaluator_name}'

    def get_log_path(self, file_extension: str = 'json') -> str:
        """Get the path to the log file.

        Args:
            file_extension (str): The file extension of the log file.
                Default: 'json'.
        """
        model_name = self.model['type']
        dataset_name = self.dataloader['dataset']['type']
        evaluator_name = self.evaluator[0]['type']

        return osp.join(self.cfg.work_dir, model_name, dataset_name,
                        f'{evaluator_name}.{file_extension}')

    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
        """Get the path to the output file.

        Args:
            file_extension (str): The file extension of the log file.
                Default: 'json'.
        """
        model_name = self.model['type']
        dataset_name = self.dataloader['dataset']['type']
        evaluator_name = self.evaluator[0]['type']

        return [
            osp.join(self.cfg.work_dir, model_name, dataset_name,
                     f'{evaluator_name}.{file_extension}')
        ]

    def get_command(self, cfg_path, template):
        """Get the command template for the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): A template string containing '{task_cmd}' as a
                placeholder for the actual command.
        """
        script_path = __file__
        if self.num_gpus > 0:
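            # Pick a random master port to reduce the chance of clashes
            # when several distributed tasks start on the same machine.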
            port = random.randint(12000, 32000)
            command = (f'torchrun --master_port={port} '
                       f'--nproc_per_node {self.num_procs} '
                       f'{script_path} {cfg_path}')
        else:
            command = f'python {script_path} {cfg_path}'

        return template.format(task_cmd=command)

    def run(self):
        from mmengine.runner import Runner

        # Initialize the distributed environment; only the slurm, pytorch
        # and mpi launchers are supported.
        init_dist(self.cfg.launcher)
        self.logger.info(f'Task {self.name}')
        # build dataloader
        dataloader = Runner.build_dataloader(self.dataloader)
        # build model
        model = build_model(self.cfg)
        model.eval()
        # build evaluator
        evaluator = Evaluator(self.evaluator)

        for batch in track_iter_progress(dataloader):
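            # Under DDP the actual model sits behind `.module`, so call its
            # `forward` directly.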
            if dist.is_initialized():
                data_samples = model.module.forward(batch)
            else:
                data_samples = model.forward(batch)
            if not isinstance(data_samples, Sequence):
                data_samples = [data_samples]
            evaluator.process(data_samples)

        metrics = evaluator.evaluate(len(dataloader.dataset))
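        # Dump metrics to <work_dir>/<model>/<dataset>/<evaluator>.json.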
        metrics_file = self.get_output_paths()[0]
        mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
        with open(metrics_file, 'w') as f:
            json.dump(metrics, f)


def parse_args():
    parser = argparse.ArgumentParser(description='Model Inferencer')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args
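

# Typical invocations (file and config names are illustrative;
# `get_command` above builds the real command from `__file__` and the
# task config):
#
#   python <this_script>.py configs/my_task.py
#   torchrun --master_port=<port> --nproc_per_node <num_procs> \
#       <this_script>.py configs/my_task.py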


if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)
    start_time = time.time()
    inferencer = MultimodalInferTask(cfg)
    inferencer.run()
    end_time = time.time()
    get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')