GraCo / isegm /utils /log.py
zhaoyian01's picture
Add application file
6d1366a
raw
history blame
2.71 kB
import io
import time
import logging
from datetime import datetime
import numpy as np
from torch.utils.tensorboard import SummaryWriter
# Name of the shared logger and the timestamp format used for file log records.
LOGGER_NAME = 'root'
LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S'

# Shared logger, configured at import time: INFO level, with one stream
# handler attached so records are emitted even before a file handler is added.
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.INFO)

handler = logging.StreamHandler()
logger.addHandler(handler)
def add_logging(logs_path, prefix):
    """Attach a timestamped file handler to the module-level ``logger``.

    The log file is named ``<prefix><YYYY-mm-dd_HH-MM-SS>.log``, using the
    local time at the moment of the call, and is placed inside ``logs_path``.
    Every call adds one more file handler to ``logger``.

    Args:
        logs_path: Directory that receives the log file. May be a
            ``pathlib.Path``, a ``str`` or any other ``os.PathLike``. The
            directory is not created by this function.
        prefix: String prepended to the timestamp in the file name.
    """
    # Imported locally so the function does not rely on a module-level import.
    from pathlib import Path

    timestamp = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')
    log_name = prefix + timestamp + '.log'
    # Coerce to Path so plain string directories work with the `/` operator.
    stdout_log_path = Path(logs_path) / log_name

    fh = logging.FileHandler(str(stdout_log_path))
    formatter = logging.Formatter(
        fmt='(%(levelname)s) %(asctime)s: %(message)s',
        datefmt=LOGGER_DATEFMT,
    )
    fh.setFormatter(formatter)
    logger.addHandler(fh)
class TqdmToLogger(io.StringIO):
    """File-like object that forwards the most recent written text to a logger.

    Judging by the name, it is meant to be handed to a tqdm progress bar as
    its output stream (NOTE(review): confirm against callers). Only the last
    text passed to ``write`` is kept, and ``flush`` emits it at most once per
    ``mininterval`` seconds so frequent updates do not flood the log.
    """

    # Class-level defaults; real values are assigned per instance in __init__.
    logger = None
    level = None
    buf = ''

    def __init__(self, logger, level=None, mininterval=5):
        """Create the stream.

        Args:
            logger: Object whose ``log(level, msg)`` method receives messages.
            level: Logging level for emitted messages. Any falsy value
                (``None`` or ``0``) falls back to ``logging.INFO``.
            mininterval: Minimum number of seconds between two emitted
                messages.
        """
        super(TqdmToLogger, self).__init__()
        self.logger = logger
        self.level = level or logging.INFO
        self.mininterval = mininterval
        # Zero means "never emitted", so the first non-empty flush goes out
        # (as long as the current epoch time exceeds ``mininterval``).
        self.last_time = 0
        self.buf = ''

    def write(self, buf):
        """Remember ``buf`` (stripped of CR/LF/tab/space at both ends).

        Any previously stored text is overwritten; nothing is written to the
        underlying ``StringIO`` buffer.

        Returns:
            The number of characters in ``buf``, as the ``io`` text stream
            protocol expects from ``write``.
        """
        self.buf = buf.strip('\r\n\t ')
        return len(buf)

    def flush(self):
        """Log the stored text if it is non-empty and the interval has passed.

        The stored text is not cleared after logging, so a later flush may
        emit the same text again once ``mininterval`` has elapsed.
        """
        if self.buf and time.time() - self.last_time > self.mininterval:
            self.logger.log(self.level, self.buf)
            self.last_time = time.time()
class SummaryWriterAvg(SummaryWriter):
    """``SummaryWriter`` that averages scalars before writing them.

    Scalar values are accumulated per tag; once ``dump_period`` values have
    been collected for a tag, their mean is written and the accumulator for
    that tag starts over.
    """

    def __init__(self, *args, dump_period=20, **kwargs):
        """Forward ``args``/``kwargs`` to ``SummaryWriter``.

        Args:
            dump_period: Number of values collected per tag before their
                average is written.
        """
        super().__init__(*args, **kwargs)
        self._dump_period = dump_period
        self._avg_scalars = dict()

    def add_scalar(self, tag, value, global_step=None, disable_avg=False):
        """Record ``value`` under ``tag``, averaging unless told otherwise.

        Args:
            tag: Name of the scalar series.
            value: Value to record. Tuples, lists and dicts bypass averaging.
            global_step: Step passed through to the parent writer. For an
                averaged write this is the step of the call that filled the
                accumulator.
            disable_avg: When true, write ``value`` immediately.
        """
        write_directly = disable_avg or isinstance(value, (tuple, list, dict))
        if write_directly:
            super().add_scalar(tag, np.array(value), global_step=global_step)
            return

        # Compare against None explicitly: an empty accumulator is falsy
        # because it defines __len__.
        accumulator = self._avg_scalars.get(tag)
        if accumulator is None:
            accumulator = ScalarAccumulator(self._dump_period)
            self._avg_scalars[tag] = accumulator

        accumulator.add(value)
        if not accumulator.is_full():
            return

        super().add_scalar(tag, accumulator.value, global_step=global_step)
        accumulator.reset()
class ScalarAccumulator(object):
    """Running sum and count of values, with a fill threshold.

    Attributes are ``sum`` (total of added values), ``cnt`` (number of added
    values) and ``period`` (count at which the accumulator reports full).
    """

    def __init__(self, period):
        """Start empty; ``period`` is the count at which ``is_full`` is true."""
        self.reset()
        self.period = period

    def add(self, value):
        """Add ``value`` to the running sum and bump the count."""
        self.sum += value
        self.cnt += 1

    @property
    def value(self):
        """Mean of the added values, or 0 when nothing has been added."""
        if self.cnt <= 0:
            return 0
        return self.sum / self.cnt

    def reset(self):
        """Drop all accumulated values (``period`` is left untouched)."""
        self.sum = 0
        self.cnt = 0

    def is_full(self):
        """Return whether at least ``period`` values have been added."""
        return self.period <= self.cnt

    def __len__(self):
        """Number of values added since the last reset."""
        return self.cnt