Update inference.py
Browse files- inference.py +2 -2
inference.py
CHANGED
@@ -40,7 +40,7 @@ class DualBertModel: # use this class to encode your document and queries.
|
|
40 |
attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
|
41 |
return input_ids, attention_mask
|
42 |
|
43 |
-
def encode_queries(self, queries, batch_size: int, **kwargs)
|
44 |
"""
|
45 |
Encodes a list of strings (queries) and returns a list of encodings.
|
46 |
"""
|
@@ -54,7 +54,7 @@ class DualBertModel: # use this class to encode your document and queries.
|
|
54 |
return torch.concat(to_return)
|
55 |
|
56 |
|
57 |
-
def encode_corpus(self, corpus, batch_size: int, **kwargs)
|
58 |
"""
|
59 |
Encodes a list of strings (documents) and returns a list of encodings.
|
60 |
"""
|
|
|
40 |
attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
|
41 |
return input_ids, attention_mask
|
42 |
|
43 |
+
def encode_queries(self, queries, batch_size: int, **kwargs):
|
44 |
"""
|
45 |
Encodes a list of strings (queries) and returns a list of encodings.
|
46 |
"""
|
|
|
54 |
return torch.concat(to_return)
|
55 |
|
56 |
|
57 |
+
def encode_corpus(self, corpus, batch_size: int, **kwargs):
|
58 |
"""
|
59 |
Encodes a list of strings (documents) and returns a list of encodings.
|
60 |
"""
|