Spaces:
Running
Running
mohdelgaar
commited on
Commit
•
59dd739
1
Parent(s):
c47c7dc
refactor
Browse files
app.py
CHANGED
@@ -1,8 +1,12 @@
|
|
|
|
1 |
import argparse
|
2 |
import torch
|
|
|
3 |
from data import load_tokenizer
|
4 |
from model import load_model
|
5 |
-
from
|
|
|
|
|
6 |
|
7 |
parser = argparse.ArgumentParser()
|
8 |
parser.add_argument('--data_dir', default='/data/mohamed/data')
|
@@ -72,6 +76,240 @@ elif args.task == 'token':
|
|
72 |
elif args.label_encoding == 'boe':
|
73 |
args.num_labels *= 3
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
77 |
tokenizer = load_tokenizer(args.model_name)
|
|
|
1 |
+
import re
|
2 |
import argparse
|
3 |
import torch
|
4 |
+
import gradio as gr
|
5 |
from data import load_tokenizer
|
6 |
from model import load_model
|
7 |
+
from datetime import datetime
|
8 |
+
from dateutil import parser
|
9 |
+
from demo_assets import *
|
10 |
|
11 |
parser = argparse.ArgumentParser()
|
12 |
parser.add_argument('--data_dir', default='/data/mohamed/data')
|
|
|
76 |
elif args.label_encoding == 'boe':
|
77 |
args.num_labels *= 3
|
78 |
|
79 |
+
categories = ['Contact related', 'Gathering additional information', 'Defining problem',
|
80 |
+
'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
|
81 |
+
'Deferment', 'Advice and precaution', 'Legal and insurance related']
|
82 |
+
unicode_symbols = [
|
83 |
+
"\U0001F91D", # Handshake
|
84 |
+
"\U0001F50D", # Magnifying glass
|
85 |
+
"\U0001F9E9", # Puzzle piece
|
86 |
+
"\U0001F3AF", # Target
|
87 |
+
"\U0001F48A", # Pill
|
88 |
+
"\U00002702", # Surgical scissors
|
89 |
+
"\U0001F9EA", # Test tube
|
90 |
+
"\U000023F0", # Alarm clock
|
91 |
+
"\U000026A0", # Warning sign
|
92 |
+
"\U0001F4C4" # Document
|
93 |
+
]
|
94 |
+
|
95 |
+
OTHERS_ID = 18
|
96 |
+
def postprocess_labels(text, logits, t2c):
|
97 |
+
tags = [None for _ in text]
|
98 |
+
labels = logits.argmax(-1)
|
99 |
+
for i,cat in enumerate(labels):
|
100 |
+
if cat != OTHERS_ID:
|
101 |
+
char_ids = t2c(i)
|
102 |
+
if char_ids is None:
|
103 |
+
continue
|
104 |
+
for idx in range(char_ids.start, char_ids.end):
|
105 |
+
if tags[idx] is None and idx < len(text):
|
106 |
+
tags[idx] = categories[cat // 2]
|
107 |
+
for i in range(len(text)-1):
|
108 |
+
if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
|
109 |
+
tags[i] = tags[i-1]
|
110 |
+
return tags
|
111 |
+
|
112 |
+
def indicators_to_spans(labels, t2c = None):
|
113 |
+
def add_span(c, start, end):
|
114 |
+
if t2c(start) is None or t2c(end) is None:
|
115 |
+
start, end = -1, -1
|
116 |
+
else:
|
117 |
+
start = t2c(start).start
|
118 |
+
end = t2c(end).end
|
119 |
+
span = (c, start, end)
|
120 |
+
spans.add(span)
|
121 |
+
|
122 |
+
spans = set()
|
123 |
+
num_tokens = len(labels)
|
124 |
+
num_classes = OTHERS_ID // 2
|
125 |
+
start = None
|
126 |
+
cls = None
|
127 |
+
for t in range(num_tokens):
|
128 |
+
if start and labels[t] == cls + 1:
|
129 |
+
continue
|
130 |
+
elif start:
|
131 |
+
add_span(cls // 2, start, t - 1)
|
132 |
+
start = None
|
133 |
+
# if not start and labels[t] in [2*x for x in range(num_classes)]:
|
134 |
+
if not start and labels[t] != OTHERS_ID:
|
135 |
+
start = t
|
136 |
+
cls = int(labels[t]) // 2 * 2
|
137 |
+
return spans
|
138 |
+
|
139 |
+
def extract_date(text):
|
140 |
+
pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
|
141 |
+
match = re.search(pattern, text).group(1)
|
142 |
+
start, end = None, None
|
143 |
+
for i, c in enumerate(match):
|
144 |
+
if start is None and c.isnumeric():
|
145 |
+
start = i
|
146 |
+
elif c.isnumeric():
|
147 |
+
end = i + 1
|
148 |
+
match = match[start:end]
|
149 |
+
return match
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
def run_gradio(model, tokenizer):
|
154 |
+
def predict(text):
|
155 |
+
encoding = tokenizer.encode_plus(text)
|
156 |
+
x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
|
157 |
+
mask = torch.ones_like(x)
|
158 |
+
output = model.generate(x, mask)[0]
|
159 |
+
return output, encoding.token_to_chars
|
160 |
+
|
161 |
+
def process(text):
|
162 |
+
if text is not None:
|
163 |
+
output, t2c = predict(text)
|
164 |
+
tags = postprocess_labels(text, output, t2c)
|
165 |
+
with open('log.csv', 'a') as f:
|
166 |
+
f.write(f'{datetime.now()},{text}\n')
|
167 |
+
return list(zip(text, tags))
|
168 |
+
else:
|
169 |
+
return text
|
170 |
+
|
171 |
+
def process_sum(*inputs):
|
172 |
+
global sum_c
|
173 |
+
dates = {}
|
174 |
+
for i in range(sum_c):
|
175 |
+
text = inputs[i]
|
176 |
+
output, t2c = predict(text)
|
177 |
+
spans = indicators_to_spans(output.argmax(-1), t2c)
|
178 |
+
date = extract_date(text)
|
179 |
+
present_decs = set(cat for cat, _, _ in spans)
|
180 |
+
decs = {k: [] for k in sorted(present_decs)}
|
181 |
+
for c, s, e in spans:
|
182 |
+
decs[c].append(text[s:e])
|
183 |
+
dates[date] = decs
|
184 |
+
|
185 |
+
out = ""
|
186 |
+
for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
|
187 |
+
out += f'## **[{date}]**\n\n'
|
188 |
+
decs = dates[date]
|
189 |
+
for c in decs:
|
190 |
+
out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
|
191 |
+
for dec in decs[c]:
|
192 |
+
out += f'{dec}\n\n'
|
193 |
+
|
194 |
+
return out
|
195 |
+
|
196 |
+
global sum_c
|
197 |
+
sum_c = 1
|
198 |
+
SUM_INPUTS = 20
|
199 |
+
def update_inputs(inputs):
|
200 |
+
outputs = []
|
201 |
+
if inputs is None:
|
202 |
+
c = 0
|
203 |
+
else:
|
204 |
+
inputs = [open(f.name).read() for f in inputs]
|
205 |
+
for i, text in enumerate(inputs):
|
206 |
+
outputs.append(gr.update(value=text, visible=True))
|
207 |
+
c = len(inputs)
|
208 |
+
|
209 |
+
n = SUM_INPUTS
|
210 |
+
for i in range(n - c):
|
211 |
+
outputs.append(gr.update(value='', visible=False))
|
212 |
+
global sum_c; sum_c = c
|
213 |
+
return outputs
|
214 |
+
|
215 |
+
def add_ex(*inputs):
|
216 |
+
global sum_c
|
217 |
+
new_idx = sum_c
|
218 |
+
if new_idx < SUM_INPUTS:
|
219 |
+
out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
|
220 |
+
sum_c += 1
|
221 |
+
else:
|
222 |
+
out = inputs
|
223 |
+
return out
|
224 |
+
|
225 |
+
def sub_ex(*inputs):
|
226 |
+
global sum_c
|
227 |
+
new_idx = sum_c - 1
|
228 |
+
if new_idx > 0:
|
229 |
+
out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
|
230 |
+
sum_c -= 1
|
231 |
+
else:
|
232 |
+
out = inputs
|
233 |
+
return out
|
234 |
+
|
235 |
+
|
236 |
+
device = model.backbone.device
|
237 |
+
# colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
|
238 |
+
# 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
|
239 |
+
colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
|
240 |
+
|
241 |
+
color_map = {cat: colors[i] for i,cat in enumerate(categories)}
|
242 |
+
|
243 |
+
det_desc = ['Admit, discharge, follow-up, referral',
|
244 |
+
'Ordering test, consulting colleague, seeking external information',
|
245 |
+
'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
|
246 |
+
'Quantitative or qualitative',
|
247 |
+
'Start, stop, alter, maintain, refrain',
|
248 |
+
'Start, stop, alter, maintain, refrain',
|
249 |
+
'Positive, negative, ambiguous test results',
|
250 |
+
'Transfer responsibility, wait and see, change subject',
|
251 |
+
'Advice or precaution',
|
252 |
+
'Sick leave, drug refund, insurance, disability']
|
253 |
+
|
254 |
+
desc = '### Zones (categories)\n'
|
255 |
+
desc += '| | |\n| --- | --- |\n'
|
256 |
+
for i,cat in enumerate(categories):
|
257 |
+
desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
|
258 |
+
|
259 |
+
#colors
|
260 |
+
#markdown labels
|
261 |
+
#legend and desc
|
262 |
+
#css font-size
|
263 |
+
css = '.category-legend {border:1px dashed black;}'\
|
264 |
+
'.text-sm {font-size: 1.5rem; line-height: 200%;}'\
|
265 |
+
'.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
|
266 |
+
'.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
|
267 |
+
'.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
|
268 |
+
'line-height: 1.6;}'\
|
269 |
+
'#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
|
270 |
+
title='Clinical Decision Zoning'
|
271 |
+
with gr.Blocks(title=title, css=css) as demo:
|
272 |
+
gr.Markdown(f'# {title}')
|
273 |
+
with gr.Tab("Label a Clinical Note"):
|
274 |
+
with gr.Row():
|
275 |
+
with gr.Column():
|
276 |
+
gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
|
277 |
+
text_input = gr.Textbox(
|
278 |
+
# value=examples[0],
|
279 |
+
label="",
|
280 |
+
placeholder="Enter text here...")
|
281 |
+
text_btn = gr.Button('Run')
|
282 |
+
with gr.Column():
|
283 |
+
gr.Markdown("## Labeled Summary or Note"),
|
284 |
+
text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
|
285 |
+
gr.Examples(text_examples, inputs=text_input)
|
286 |
+
with gr.Tab("Summarize Patient History"):
|
287 |
+
with gr.Row():
|
288 |
+
with gr.Column():
|
289 |
+
sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
|
290 |
+
sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
|
291 |
+
for i in range(2, SUM_INPUTS + 1)])
|
292 |
+
sum_btn = gr.Button('Run')
|
293 |
+
with gr.Row():
|
294 |
+
ex_add = gr.Button("+")
|
295 |
+
ex_sub = gr.Button("-")
|
296 |
+
upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
|
297 |
+
gr.Examples(sum_examples, inputs=upload,
|
298 |
+
fn = update_inputs, outputs=sum_inputs, run_on_click=True)
|
299 |
+
with gr.Column():
|
300 |
+
gr.Markdown("## Summarized Clinical Decision History")
|
301 |
+
sum_out = gr.Markdown(elem_id='sum-out')
|
302 |
+
gr.Markdown(desc)
|
303 |
+
|
304 |
+
# Functions
|
305 |
+
text_input.submit(process, inputs=text_input, outputs=text_out)
|
306 |
+
text_btn.click(process, inputs=text_input, outputs=text_out)
|
307 |
+
upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
|
308 |
+
ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
|
309 |
+
ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
|
310 |
+
sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
|
311 |
+
# demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
|
312 |
+
demo.launch(share=False)
|
313 |
|
314 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
315 |
tokenizer = load_tokenizer(args.model_name)
|
demo.py
DELETED
@@ -1,241 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import torch
|
3 |
-
from datetime import datetime
|
4 |
-
from dateutil import parser
|
5 |
-
from demo_assets import *
|
6 |
-
import re
|
7 |
-
|
8 |
-
categories = ['Contact related', 'Gathering additional information', 'Defining problem',
|
9 |
-
'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
|
10 |
-
'Deferment', 'Advice and precaution', 'Legal and insurance related']
|
11 |
-
unicode_symbols = [
|
12 |
-
"\U0001F91D", # Handshake
|
13 |
-
"\U0001F50D", # Magnifying glass
|
14 |
-
"\U0001F9E9", # Puzzle piece
|
15 |
-
"\U0001F3AF", # Target
|
16 |
-
"\U0001F48A", # Pill
|
17 |
-
"\U00002702", # Surgical scissors
|
18 |
-
"\U0001F9EA", # Test tube
|
19 |
-
"\U000023F0", # Alarm clock
|
20 |
-
"\U000026A0", # Warning sign
|
21 |
-
"\U0001F4C4" # Document
|
22 |
-
]
|
23 |
-
|
24 |
-
OTHERS_ID = 18
|
25 |
-
def postprocess_labels(text, logits, t2c):
|
26 |
-
tags = [None for _ in text]
|
27 |
-
labels = logits.argmax(-1)
|
28 |
-
for i,cat in enumerate(labels):
|
29 |
-
if cat != OTHERS_ID:
|
30 |
-
char_ids = t2c(i)
|
31 |
-
if char_ids is None:
|
32 |
-
continue
|
33 |
-
for idx in range(char_ids.start, char_ids.end):
|
34 |
-
if tags[idx] is None and idx < len(text):
|
35 |
-
tags[idx] = categories[cat // 2]
|
36 |
-
for i in range(len(text)-1):
|
37 |
-
if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
|
38 |
-
tags[i] = tags[i-1]
|
39 |
-
return tags
|
40 |
-
|
41 |
-
def indicators_to_spans(labels, t2c = None):
|
42 |
-
def add_span(c, start, end):
|
43 |
-
if t2c(start) is None or t2c(end) is None:
|
44 |
-
start, end = -1, -1
|
45 |
-
else:
|
46 |
-
start = t2c(start).start
|
47 |
-
end = t2c(end).end
|
48 |
-
span = (c, start, end)
|
49 |
-
spans.add(span)
|
50 |
-
|
51 |
-
spans = set()
|
52 |
-
num_tokens = len(labels)
|
53 |
-
num_classes = OTHERS_ID // 2
|
54 |
-
start = None
|
55 |
-
cls = None
|
56 |
-
for t in range(num_tokens):
|
57 |
-
if start and labels[t] == cls + 1:
|
58 |
-
continue
|
59 |
-
elif start:
|
60 |
-
add_span(cls // 2, start, t - 1)
|
61 |
-
start = None
|
62 |
-
# if not start and labels[t] in [2*x for x in range(num_classes)]:
|
63 |
-
if not start and labels[t] != OTHERS_ID:
|
64 |
-
start = t
|
65 |
-
cls = int(labels[t]) // 2 * 2
|
66 |
-
return spans
|
67 |
-
|
68 |
-
def extract_date(text):
|
69 |
-
pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
|
70 |
-
match = re.search(pattern, text).group(1)
|
71 |
-
start, end = None, None
|
72 |
-
for i, c in enumerate(match):
|
73 |
-
if start is None and c.isnumeric():
|
74 |
-
start = i
|
75 |
-
elif c.isnumeric():
|
76 |
-
end = i + 1
|
77 |
-
match = match[start:end]
|
78 |
-
return match
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
def run_gradio(model, tokenizer):
|
83 |
-
def predict(text):
|
84 |
-
encoding = tokenizer.encode_plus(text)
|
85 |
-
x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
|
86 |
-
mask = torch.ones_like(x)
|
87 |
-
output = model.generate(x, mask)[0]
|
88 |
-
return output, encoding.token_to_chars
|
89 |
-
|
90 |
-
def process(text):
|
91 |
-
if text is not None:
|
92 |
-
output, t2c = predict(text)
|
93 |
-
tags = postprocess_labels(text, output, t2c)
|
94 |
-
with open('log.csv', 'a') as f:
|
95 |
-
f.write(f'{datetime.now()},{text}\n')
|
96 |
-
return list(zip(text, tags))
|
97 |
-
else:
|
98 |
-
return text
|
99 |
-
|
100 |
-
def process_sum(*inputs):
|
101 |
-
global sum_c
|
102 |
-
dates = {}
|
103 |
-
for i in range(sum_c):
|
104 |
-
text = inputs[i]
|
105 |
-
output, t2c = predict(text)
|
106 |
-
spans = indicators_to_spans(output.argmax(-1), t2c)
|
107 |
-
date = extract_date(text)
|
108 |
-
present_decs = set(cat for cat, _, _ in spans)
|
109 |
-
decs = {k: [] for k in sorted(present_decs)}
|
110 |
-
for c, s, e in spans:
|
111 |
-
decs[c].append(text[s:e])
|
112 |
-
dates[date] = decs
|
113 |
-
|
114 |
-
out = ""
|
115 |
-
for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
|
116 |
-
out += f'## **[{date}]**\n\n'
|
117 |
-
decs = dates[date]
|
118 |
-
for c in decs:
|
119 |
-
out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
|
120 |
-
for dec in decs[c]:
|
121 |
-
out += f'{dec}\n\n'
|
122 |
-
|
123 |
-
return out
|
124 |
-
|
125 |
-
global sum_c
|
126 |
-
sum_c = 1
|
127 |
-
SUM_INPUTS = 20
|
128 |
-
def update_inputs(inputs):
|
129 |
-
outputs = []
|
130 |
-
if inputs is None:
|
131 |
-
c = 0
|
132 |
-
else:
|
133 |
-
inputs = [open(f.name).read() for f in inputs]
|
134 |
-
for i, text in enumerate(inputs):
|
135 |
-
outputs.append(gr.update(value=text, visible=True))
|
136 |
-
c = len(inputs)
|
137 |
-
|
138 |
-
n = SUM_INPUTS
|
139 |
-
for i in range(n - c):
|
140 |
-
outputs.append(gr.update(value='', visible=False))
|
141 |
-
global sum_c; sum_c = c
|
142 |
-
return outputs
|
143 |
-
|
144 |
-
def add_ex(*inputs):
|
145 |
-
global sum_c
|
146 |
-
new_idx = sum_c
|
147 |
-
if new_idx < SUM_INPUTS:
|
148 |
-
out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
|
149 |
-
sum_c += 1
|
150 |
-
else:
|
151 |
-
out = inputs
|
152 |
-
return out
|
153 |
-
|
154 |
-
def sub_ex(*inputs):
|
155 |
-
global sum_c
|
156 |
-
new_idx = sum_c - 1
|
157 |
-
if new_idx > 0:
|
158 |
-
out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
|
159 |
-
sum_c -= 1
|
160 |
-
else:
|
161 |
-
out = inputs
|
162 |
-
return out
|
163 |
-
|
164 |
-
|
165 |
-
device = model.backbone.device
|
166 |
-
# colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
|
167 |
-
# 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
|
168 |
-
colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
|
169 |
-
|
170 |
-
color_map = {cat: colors[i] for i,cat in enumerate(categories)}
|
171 |
-
|
172 |
-
det_desc = ['Admit, discharge, follow-up, referral',
|
173 |
-
'Ordering test, consulting colleague, seeking external information',
|
174 |
-
'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
|
175 |
-
'Quantitative or qualitative',
|
176 |
-
'Start, stop, alter, maintain, refrain',
|
177 |
-
'Start, stop, alter, maintain, refrain',
|
178 |
-
'Positive, negative, ambiguous test results',
|
179 |
-
'Transfer responsibility, wait and see, change subject',
|
180 |
-
'Advice or precaution',
|
181 |
-
'Sick leave, drug refund, insurance, disability']
|
182 |
-
|
183 |
-
desc = '### Zones (categories)\n'
|
184 |
-
desc += '| | |\n| --- | --- |\n'
|
185 |
-
for i,cat in enumerate(categories):
|
186 |
-
desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
|
187 |
-
|
188 |
-
#colors
|
189 |
-
#markdown labels
|
190 |
-
#legend and desc
|
191 |
-
#css font-size
|
192 |
-
css = '.category-legend {border:1px dashed black;}'\
|
193 |
-
'.text-sm {font-size: 1.5rem; line-height: 200%;}'\
|
194 |
-
'.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
|
195 |
-
'.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
|
196 |
-
'.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
|
197 |
-
'line-height: 1.6;}'\
|
198 |
-
'#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
|
199 |
-
title='Clinical Decision Zoning'
|
200 |
-
with gr.Blocks(title=title, css=css) as demo:
|
201 |
-
gr.Markdown(f'# {title}')
|
202 |
-
with gr.Tab("Label a Clinical Note"):
|
203 |
-
with gr.Row():
|
204 |
-
with gr.Column():
|
205 |
-
gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
|
206 |
-
text_input = gr.Textbox(
|
207 |
-
# value=examples[0],
|
208 |
-
label="",
|
209 |
-
placeholder="Enter text here...")
|
210 |
-
text_btn = gr.Button('Run')
|
211 |
-
with gr.Column():
|
212 |
-
gr.Markdown("## Labeled Summary or Note"),
|
213 |
-
text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
|
214 |
-
gr.Examples(text_examples, inputs=text_input)
|
215 |
-
with gr.Tab("Summarize Patient History"):
|
216 |
-
with gr.Row():
|
217 |
-
with gr.Column():
|
218 |
-
sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
|
219 |
-
sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
|
220 |
-
for i in range(2, SUM_INPUTS + 1)])
|
221 |
-
sum_btn = gr.Button('Run')
|
222 |
-
with gr.Row():
|
223 |
-
ex_add = gr.Button("+")
|
224 |
-
ex_sub = gr.Button("-")
|
225 |
-
upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
|
226 |
-
gr.Examples(sum_examples, inputs=upload,
|
227 |
-
fn = update_inputs, outputs=sum_inputs, run_on_click=True)
|
228 |
-
with gr.Column():
|
229 |
-
gr.Markdown("## Summarized Clinical Decision History")
|
230 |
-
sum_out = gr.Markdown(elem_id='sum-out')
|
231 |
-
gr.Markdown(desc)
|
232 |
-
|
233 |
-
# Functions
|
234 |
-
text_input.submit(process, inputs=text_input, outputs=text_out)
|
235 |
-
text_btn.click(process, inputs=text_input, outputs=text_out)
|
236 |
-
upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
|
237 |
-
ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
|
238 |
-
ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
|
239 |
-
sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
|
240 |
-
# demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
|
241 |
-
demo.launch(share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|