mohdelgaar commited on
Commit
55d47b9
1 Parent(s): 59dd739

Add timeline tool

Browse files
Files changed (1) hide show
  1. app.py +310 -212
app.py CHANGED
@@ -2,63 +2,69 @@ import re
2
  import argparse
3
  import torch
4
  import gradio as gr
 
 
 
5
  from data import load_tokenizer
6
  from model import load_model
7
  from datetime import datetime
8
  from dateutil import parser
9
  from demo_assets import *
 
10
 
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument('--data_dir', default='/data/mohamed/data')
13
- parser.add_argument('--aim_repo', default='/data/mohamed/')
14
- parser.add_argument('--ckpt', default='electra-base.pt')
15
- parser.add_argument('--aim_exp', default='mimic-decisions-1215')
16
- parser.add_argument('--label_encoding', default='multiclass')
17
- parser.add_argument('--multiclass', action='store_true')
18
- parser.add_argument('--debug', action='store_true')
19
- parser.add_argument('--save_losses', action='store_true')
20
- parser.add_argument('--task', default='token', choices=['seq', 'token'])
21
- parser.add_argument('--max_len', type=int, default=512)
22
- parser.add_argument('--num_layers', type=int, default=3)
23
- parser.add_argument('--kernels', nargs=3, type=int, default=[1,2,3])
24
- parser.add_argument('--model', default='roberta-base',)
25
- parser.add_argument('--model_name', default='google/electra-base-discriminator',)
26
- parser.add_argument('--gpu', default='0')
27
- parser.add_argument('--grad_accumulation', default=2, type=int)
28
- parser.add_argument('--pheno_id', type=int)
29
- parser.add_argument('--unseen_pheno', type=int)
30
- parser.add_argument('--text_subset')
31
- parser.add_argument('--pheno_n', type=int, default=500)
32
- parser.add_argument('--hidden_size', type=int, default=100)
33
- parser.add_argument('--emb_size', type=int, default=400)
34
- parser.add_argument('--total_steps', type=int, default=5000)
35
- parser.add_argument('--train_log', type=int, default=500)
36
- parser.add_argument('--val_log', type=int, default=1000)
37
- parser.add_argument('--seed', default = '0')
38
- parser.add_argument('--num_phenos', type=int, default=10)
39
- parser.add_argument('--num_decs', type=int, default=9)
40
- parser.add_argument('--num_umls_tags', type=int, default=33)
41
- parser.add_argument('--batch_size', type=int, default=8)
42
- parser.add_argument('--pos_weight', type=float, default=1.25)
43
- parser.add_argument('--alpha_distil', type=float, default=1)
44
- parser.add_argument('--distil', action='store_true')
45
- parser.add_argument('--distil_att', action='store_true')
46
- parser.add_argument('--distil_ckpt')
47
- parser.add_argument('--use_umls', action='store_true')
48
- parser.add_argument('--include_nolabel', action='store_true')
49
- parser.add_argument('--truncate_train', action='store_true')
50
- parser.add_argument('--truncate_eval', action='store_true')
51
- parser.add_argument('--load_ckpt', action='store_true')
52
- parser.add_argument('--gradio', action='store_true')
53
- parser.add_argument('--optuna', action='store_true')
54
- parser.add_argument('--mimic_data', action='store_true')
55
- parser.add_argument('--eval_only', action='store_true')
56
- parser.add_argument('--lr', type=float, default=4e-5)
57
- parser.add_argument('--resample', default='')
58
- parser.add_argument('--verbose', type=bool, default=True)
59
- parser.add_argument('--use_crf', type=bool)
60
- parser.add_argument('--print_spans', action='store_true')
61
- args = parser.parse_args()
 
 
62
 
63
  if args.task == 'seq' and args.pheno_id is not None:
64
  args.num_labels = 1
@@ -150,170 +156,262 @@ def extract_date(text):
150
 
151
 
152
 
153
- def run_gradio(model, tokenizer):
154
- def predict(text):
155
- encoding = tokenizer.encode_plus(text)
156
- x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
157
- mask = torch.ones_like(x)
158
- output = model.generate(x, mask)[0]
159
- return output, encoding.token_to_chars
160
 
161
- def process(text):
162
- if text is not None:
163
- output, t2c = predict(text)
164
- tags = postprocess_labels(text, output, t2c)
165
- with open('log.csv', 'a') as f:
166
- f.write(f'{datetime.now()},{text}\n')
167
- return list(zip(text, tags))
168
- else:
169
- return text
170
-
171
- def process_sum(*inputs):
172
- global sum_c
173
- dates = {}
174
- for i in range(sum_c):
175
- text = inputs[i]
176
- output, t2c = predict(text)
177
- spans = indicators_to_spans(output.argmax(-1), t2c)
178
- date = extract_date(text)
179
- present_decs = set(cat for cat, _, _ in spans)
180
- decs = {k: [] for k in sorted(present_decs)}
181
- for c, s, e in spans:
182
- decs[c].append(text[s:e])
183
- dates[date] = decs
184
-
185
- out = ""
186
- for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
187
- out += f'## **[{date}]**\n\n'
188
- decs = dates[date]
189
- for c in decs:
190
- out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
191
- for dec in decs[c]:
192
- out += f'{dec}\n\n'
193
-
194
- return out
195
 
 
196
  global sum_c
197
- sum_c = 1
198
- SUM_INPUTS = 20
199
- def update_inputs(inputs):
200
- outputs = []
201
- if inputs is None:
202
- c = 0
203
- else:
204
- inputs = [open(f.name).read() for f in inputs]
205
- for i, text in enumerate(inputs):
206
- outputs.append(gr.update(value=text, visible=True))
207
- c = len(inputs)
208
-
209
- n = SUM_INPUTS
210
- for i in range(n - c):
211
- outputs.append(gr.update(value='', visible=False))
212
- global sum_c; sum_c = c
213
- return outputs
214
-
215
- def add_ex(*inputs):
216
- global sum_c
217
- new_idx = sum_c
218
- if new_idx < SUM_INPUTS:
219
- out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
220
- sum_c += 1
221
- else:
222
- out = inputs
223
- return out
224
-
225
- def sub_ex(*inputs):
226
- global sum_c
227
- new_idx = sum_c - 1
228
- if new_idx > 0:
229
- out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
230
- sum_c -= 1
231
- else:
232
- out = inputs
233
- return out
234
-
235
-
236
- device = model.backbone.device
237
- # colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
238
- # 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
239
- colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
240
-
241
- color_map = {cat: colors[i] for i,cat in enumerate(categories)}
242
-
243
- det_desc = ['Admit, discharge, follow-up, referral',
244
- 'Ordering test, consulting colleague, seeking external information',
245
- 'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
246
- 'Quantitative or qualitative',
247
- 'Start, stop, alter, maintain, refrain',
248
- 'Start, stop, alter, maintain, refrain',
249
- 'Positive, negative, ambiguous test results',
250
- 'Transfer responsibility, wait and see, change subject',
251
- 'Advice or precaution',
252
- 'Sick leave, drug refund, insurance, disability']
253
-
254
- desc = '### Zones (categories)\n'
255
- desc += '| | |\n| --- | --- |\n'
256
- for i,cat in enumerate(categories):
257
- desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
258
-
259
- #colors
260
- #markdown labels
261
- #legend and desc
262
- #css font-size
263
- css = '.category-legend {border:1px dashed black;}'\
264
- '.text-sm {font-size: 1.5rem; line-height: 200%;}'\
265
- '.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
266
- '.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
267
- '.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
268
- 'line-height: 1.6;}'\
269
- '#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
270
- title='Clinical Decision Zoning'
271
- with gr.Blocks(title=title, css=css) as demo:
272
- gr.Markdown(f'# {title}')
273
- with gr.Tab("Label a Clinical Note"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  with gr.Row():
275
- with gr.Column():
276
- gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
277
- text_input = gr.Textbox(
278
- # value=examples[0],
279
- label="",
280
- placeholder="Enter text here...")
281
- text_btn = gr.Button('Run')
282
- with gr.Column():
283
- gr.Markdown("## Labeled Summary or Note"),
284
- text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
285
- gr.Examples(text_examples, inputs=text_input)
286
- with gr.Tab("Summarize Patient History"):
287
  with gr.Row():
288
- with gr.Column():
289
- sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
290
- sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
291
- for i in range(2, SUM_INPUTS + 1)])
292
- sum_btn = gr.Button('Run')
293
- with gr.Row():
294
- ex_add = gr.Button("+")
295
- ex_sub = gr.Button("-")
296
- upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
297
- gr.Examples(sum_examples, inputs=upload,
298
- fn = update_inputs, outputs=sum_inputs, run_on_click=True)
299
- with gr.Column():
300
- gr.Markdown("## Summarized Clinical Decision History")
301
- sum_out = gr.Markdown(elem_id='sum-out')
302
- gr.Markdown(desc)
303
-
304
- # Functions
305
- text_input.submit(process, inputs=text_input, outputs=text_out)
306
- text_btn.click(process, inputs=text_input, outputs=text_out)
307
- upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
308
- ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
309
- ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
310
- sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
311
- # demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
312
- demo.launch(share=False)
313
 
314
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
315
- tokenizer = load_tokenizer(args.model_name)
316
- model = load_model(args, device)[0]
317
- model.eval()
318
- torch.set_grad_enabled(False)
319
- run_gradio(model, tokenizer)
 
 
 
 
 
 
 
 
2
  import argparse
3
  import torch
4
  import gradio as gr
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ import numpy as np
8
  from data import load_tokenizer
9
  from model import load_model
10
  from datetime import datetime
11
  from dateutil import parser
12
  from demo_assets import *
13
+ from typing import List, Dict, Any
14
 
15
+ def get_args():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--data_dir', default='/data/mohamed/data')
18
+ parser.add_argument('--aim_repo', default='/data/mohamed/')
19
+ parser.add_argument('--ckpt', default='electra-base.pt')
20
+ parser.add_argument('--aim_exp', default='mimic-decisions-1215')
21
+ parser.add_argument('--label_encoding', default='multiclass')
22
+ parser.add_argument('--multiclass', action='store_true')
23
+ parser.add_argument('--debug', action='store_true')
24
+ parser.add_argument('--save_losses', action='store_true')
25
+ parser.add_argument('--task', default='token', choices=['seq', 'token'])
26
+ parser.add_argument('--max_len', type=int, default=512)
27
+ parser.add_argument('--num_layers', type=int, default=3)
28
+ parser.add_argument('--kernels', nargs=3, type=int, default=[1,2,3])
29
+ parser.add_argument('--model', default='roberta-base',)
30
+ parser.add_argument('--model_name', default='google/electra-base-discriminator',)
31
+ parser.add_argument('--gpu', default='0')
32
+ parser.add_argument('--grad_accumulation', default=2, type=int)
33
+ parser.add_argument('--pheno_id', type=int)
34
+ parser.add_argument('--unseen_pheno', type=int)
35
+ parser.add_argument('--text_subset')
36
+ parser.add_argument('--pheno_n', type=int, default=500)
37
+ parser.add_argument('--hidden_size', type=int, default=100)
38
+ parser.add_argument('--emb_size', type=int, default=400)
39
+ parser.add_argument('--total_steps', type=int, default=5000)
40
+ parser.add_argument('--train_log', type=int, default=500)
41
+ parser.add_argument('--val_log', type=int, default=1000)
42
+ parser.add_argument('--seed', default = '0')
43
+ parser.add_argument('--num_phenos', type=int, default=10)
44
+ parser.add_argument('--num_decs', type=int, default=9)
45
+ parser.add_argument('--num_umls_tags', type=int, default=33)
46
+ parser.add_argument('--batch_size', type=int, default=8)
47
+ parser.add_argument('--pos_weight', type=float, default=1.25)
48
+ parser.add_argument('--alpha_distil', type=float, default=1)
49
+ parser.add_argument('--distil', action='store_true')
50
+ parser.add_argument('--distil_att', action='store_true')
51
+ parser.add_argument('--distil_ckpt')
52
+ parser.add_argument('--use_umls', action='store_true')
53
+ parser.add_argument('--include_nolabel', action='store_true')
54
+ parser.add_argument('--truncate_train', action='store_true')
55
+ parser.add_argument('--truncate_eval', action='store_true')
56
+ parser.add_argument('--load_ckpt', action='store_true')
57
+ parser.add_argument('--gradio', action='store_true')
58
+ parser.add_argument('--optuna', action='store_true')
59
+ parser.add_argument('--mimic_data', action='store_true')
60
+ parser.add_argument('--eval_only', action='store_true')
61
+ parser.add_argument('--lr', type=float, default=4e-5)
62
+ parser.add_argument('--resample', default='')
63
+ parser.add_argument('--verbose', type=bool, default=True)
64
+ parser.add_argument('--use_crf', type=bool)
65
+ parser.add_argument('--print_spans', action='store_true')
66
+ return parser.parse_args()
67
+ args = get_args()
68
 
69
  if args.task == 'seq' and args.pheno_id is not None:
70
  args.num_labels = 1
 
156
 
157
 
158
 
159
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160
+ tokenizer = load_tokenizer(args.model_name)
161
+ model = load_model(args, device)[0]
162
+ model.eval()
163
+ torch.set_grad_enabled(False)
 
 
164
 
165
+ def predict(text):
166
+ encoding = tokenizer.encode_plus(text)
167
+ x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
168
+ mask = torch.ones_like(x)
169
+ output = model.generate(x, mask)[0]
170
+ return output, encoding.token_to_chars
171
+
172
+ def process(text):
173
+ if text is not None:
174
+ output, t2c = predict(text)
175
+ tags = postprocess_labels(text, output, t2c)
176
+ with open('log.csv', 'a') as f:
177
+ f.write(f'{datetime.now()},{text}\n')
178
+ return list(zip(text, tags))
179
+ else:
180
+ return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ def process_sum(*inputs):
183
  global sum_c
184
+ dates = {}
185
+ for i in range(sum_c):
186
+ text = inputs[i]
187
+ output, t2c = predict(text)
188
+ spans = indicators_to_spans(output.argmax(-1), t2c)
189
+ date = extract_date(text)
190
+ present_decs = set(cat for cat, _, _ in spans)
191
+ decs = {k: [] for k in sorted(present_decs)}
192
+ for c, s, e in spans:
193
+ decs[c].append(text[s:e])
194
+ dates[date] = decs
195
+
196
+ out = ""
197
+ for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
198
+ out += f'## **[{date}]**\n\n'
199
+ decs = dates[date]
200
+ for c in decs:
201
+ out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
202
+ for dec in decs[c]:
203
+ out += f'{dec}\n\n'
204
+
205
+ return out
206
+
207
+
208
+ def get_structured_data(*inputs):
209
+ global sum_c
210
+ data = []
211
+ for i in range(sum_c):
212
+ text = inputs[i]
213
+ output, t2c = predict(text)
214
+ spans = indicators_to_spans(output.argmax(-1), t2c)
215
+ date = extract_date(text)
216
+ for c, s, e in spans:
217
+ data.append({
218
+ 'date': date,
219
+ 'timestamp': parser.parse(date),
220
+ 'decision_type': categories[c], 'details': text[s:e]})
221
+ return data
222
+
223
+ def update_inputs(inputs):
224
+ outputs = []
225
+ if inputs is None:
226
+ c = 0
227
+ else:
228
+ inputs = [open(f.name).read() for f in inputs]
229
+ for i, text in enumerate(inputs):
230
+ outputs.append(gr.update(value=text, visible=True))
231
+ c = len(inputs)
232
+
233
+ n = SUM_INPUTS
234
+ for i in range(n - c):
235
+ outputs.append(gr.update(value='', visible=False))
236
+ global sum_c; sum_c = c
237
+ global structured_data
238
+ structured_data = get_structured_data(*inputs) if inputs is not None else []
239
+ return outputs
240
+
241
+ def add_ex(*inputs):
242
+ global sum_c
243
+ new_idx = sum_c
244
+ if new_idx < SUM_INPUTS:
245
+ out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
246
+ sum_c += 1
247
+ else:
248
+ out = inputs
249
+ return out
250
+
251
+ def sub_ex(*inputs):
252
+ global sum_c
253
+ new_idx = sum_c - 1
254
+ if new_idx > 0:
255
+ out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
256
+ sum_c -= 1
257
+ else:
258
+ out = inputs
259
+ return out
260
+
261
+
262
+ def create_timeline_plot(data: List[Dict[str, Any]]):
263
+ df = pd.DataFrame(data)
264
+ # df['int_cat'] = pd.factorize(df['decision_type'])[0]
265
+ # df['int_cat_jittered'] = df['int_cat'] + np.random.uniform(-0.1, 0.1, size=len(df))
266
+ # fig = px.scatter(df, x='date', y='int_cat_jittered', color='decision_type', hover_data=['details'],
267
+ # title='Patient Timeline')
268
+ # fig.update_layout(
269
+ # yaxis=dict(
270
+ # tickmode='array',
271
+ # tickvals=df['int_cat'].unique(),
272
+ # ticktext=df['decision_type'].unique()),
273
+ # xaxis_title='Date',
274
+ # yaxis_title='Category')
275
+ fig = px.strip(df, x='date', y='decision_type', color='decision_type', hover_data=['details'],
276
+ stripmode = "overlay",
277
+ title='Patient Timeline')
278
+ fig.update_traces(jitter=1.0, marker=dict(size=10, opacity=0.6))
279
+ fig.update_layout(height=600)
280
+ return fig
281
+
282
+ def filter_timeline(decision_type: str, start_date: str, end_date: str) -> px.scatter:
283
+ global structured_data
284
+ filtered_data = structured_data
285
+ if 'All' not in decision_types:
286
+ filtered_data = [event for event in filtered_data if event['decision_type'] in decision_types]
287
+
288
+ start = parser.parse(start_date)
289
+ end = parser.parse(end_date)
290
+ filtered_data = [event for event in filtered_data if start <= event['timestamp'] <= end]
291
+
292
+ return create_timeline_plot(filtered_data)
293
+
294
+ def generate_summary(*inputs) -> str:
295
+ global structured_data
296
+ structured_data = get_structured_data(*inputs)
297
+ decision_types = {}
298
+ for event in structured_data:
299
+ decision_type = event['decision_type']
300
+ decision_types[decision_type] = decision_types.get(decision_type, 0) + 1
301
+
302
+ summary = "Decision Type Summary:\n"
303
+ for decision_type, count in decision_types.items():
304
+ summary += f"{decision_type}: {count}\n"
305
+ return summary, create_timeline_plot(structured_data)
306
+
307
+ global sum_c
308
+ sum_c = 1
309
+ SUM_INPUTS = 20
310
+ structured_data = []
311
+
312
+ device = model.backbone.device
313
+ # colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
314
+ # 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
315
+ colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
316
+
317
+ color_map = {cat: colors[i] for i,cat in enumerate(categories)}
318
+
319
+ det_desc = ['Admit, discharge, follow-up, referral',
320
+ 'Ordering test, consulting colleague, seeking external information',
321
+ 'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
322
+ 'Quantitative or qualitative',
323
+ 'Start, stop, alter, maintain, refrain',
324
+ 'Start, stop, alter, maintain, refrain',
325
+ 'Positive, negative, ambiguous test results',
326
+ 'Transfer responsibility, wait and see, change subject',
327
+ 'Advice or precaution',
328
+ 'Sick leave, drug refund, insurance, disability']
329
+
330
+ desc = '### Zones (categories)\n'
331
+ desc += '| | |\n| --- | --- |\n'
332
+ for i,cat in enumerate(categories):
333
+ desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
334
+
335
+ #colors
336
+ #markdown labels
337
+ #legend and desc
338
+ #css font-size
339
+ css = '.category-legend {border:1px dashed black;}'\
340
+ '.text-sm {font-size: 1.5rem; line-height: 200%;}'\
341
+ '.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
342
+ '.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
343
+ '.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
344
+ 'line-height: 1.6;}'\
345
+ '#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
346
+ title='Clinical Decision Zoning'
347
+ with gr.Blocks(title=title, css=css) as demo:
348
+ gr.Markdown(f'# {title}')
349
+ with gr.Tab("Label a Clinical Note"):
350
+ with gr.Row():
351
+ with gr.Column():
352
+ gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
353
+ text_input = gr.Textbox(
354
+ # value=examples[0],
355
+ label="",
356
+ placeholder="Enter text here...")
357
+ text_btn = gr.Button('Run')
358
+ with gr.Column():
359
+ gr.Markdown("## Labeled Summary or Note"),
360
+ text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
361
+ gr.Examples(text_examples, inputs=text_input)
362
+ with gr.Tab("Summarize Patient History"):
363
+ with gr.Row():
364
+ with gr.Column():
365
+ sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
366
+ sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
367
+ for i in range(2, SUM_INPUTS + 1)])
368
+ sum_btn = gr.Button('Run')
369
+ with gr.Row():
370
+ ex_add = gr.Button("+")
371
+ ex_sub = gr.Button("-")
372
+ upload = gr.File(label='Upload clinical notes', file_types=['text'], file_count='multiple')
373
+ gr.Examples(sum_examples, inputs=upload,
374
+ fn = update_inputs, outputs=sum_inputs, run_on_click=True)
375
+ with gr.Column():
376
+ gr.Markdown("## Summarized Clinical Decision History")
377
+ sum_out = gr.Markdown(elem_id='sum-out')
378
+ with gr.Tab("Timeline Visualization Tool"):
379
+ with gr.Column():
380
+ sum_inputs2 = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
381
+ sum_inputs2.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
382
+ for i in range(2, SUM_INPUTS + 1)])
383
  with gr.Row():
384
+ ex_add2 = gr.Button("+")
385
+ ex_sub2 = gr.Button("-")
386
+ upload2 = gr.File(label='Upload clinical notes', file_types=['text'], file_count='multiple')
387
+ gr.Examples(sum_examples, inputs=upload2,
388
+ fn = update_inputs, outputs=sum_inputs2, run_on_click=True)
389
+ with gr.Column():
 
 
 
 
 
 
390
  with gr.Row():
391
+ decision_type = gr.Dropdown(["All"] + categories,
392
+ multiselect=True,
393
+ label="Decision Type", value="All")
394
+ start_date = gr.Textbox(label="Start Date (MM/DD/YYYY)", value="01/01/2006")
395
+ end_date = gr.Textbox(label="End Date (MM/DD/YYYY)", value="12/31/2024")
396
+
397
+ filter_button = gr.Button("Filter Timeline")
398
+
399
+ timeline_plot = gr.Plot()
400
+
401
+ summary_button = gr.Button("Generate Summary")
402
+ summary_output = gr.Textbox(label="Summary")
403
+ gr.Markdown(desc)
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
+ # Functions
406
+ text_input.submit(process, inputs=text_input, outputs=text_out)
407
+ text_btn.click(process, inputs=text_input, outputs=text_out)
408
+ upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
409
+ upload2.change(update_inputs, inputs=upload2, outputs=sum_inputs2)
410
+ ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
411
+ ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
412
+ ex_add2.click(add_ex, inputs=sum_inputs2, outputs=sum_inputs2)
413
+ ex_sub2.click(sub_ex, inputs=sum_inputs2, outputs=sum_inputs2)
414
+ sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
415
+ filter_button.click(filter_timeline, inputs=[decision_type, start_date, end_date], outputs=timeline_plot)
416
+ summary_button.click(generate_summary, inputs=sum_inputs2, outputs=[summary_output, timeline_plot])
417
+ demo.launch(share=True)