File size: 960 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import os
import ujson
import torch
import random
from collections import defaultdict, OrderedDict
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
def load_model(args, do_print=True, base_model='bert-base-multilingual-uncased'):
    """Build a ColBERT model, load a checkpoint into it, and switch it to eval mode.

    Args:
        args: Object exposing the attributes ``query_maxlen``, ``doc_maxlen``,
            ``dim``, ``similarity``, ``mask_punctuation`` and ``checkpoint``.
        do_print: Forwarded as ``condition`` to ``print_message`` and as
            ``do_print`` to ``load_checkpoint``.
        base_model: Name (or path) handed to ``ColBERT.from_pretrained`` as the
            pretrained backbone. Defaults to the value that used to be
            hard-coded here, so existing callers are unaffected.

    Returns:
        A ``(colbert, checkpoint)`` tuple: the model after ``.to(DEVICE)`` and
        ``.eval()``, and whatever ``load_checkpoint`` returned.
    """
    colbert = ColBERT.from_pretrained(
        base_model,
        query_maxlen=args.query_maxlen,
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        similarity_metric=args.similarity,
        mask_punctuation=args.mask_punctuation,
    )
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    # load_checkpoint receives the model itself; presumably it copies the saved
    # weights into `colbert` in place -- confirm against colbert.utils.utils.
    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
|