mohdelgaar committed
Commit c47c7dc
0 Parent(s)

Initial commit

Files changed (13):
  1. .gitattributes +34 -0
  2. README.md +12 -0
  3. app.py +81 -0
  4. data.py +487 -0
  5. demo.py +241 -0
  6. demo_assets.py +20 -0
  7. electra-base.pt +3 -0
  8. examples/note1.txt +17 -0
  9. examples/note2.txt +17 -0
  10. examples/note3.txt +17 -0
  11. examples/note4.txt +17 -0
  12. model.py +206 -0
  13. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Clinical Decisions
+ emoji: ⚕️
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 4.40.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,81 @@
+ import argparse
+ import torch
+ from data import load_tokenizer
+ from model import load_model
+ from demo import run_gradio
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', default='/data/mohamed/data')
+ parser.add_argument('--aim_repo', default='/data/mohamed/')
+ parser.add_argument('--ckpt', default='electra-base.pt')
+ parser.add_argument('--aim_exp', default='mimic-decisions-1215')
+ parser.add_argument('--label_encoding', default='multiclass')
+ parser.add_argument('--multiclass', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--save_losses', action='store_true')
+ parser.add_argument('--task', default='token', choices=['seq', 'token'])
+ parser.add_argument('--max_len', type=int, default=512)
+ parser.add_argument('--num_layers', type=int, default=3)
+ parser.add_argument('--kernels', nargs=3, type=int, default=[1,2,3])
+ parser.add_argument('--model', default='roberta-base')
+ parser.add_argument('--model_name', default='google/electra-base-discriminator')
+ parser.add_argument('--gpu', default='0')
+ parser.add_argument('--grad_accumulation', default=2, type=int)
+ parser.add_argument('--pheno_id', type=int)
+ parser.add_argument('--unseen_pheno', type=int)
+ parser.add_argument('--text_subset')
+ parser.add_argument('--pheno_n', type=int, default=500)
+ parser.add_argument('--hidden_size', type=int, default=100)
+ parser.add_argument('--emb_size', type=int, default=400)
+ parser.add_argument('--total_steps', type=int, default=5000)
+ parser.add_argument('--train_log', type=int, default=500)
+ parser.add_argument('--val_log', type=int, default=1000)
+ parser.add_argument('--seed', default='0')
+ parser.add_argument('--num_phenos', type=int, default=10)
+ parser.add_argument('--num_decs', type=int, default=9)
+ parser.add_argument('--num_umls_tags', type=int, default=33)
+ parser.add_argument('--batch_size', type=int, default=8)
+ parser.add_argument('--pos_weight', type=float, default=1.25)
+ parser.add_argument('--alpha_distil', type=float, default=1)
+ parser.add_argument('--distil', action='store_true')
+ parser.add_argument('--distil_att', action='store_true')
+ parser.add_argument('--distil_ckpt')
+ parser.add_argument('--use_umls', action='store_true')
+ parser.add_argument('--include_nolabel', action='store_true')
+ parser.add_argument('--truncate_train', action='store_true')
+ parser.add_argument('--truncate_eval', action='store_true')
+ parser.add_argument('--load_ckpt', action='store_true')
+ parser.add_argument('--gradio', action='store_true')
+ parser.add_argument('--optuna', action='store_true')
+ parser.add_argument('--mimic_data', action='store_true')
+ parser.add_argument('--eval_only', action='store_true')
+ parser.add_argument('--lr', type=float, default=4e-5)
+ parser.add_argument('--resample', default='')
+ parser.add_argument('--verbose', type=bool, default=True)
+ parser.add_argument('--use_crf', type=bool)
+ parser.add_argument('--print_spans', action='store_true')
+ args = parser.parse_args()
+
+ if args.task == 'seq' and args.pheno_id is not None:
+     args.num_labels = 1
+ elif args.task == 'seq':
+     args.num_labels = args.num_phenos
+ elif args.task == 'token':
+     if args.use_umls:
+         args.num_labels = args.num_umls_tags
+     else:
+         args.num_labels = args.num_decs
+     if args.label_encoding == 'multiclass':
+         args.num_labels = args.num_labels * 2 + 1
+     elif args.label_encoding == 'bo':
+         args.num_labels *= 2
+     elif args.label_encoding == 'boe':
+         args.num_labels *= 3
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tokenizer = load_tokenizer(args.model_name)
+ model = load_model(args, device)[0]
+ model.eval()
+ torch.set_grad_enabled(False)
+ run_gradio(model, tokenizer)
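
For readers who want to exercise the pipeline outside the Gradio UI, here is a minimal sketch (not part of the committed files) that mirrors app.py's defaults: token-level task, multiclass label encoding, and the bundled ELECTRA checkpoint. The Namespace values are copied from the defaults above, and the input sentence is made up.

```python
# Sketch only: label one note programmatically with the same defaults app.py uses.
# Assumes electra-base.pt sits in the working directory and the modules above import cleanly.
import torch
from argparse import Namespace
from data import load_tokenizer
from model import load_model

args = Namespace(
    task='token', label_encoding='multiclass',
    num_labels=9 * 2 + 1,                     # 9 decision categories x (begin, inside) + "no decision"
    model='electra', model_name='google/electra-base-discriminator',
    ckpt='electra-base.pt', distil=False, distil_att=False,
    use_crf=False, lr=4e-5, total_steps=5000, max_len=512)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = load_tokenizer(args.model_name)
model = load_model(args, device)[0].eval()

with torch.no_grad():
    enc = tokenizer.encode_plus('Patient admitted for hypoxemia; started on zosyn.')  # made-up note
    x = torch.tensor(enc['input_ids']).unsqueeze(0)
    logits = model.generate(x, torch.ones_like(x))   # shape (1, num_tokens, 19)
    print(logits.argmax(-1))                         # per-token decision-category ids
```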
data.py ADDED
@@ -0,0 +1,487 @@
+ import torch
+ import json
+ import os
+ import pandas as pd
+ import numpy as np
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoTokenizer
+ from glob import glob
+ from collections.abc import Iterable
+ from collections import defaultdict
+
+
+ pheno_map = {'alcohol.abuse': 0,
+              'advanced.lung.disease': 1,
+              'advanced.heart.disease': 2,
+              'chronic.pain.fibromyalgia': 3,
+              'other.substance.abuse': 4,
+              'psychiatric.disorders': 5,
+              'obesity': 6,
+              'depression': 7,
+              'advanced.cancer': 8,
+              'chronic.neurological.dystrophies': 9,
+              'none': -1}
+ rev_pheno_map = {v: k for k, v in pheno_map.items()}
+ valid_cats = range(0, 9)
+
+ umls_cats = ['T114', 'T029', 'T073', 'T058', 'T191', 'T200', 'T048', 'T019', 'T046', 'T023', 'T041', 'T059', 'T184', 'T034', 'T116', 'T039', 'T127', 'T201', 'T129', 'T067', 'T109', 'T197', 'T131', 'T130', 'T126', 'T061', 'T203', 'T047', 'T037', 'T074', 'T031', 'T195', 'T168']
+ umls_map = {s: i for i, s in enumerate(umls_cats)}
+
+ def gen_splits(args, phenos):
+     np.random.seed(0)
+     if args.task == 'token':
+         files = glob(os.path.join(args.data_dir, 'mimic_decisions/data/**/*'))
+         if args.use_umls:
+             files = ["/".join(x.split('/')[-1:]) for x in files]
+         else:
+             files = ["/".join(x.split('/')[-2:]) for x in files]
+         subjects = np.unique([os.path.basename(x).split('_')[0] for x in files])
+     elif phenos is not None:
+         subjects = phenos['subject_id'].unique()
+     else:
+         raise ValueError
+
+     phenos['phenotype_label'] = phenos['phenotype_label'].apply(lambda x: x.lower())
+
+     n = len(subjects)
+     train_count = int(0.8*n)
+     val_count = int(0.9*n) - int(0.8*n)
+     test_count = n - int(0.9*n)
+
+     train, val, test = [], [], []
+     np.random.shuffle(subjects)
+     subjects = list(subjects)
+     pheno_list = set(pheno_map.keys())
+     if args.unseen_pheno is not None:
+         test_phenos = {rev_pheno_map[args.unseen_pheno]}
+         unseen_pheno = rev_pheno_map[args.unseen_pheno]
+         train_phenos = pheno_list - test_phenos
+     else:
+         test_phenos = pheno_list
+         train_phenos = pheno_list
+         unseen_pheno = 'null'
+     while len(subjects) > 0:
+         if len(pheno_list) > 0:
+             for pheno in pheno_list:
+                 if len(train) < train_count and pheno in train_phenos:
+                     el = None
+                     for i, subj in enumerate(subjects):
+                         row = phenos[phenos.subject_id == subj]
+                         if row['phenotype_label'].apply(lambda x: pheno in x and not unseen_pheno in x).any():
+                             el = subjects.pop(i)
+                             break
+                     if el is not None:
+                         train.append(el)
+                     elif el is None:
+                         pheno_list.remove(pheno)
+                         break
+                 if len(val) < val_count and (not args.pheno_id or len(val) <= (0.5*val_count)):
+                     el = None
+                     for i, subj in enumerate(subjects):
+                         row = phenos[phenos.subject_id == subj]
+                         if row['phenotype_label'].apply(lambda x: pheno in x).any():
+                             el = subjects.pop(i)
+                             break
+                     if el is not None:
+                         val.append(el)
+                     elif el is None:
+                         pheno_list.remove(pheno)
+                         break
+                 if len(test) < test_count or (args.unseen_pheno is not None and pheno in test_phenos):
+                     el = None
+                     for i, subj in enumerate(subjects):
+                         row = phenos[phenos.subject_id == subj]
+                         if row['phenotype_label'].apply(lambda x: pheno in x).any():
+                             el = subjects.pop(i)
+                             break
+                     if el is not None:
+                         test.append(el)
+                     elif el is None:
+                         pheno_list.remove(pheno)
+                         break
+         else:
+             if len(train) < train_count:
+                 el = subjects.pop()
+                 if el is not None:
+                     train.append(el)
+             if len(val) < val_count:
+                 el = subjects.pop()
+                 if el is not None:
+                     val.append(el)
+             if len(test) < test_count:
+                 el = subjects.pop()
+                 if el is not None:
+                     test.append(el)
+
+     if args.task == 'token':
+         train = [x for x in files if os.path.basename(x).split('_')[0] in train]
+         val = [x for x in files if os.path.basename(x).split('_')[0] in val]
+         test = [x for x in files if os.path.basename(x).split('_')[0] in test]
+     elif phenos is not None:
+         train = phenos[phenos.subject_id.isin(train)]
+         val = phenos[phenos.subject_id.isin(val)]
+         test = phenos[phenos.subject_id.isin(test)]
+     return train, val, test
+
+ class MyDataset(Dataset):
+     def __init__(self, args, tokenizer, data_source, phenos, train=False):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.data = []
+         self.train = train
+         self.pheno_ids = defaultdict(list)
+         self.dec_ids = {k: [] for k in pheno_map.keys()}
+
+         if args.task == 'seq':
+             for i, row in data_source.iterrows():
+                 sample = self.load_phenos(args, row, i)
+                 self.data.append(sample)
+         else:
+             for i, fn in enumerate(data_source):
+                 sample = self.load_decisions(args, fn, i, phenos)
+                 self.data.append(sample)
+
+     def load_phenos(self, args, row, idx):
+         txt_candidates = glob(os.path.join(args.data_dir,
+             f'mimic_decisions/raw_text/{row["subject_id"]}_{row["hadm_id"]}*.txt'))
+         text = open(txt_candidates[0]).read()
+         if args.pheno_n == 500:
+             file_dir = glob(os.path.join(args.data_dir,
+                 f'mimic_decisions/data/*/{row["subject_id"]}_{row["hadm_id"]}*.json'))[0]
+             with open(file_dir) as f:
+                 data = json.load(f, strict=False)
+                 annots = data[0]['annotations']
+
+             if args.text_subset:
+                 unlabeled_text = np.ones(len(text), dtype=bool)
+                 labeled_text = np.zeros(len(text), dtype=bool)
+                 for annot in annots:
+                     cat = parse_cat(annot['category'])
+                     start, end = map(int, (annot['start_offset'], annot['end_offset']))
+                     if cat is not None:
+                         unlabeled_text[start:end] = 0
+                     if cat in args.text_subset:
+                         labeled_text[start:end] = 1
+
+                 combined_text = unlabeled_text | labeled_text if args.include_nolabel else labeled_text
+                 text = "".join([c for i, c in enumerate(text) if combined_text[i]])
+
+             encoding = self.tokenizer.encode_plus(text,
+                 truncation=args.truncate_train if self.train else args.truncate_eval)
+
+             ids = np.zeros((args.num_decs, len(encoding['input_ids'])))
+             for annot in annots:
+                 start = int(annot['start_offset'])
+
+                 enc_start = encoding.char_to_token(start)
+                 i = 1
+                 while enc_start is None:
+                     enc_start = encoding.char_to_token(start+i)
+                     i += 1
+
+                 end = int(annot['end_offset'])
+                 enc_end = encoding.char_to_token(end)
+                 j = 1
+                 while enc_end is None:
+                     enc_end = encoding.char_to_token(end-j)
+                     j += 1
+
+                 if enc_start is None or enc_end is None:
+                     raise ValueError
+
+                 cat = parse_cat(annot['category'])
+                 if not cat or cat not in valid_cats:
+                     continue
+                 ids[cat-1, enc_start:enc_end] = 1
+         else:
+             encoding = self.tokenizer.encode_plus(text,
+                 truncation=args.truncate_train if self.train else args.truncate_eval)
+             ids = None
+
+         labels = np.zeros(args.num_phenos)
+
+         if args.pheno_n in (500, 800):
+             sample_phenos = row['phenotype_label']
+             if sample_phenos != 'none':
+                 for pheno in sample_phenos.split(','):
+                     labels[pheno_map[pheno.lower()]] = 1
+
+         elif args.pheno_n == 1500:
+             for k, v in pheno_map.items():
+                 if row[k] == 1:
+                     labels[v] = 1
+
+         if args.pheno_id is not None:
+             if args.pheno_id == -1:
+                 labels = [0.0 if any(labels) else 1.0]
+             else:
+                 labels = [labels[args.pheno_id]]
+
+         return encoding['input_ids'], labels, ids
+
+     def load_decisions(self, args, fn, idx, phenos):
+         basename = os.path.basename(fn).split("-")[0]
+         if args.use_umls:
+             file_dir = os.path.join(args.data_dir, 'mimic_decisions/umls', basename)
+         else:
+             file_dir = os.path.join(args.data_dir, 'mimic_decisions/data', fn)
+
+         pheno_id = "_".join(basename.split("_")[:3]) + '.txt'
+         txt_candidates = glob(os.path.join(args.data_dir,
+             f'mimic_decisions/raw_text/{basename}*.txt'))
+         text = open(txt_candidates[0]).read()
+         encoding = self.tokenizer.encode_plus(text,
+             max_length=args.max_len,
+             truncation=args.truncate_train if self.train else args.truncate_eval,
+             padding='max_length',
+         )
+         if pheno_id in phenos.index:
+             sample_phenos = phenos.loc[pheno_id]['phenotype_label']
+             for pheno in sample_phenos.split(','):
+                 self.pheno_ids[pheno].append(idx)
+
+         with open(file_dir) as f:
+             data = json.load(f, strict=False)
+             if args.use_umls:
+                 annots = data
+             else:
+                 annots = data[0]['annotations']
+
+         if args.label_encoding == 'multiclass':
+             labels = np.full(len(encoding['input_ids']), args.num_labels-1, dtype=int)
+         else:
+             labels = np.zeros((len(encoding['input_ids']), args.num_labels))
+         for annot in annots:
+             start = int(annot['start_offset'])
+
+             enc_start = encoding.char_to_token(start)
+             i = 1
+             while enc_start is None and i < 10:
+                 enc_start = encoding.char_to_token(start+i)
+                 i += 1
+             if i == 10:
+                 break
+
+             end = int(annot['end_offset'])
+             enc_end = encoding.char_to_token(end)
+             j = 1
+             while enc_end is None and j < 10:
+                 enc_end = encoding.char_to_token(end-j)
+                 j += 1
+             if j == 10:
+                 enc_end = len(encoding.input_ids)
+
+             if enc_start is None or enc_end is None:
+                 raise ValueError
+
+             if args.label_encoding == 'multiclass' and any([x in [2*y for y in range(args.num_labels//2)] for x in labels[enc_start:enc_end]]):
+                 continue
+
+             if args.use_umls:
+                 cat = umls_map.get(annot['category'], None)
+             else:
+                 cat = parse_cat(annot['category'])
+                 if cat:
+                     cat -= 1
+             if cat is None or (not args.use_umls and cat not in valid_cats):
+                 continue
+             if args.label_encoding == 'multiclass':
+                 cat1 = cat * 2
+                 cat2 = cat * 2 + 1
+                 labels[enc_start] = cat1
+                 labels[enc_start+1:enc_end] = cat2
+             elif args.label_encoding == 'bo':
+                 cat1 = cat * 2
+                 cat2 = cat * 2 + 1
+                 labels[enc_start, cat1] = 1
+                 labels[enc_start+1:enc_end, cat2] = 1
+             elif args.label_encoding == 'boe':
+                 cat1 = cat * 3
+                 cat2 = cat * 3 + 1
+                 cat3 = cat * 3 + 2
+                 labels[enc_start, cat1] = 1
+                 labels[enc_start+1:enc_end-1, cat2] = 1
+                 labels[enc_end-1, cat3] = 1
+             else:
+                 labels[enc_start:enc_end, cat] = 1
+
+         return {'input_ids': encoding['input_ids'], 'labels': labels, 't2c': encoding.token_to_chars}
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+     def __len__(self):
+         return len(self.data)
+
+ def parse_cat(cat):
+     for i, c in enumerate(cat):
+         if c.isnumeric():
+             # guard against a digit in the final position before looking ahead
+             if i + 1 < len(cat) and cat[i+1].isnumeric():
+                 return int(cat[i:i+2])
+             return int(c)
+     return None
+
+
+ def load_phenos(args):
+     if args.pheno_n == 500:
+         phenos = pd.read_csv(os.path.join(args.data_dir,
+             'mimic_decisions/phenos500'),
+             sep='\t').rename(lambda x: x.strip(), axis=1)
+         phenos['raw_text'] = phenos['raw_text'].apply(lambda x: os.path.basename(x))
+         phenos[['SUBJECT_ID', 'HADM_ID', 'ROW_ID']] = \
+             [os.path.splitext(x)[0].split('_')[:3] for x in phenos['raw_text']]
+         phenos = phenos[phenos['phenotype_label'] != '?']
+     elif args.pheno_n == 800:
+         phenos = pd.read_csv(os.path.join(args.data_dir, 'mimic_decisions/phenos800.csv'))
+         phenos.rename({'Ham_ID': 'HADM_ID'}, inplace=True, axis=1)
+         phenos = phenos[phenos.phenotype_label != '?']
+     elif args.pheno_n == 1500:
+         phenos = pd.read_csv(os.path.join(args.data_dir, 'mimic_decisions/phenos1500.csv'))
+         phenos.rename({'Hospital.Admission.ID': 'HADM_ID',
+                        'subject.id': 'SUBJECT_ID'}, inplace=True, axis=1)
+         phenos = phenos[phenos.Unsure != 1]
+         phenos['psychiatric.disorders'] = phenos['Dementia']\
+             | phenos['Developmental.Delay.Retardation']\
+             | phenos['Schizophrenia.and.other.Psychiatric.Disorders']
+     else:
+         raise ValueError
+     phenos.rename(lambda k: k.lower(), inplace=True, axis=1)
+     return phenos
+
+ def downsample(dataset):
+     from sklearn.utils import resample  # imported here so scikit-learn stays optional
+     data = dataset.data
+     class0 = [x for x in data if x[1][0] == 0]
+     class1 = [x for x in data if x[1][0] == 1]
+
+     if len(class0) > len(class1):
+         class0 = resample(class0, replace=False, n_samples=len(class1), random_state=0)
+     else:
+         class1 = resample(class1, replace=False, n_samples=len(class0), random_state=0)
+     dataset.data = class0 + class1
+
+ def upsample(dataset):
+     from sklearn.utils import resample  # imported here so scikit-learn stays optional
+     data = dataset.data
+     class0 = [x for x in data if x[1][0] == 0]
+     class1 = [x for x in data if x[1][0] == 1]
+
+     if len(class0) > len(class1):
+         class1 = resample(class1, replace=True, n_samples=len(class0), random_state=0)
+     else:
+         class0 = resample(class0, replace=True, n_samples=len(class1), random_state=0)
+     dataset.data = class0 + class1
+
+ def load_tokenizer(name):
+     return AutoTokenizer.from_pretrained(name)
+
+ def load_data(args):
+     def collate_segment(batch):
+         xs = []
+         ys = []
+         t2cs = []
+         has_ids = 'ids' in batch[0]
+         if has_ids:
+             idss = []
+         else:
+             ids = None
+         masks = []
+         for i in range(len(batch)):
+             x = batch[i]['input_ids']
+             y = batch[i]['labels']
+             if has_ids:
+                 ids = batch[i]['ids']
+             n = len(x)
+             if n > args.max_len:
+                 start = np.random.randint(0, n - args.max_len + 1)
+                 x = x[start:start + args.max_len]
+                 if args.task == 'token':
+                     y = y[start:start + args.max_len]
+                 if has_ids:
+                     new_ids = []
+                     ids = [x[start:start + args.max_len] for x in ids]
+                     for subids in ids:
+                         subids = [idx for idx, x in enumerate(subids) if x]
+                         new_ids.append(subids)
+                     all_ids = set([y for x in new_ids for y in x])
+                     nones = set(range(args.max_len)) - all_ids
+                     new_ids.append(list(nones))
+                 mask = [1] * args.max_len
+             elif n < args.max_len:
+                 x = np.pad(x, (0, args.max_len - n))
+                 if args.task == 'token':
+                     y = np.pad(y, ((0, args.max_len - n), (0, 0)))
+                 mask = [1] * n + [0] * (args.max_len - n)
+             else:
+                 mask = [1] * n
+             xs.append(x)
+             ys.append(y)
+             t2cs.append(batch[i]['t2c'])
+             if has_ids:
+                 idss.append(new_ids)
+             masks.append(mask)
+
+         xs = torch.tensor(xs)
+         ys = torch.tensor(ys)
+         masks = torch.tensor(masks)
+         return {'input_ids': xs, 'labels': ys, 'ids': ids, 'mask': masks, 't2c': t2cs}
+
+     def collate_full(batch):
+         lens = [len(x['input_ids']) for x in batch]
+         max_len = max(args.max_len, max(lens))
+         for i in range(len(batch)):
+             batch[i]['input_ids'] = np.pad(batch[i]['input_ids'], (0, max_len - lens[i]))
+             if args.task == 'token':
+                 if args.label_encoding == 'multiclass':
+                     batch[i]['labels'] = np.pad(batch[i]['labels'], (0, max_len - lens[i]), constant_values=-100)
+                 else:
+                     batch[i]['labels'] = np.pad(batch[i]['labels'], ((0, max_len - lens[i]), (0, 0)))
+             mask = [1] * lens[i] + [0] * (max_len - lens[i])
+             batch[i]['mask'] = mask
+
+         batch = {k: torch.tensor(np.array([sample[k] for sample in batch])) if isinstance(batch[0][k], Iterable) else
+                  [sample[k] for sample in batch]
+                  for k in batch[0].keys()}
+         return batch
+
+     tokenizer = load_tokenizer(args.model_name)
+     args.vocab_size = tokenizer.vocab_size
+     args.max_length = min(tokenizer.model_max_length, 512)
+
+     if args.mimic_data:
+         from datasets import Dataset
+         df = pd.read_csv('/data/mohamed/data/mimiciii/NOTEEVENTS.csv.gz',
+                          usecols=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'TEXT'])
+         data = Dataset.from_pandas(df)
+         return data, tokenizer
+     else:
+         phenos = load_phenos(args)
+         train_files, val_files, test_files = gen_splits(args, phenos)
+         phenos.set_index('raw_text', inplace=True)
+         train_dataset = MyDataset(args, tokenizer, train_files, phenos, train=True)
+
+         if args.resample == 'down':
+             downsample(train_dataset)
+         elif args.resample == 'up':
+             upsample(train_dataset)
+         val_dataset = MyDataset(args, tokenizer, val_files, phenos)
+         test_dataset = MyDataset(args, tokenizer, test_files, phenos)
+         print('Train dataset:', len(train_dataset))
+         print('Val dataset:', len(val_dataset))
+         print('Test dataset:', len(test_dataset))
+         train_ns = DataLoader(train_dataset, 1, False,
+                               collate_fn=collate_full,
+                               )
+         train_dataloader = DataLoader(train_dataset, args.batch_size, True,
+                                       collate_fn=collate_segment,
+                                       )
+         val_dataloader = DataLoader(val_dataset, 1, False, collate_fn=collate_full)
+         test_dataloader = DataLoader(test_dataset, 1, False, collate_fn=collate_full)
+
+         train_files = [os.path.basename(x).split('-')[0] for x in train_files]
+         val_files = [os.path.basename(x).split('-')[0] for x in val_files]
+         test_files = [os.path.basename(x).split('-')[0] for x in test_files]
+
+         return train_dataloader, val_dataloader, test_dataloader, train_ns, [train_files, val_files, test_files]
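
The 'multiclass' branch in load_decisions is essentially a begin/inside tagging scheme flattened into one class space. A small illustration (not part of the commit) of how a span for category c lands in the label vector, assuming the default num_decs = 9; the span indices and category index are hypothetical.

```python
# Sketch only: the 'multiclass' token-label layout used by MyDataset.load_decisions.
# With num_decs = 9, num_labels = 9*2 + 1 = 19: category c tags its first token as 2*c,
# the rest of the span as 2*c + 1, and 18 (num_labels - 1) means "no decision annotation".
import numpy as np

num_decs = 9
no_label = num_decs * 2                      # 18, the same value demo.py calls OTHERS_ID

labels = np.full(12, no_label, dtype=int)
cat, enc_start, enc_end = 4, 3, 7            # a hypothetical span for category index 4
labels[enc_start] = cat * 2                  # span-begin tag
labels[enc_start + 1:enc_end] = cat * 2 + 1  # span-inside tags
print(labels)   # [18 18 18  8  9  9  9 18 18 18 18 18]
```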
demo.py ADDED
@@ -0,0 +1,241 @@
+ import gradio as gr
+ import torch
+ from datetime import datetime
+ from dateutil import parser
+ from demo_assets import *
+ import re
+
+ categories = ['Contact related', 'Gathering additional information', 'Defining problem',
+               'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
+               'Deferment', 'Advice and precaution', 'Legal and insurance related']
+ unicode_symbols = [
+     "\U0001F91D",  # Handshake
+     "\U0001F50D",  # Magnifying glass
+     "\U0001F9E9",  # Puzzle piece
+     "\U0001F3AF",  # Target
+     "\U0001F48A",  # Pill
+     "\U00002702",  # Surgical scissors
+     "\U0001F9EA",  # Test tube
+     "\U000023F0",  # Alarm clock
+     "\U000026A0",  # Warning sign
+     "\U0001F4C4"   # Document
+ ]
+
+ OTHERS_ID = 18
+ def postprocess_labels(text, logits, t2c):
+     tags = [None for _ in text]
+     labels = logits.argmax(-1)
+     for i, cat in enumerate(labels):
+         if cat != OTHERS_ID:
+             char_ids = t2c(i)
+             if char_ids is None:
+                 continue
+             for idx in range(char_ids.start, char_ids.end):
+                 if tags[idx] is None and idx < len(text):
+                     tags[idx] = categories[cat // 2]
+     for i in range(len(text)-1):
+         if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
+             tags[i] = tags[i-1]
+     return tags
+
+ def indicators_to_spans(labels, t2c=None):
+     def add_span(c, start, end):
+         if t2c(start) is None or t2c(end) is None:
+             start, end = -1, -1
+         else:
+             start = t2c(start).start
+             end = t2c(end).end
+         span = (c, start, end)
+         spans.add(span)
+
+     spans = set()
+     num_tokens = len(labels)
+     num_classes = OTHERS_ID // 2
+     start = None
+     cls = None
+     for t in range(num_tokens):
+         if start and labels[t] == cls + 1:
+             continue
+         elif start:
+             add_span(cls // 2, start, t - 1)
+             start = None
+         # if not start and labels[t] in [2*x for x in range(num_classes)]:
+         if not start and labels[t] != OTHERS_ID:
+             start = t
+             cls = int(labels[t]) // 2 * 2
+     return spans
+
+ def extract_date(text):
+     pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
+     match = re.search(pattern, text).group(1)
+     start, end = None, None
+     for i, c in enumerate(match):
+         if start is None and c.isnumeric():
+             start = i
+         elif c.isnumeric():
+             end = i + 1
+     match = match[start:end]
+     return match
+
+
+ def run_gradio(model, tokenizer):
+     def predict(text):
+         encoding = tokenizer.encode_plus(text)
+         x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
+         mask = torch.ones_like(x)
+         output = model.generate(x, mask)[0]
+         return output, encoding.token_to_chars
+
+     def process(text):
+         if text is not None:
+             output, t2c = predict(text)
+             tags = postprocess_labels(text, output, t2c)
+             with open('log.csv', 'a') as f:
+                 f.write(f'{datetime.now()},{text}\n')
+             return list(zip(text, tags))
+         else:
+             return text
+
+     def process_sum(*inputs):
+         global sum_c
+         dates = {}
+         for i in range(sum_c):
+             text = inputs[i]
+             output, t2c = predict(text)
+             spans = indicators_to_spans(output.argmax(-1), t2c)
+             date = extract_date(text)
+             present_decs = set(cat for cat, _, _ in spans)
+             decs = {k: [] for k in sorted(present_decs)}
+             for c, s, e in spans:
+                 decs[c].append(text[s:e])
+             dates[date] = decs
+
+         out = ""
+         for date in sorted(dates.keys(), key=lambda x: parser.parse(x)):
+             out += f'## **[{date}]**\n\n'
+             decs = dates[date]
+             for c in decs:
+                 out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
+                 for dec in decs[c]:
+                     out += f'{dec}\n\n'
+
+         return out
+
+     global sum_c
+     sum_c = 1
+     SUM_INPUTS = 20
+     def update_inputs(inputs):
+         outputs = []
+         if inputs is None:
+             c = 0
+         else:
+             inputs = [open(f.name).read() for f in inputs]
+             for i, text in enumerate(inputs):
+                 outputs.append(gr.update(value=text, visible=True))
+             c = len(inputs)
+
+         n = SUM_INPUTS
+         for i in range(n - c):
+             outputs.append(gr.update(value='', visible=False))
+         global sum_c; sum_c = c
+         return outputs
+
+     def add_ex(*inputs):
+         global sum_c
+         new_idx = sum_c
+         if new_idx < SUM_INPUTS:
+             out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
+             sum_c += 1
+         else:
+             out = inputs
+         return out
+
+     def sub_ex(*inputs):
+         global sum_c
+         new_idx = sum_c - 1
+         if new_idx > 0:
+             out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
+             sum_c -= 1
+         else:
+             out = inputs
+         return out
+
+
+     device = model.backbone.device
+     # colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
+     #           'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
+     colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
+
+     color_map = {cat: colors[i] for i, cat in enumerate(categories)}
+
+     det_desc = ['Admit, discharge, follow-up, referral',
+                 'Ordering test, consulting colleague, seeking external information',
+                 'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
+                 'Quantitative or qualitative',
+                 'Start, stop, alter, maintain, refrain',
+                 'Start, stop, alter, maintain, refrain',
+                 'Positive, negative, ambiguous test results',
+                 'Transfer responsibility, wait and see, change subject',
+                 'Advice or precaution',
+                 'Sick leave, drug refund, insurance, disability']
+
+     desc = '### Zones (categories)\n'
+     desc += '| | |\n| --- | --- |\n'
+     for i, cat in enumerate(categories):
+         desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
+
+     # colors
+     # markdown labels
+     # legend and desc
+     # css font-size
+     css = '.category-legend {border:1px dashed black;}'\
+         '.text-sm {font-size: 1.5rem; line-height: 200%;}'\
+         '.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
+         '.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
+         '.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
+         'line-height: 1.6;}'\
+         '#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);}'
+     title = 'Clinical Decision Zoning'
+     with gr.Blocks(title=title, css=css) as demo:
+         gr.Markdown(f'# {title}')
+         with gr.Tab("Label a Clinical Note"):
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("## Enter a Discharge Summary or Clinical Note")
+                     text_input = gr.Textbox(
+                         # value=examples[0],
+                         label="",
+                         placeholder="Enter text here...")
+                     text_btn = gr.Button('Run')
+                 with gr.Column():
+                     gr.Markdown("## Labeled Summary or Note")
+                     text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
+             gr.Examples(text_examples, inputs=text_input)
+         with gr.Tab("Summarize Patient History"):
+             with gr.Row():
+                 with gr.Column():
+                     sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
+                     sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
+                                        for i in range(2, SUM_INPUTS + 1)])
+                     sum_btn = gr.Button('Run')
+                     with gr.Row():
+                         ex_add = gr.Button("+")
+                         ex_sub = gr.Button("-")
+                     upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
+                     gr.Examples(sum_examples, inputs=upload,
+                                 fn=update_inputs, outputs=sum_inputs, run_on_click=True)
+                 with gr.Column():
+                     gr.Markdown("## Summarized Clinical Decision History")
+                     sum_out = gr.Markdown(elem_id='sum-out')
+         gr.Markdown(desc)
+
+         # Event bindings
+         text_input.submit(process, inputs=text_input, outputs=text_out)
+         text_btn.click(process, inputs=text_input, outputs=text_out)
+         upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
+         ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
+         ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
+         sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
+     # demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
+     demo.launch(share=False)
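
As a quick sanity check on the span logic above, the sketch below (not part of the commit) feeds indicators_to_spans a hand-made label sequence and a stand-in for the tokenizer's token_to_chars; fake_t2c, CharSpan, and the label values are illustrative assumptions where token i is pretended to cover characters 4*i to 4*i + 3.

```python
# Sketch only: recover (category, char_start, char_end) spans from per-token label ids.
from collections import namedtuple
from demo import indicators_to_spans

CharSpan = namedtuple('CharSpan', ['start', 'end'])

def fake_t2c(i):
    # stand-in for encoding.token_to_chars: token i covers characters [4*i, 4*i + 3)
    return CharSpan(4 * i, 4 * i + 3)

# 18 = OTHERS_ID; 8/9 = begin/inside of category 4; 2/3 = begin/inside of category 1
labels = [18, 8, 9, 9, 18, 18, 2, 3, 18]
print(indicators_to_spans(labels, fake_t2c))
# e.g. {(4, 4, 15), (1, 24, 31)} -> (category id, character start, character end)
```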
demo_assets.py ADDED
@@ -0,0 +1,20 @@
+ sum_examples = [
+     [['examples/note%d.txt'%i for i in range(1,n)]]
+     for n in range(5, 1, -1)
+ ]
+
+ text_examples = [
+     ["a 72 year old female with chronic indwelling foley. GNR identified in her blood and there is concern for possible resistant pseudomonas. vanco discontinued and ceftriaxone replaced with zosyn. "],
+     ["This is a 73 year old man with CMML with recent admission to OSH for emergent splenectomy after splenic rupture who is admitted for hypoxemia and worsening bilateral ground glass opacities. He was treated aggressively on the floor with antibiotics and other etiologies (PE, MI, etc) were appropriately addressed. He was fluid resusitated and continued on his CMML regimen. Despite this, the patient became progressively hypotensive and was transferred to the ICU for further care.\n"
+      "In the ICU the patient continued to deteriorate and developed progressive hypotension and acidosis despite aggressive fluid repletion, pressor support, and bicarbonate drip. He received >8L NS, 8amps bicarb, pressor support w/ levophed and vasopressin, and maximum ventilatory support. Despite these measures, his lactate continued to trend upwards and he became progressively more hypotensive on the PEEP settings required to adequately oxygenate him. Furthermore, the patient developed tumor lysis syndrome in the setting of his chemotherapy and became anuric producing only 40cc of urine over 8hr. Renal service was called emergently to consider dialysis but the family elected to change his code status to DNR/DNI and focus care on comfort as a priority, after discussion w/ his oncologist Dr [**First Name (STitle) 1557**] and to defer more aggressive therapy."],
+     ["48 year old male with complicated past medical history, multiple problems notably including ESRD s/p renal transplant complicated by collapsing FSGS, recent MRSA line sepsis, here with fevers and hypotension at dialysis, code sepsis."
+      "He met sepsis criteria with fever, tachycardia and likely source of infection at site of tunneled dialysis catheter. Also had leukocytosis with L shift. CXR clear, urine not produced for sample. No central line placed [**3-5**] lack of access. Treated with 2 doses linezolid PO; d/w renal team - preferred vanco use, patient switched to vanco by level and d/c on vanco at HD. Underwent stim test; failed, started on hydrocort at stress dose levels (50 q6), d/w renal, felt uneccessary, patient started on prednisone taper back to home dose of 5 mg PO qd. Held HTN meds in setting of sepsis. Received dose of vanco prior to d/c."
+      "Dialysis Catheter - noted morning after admission to be clotted; question whether this was related to blood draw. Instilled tPA in catheter overnight; were able to use cath in AM for HD. "
+      "ESRD - Started on prograf; monitored levels, d/c on home dose. As per pharm, must continue to monitor levels in context of using itraconazole. Continued patient on bactrim for prophylaxis given tacrolimus use. To go to dialysis. 7 point HCT drop noted during admission; thought elevated HCT hemoconcentration. Hemolysis labs neg, no stool to guaiac. "
+      "PTT elevation - noted on admission, resolved in ICU. DIC labs negative. PT/PTT elevation at discharge c/w warfarin/SC heparin use."
+      "Hypertension: History of HTN, on lopressor and diltiazem. "
+      "Pulmonary Aspergillus: Stable. On itraconazole and followed by pulmonary as an outpatient. Continued in house"
+      "Atrial fibrillation: He is normally rate controlled with metoprolol and anticoagulated with coumadin. NSR on EKG here, continued warfarin, held beta blocker . "
+      "..."
+      "Call your PCP or return to the ED for fevers/chills/shakes, chest pain, shortness of breath, pain at the site of your dialysis catheter, nausea, vomiting, or swelling in your legs/feet. "]
+ ]
electra-base.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64dc780def96ec19006340e180a60531dacd8db9e0e4e206ba57c720e79775d3
+ size 435711089
examples/note1.txt ADDED
@@ -0,0 +1,17 @@
+ Date: 2/12/2024
+ History: the patient is a 45 yo F here for nervousness. A few weeks ago she noticed that she was feeling more nervous than usual and that it has been worsening. It is exacerbated by family and work. She feels especially nervous on Sunday night and Monday morning when she is preparing for the week. She is unable to fall asleep and doesn’t want to eat anything, though she does make herself eat. Nothing helps her nervousness. She otherwise denies significant changes in appetite, weight loss, or overall wellbeing. She denies fevers, chills, nausea, constipation, diarrhea, skin changes, racing heart, shortness of breath, dizziness, headaches or rashes.
+ ROS: otherwise negative
+ PMH: None; PSH: None
+ Meds: Tylenol for occasional HA
+ FHX: Father had an MI, died at 65yo
+ Allergies: NKDA
+ SH: Lives at home with husband, mother, and youngest son. Is an English literature professor at a local college.
+ Has 2 drinks/mo, no tobacco or drug use.
+ Physical Examination:
+ VS: Blood Pressure: 130/85 mm Hg
+ Heart Rate: 96/min
+ Gen: No acute distress, conversational, thin
+ Neck: No thyromegaly, no lymphadenopathy
+ Heart: RRR, no murmurs, rubs or gallops. Radial pulses +2 bilaterally
+ Lungs: Clear to auscultation bilaterally, no wheezes
+ Psych: Well-groomed. Non-pressured speech, linear thought process.
examples/note2.txt ADDED
@@ -0,0 +1,17 @@
+ Date: 3/18/2024
+ History: the patient, a 45-year-old female, returns for a follow-up regarding her nervousness. She reports a slight improvement in her symptoms with reduced intensity of nervousness on Sunday nights and Monday mornings. However, she still experiences difficulty sleeping and occasional lack of appetite. She has tried meditation and deep breathing exercises, which have provided minimal relief. No new symptoms have emerged since the last visit. She continues to deny fevers, chills, nausea, or other systemic symptoms.
+ ROS: Negative except as noted.
+ PMH: No changes.
+ PSH: None.
+ Meds: Tylenol for occasional headaches. Started on a trial of low-dose sertraline since the last visit.
+ FHX: No changes.
+ Allergies: NKDA.
+ SH: No changes in social circumstances. Continues to work as an English literature professor.
+ Physical Examination:
+ VS: Blood Pressure: 128/82 mm Hg, Heart Rate: 92/min
+ Gen: Appears more relaxed than the previous visit.
+ Neck: No changes.
+ Heart: Unchanged.
+ Lungs: Clear to auscultation.
+ Psych: Appears slightly more at ease, maintains good eye contact, speech and thought process remain coherent.
+ Assessment/Plan: Improvement noted with sertraline. Will continue the current dose and re-evaluate in 3 months. Encouraged to continue non-pharmacological interventions like meditation and deep breathing exercises. Consider referral to therapy for additional support.
examples/note3.txt ADDED
@@ -0,0 +1,17 @@
+ Date: 6/2/2024
+ History: the patient, now 46, reports further improvement in her symptoms of nervousness. She has started seeing a therapist, which she finds helpful. Her sleep has improved, and she no longer experiences significant appetite loss. She has not developed any new symptoms and continues to deny systemic symptoms.
+ ROS: Negative except as noted.
+ PMH: No changes.
+ PSH: None.
+ Meds: Continues sertraline. No longer using Tylenol for headaches.
+ FHX: No changes.
+ Allergies: NKDA.
+ SH: Stable home and work life. Activities and responsibilities as an English literature professor are well-managed.
+ Physical Examination:
+ VS: Blood Pressure: 125/80 mm Hg, Heart Rate: 88/min
+ Gen: Appears comfortable and at ease.
+ Neck: No changes.
+ Heart: Unchanged.
+ Lungs: Clear to auscultation.
+ Psych: Noticeable improvement in mood and anxiety levels. Reports feeling more in control.
+ Assessment/Plan: Significant improvement with sertraline and therapy. Plan to continue current management and follow up in 6 months or as needed. Discuss potential for gradually reducing medication under supervision if improvement sustains.
examples/note4.txt ADDED
@@ -0,0 +1,17 @@
+ Date: 12/29/2024
+ History: the patient comes in for a scheduled follow-up. She feels much better and has managed to maintain her improvements. She expresses a desire to start tapering off sertraline under medical supervision. No new health concerns have been noted. She remains active at work and home.
+ ROS: Entirely negative.
+ PMH: No changes.
+ PSH: None.
+ Meds: Sertraline, with a plan to taper.
+ FHX: No changes.
+ Allergies: NKDA.
+ SH: Stable and positive home and work environment.
+ Physical Examination:
+ VS: Blood Pressure: 122/78 mm Hg, Heart Rate: 84/min
+ Gen: Looks healthy and content.
+ Neck: No changes.
+ Heart: Unchanged.
+ Lungs: Clear.
+ Psych: Maintained improvement in mental health. Ready for gradual medication reduction.
+ Assessment/Plan: Patient has shown sustained improvement and is interested in tapering off medication. Will initiate a slow tapering process of sertraline and monitor closely for any recurrence of symptoms. Continue therapy and supportive measures. Next follow-up scheduled in 3 months to assess progress.
model.py ADDED
@@ -0,0 +1,206 @@
+ import copy
+ import torch
+ from torch import nn
+ from transformers import AutoModel
+ from torch.optim import AdamW
+ from transformers import get_linear_schedule_with_warmup
+ # from torchcrf import CRF
+
+ class MyModel(nn.Module):
+     def __init__(self, args, backbone):
+         super().__init__()
+         self.args = args
+         self.backbone = backbone
+         self.cls_id = 0
+         hidden_dim = self.backbone.config.hidden_size
+         self.classifier = nn.Sequential(
+             nn.Dropout(0.1),
+             nn.Linear(hidden_dim, args.num_labels)
+         )
+
+         if args.distil_att:
+             self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))
+
+     def forward(self, x, mask):
+         x = x.to(self.backbone.device)
+         mask = mask.to(self.backbone.device)
+         out = self.backbone(x, attention_mask=mask, output_attentions=True)
+         return out, self.classifier(out.last_hidden_state)
+
+     def decisions(self, x, mask):
+         x = x.to(self.backbone.device)
+         mask = mask.to(self.backbone.device)
+         out = self.backbone(x, attention_mask=mask, output_attentions=False)
+         return out, self.classifier(out.last_hidden_state)
+
+     def phenos(self, x, mask):
+         x = x.to(self.backbone.device)
+         mask = mask.to(self.backbone.device)
+         out = self.backbone(x, attention_mask=mask, output_attentions=True)
+         return out, self.classifier(out.pooler_output)
+
+     def generate(self, x, mask, choice=None):
+         outs = []
+         if self.args.task == 'seq' or choice == 'seq':
+             for i, offset in enumerate(range(0, x.shape[1], self.args.max_len-1)):
+                 if i == 0:
+                     segment = x[:, offset:offset + self.args.max_len-1]
+                     segment_mask = mask[:, offset:offset + self.args.max_len-1]
+                 else:
+                     segment = torch.cat((torch.ones((x.shape[0], 1), dtype=int).to(x.device)
+                                          * self.cls_id,
+                                          x[:, offset:offset + self.args.max_len-1]), axis=1)
+                     segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device),
+                                               mask[:, offset:offset + self.args.max_len-1]), axis=1)
+                 logits = self.phenos(segment, segment_mask)[1]
+                 outs.append(logits)
+
+             return torch.max(torch.stack(outs, 1), 1).values
+         elif self.args.task == 'token':
+             for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
+                 segment = x[:, offset:offset + self.args.max_len]
+                 segment_mask = mask[:, offset:offset + self.args.max_len]
+                 h = self.decisions(segment, segment_mask)[0].last_hidden_state
+                 outs.append(h)
+             h = torch.cat(outs, 1)
+             return self.classifier(h)
+
+ class CNN(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.emb = nn.Embedding(args.vocab_size, args.emb_size)
+         self.model = nn.Sequential(
+             nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
+                       padding='same' if args.task == 'token' else 'valid'),
+             nn.ReLU(),
+             nn.MaxPool1d(1),
+             nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
+                       padding='same' if args.task == 'token' else 'valid'),
+             nn.ReLU(),
+             nn.MaxPool1d(1),
+             nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
+                       padding='same' if args.task == 'token' else 'valid'),
+             nn.ReLU(),
+             nn.MaxPool1d(1),
+         )
+         if args.task == 'seq':
+             out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
+         elif args.task == 'token':
+             out_shape = 1
+         self.classifier = nn.Linear(args.hidden_size*out_shape, args.num_labels)
+         self.dropout = nn.Dropout()
+         self.args = args
+         self.device = None
+
+     def forward(self, x, _):
+         x = x.to(self.device)
+         bs = x.shape[0]
+         x = self.emb(x)
+         x = x.transpose(1,2)
+         x = self.model(x)
+         x = self.dropout(x)
+         if self.args.task == 'token':
+             x = x.transpose(1,2)
+             h = self.classifier(x)
+             return x, h
+         elif self.args.task == 'seq':
+             x = x.reshape(bs, -1)
+             x = self.classifier(x)
+             return x
+
+     def generate(self, x, _):
+         outs = []
+         for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
+             segment = x[:, offset:offset + self.args.max_len]
+             n = segment.shape[1]
+             if n != self.args.max_len:
+                 segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
+             if self.args.task == 'seq':
+                 logits = self(segment, None)
+                 outs.append(logits)
+             elif self.args.task == 'token':
+                 h = self(segment, None)[0]
+                 h = h[:,:n]
+                 outs.append(h)
+         if self.args.task == 'seq':
+             return torch.max(torch.stack(outs, 1), 1).values
+         elif self.args.task == 'token':
+             h = torch.cat(outs, 1)
+             return self.classifier(h)
+
+ class LSTM(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.emb = nn.Embedding(args.vocab_size, args.emb_size)
+         self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
+                              batch_first=True, bidirectional=True)
+         dim = 2*args.num_layers*args.hidden_size if args.task == 'seq' else 2*args.hidden_size
+         self.classifier = nn.Linear(dim, args.num_labels)
+         self.dropout = nn.Dropout()
+         self.args = args
+         self.device = None
+
+     def forward(self, x, _):
+         x = x.to(self.device)
+         x = self.emb(x)
+         o, (x, _) = self.model(x)
+         o_out = self.classifier(o) if self.args.task == 'token' else None
+         if self.args.task == 'seq':
+             x = torch.cat([h for h in x], 1)
+             x = self.dropout(x)
+             x = self.classifier(x)
+         return (x, o), o_out
+
+     def generate(self, x, _):
+         outs = []
+         for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
+             segment = x[:, offset:offset + self.args.max_len]
+             if self.args.task == 'seq':
+                 logits = self(segment, None)[0][0]
+                 outs.append(logits)
+             elif self.args.task == 'token':
+                 h = self(segment, None)[0][1]
+                 outs.append(h)
+         if self.args.task == 'seq':
+             return torch.max(torch.stack(outs, 1), 1).values
+         elif self.args.task == 'token':
+             h = torch.cat(outs, 1)
+             return self.classifier(h)
+
+ def load_model(args, device):
+     if args.model == 'lstm':
+         model = LSTM(args).to(device)
+         model.device = device
+     elif args.model == 'cnn':
+         model = CNN(args).to(device)
+         model.device = device
+     else:
+         model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
+     if args.ckpt:
+         model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)
+     if args.distil:
+         args2 = copy.deepcopy(args)
+         args2.task = 'token'
+         # args2.num_labels = args.num_decs
+         args2.num_labels = args.num_umls_tags
+         model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
+         model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
+         for p in model_B.parameters():
+             p.requires_grad = False
+     else:
+         model_B = None
+     if args.label_encoding == 'multiclass':
+         if args.use_crf:
+             crit = CRF(args.num_labels, batch_first=True).to(device)
+         else:
+             crit = nn.CrossEntropyLoss(reduction='none')
+     else:
+         crit = nn.BCEWithLogitsLoss(
+             pos_weight=torch.ones(args.num_labels).to(device)*args.pos_weight,
+             reduction='none'
+         )
+     optimizer = AdamW(model.parameters(), lr=args.lr)
+     lr_scheduler = get_linear_schedule_with_warmup(optimizer,
+         int(0.1*args.total_steps), args.total_steps)
+
+     return model, crit, optimizer, lr_scheduler, model_B
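
All three model classes share the same generate() idea: slice a long note into max_len windows, encode each window, and stitch the token-level outputs back together before the final classifier. The sketch below (not part of the commit) shows the shape bookkeeping with the lightweight CNN variant so it runs without downloading a pretrained backbone; every value in the Namespace is an illustrative assumption.

```python
# Sketch only: sliding-window inference over a sequence longer than max_len.
import torch
from argparse import Namespace
from model import CNN

args = Namespace(task='token', vocab_size=1000, emb_size=32, hidden_size=16,
                 kernels=[1, 2, 3], num_labels=19, max_len=512)
model = CNN(args).eval()
model.device = torch.device('cpu')

x = torch.randint(0, args.vocab_size, (1, 1300))   # a dummy "note" of 1300 token ids
with torch.no_grad():
    logits = model.generate(x, None)               # windows of 512, 512, and 276 tokens
print(logits.shape)                                # torch.Size([1, 1300, 19])
```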
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ argparse
+ numpy
+ pandas
+ torch==1.13.1
+ transformers==4.38.1