irow commited on
Commit
d09fda3
1 Parent(s): 4529fd0

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -40,7 +40,7 @@ class DualBertModel: # use this class to encode your document and queries.
40
  attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
41
  return input_ids, attention_mask
42
 
43
- def encode_queries(self, queries, batch_size: int, **kwargs) -> np.ndarray:
44
  """
45
  Encodes a list of strings (queries) and returns a list of encodings.
46
  """
@@ -54,7 +54,7 @@ class DualBertModel: # use this class to encode your document and queries.
54
  return torch.concat(to_return)
55
 
56
 
57
- def encode_corpus(self, corpus, batch_size: int, **kwargs) -> np.ndarray:
58
  """
59
  Encodes a list of strings (documents) and returns a list of encodings.
60
  """
 
40
  attention_mask = torch.tensor(encoded.get('attention_mask')).to(device)
41
  return input_ids, attention_mask
42
 
43
+ def encode_queries(self, queries, batch_size: int, **kwargs):
44
  """
45
  Encodes a list of strings (queries) and returns a list of encodings.
46
  """
 
54
  return torch.concat(to_return)
55
 
56
 
57
+ def encode_corpus(self, corpus, batch_size: int, **kwargs):
58
  """
59
  Encodes a list of strings (documents) and returns a list of encodings.
60
  """