Shaltiel commited on
Commit
e6626d0
1 Parent(s): 59a553d

Upload BertForPrefixMarking.py

Browse files
Files changed (1) hide show
  1. BertForPrefixMarking.py +1 -0
BertForPrefixMarking.py CHANGED
@@ -123,6 +123,7 @@ class BertForPrefixMarking(BertPreTrainedModel):
123
  def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
124
  # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
125
  inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
 
126
 
127
  # run through bert
128
  logits = self.forward(**inputs, return_dict=True).logits
 
123
  def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
124
  # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
125
  inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
126
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
127
 
128
  # run through bert
129
  logits = self.forward(**inputs, return_dict=True).logits