Upload BertForPrefixMarking.py
Browse files- BertForPrefixMarking.py +1 -0
BertForPrefixMarking.py
CHANGED
@@ -123,6 +123,7 @@ class BertForPrefixMarking(BertPreTrainedModel):
|
|
123 |
def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
|
124 |
# step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
|
125 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
|
|
126 |
|
127 |
# run through bert
|
128 |
logits = self.forward(**inputs, return_dict=True).logits
|
|
|
123 |
def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
|
124 |
# step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
|
125 |
inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
|
126 |
+
inputs = {k:v.to(self.device) for k,v in inputs.items()}
|
127 |
|
128 |
# run through bert
|
129 |
logits = self.forward(**inputs, return_dict=True).logits
|