danielhajialigol commited on
Commit
adc6c07
1 Parent(s): bc31c45

fixing ICD padding and related summary bank

Browse files
Files changed (2) hide show
  1. app.py +30 -18
  2. utils.py +12 -7
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import numpy as np
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
@@ -27,8 +27,7 @@ mimic.eval()
27
  # disease ner model
28
  pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
29
 
30
- #
31
-
32
  ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
33
  ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
34
  ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
@@ -74,19 +73,19 @@ def find_related_summaries(text):
74
  scores = torch.mm(related_tensor, embedding.transpose(1,0))
75
  scores_indices = scores.topk(k=50, dim=0)
76
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
77
- summaries = []
78
  score_set = set()
79
  for summary_idx, score in zip(indices, scores):
80
  score = score.item()
81
- if len(summaries) == 5:
82
  break
83
  corresp_summary = all_summaries[summary_idx]
84
  if score in score_set:
85
  continue
86
- summary = f'{round(score,2)}% Similarity Rate for the following Discharge Summary:\n\n{corresp_summary}'
87
- summaries.append([summary])
88
  score_set.add(score)
89
- return summaries
90
 
91
 
92
 
@@ -112,7 +111,8 @@ def run(text, related_discharges=False):
112
  return visualize_attn(model_results=model_results)
113
  return (
114
  visualize_attn(model_results=model_results),
115
- gr.Dataset.update(samples=related_summaries, visible=True, label='Related Discharge Summaries'),
 
116
  gr.ClearButton.update(visible=True),
117
  gr.TextArea.update(visible=True),
118
  gr.Button.update(visible=True),
@@ -149,14 +149,19 @@ def load_example(example_id):
149
  return prettify_text(related_chosen)
150
  # return related_chosen
151
 
 
 
 
 
 
 
152
  def prettify_text(nested_list):
153
- idx = 1
154
  string = ''
155
  for li in nested_list:
 
156
  delimiters = 99 * '='
157
- string += f'({idx})\n{li[0]}\n{delimiters}\n'
158
- idx += 1
159
- return string
160
 
161
  def remove_most_recent():
162
  global related_chosen
@@ -175,7 +180,10 @@ def main():
175
  This interface outlines DRGCoder, an explainable clinical coding for the early prediction of diagnostic-related groups (DRGs). Please note all summaries will be truncated to 512 words if longer.
176
  """)
177
  with gr.Row() as row:
178
- input = gr.Textbox(label="Input Discharge Summary Here", placeholder='sample discharge summary')
 
 
 
179
  with gr.Row() as row:
180
  gr.Examples(examples, [input])
181
  with gr.Row() as row:
@@ -215,18 +223,22 @@ def main():
215
 
216
  # input to related summaries
217
  with gr.Row() as row:
218
- input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
219
  with gr.Row() as row:
220
  rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
221
  sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
222
 
223
  with gr.Row() as row:
224
- related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
225
-
 
 
 
226
  # initial run
227
  btn.click(run, inputs=[input], outputs=[attn_viz, related, attn_clr_btn, input_related, sbm_btn, rmv_related_btn])
228
  # find related summaries
229
- related.click(load_example, inputs=[related], outputs=[input_related])
 
230
  # remove related summaries
231
  rmv_related_btn.click(remove_most_recent, outputs=[input_related])
232
 
 
1
+ import re
2
  import gradio as gr
3
  import pandas as pd
4
  import torch
 
27
  # disease ner model
28
  pipe = pipeline("token-classification", model="alvaroalon2/biobert_diseases_ner")
29
 
30
+ # default DRG summary examples
 
31
  ex1 = """HEAD CT: Head CT showed no intracranial hemorrhage or mass effect, but old infarction consistent with past medical history."""
32
  ex2 = """Radiologic studies also included a chest CT, which confirmed cavitary lesions in the left lung apex consistent with infectious tuberculosis. This also moderate-sized left pleural effusion."""
33
  ex3 = """We have discharged Mrs Smith on regular oral Furosemide (40mg OD) and we have requested an outpatient ultrasound of her renal tract which will be performed in the next few weeks. We will review Mrs Smith in the Cardiology Outpatient Clinic in 6 weeks time."""
 
73
  scores = torch.mm(related_tensor, embedding.transpose(1,0))
74
  scores_indices = scores.topk(k=50, dim=0)
75
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
76
+ summary_score_list = []
77
  score_set = set()
78
  for summary_idx, score in zip(indices, scores):
79
  score = score.item()
80
+ if len(summary_score_list) == 5:
81
  break
82
  corresp_summary = all_summaries[summary_idx]
83
  if score in score_set:
84
  continue
85
+ summary_score_list.append(
86
+ [round(score,2), corresp_summary])
87
  score_set.add(score)
88
+ return summary_score_list
89
 
90
 
91
 
 
111
  return visualize_attn(model_results=model_results)
112
  return (
113
  visualize_attn(model_results=model_results),
114
+ # gr.Dataset.update(samples=related_summaries, visible=True, label='Related Discharge Summaries'),
115
+ gr.DataFrame.update(value=related_summaries, visible=True),
116
  gr.ClearButton.update(visible=True),
117
  gr.TextArea.update(visible=True),
118
  gr.Button.update(visible=True),
 
149
  return prettify_text(related_chosen)
150
  # return related_chosen
151
 
152
+ def load_df_example(df, event: gr.SelectData):
153
+ global related_chosen
154
+ discharge_summary = event.value
155
+ related_chosen.append([discharge_summary])
156
+ return prettify_text(related_chosen)
157
+
158
  def prettify_text(nested_list):
 
159
  string = ''
160
  for li in nested_list:
161
+ striped = re.sub(' +', ' ', li[0]).strip()
162
  delimiters = 99 * '='
163
+ string += f'{striped}\n{delimiters}\n'
164
+ return string.strip()
 
165
 
166
  def remove_most_recent():
167
  global related_chosen
 
180
  This interface outlines DRGCoder, an explainable clinical coding for the early prediction of diagnostic-related groups (DRGs). Please note all summaries will be truncated to 512 words if longer.
181
  """)
182
  with gr.Row() as row:
183
+ input = gr.Textbox(
184
+ label="Input Discharge Summary Here", placeholder='sample discharge summary',
185
+ text_align='left', interactive=True
186
+ )
187
  with gr.Row() as row:
188
  gr.Examples(examples, [input])
189
  with gr.Row() as row:
 
223
 
224
  # input to related summaries
225
  with gr.Row() as row:
226
+ input_related = gr.TextArea(label="Input up to 3 Related Discharge Summaries Here", visible=False, text_align='left', min_width=300)
227
  with gr.Row() as row:
228
  rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
229
  sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
230
 
231
  with gr.Row() as row:
232
+ # related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index', headers=['AAAAA', 'BBBB', 'CCCCC', 'DDDDD', 'RRRRR'])
233
+ related = gr.DataFrame(
234
+ value=None, headers=['Similarity Score', 'Related Discharge Summary'], max_rows=5,
235
+ datatype=['number', 'str'], col_count=(2, 'fixed'), visible=False
236
+ )
237
  # initial run
238
  btn.click(run, inputs=[input], outputs=[attn_viz, related, attn_clr_btn, input_related, sbm_btn, rmv_related_btn])
239
  # find related summaries
240
+ # related.click(load_example, inputs=[related], outputs=[input_related])
241
+ related.select(load_df_example, inputs=[related], outputs=[input_related])
242
  # remove related summaries
243
  rmv_related_btn.click(remove_most_recent, outputs=[input_related])
244
 
utils.py CHANGED
@@ -286,27 +286,32 @@ def modify_drg_html(html, drg_link):
286
 
287
  def get_icd_html(icd_list):
288
  if len(icd_list) == 0:
289
- return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
290
  final_html = '<td>'
291
  icd_set = set()
292
- for icd_dict in icd_list:
 
293
  text, link = icd_dict['text'], icd_dict['link']
294
  if text in icd_set:
295
  continue
296
- tmp_html = visualization.format_classname(classname=text)
297
- html = modify_code_html(html=tmp_html, link=link, icd=True)
298
- final_html += html
299
  icd_set.add(text)
 
 
 
 
300
  return final_html + '</td>'
301
 
302
 
303
  def get_disease_html(diseases):
304
  if len(diseases) == 0:
305
- return '<td><text style="padding-right:2em"><b>N/A</b></text></td>'
306
  diseases = list(set(diseases))
307
  diseases_str = ', '.join(diseases)
308
  html = visualization.format_classname(classname=diseases_str)
309
- return html + '</td>'
310
 
311
 
312
 
 
286
 
287
  def get_icd_html(icd_list):
288
  if len(icd_list) == 0:
289
+ return '<td><text style="padding-left:2em"><b>N/A</b></text></td>'
290
  final_html = '<td>'
291
  icd_set = set()
292
+ style="border-style: solid; overflow: visible; min-width: calc(min(0px, 100%)); border-width: var(--block-border-width);"
293
+ for i, icd_dict in enumerate(icd_list):
294
  text, link = icd_dict['text'], icd_dict['link']
295
  if text in icd_set:
296
  continue
297
+ # tmp_html = visualization.format_classname(classname=text)
298
+ # html = modify_code_html(html=tmp_html, link=link, icd=True)
299
+ # style="padding-left:2em; font-weight:bold;"
300
  icd_set.add(text)
301
+ if i+1 < len(icd_list):
302
+ text += ','
303
+ html = f'<a style="{style}" href="{link}">{text}</a><br>'
304
+ final_html += html
305
  return final_html + '</td>'
306
 
307
 
308
  def get_disease_html(diseases):
309
  if len(diseases) == 0:
310
+ return '<td><text style="padding-left:2em"><b>N/A</b></text></td>'
311
  diseases = list(set(diseases))
312
  diseases_str = ', '.join(diseases)
313
  html = visualization.format_classname(classname=diseases_str)
314
+ return html
315
 
316
 
317