Update decode_utils.py
Browse files- decode_utils.py +1 -4
decode_utils.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import logging
|
2 |
import math
|
3 |
import re
|
4 |
from typing import (
|
@@ -14,8 +13,6 @@ import torch.nn as nn
|
|
14 |
from tqdm import tqdm
|
15 |
from transformers import PreTrainedTokenizer
|
16 |
|
17 |
-
logger = logging.getLogger(__name__)
|
18 |
-
|
19 |
|
20 |
def get_id_and_prob(spans, offset_map):
|
21 |
prompt_length = 0
|
@@ -445,7 +442,7 @@ class UIEDecoder(nn.Module):
|
|
445 |
}
|
446 |
|
447 |
for k, v in batch.items():
|
448 |
-
batch[k] = torch.
|
449 |
|
450 |
outputs = self(**batch)
|
451 |
start_prob, end_prob = outputs[0], outputs[1]
|
|
|
|
|
1 |
import math
|
2 |
import re
|
3 |
from typing import (
|
|
|
13 |
from tqdm import tqdm
|
14 |
from transformers import PreTrainedTokenizer
|
15 |
|
|
|
|
|
16 |
|
17 |
def get_id_and_prob(spans, offset_map):
|
18 |
prompt_length = 0
|
|
|
442 |
}
|
443 |
|
444 |
for k, v in batch.items():
|
445 |
+
batch[k] = torch.tensor(v, device=self.device)
|
446 |
|
447 |
outputs = self(**batch)
|
448 |
start_prob, end_prob = outputs[0], outputs[1]
|