Spaces:
Running
Running
danielhajialigol
committed on
Commit
•
adc6c07
1
Parent(s):
bc31c45
fixing ICD padding and related summary bank
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import torch
|
@@ -27,8 +27,7 @@ mimic.eval()
|
|
27 |
# disease ner model
|
28 |
pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
|
33 |
ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
|
34 |
ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
|
@@ -74,19 +73,19 @@ def find_related_summaries(text):
|
|
74 |
scores = torch.mm(related_tensor, embedding.transpose(1,0))
|
75 |
scores_indices = scores.topk(k=50, dim=0)
|
76 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
77 |
-
|
78 |
score_set = set()
|
79 |
for summary_idx, score in zip(indices, scores):
|
80 |
score = score.item()
|
81 |
-
if len(
|
82 |
break
|
83 |
corresp_summary = all_summaries[summary_idx]
|
84 |
if score in score_set:
|
85 |
continue
|
86 |
-
|
87 |
-
|
88 |
score_set.add(score)
|
89 |
-
return
|
90 |
|
91 |
|
92 |
|
@@ -112,7 +111,8 @@ def run(text, related_discharges=False):
|
|
112 |
return visualize_attn(model_results=model_results)
|
113 |
return (
|
114 |
visualize_attn(model_results=model_results),
|
115 |
-
gr.Dataset.update(samples=related_summaries, visible=True, label='Related Discharge Summaries'),
|
|
|
116 |
gr.ClearButton.update(visible=True),
|
117 |
gr.TextArea.update(visible=True),
|
118 |
gr.Button.update(visible=True),
|
@@ -149,14 +149,19 @@ def load_example(example_id):
|
|
149 |
return prettify_text(related_chosen)
|
150 |
# return related_chosen
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
def prettify_text(nested_list):
|
153 |
-
idx = 1
|
154 |
string = ''
|
155 |
for li in nested_list:
|
|
|
156 |
delimiters = 99 * '='
|
157 |
-
string += f'
|
158 |
-
|
159 |
-
return string
|
160 |
|
161 |
def remove_most_recent():
|
162 |
global related_chosen
|
@@ -175,7 +180,10 @@ def main():
|
|
175 |
This interface outlines DRGCoder, an explainable clinical coding for the early prediction of diagnostic-related groups (DRGs). Please note all summaries will be truncated to 512 words if longer.
|
176 |
""")
|
177 |
with gr.Row() as row:
|
178 |
-
input = gr.Textbox(
|
|
|
|
|
|
|
179 |
with gr.Row() as row:
|
180 |
gr.Examples(examples, [input])
|
181 |
with gr.Row() as row:
|
@@ -215,18 +223,22 @@ def main():
|
|
215 |
|
216 |
# input to related summaries
|
217 |
with gr.Row() as row:
|
218 |
-
input_related = gr.TextArea(label="Input up to 3 Related Discharge
|
219 |
with gr.Row() as row:
|
220 |
rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
|
221 |
sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
|
222 |
|
223 |
with gr.Row() as row:
|
224 |
-
related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
|
225 |
-
|
|
|
|
|
|
|
226 |
# initial run
|
227 |
btn.click(run, inputs=[input], outputs=[attn_viz, related, attn_clr_btn, input_related, sbm_btn, rmv_related_btn])
|
228 |
# find related summaries
|
229 |
-
related.click(load_example, inputs=[related], outputs=[input_related])
|
|
|
230 |
# remove related summaries
|
231 |
rmv_related_btn.click(remove_most_recent, outputs=[input_related])
|
232 |
|
|
|
1 |
+
import re
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import torch
|
|
|
27 |
# disease ner model
|
28 |
pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
|
29 |
|
30 |
+
# default DRG summary examples
|
|
|
31 |
ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
|
32 |
ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
|
33 |
ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
|
|
|
73 |
scores = torch.mm(related_tensor, embedding.transpose(1,0))
|
74 |
scores_indices = scores.topk(k=50, dim=0)
|
75 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
76 |
+
summary_score_list = []
|
77 |
score_set = set()
|
78 |
for summary_idx, score in zip(indices, scores):
|
79 |
score = score.item()
|
80 |
+
if len(summary_score_list) == 5:
|
81 |
break
|
82 |
corresp_summary = all_summaries[summary_idx]
|
83 |
if score in score_set:
|
84 |
continue
|
85 |
+
summary_score_list.append(
|
86 |
+
[round(score,2), corresp_summary])
|
87 |
score_set.add(score)
|
88 |
+
return summary_score_list
|
89 |
|
90 |
|
91 |
|
|
|
111 |
return visualize_attn(model_results=model_results)
|
112 |
return (
|
113 |
visualize_attn(model_results=model_results),
|
114 |
+
# gr.Dataset.update(samples=related_summaries, visible=True, label='Related Discharge Summaries'),
|
115 |
+
gr.DataFrame.update(value=related_summaries, visible=True),
|
116 |
gr.ClearButton.update(visible=True),
|
117 |
gr.TextArea.update(visible=True),
|
118 |
gr.Button.update(visible=True),
|
|
|
149 |
return prettify_text(related_chosen)
|
150 |
# return related_chosen
|
151 |
|
152 |
+
def load_df_example(df, event: gr.SelectData):
    """Add the discharge summary of the clicked table row to the chosen list.

    The related-summaries table has two columns (similarity score, summary
    text). Using ``event.value`` directly would store the numeric score when
    the user clicks the score cell, which later breaks ``prettify_text``
    (it runs a regex over the stored value). The summary is therefore always
    looked up from the selected row instead of taken from the clicked cell.

    Args:
        df: Current value of the related-summaries table. The component is
            created without a ``type`` argument, so this is expected to be a
            pandas DataFrame -- NOTE(review): confirm if ``type`` is changed.
        event: Gradio selection event; ``event.index`` holds the
            (row, column) position of the clicked cell.

    Returns:
        The prettified text of all summaries chosen so far.
    """
    global related_chosen
    summary_column = 1  # column 0 is the score, column 1 the summary text
    row_idx = event.index[0]
    discharge_summary = df.iloc[row_idx, summary_column]
    related_chosen.append([discharge_summary])
    return prettify_text(related_chosen)
|
157 |
+
|
158 |
def prettify_text(nested_list):
    """Render the chosen summaries as one delimiter-separated text block.

    Args:
        nested_list: Iterable of sequences whose first element is the
            summary text (a string).

    Returns:
        A string in which every summary (runs of spaces collapsed, outer
        whitespace removed) is followed by a line of 99 ``=`` characters.
        Leading and trailing whitespace of the whole result is stripped; an
        empty input yields an empty string.
    """
    # Loop-invariant: build the separator line once.
    delimiters = 99 * '='
    blocks = []
    for li in nested_list:
        stripped = re.sub(' +', ' ', li[0]).strip()
        blocks.append(f'{stripped}\n{delimiters}')
    # join avoids repeated string reallocation from += inside the loop.
    return '\n'.join(blocks).strip()
|
|
|
165 |
|
166 |
def remove_most_recent():
|
167 |
global related_chosen
|
|
|
180 |
This interface outlines DRGCoder, an explainable clinical coding for the early prediction of diagnostic-related groups (DRGs). Please note all summaries will be truncated to 512 words if longer.
|
181 |
""")
|
182 |
with gr.Row() as row:
|
183 |
+
input = gr.Textbox(
|
184 |
+
label="Input Discharge Summary Here", placeholder='sample discharge summary',
|
185 |
+
text_align='left', interactive=True
|
186 |
+
)
|
187 |
with gr.Row() as row:
|
188 |
gr.Examples(examples, [input])
|
189 |
with gr.Row() as row:
|
|
|
223 |
|
224 |
# input to related summaries
|
225 |
with gr.Row() as row:
|
226 |
+
input_related = gr.TextArea(label="Input up to 3 Related Discharge Summaries Here", visible=False, text_align='left', min_width=300)
|
227 |
with gr.Row() as row:
|
228 |
rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
|
229 |
sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
|
230 |
|
231 |
with gr.Row() as row:
|
232 |
+
# related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index', headers=['AAAAA', 'BBBB', 'CCCCC', 'DDDDD', 'RRRRR'])
|
233 |
+
related = gr.DataFrame(
|
234 |
+
value=None, headers=['Similarity Score', 'Related Discharge Summary'], max_rows=5,
|
235 |
+
datatype=['number', 'str'], col_count=(2, 'fixed'), visible=False
|
236 |
+
)
|
237 |
# initial run
|
238 |
btn.click(run, inputs=[input], outputs=[attn_viz, related, attn_clr_btn, input_related, sbm_btn, rmv_related_btn])
|
239 |
# find related summaries
|
240 |
+
# related.click(load_example, inputs=[related], outputs=[input_related])
|
241 |
+
related.select(load_df_example, inputs=[related], outputs=[input_related])
|
242 |
# remove related summaries
|
243 |
rmv_related_btn.click(remove_most_recent, outputs=[input_related])
|
244 |
|
utils.py
CHANGED
@@ -286,27 +286,32 @@ def modify_drg_html(html, drg_link):
|
|
286 |
|
287 |
def get_icd_html(icd_list):
|
288 |
if len(icd_list) == 0:
|
289 |
-
return '<td><text style="padding-
|
290 |
final_html = '<td>'
|
291 |
icd_set = set()
|
292 |
-
|
|
|
293 |
text, link = icd_dict['text'], icd_dict['link']
|
294 |
if text in icd_set:
|
295 |
continue
|
296 |
-
tmp_html = visualization.format_classname(classname=text)
|
297 |
-
html = modify_code_html(html=tmp_html, link=link, icd=True)
|
298 |
-
|
299 |
icd_set.add(text)
|
|
|
|
|
|
|
|
|
300 |
return final_html + '</td>'
|
301 |
|
302 |
|
303 |
def get_disease_html(diseases):
|
304 |
if len(diseases) == 0:
|
305 |
-
return '<td><text style="padding-
|
306 |
diseases = list(set(diseases))
|
307 |
diseases_str = ', '.join(diseases)
|
308 |
html = visualization.format_classname(classname=diseases_str)
|
309 |
-
return html
|
310 |
|
311 |
|
312 |
|
|
|
286 |
|
287 |
def get_icd_html(icd_list):
    """Build the HTML table cell listing ICD codes as links.

    Args:
        icd_list: Sequence of dicts, each with a ``'text'`` (label) and a
            ``'link'`` (URL) entry. Entries whose ``'text'`` was already seen
            are dropped; the first occurrence wins.

    Returns:
        A ``<td>...</td>`` string containing one ``<a>`` element per unique
        ICD text, each followed by ``<br>``. Every label except the last
        rendered one ends with a comma. An empty input yields an ``N/A`` cell.
    """
    if len(icd_list) == 0:
        return '<td><text style="padding-left:2em"><b>N/A</b></text></td>'
    style = "border-style: solid; overflow: visible; min-width: calc(min(0px, 100%)); border-width: var(--block-border-width);"

    # De-duplicate first (insertion order kept) so the separator comma can be
    # decided from the rendered position; deciding it from the raw list index
    # leaves a dangling comma when the trailing entries are duplicates.
    unique_links = {}
    for icd_dict in icd_list:
        unique_links.setdefault(icd_dict['text'], icd_dict['link'])

    last_idx = len(unique_links) - 1
    anchors = []
    for i, (text, link) in enumerate(unique_links.items()):
        label = text + ',' if i < last_idx else text
        anchors.append(f'<a style="{style}" href="{link}">{label}</a><br>')
    return '<td>' + ''.join(anchors) + '</td>'
|
306 |
|
307 |
|
308 |
def get_disease_html(diseases):
    """Build the HTML table cell listing the detected diseases.

    Args:
        diseases: Sequence of disease name strings; duplicates are allowed.

    Returns:
        An ``N/A`` table cell when ``diseases`` is empty, otherwise the result
        of ``visualization.format_classname`` applied to the unique names
        joined with ``', '``.
    """
    if len(diseases) == 0:
        return '<td><text style="padding-left:2em"><b>N/A</b></text></td>'
    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the displayed order vary between runs (string hash randomisation).
    unique_diseases = list(dict.fromkeys(diseases))
    diseases_str = ', '.join(unique_diseases)
    return visualization.format_classname(classname=diseases_str)
|
315 |
|
316 |
|
317 |
|