danielhajialigol committed
Commit eca4ff8 • 1 Parent(s): 9901139

fix padding issue
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from model import MimicTransformer
 from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn
 from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
-
+
 model_path = 'checkpoint_0_9113.bin'
 related_tensor = torch.load('discharge_embeddings.pt')
 all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
@@ -56,6 +56,7 @@ def get_model_results(text):
     logits = outputs[0][0]
     out = logits.detach().cpu()[0]
     drg_code = i2d[out.argmax().item()]
+    print(out.topk(5))
     prob = torch.nn.functional.softmax(out).max()
     return {
         'class': drg_code,
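The only functional change in app.py is the added debug print. For context, Tensor.topk(5) returns the five largest logits together with their indices, i.e. the model's five most likely DRG classes before softmax. A minimal sketch of what that call produces (the logits vector here is made up for illustration):

import torch

# Stand-in for the DRG logits vector; in the app this is
# out = logits.detach().cpu()[0].
out = torch.tensor([0.1, 2.3, -0.4, 1.7, 0.9, 3.2])

values, indices = out.topk(5)   # five largest entries, in descending order
print(values)    # tensor([3.2000, 2.3000, 1.7000, 0.9000, 0.1000])
print(indices)   # tensor([5, 1, 3, 4, 0])

# The app keeps only the argmax (i2d[out.argmax().item()]); printing
# topk(5) exposes the runner-up candidates for inspection.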
utils.py CHANGED
@@ -204,7 +204,9 @@ def tokenize_icds(tokenizer, annotations, token_ids):
 
 def get_attribution(text, tokenizer, model_outputs, inputs, k=7):
     tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-    padding_idx = …
+    padding_idx = 512
+    if '[PAD]' in tokens:
+        padding_idx = tokens.index('[PAD]')
     tokens = tokens[:padding_idx][1:-1]
     attn = model_outputs[-1][0]
     agg_attn, final_text = reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn)
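The utils.py change is the substance of the commit. The removed line is truncated in this view, but the replacement makes the failure mode clear: when a discharge summary fills the full 512-token context window, the token list contains no '[PAD]' entry, and an unguarded list.index('[PAD]') raises ValueError. The new code tests membership first and falls back to the 512-token maximum. A standalone sketch of the same guard, using a hypothetical token list:

# Hypothetical output of tokenizer.convert_ids_to_tokens(...).
tokens = ['[CLS]', 'patient', 'discharged', 'stable', '[SEP]', '[PAD]', '[PAD]']

# Guarded lookup, as in the patched get_attribution: list.index raises
# ValueError when '[PAD]' is absent, so test membership first.
padding_idx = 512
if '[PAD]' in tokens:
    padding_idx = tokens.index('[PAD]')

# Truncate at the first pad, then strip [CLS] and [SEP].
tokens = tokens[:padding_idx][1:-1]

print(padding_idx)  # 5
print(tokens)       # ['patient', 'discharged', 'stable']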