Spaces:
Runtime error
Runtime error
import os | |
import torch | |
# from colbert.utils.runs import Run | |
from colbert.utils.utils import print_message, save_checkpoint | |
from colbert.parameters import SAVED_CHECKPOINTS | |
from colbert.infra.run import Run | |
def print_progress(scores):
    """Print the rounded means of the first two score columns and their gap.

    ``scores`` is indexed as ``scores[:, 0]`` and ``scores[:, 1]``, so it must
    be a 2-D tensor-like object with at least two columns whose slices
    support ``.mean().item()``. Column 0 is reported first and column 1
    second (the original code labels them "positive" and "negative").

    Each mean is rounded to two decimals *before* the difference is taken,
    so the printed gap is the difference of the rounded values.
    """
    pos_mean, neg_mean = (
        round(scores[:, column].mean().item(), 2) for column in (0, 1)
    )
    margin = pos_mean - neg_mean
    print("#>>> ", pos_mean, neg_mean, '\t\t|\t\t', margin)
def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False):
    """Save a model checkpoint when the current batch calls for one.

    A checkpoint is written when any of the following holds:

    * ``consumed_all_triples`` is true, or ``batch_idx`` is a multiple of
      2000 -> saved under ``<checkpoints_path>/colbert``;
    * ``batch_idx`` is listed in ``SAVED_CHECKPOINTS`` -> saved under
      ``<checkpoints_path>/colbert-<batch_idx>``.

    If both conditions hold, only the batch-specific path is written (the
    second rule overrides the first).

    Args:
        args: Unused by this function; kept for call-site compatibility.
        colbert: Object exposing ``save(path)`` either directly or through a
            ``.module`` attribute (presumably a DataParallel-style wrapper;
            confirm against callers).
        optimizer: Unused by this function; kept for call-site compatibility.
        batch_idx: Index of the current training batch.
        savepath: Directory for checkpoints. When falsy, defaults to
            ``<Run().path_>/checkpoints``.
        consumed_all_triples: Force a save to the rolling ``colbert`` path.

    Returns:
        The path handed to ``save``, or ``None`` if nothing was saved.
    """
    # TODO: Call provenance() on the argument values that support it??

    checkpoints_path = savepath or os.path.join(Run().path_, 'checkpoints')

    # Resolve the save callable up front. Only a missing attribute should
    # trigger the fallback; a bare `except` would also swallow
    # KeyboardInterrupt/SystemExit and unrelated errors.
    try:
        save = colbert.save
    except AttributeError:
        save = colbert.module.save

    # The directory is created even when no checkpoint is written this call
    # (matches the original behaviour); exist_ok avoids a check-then-create
    # race between concurrent workers.
    os.makedirs(checkpoints_path, exist_ok=True)

    path_save = None

    if consumed_all_triples or (batch_idx % 2000 == 0):
        path_save = os.path.join(checkpoints_path, "colbert")

    # Deliberately evaluated second: a milestone batch replaces the rolling
    # path rather than saving twice.
    if batch_idx in SAVED_CHECKPOINTS:
        path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}")

    if path_save:
        print(f"#> Saving a checkpoint to {path_save} ..")
        save(path_save)

    return path_save