import logging

import torch.distributed as dist

LOG_LEVEL = logging.INFO
SUBPROCESS_LOG_LEVEL = logging.ERROR
LOG_FORMATTER = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s'

# Loggers already configured by get_logger(): repeated calls for the same
# name return the existing logger instead of attaching duplicate handlers.
logger_initialized = {}

|
def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'):
    """Initialize and return a logger that is rank-aware under
    torch.distributed.

    Args:
        name (str): Logger name.
        level (int): Log level for the rank-0 process. Defaults to LOG_LEVEL.
        log_file (str | None): Path of a log file; if given, a FileHandler
            is attached on rank 0 only. Defaults to None.
        file_mode (str): Mode passed to the FileHandler. Defaults to 'w'.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    formatter = logging.Formatter(LOG_FORMATTER)

    # Demote StreamHandlers on the root logger (e.g. one added by
    # logging.basicConfig) to ERROR so records are not printed twice. The
    # exact type check is deliberate: FileHandler subclasses StreamHandler
    # and must be left untouched.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)
|
    # Only rank 0 of a distributed job emits full output; the other ranks
    # are limited to errors below.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
|
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)
        logger.addHandler(file_handler)
|
    if rank == 0:
        logger.setLevel(level)
    else:
        # Subprocesses (non-zero ranks) only surface errors.
        logger.setLevel(SUBPROCESS_LOG_LEVEL)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)
    logger.addHandler(stream_handler)
|
    logger_initialized[name] = True
    return logger
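

# A minimal usage sketch (an illustration, not part of the original module;
# the logger name 'demo' and the path 'demo.log' are hypothetical). On rank 0
# the INFO record reaches both the console and demo.log; under an initialized
# torch.distributed job, other ranks would surface only ERROR-level records.
if __name__ == '__main__':
    logger = get_logger('demo', log_file='demo.log')
    logger.info('visible on rank 0, on the console and in demo.log')
    logger.error('visible on every rank')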