import copy
import os
from abc import abstractmethod
from typing import List

from mmengine.config import ConfigDict

from opencompass.utils import get_infer_output_path, task_abbr_from_cfg


class BaseTask:
    """Base class for all tasks.

    A task can be run in two ways:
    1. Directly, by calling the `run` method.
    2. Indirectly, by calling the `get_command` method to get a shell
        command and then running that command in the shell.

    Args:
        cfg (ConfigDict): Config dict. It must provide the keys 'models',
            'datasets' and 'work_dir'.
    """

    # The prefix of the task name.
    name_prefix: str = None
    # The subdirectory of the work directory to store the log files.
    log_subdir: str = None
    # The subdirectory of the work directory to store the output files.
    output_subdir: str = None

    def __init__(self, cfg: ConfigDict):
        # Work on a deep copy so the caller's config is never mutated.
        cfg = copy.deepcopy(cfg)
        self.cfg = cfg
        self.model_cfgs = cfg['models']
        self.dataset_cfgs = cfg['datasets']
        self.work_dir = cfg['work_dir']

    @abstractmethod
    def run(self):
        """Run the task."""

    @abstractmethod
    def get_command(self, cfg_path, template) -> str:
        """Get the shell command that runs the task.

        Args:
            cfg_path (str): The path to the config file of the task.
            template (str): A template that contains the '{task_cmd}'
                placeholder, which is replaced by the task command.

        Returns:
            str: The command to run in the shell.
        """

    @property
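    def name(self) -> str:
        """The task name: `name_prefix` followed by the abbreviation built
        from the task's model and dataset configs."""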
        return self.name_prefix + task_abbr_from_cfg(
            {
                'models': self.model_cfgs,
                'datasets': self.dataset_cfgs
            })

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.cfg})'

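    def get_log_path(self, file_extension: str = 'json') -> str:
        """Get the path to the log file.

        The path is derived from the first model and the first dataset of
        the task, under `log_subdir` in the work directory.

        Args:
            file_extension (str): The file extension of the log file.
                Default: 'json'.
        """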
        return get_infer_output_path(
            self.model_cfgs[0], self.dataset_cfgs[0][0],
            os.path.join(self.work_dir, self.log_subdir), file_extension)

    def get_output_paths(self, file_extension: str = 'json') -> List[str]:
        """Get the paths to the output files.

        Every file should exist if the task succeeds. One path is returned
        for each (model, dataset) pair of the task.

        Args:
            file_extension (str): The file extension of the output files.
                Default: 'json'.
        """
        output_paths = []
        for model, datasets in zip(self.model_cfgs, self.dataset_cfgs):
            for dataset in datasets:
                output_paths.append(
                    get_infer_output_path(
                        model, dataset,
                        os.path.join(self.work_dir, self.output_subdir),
                        file_extension))
        return output_paths
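

# Illustrative sketch only, not part of the OpenCompass API: one way a
# subclass's `get_command` might fill the '{task_cmd}' placeholder described
# in the docstring above. The helper name `_example_fill_template` is
# hypothetical, and real subclasses may build their commands differently.
def _example_fill_template(template: str, task_cmd: str) -> str:
    """Substitute `task_cmd` into a template that contains '{task_cmd}'."""
    if '{task_cmd}' not in template:
        raise ValueError("template must contain the '{task_cmd}' placeholder")
    # `str.format` replaces the named placeholder with the task command.
    return template.format(task_cmd=task_cmd)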