File size: 1,108 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import logging
import torch.distributed as dist
LOG_LEVEL = logging.INFO
SUBPROCESS_LOG_LEVEL = logging.ERROR
LOG_FORMATTER = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s'


def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'):
    """Return the logger ``name`` configured for (optionally distributed) use.

    The logger always gets a plain ``logging.StreamHandler`` formatted with
    ``LOG_FORMATTER``. On rank 0 (also the case when ``torch.distributed`` is
    unavailable or not initialized) the logger level is ``level``. If
    ``log_file`` is given, a ``logging.FileHandler`` writing to it is attached
    as well. On any other rank the logger level is ``SUBPROCESS_LOG_LEVEL``
    and no file handler is attached.

    Repeated calls with the same ``name`` do not stack handlers:

    * the plain stream handler from an earlier call is reused, with its
      formatter and level updated;
    * a file handler already writing to the same path is reused rather than
      reopened.

    Side effect: every handler on the root logger whose type is exactly
    ``logging.StreamHandler`` has its level set to ``logging.ERROR``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to the logger (rank 0) and to its handlers.
        log_file: Optional path of a log file; only used on rank 0.
        file_mode: Mode used when opening ``log_file``; defaults to ``'w'``.

    Returns:
        The configured ``logging.Logger``.
    """
    import os

    formatter = logging.Formatter(LOG_FORMATTER)
    logger = logging.getLogger(name)

    # Raise plain console handlers on the root logger to ERROR, so records
    # propagating up from this logger are only echoed there at ERROR and
    # above. The exact-type check deliberately skips subclasses such as
    # FileHandler.
    for root_handler in logger.root.handlers:
        if type(root_handler) is logging.StreamHandler:
            root_handler.setLevel(logging.ERROR)

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    if rank == 0 and log_file is not None:
        # FileHandler stores the absolute path in ``baseFilename``. Opening the
        # same file a second time (mode 'w' truncates it) would clobber and
        # duplicate output, so reuse an existing handler for that path.
        log_path = os.path.abspath(os.fspath(log_file))
        file_handler = next(
            (h for h in logger.handlers
             if isinstance(h, logging.FileHandler)
             and h.baseFilename == log_path),
            None,
        )
        if file_handler is None:
            file_handler = logging.FileHandler(log_file, file_mode)
            logger.addHandler(file_handler)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)

    if rank == 0:
        logger.setLevel(level)
    else:
        logger.setLevel(SUBPROCESS_LOG_LEVEL)

    # Reuse the console handler from a previous call instead of adding another
    # one; stacking them prints every record once per call.
    stream_handler = next(
        (h for h in logger.handlers if type(h) is logging.StreamHandler),
        None,
    )
    if stream_handler is None:
        stream_handler = logging.StreamHandler()
        logger.addHandler(stream_handler)
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)

    return logger
|