import logging
import os

from mmengine.logging import MMLogger

_nameToLevel = {
    'CRITICAL': logging.CRITICAL,
    'FATAL': logging.FATAL,
    'ERROR': logging.ERROR,
    'WARN': logging.WARNING,
    'WARNING': logging.WARNING,
    'INFO': logging.INFO,
    'DEBUG': logging.DEBUG,
    'NOTSET': logging.NOTSET,
}


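def get_logger(log_level: str = 'INFO',
               filter_duplicate_level=None) -> MMLogger:
    """Get the shared OpenCompass logger.

    The logger is a process-wide ``MMLogger`` instance named
    ``'OpenCompass'``. ``log_level`` only takes effect on the first call,
    when the instance is created; later calls ignore it and return the
    existing instance.

    Args:
        log_level (str): The log level used when the logger is first
            created. Default: 'INFO'. Choices are 'DEBUG', 'INFO', 'WARNING',
            'ERROR', 'CRITICAL'.
        filter_duplicate_level (str | list[str], optional): Level name(s)
            whose repeated messages are suppressed, e.g. ``'error'`` or
            ``'error,warning'``. If None, the value of the environment
            variable ``OPENCOMPASS_FILTER_DUPLICATE_LEVEL`` is used, and no
            filter is added when that is unset or empty. Default: None.

    Returns:
        MMLogger: The OpenCompass logger.
    """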
    if not MMLogger.check_instance_created('OpenCompass'):
        logger = MMLogger.get_instance('OpenCompass',
                                       logger_name='OpenCompass',
                                       log_level=log_level)
    else:
        logger = MMLogger.get_instance('OpenCompass')

    if filter_duplicate_level is None:
        # export OPENCOMPASS_FILTER_DUPLICATE_LEVEL=error
        # export OPENCOMPASS_FILTER_DUPLICATE_LEVEL=error,warning
        filter_duplicate_level = os.getenv(
            'OPENCOMPASS_FILTER_DUPLICATE_LEVEL', None)

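    if filter_duplicate_level:
        new_filter = FilterDuplicateMessage('OpenCompass',
                                            filter_duplicate_level)
        # The logger is a singleton, so skip adding a filter when one for the
        # same levels is already attached; otherwise every call would stack
        # another copy (each with its own ``seen`` set).
        already_attached = any(
            isinstance(f, FilterDuplicateMessage)
            and set(f.filter_duplicate_level) == set(
                new_filter.filter_duplicate_level) for f in logger.filters)
        if not already_attached:
            logger.addFilter(new_filter)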

    return logger


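class FilterDuplicateMessage(logging.Filter):
    """Suppress repeated log messages at selected levels.

    A record whose level is listed in ``filter_duplicate_level`` is emitted
    the first time a given message is seen and dropped afterwards. Records
    at other levels always pass.

    Args:
        name (str): Name of the filter (stored by :class:`logging.Filter`;
            the overridden :meth:`filter` does not use it).
        filter_duplicate_level (str | list[str]): Level name(s) to
            deduplicate, either as a list or as a comma-separated string such
            as ``'error,warning'``. Names are case-insensitive.

    Raises:
        ValueError: If a level name is not recognized.
    """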

    def __init__(self, name, filter_duplicate_level):
        super().__init__(name)
        self.seen: set = set()

        if isinstance(filter_duplicate_level, str):
            filter_duplicate_level = filter_duplicate_level.split(',')

        self.filter_duplicate_level = []
        for level in filter_duplicate_level:
            _level = level.strip().upper()
            if _level not in _nameToLevel:
                raise ValueError(f'Invalid log level: {_level}')
            self.filter_duplicate_level.append(_nameToLevel[_level])

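    def filter(self, record: logging.LogRecord) -> bool:
        """Drop a record if its message was already seen at a filtered level.

        Args:
            record (LogRecord): The log record.

        Returns:
            bool: Whether to output the log record.
        """
        if record.levelno not in self.filter_duplicate_level:
            return True

        # Compare the fully formatted message: ``record.msg`` is the raw
        # template (possibly a non-hashable object), so %-style calls with
        # different arguments would otherwise be treated as duplicates.
        message = record.getMessage()
        if message not in self.seen:
            self.seen.add(message)
            return True
        return False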
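

# Illustrative self-check (a sketch, not part of the OpenCompass API): run
# this module directly to see FilterDuplicateMessage drop a repeated ERROR
# message while letting repeated WARNING messages through. It assumes a plain
# stdlib logger named 'demo' and a hypothetical _ListHandler helper, so the
# shared 'OpenCompass' MMLogger instance is left untouched.
if __name__ == '__main__':

    class _ListHandler(logging.Handler):
        """Hypothetical helper: collect messages that pass all filters."""

        def __init__(self):
            super().__init__()
            self.messages = []

        def emit(self, record):
            self.messages.append(record.getMessage())

    demo_logger = logging.getLogger('demo')
    demo_logger.setLevel(logging.DEBUG)
    demo_logger.propagate = False  # keep demo output out of the root logger
    handler = _ListHandler()
    demo_logger.addHandler(handler)
    # Deduplicate ERROR records only.
    demo_logger.addFilter(FilterDuplicateMessage('demo', 'error'))

    demo_logger.error('disk full')
    demo_logger.error('disk full')  # dropped: repeated ERROR message
    demo_logger.warning('slow step')
    demo_logger.warning('slow step')  # kept: WARNING is not deduplicated
    print(handler.messages)  # ['disk full', 'slow step', 'slow step']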