# Provenance (from the hosting page's file view): uploaded by faceplugin,
# commit 901e379 ("initial commit"), 1.12 kB.
import datetime
import time

import torch
def str2bool(s):
    """Interpret a command-line style string as a boolean.

    Matching is case-insensitive and exact: only ``"true"`` (any casing)
    and ``"1"`` yield True. Every other string, including ``"yes"`` or
    strings with surrounding whitespace, yields False.
    """
    normalized = s.lower()
    return normalized == 'true' or normalized == '1'
class Timer:
    """Measure elapsed time for independent, named intervals.

    ``start(key)`` begins (or restarts) the interval for *key*;
    ``end(key)`` stops it and returns the elapsed time in seconds.
    Several keys may be running at the same time.
    """

    def __init__(self):
        # Maps interval key -> start timestamp taken from time.perf_counter().
        self.clock = {}

    def start(self, key="default"):
        """Begin timing the interval named *key*, overwriting any running one."""
        # perf_counter() is monotonic and high-resolution. datetime.now()
        # returns naive local wall-clock time, which jumps on NTP corrections
        # and DST changes and can yield wrong or negative durations.
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop timing *key* and return the elapsed time in seconds (float).

        The key is removed, so each ``start`` pairs with exactly one ``end``.

        Raises:
            KeyError: if *key* was never started or has already been ended.
                KeyError subclasses Exception, so handlers written for the
                previous generic Exception still catch it.
        """
        try:
            started = self.clock.pop(key)
        except KeyError:
            raise KeyError(f"{key} is not in the clock.") from None
        return time.perf_counter() - started
def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Write a training checkpoint and a standalone copy of the model weights.

    Two files are produced with ``torch.save``:

    * *checkpoint_path* gets a dict with the keys ``'epoch'``, ``'model'``,
      ``'optimizer'`` and ``'best_score'``, holding the matching arguments;
    * *model_path* gets *net_state_dict* on its own.
    """
    checkpoint = {
        'epoch': epoch,
        'model': net_state_dict,
        'optimizer': optimizer_state_dict,
        'best_score': best_score,
    }
    torch.save(checkpoint, checkpoint_path)
    torch.save(net_state_dict, model_path)
def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously serialized with ``torch.save``.

    Args:
        checkpoint_path: path (or file-like object) accepted by ``torch.load``.
        map_location: optional device remapping forwarded to ``torch.load``,
            e.g. ``"cpu"`` to open a checkpoint saved from GPU tensors on a
            CPU-only machine. The default ``None`` leaves tensors on the
            devices they were saved from, which matches the previous behaviour.

    Returns:
        The deserialized object exactly as ``torch.load`` returns it.
    """
    # NOTE(security): torch.load is pickle-based. Only load checkpoints from
    # trusted sources; whether weights_only protection is on by default
    # depends on the installed torch version.
    return torch.load(checkpoint_path, map_location=map_location)
def freeze_net_layers(net):
    """Turn off gradient tracking for every parameter yielded by ``net.parameters()``.

    The parameters are modified in place and nothing is returned.
    """
    params = net.parameters()
    for tensor in params:
        tensor.requires_grad_(False)
def store_labels(path, labels):
    """Write *labels* to *path*, one per line, with no trailing newline.

    Args:
        path: destination file; it is created or overwritten.
        labels: iterable of strings.

    Raises:
        TypeError: if any label is not a ``str``. The file is not touched
            in that case.
    """
    # Build the content before opening: open(..., "w") truncates immediately,
    # so a failing join inside the `with` would wipe an existing labels file.
    content = "\n".join(labels)
    # NOTE(review): uses the platform default encoding, as before; readers of
    # this file presumably use the default too — confirm before pinning utf-8.
    with open(path, "w") as f:
        f.write(content)