danielhajialigol committed on
Commit
bc31c45
1 Parent(s): 542f530

removing redundant examples

Browse files
Files changed (2) hide show
  1. app.py +14 -8
  2. utils.py +1 -1
app.py CHANGED
@@ -8,7 +8,7 @@ from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_ic
8
  from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
9
  set_seed(42)
10
  model_path = 'checkpoint_0_9113.bin'
11
- related_tensor = torch.nn.functional.normalize(torch.load('discharge_embeddings.pt'))
12
  all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
13
 
14
  similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
@@ -72,13 +72,20 @@ def find_related_summaries(text):
72
  embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
73
  embedding = torch.nn.functional.normalize(embedding)
74
  scores = torch.mm(related_tensor, embedding.transpose(1,0))
75
- scores_indices = scores.topk(k=5, dim=0)
76
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
77
  summaries = []
 
78
  for summary_idx, score in zip(indices, scores):
 
 
 
79
  corresp_summary = all_summaries[summary_idx]
80
- summary = f'{round(score.item(),2)}% Similarity Rate for the following Discharge Summary:\n\n{corresp_summary}'
 
 
81
  summaries.append([summary])
 
82
  return summaries
83
 
84
 
@@ -208,11 +215,10 @@ def main():
208
 
209
  # input to related summaries
210
  with gr.Row() as row:
211
- with gr.Column(scale=5) as col:
212
- input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
213
- with gr.Column(scale=1) as col:
214
- rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
215
- sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
216
 
217
  with gr.Row() as row:
218
  related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
 
8
  from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
9
  set_seed(42)
10
  model_path = 'checkpoint_0_9113.bin'
11
+ related_tensor = torch.load('discharge_embeddings.pt')
12
  all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
13
 
14
  similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
 
72
  embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
73
  embedding = torch.nn.functional.normalize(embedding)
74
  scores = torch.mm(related_tensor, embedding.transpose(1,0))
75
+ scores_indices = scores.topk(k=50, dim=0)
76
  indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
77
  summaries = []
78
+ score_set = set()
79
  for summary_idx, score in zip(indices, scores):
80
+ score = score.item()
81
+ if len(summaries) == 5:
82
+ break
83
  corresp_summary = all_summaries[summary_idx]
84
+ if score in score_set:
85
+ continue
86
+ summary = f'{round(score,2)}% Similarity Rate for the following Discharge Summary:\n\n{corresp_summary}'
87
  summaries.append([summary])
88
+ score_set.add(score)
89
  return summaries
90
 
91
 
 
215
 
216
  # input to related summaries
217
  with gr.Row() as row:
218
+ input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
219
+ with gr.Row() as row:
220
+ rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
221
+ sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
 
222
 
223
  with gr.Row() as row:
224
  related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
utils.py CHANGED
@@ -317,7 +317,7 @@ def visualize_text(datarecord, drg_link, icd_annotations, diseases):
317
  "<th style='text-align: left'>Predicted DRG</th>"
318
  "<th style='text-align: left'>Word Importance</th>"
319
  "<th style='text-align: left'>Diseases</th>"
320
- "<th style='text-align: left'>ICD Codes</th>"
321
  ]
322
  pred_class_html = visualization.format_classname(datarecord.pred_class)
323
  icd_class_html = get_icd_html(icd_annotations)
 
317
  "<th style='text-align: left'>Predicted DRG</th>"
318
  "<th style='text-align: left'>Word Importance</th>"
319
  "<th style='text-align: left'>Diseases</th>"
320
+ "<th style='text-align: left'>ICD Concepts</th>"
321
  ]
322
  pred_class_html = visualization.format_classname(datarecord.pred_class)
323
  icd_class_html = get_icd_html(icd_annotations)