Spaces:
Running
Running
danielhajialigol
committed on
Commit
•
6a4a8e0
1
Parent(s):
b3e501a
Added DRG and ICD external link functionality
Browse files
app.py
CHANGED
@@ -2,10 +2,9 @@ import numpy as np
|
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import torch
|
5 |
-
import random
|
6 |
|
7 |
from model import MimicTransformer
|
8 |
-
from utils import load_rule, get_attribution, get_drg_link, visualize_attn
|
9 |
from transformers import set_seed
|
10 |
|
11 |
set_seed(42)
|
@@ -21,7 +20,7 @@ related_tensor = torch.load('discharge_embeddings.pt')
|
|
21 |
|
22 |
# get model and results
|
23 |
mimic = read_model(model=mimic, path=model_path)
|
24 |
-
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES']
|
25 |
|
26 |
tokenizer = mimic.tokenizer
|
27 |
mimic.eval()
|
@@ -78,9 +77,12 @@ def run(text, related_discharges=False):
|
|
78 |
model_results = get_model_results(text=text)
|
79 |
drg_code = model_results['class']
|
80 |
drg_link = get_drg_link(drg_code=drg_code)
|
|
|
81 |
row = rule_df[rule_df['DRG_CODE'] == drg_code]
|
82 |
drg_description = row['DESCRIPTION'].values[0]
|
83 |
model_results['class_dsc'] = drg_description
|
|
|
|
|
84 |
global related_summaries
|
85 |
# related_summaries = generate_similar_summeries()
|
86 |
related_summaries = find_related_summaries(model_results['logits'])
|
@@ -129,7 +131,8 @@ def prettify_text(nested_list):
|
|
129 |
idx = 1
|
130 |
string = ''
|
131 |
for li in nested_list:
|
132 |
-
|
|
|
133 |
idx += 1
|
134 |
return string
|
135 |
|
|
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import torch
|
|
|
5 |
|
6 |
from model import MimicTransformer
|
7 |
+
from utils import load_rule, get_attribution, get_drg_link, get_icd_annotations, visualize_attn
|
8 |
from transformers import set_seed
|
9 |
|
10 |
set_seed(42)
|
|
|
20 |
|
21 |
# get model and results
|
22 |
mimic = read_model(model=mimic, path=model_path)
|
23 |
+
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
|
24 |
|
25 |
tokenizer = mimic.tokenizer
|
26 |
mimic.eval()
|
|
|
77 |
model_results = get_model_results(text=text)
|
78 |
drg_code = model_results['class']
|
79 |
drg_link = get_drg_link(drg_code=drg_code)
|
80 |
+
icd_results = get_icd_annotations(text=text)
|
81 |
row = rule_df[rule_df['DRG_CODE'] == drg_code]
|
82 |
drg_description = row['DESCRIPTION'].values[0]
|
83 |
model_results['class_dsc'] = drg_description
|
84 |
+
model_results['drg_link'] = drg_link
|
85 |
+
model_results['icd_results'] = icd_results
|
86 |
global related_summaries
|
87 |
# related_summaries = generate_similar_summeries()
|
88 |
related_summaries = find_related_summaries(model_results['logits'])
|
|
|
131 |
idx = 1
|
132 |
string = ''
|
133 |
for li in nested_list:
|
134 |
+
delimiters = 99 * '='
|
135 |
+
string += f'({idx})\n{li[0]}\n{delimiters}\n'
|
136 |
idx += 1
|
137 |
return string
|
138 |
|
utils.py
CHANGED
@@ -66,7 +66,12 @@ def clean_text(text):
|
|
66 |
return new_text
|
67 |
|
68 |
def get_drg_link(drg_code):
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def prettify(dict_list, k):
|
72 |
li = [di[k] for di in dict_list]
|
@@ -179,7 +184,7 @@ def reconstruct_text(tokenizer, tokens, attn):
|
|
179 |
# final representation of text
|
180 |
final_text = ' '.join(reconstructed_tokens).replace(' .', '.')
|
181 |
final_text = final_text.replace(' ,', ',')
|
182 |
-
|
183 |
return aggregated_attn, reconstructed_tokens
|
184 |
|
185 |
def load_rule(path):
|
@@ -225,7 +230,7 @@ def visualize_attn(model_results):
|
|
225 |
raw_input_ids=tokens,
|
226 |
convergence_score=1
|
227 |
)
|
228 |
-
return visualize_text(viz_record)
|
229 |
|
230 |
|
231 |
def modify_attn_html(attn_html):
|
@@ -233,20 +238,46 @@ def modify_attn_html(attn_html):
|
|
233 |
htmls = [attn_split[0]]
|
234 |
for html in attn_split[1:]:
|
235 |
# wrap around href tag
|
236 |
-
href_html = f'<a href="espn.com" \
|
237 |
<mark{html} \
|
238 |
</a>'
|
239 |
htmls.append(href_html)
|
240 |
return "".join(htmls)
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
# copied out of captum because we need raw html instead of a jupyter widget
|
243 |
-
def visualize_text(datarecord):
|
244 |
dom = ["<table width: 100%>"]
|
245 |
rows = [
|
246 |
"<th style='text-align: left'>Predicted DRG</th>"
|
247 |
"<th style='text-align: left'>Word Importance</th>"
|
|
|
248 |
]
|
249 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
|
|
|
|
250 |
word_attn_html = visualization.format_word_importances(
|
251 |
datarecord.raw_input_ids, datarecord.word_attributions
|
252 |
)
|
@@ -257,6 +288,7 @@ def visualize_text(datarecord):
|
|
257 |
"<tr>",
|
258 |
pred_class_html,
|
259 |
word_attn_html,
|
|
|
260 |
"<tr>",
|
261 |
]
|
262 |
)
|
|
|
66 |
return new_text
|
67 |
|
68 |
def get_drg_link(drg_code):
    """Return the findacode.com lookup URL for a DRG code.

    Args:
        drg_code: The DRG code; it is converted with ``str()``, so an int or
            a string is accepted.

    Returns:
        The lookup URL with the code left-padded with zeros to at least
        three characters (``5`` -> ``005``, ``42`` -> ``042``). Codes that
        are already three or more characters long are used unchanged.
    """
    # rjust replaces the hand-written length checks for one- and two-character codes.
    padded_code = str(drg_code).rjust(3, '0')
    return f'https://www.findacode.com/code.php?set=DRG&c={padded_code}'
|
75 |
|
76 |
def prettify(dict_list, k):
|
77 |
li = [di[k] for di in dict_list]
|
|
|
184 |
# final representation of text
|
185 |
final_text = ' '.join(reconstructed_tokens).replace(' .', '.')
|
186 |
final_text = final_text.replace(' ,', ',')
|
187 |
+
# final_text == reconstructed_text
|
188 |
return aggregated_attn, reconstructed_tokens
|
189 |
|
190 |
def load_rule(path):
|
|
|
230 |
raw_input_ids=tokens,
|
231 |
convergence_score=1
|
232 |
)
|
233 |
+
return visualize_text(viz_record, drg_link=model_results['drg_link'], icd_annotations=model_results['icd_results'])
|
234 |
|
235 |
|
236 |
def modify_attn_html(attn_html):
|
|
|
238 |
htmls = [attn_split[0]]
|
239 |
for html in attn_split[1:]:
|
240 |
# wrap around href tag
|
241 |
+
href_html = f'<a href="https://espn.com" \
|
242 |
<mark{html} \
|
243 |
</a>'
|
244 |
htmls.append(href_html)
|
245 |
return "".join(htmls)
|
246 |
|
247 |
+
def modify_code_html(html, link, icd=False):
    """Wrap the contents of a ``<td>`` cell in a hyperlink.

    Args:
        html: Markup containing a ``<td>...</td>`` cell. Only the text
            between the first ``<td>`` and the following ``</td>`` is kept.
        link: Target URL for the anchor's ``href`` attribute.
        icd: When true, return the bare ``<a>`` element without the
            surrounding ``<td>`` wrapper.

    Returns:
        ``<td><a href="...">inner</a></td>``, or just the ``<a>`` element
        when ``icd`` is true.

    Raises:
        IndexError: If ``html`` contains no ``<td>`` tag.
    """
    # Imported locally under an alias because the ``html`` parameter shadows
    # the standard-library module of the same name inside this function.
    from html import escape as _escape_attr

    inner = html.split('<td>')[1].split('</td>')[0]
    # The opening tag is closed with '>' before the inner markup, and the URL
    # is attribute-escaped so characters such as '&' or '"' cannot break it.
    anchor = f'<a href="{_escape_attr(link, quote=True)}">{inner}</a>'
    if icd:
        return anchor
    return f'<td>{anchor}</td>'
|
253 |
+
|
254 |
+
def modify_drg_html(html, drg_link):
    """Link the DRG cell markup to ``drg_link``.

    Delegates to ``modify_code_html`` with ``icd=False``.
    """
    return modify_code_html(html, drg_link, icd=False)
|
256 |
+
|
257 |
+
def get_icd_html(icd_list):
    """Build a single ``<td>`` cell holding one link per ICD annotation.

    Args:
        icd_list: Iterable of dicts, each with a ``'text'`` key (the label
            to display) and a ``'link'`` key (the URL it points to).

    Returns:
        A ``<td>...</td>`` string containing one anchor per annotation, or a
        bold "N/A" placeholder cell when ``icd_list`` is empty or ``None``.
    """
    # The truthiness check covers both an empty list and None.
    if not icd_list:
        return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
    # icd=True drops the per-item <td> wrapper so all anchors share one cell.
    anchors = [
        modify_code_html(
            html=visualization.format_classname(classname=icd_dict['text']),
            link=icd_dict['link'],
            icd=True,
        )
        for icd_dict in icd_list
    ]
    return '<td>' + ''.join(anchors) + '</td>'
|
267 |
+
|
268 |
+
|
269 |
+
|
270 |
# copied out of captum because we need raw html instead of a jupyter widget
|
271 |
+
def visualize_text(datarecord, drg_link, icd_annotations):
|
272 |
dom = ["<table width: 100%>"]
|
273 |
rows = [
|
274 |
"<th style='text-align: left'>Predicted DRG</th>"
|
275 |
"<th style='text-align: left'>Word Importance</th>"
|
276 |
+
"<th style='text-align: left'>ICD Codes</th>"
|
277 |
]
|
278 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
279 |
+
icd_class_html = get_icd_html(icd_annotations)
|
280 |
+
pred_class_html = modify_drg_html(html=pred_class_html, drg_link=drg_link)
|
281 |
word_attn_html = visualization.format_word_importances(
|
282 |
datarecord.raw_input_ids, datarecord.word_attributions
|
283 |
)
|
|
|
288 |
"<tr>",
|
289 |
pred_class_html,
|
290 |
word_attn_html,
|
291 |
+
icd_class_html,
|
292 |
"<tr>",
|
293 |
]
|
294 |
)
|