# mColBERT: colbert/evaluation/load_model.py
# Source revision 828992f ("Adding model and checkpoint") by vjeronymo2.
import os
import ujson
import torch
import random
from collections import defaultdict, OrderedDict
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
def load_model(args, do_print=True, *, base_model='bert-base-multilingual-uncased'):
    """Build a ColBERT model, load its checkpoint, and return it in eval mode.

    The model is instantiated via ``ColBERT.from_pretrained(base_model, ...)``
    using the architecture settings found on ``args``. It is moved to
    ``DEVICE``, and then ``load_checkpoint`` is called with
    ``args.checkpoint`` and the model.

    Args:
        args: Namespace-like object that must provide ``query_maxlen``,
            ``doc_maxlen``, ``dim``, ``similarity``, ``mask_punctuation`` and
            ``checkpoint`` attributes.
        do_print: When true, progress messages are emitted; the flag is
            forwarded to both ``print_message`` and ``load_checkpoint``.
        base_model: Name or path passed as the first argument to
            ``ColBERT.from_pretrained``. Defaults to the multilingual BERT
            backbone this function previously hard-coded.
            NOTE(review): presumably this should be the backbone the
            checkpoint was trained from — confirm.

    Returns:
        A ``(colbert, checkpoint)`` tuple: the model (in eval mode) and
        whatever ``load_checkpoint`` returned (presumably the loaded
        checkpoint dict — verify against ``colbert.utils.utils``).
    """
    colbert = ColBERT.from_pretrained(
        base_model,
        query_maxlen=args.query_maxlen,
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        similarity_metric=args.similarity,
        mask_punctuation=args.mask_punctuation,
    )
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    # NOTE(review): presumably load_checkpoint restores the saved weights into
    # `colbert` in place, which is why the model is passed in; confirm.
    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    # Evaluation-only use: switch the module out of training mode
    # (affects layers such as dropout).
    colbert.eval()

    return colbert, checkpoint