Spaces:
Running
Running
mohdelgaar
commited on
Commit
•
c47c7dc
0
Parent(s):
Initial commit
Browse files- .gitattributes +34 -0
- README.md +12 -0
- app.py +81 -0
- data.py +487 -0
- demo.py +241 -0
- demo_assets.py +20 -0
- electra-base.pt +3 -0
- examples/note1.txt +17 -0
- examples/note2.txt +17 -0
- examples/note3.txt +17 -0
- examples/note4.txt +17 -0
- model.py +206 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Clinical Decisions
|
3 |
+
emoji: ⚕️
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.40.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from data import load_tokenizer
|
4 |
+
from model import load_model
|
5 |
+
from demo import run_gradio
|
6 |
+
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument('--data_dir', default='/data/mohamed/data')
|
9 |
+
parser.add_argument('--aim_repo', default='/data/mohamed/')
|
10 |
+
parser.add_argument('--ckpt', default='electra-base.pt')
|
11 |
+
parser.add_argument('--aim_exp', default='mimic-decisions-1215')
|
12 |
+
parser.add_argument('--label_encoding', default='multiclass')
|
13 |
+
parser.add_argument('--multiclass', action='store_true')
|
14 |
+
parser.add_argument('--debug', action='store_true')
|
15 |
+
parser.add_argument('--save_losses', action='store_true')
|
16 |
+
parser.add_argument('--task', default='token', choices=['seq', 'token'])
|
17 |
+
parser.add_argument('--max_len', type=int, default=512)
|
18 |
+
parser.add_argument('--num_layers', type=int, default=3)
|
19 |
+
parser.add_argument('--kernels', nargs=3, type=int, default=[1,2,3])
|
20 |
+
parser.add_argument('--model', default='roberta-base',)
|
21 |
+
parser.add_argument('--model_name', default='google/electra-base-discriminator',)
|
22 |
+
parser.add_argument('--gpu', default='0')
|
23 |
+
parser.add_argument('--grad_accumulation', default=2, type=int)
|
24 |
+
parser.add_argument('--pheno_id', type=int)
|
25 |
+
parser.add_argument('--unseen_pheno', type=int)
|
26 |
+
parser.add_argument('--text_subset')
|
27 |
+
parser.add_argument('--pheno_n', type=int, default=500)
|
28 |
+
parser.add_argument('--hidden_size', type=int, default=100)
|
29 |
+
parser.add_argument('--emb_size', type=int, default=400)
|
30 |
+
parser.add_argument('--total_steps', type=int, default=5000)
|
31 |
+
parser.add_argument('--train_log', type=int, default=500)
|
32 |
+
parser.add_argument('--val_log', type=int, default=1000)
|
33 |
+
parser.add_argument('--seed', default = '0')
|
34 |
+
parser.add_argument('--num_phenos', type=int, default=10)
|
35 |
+
parser.add_argument('--num_decs', type=int, default=9)
|
36 |
+
parser.add_argument('--num_umls_tags', type=int, default=33)
|
37 |
+
parser.add_argument('--batch_size', type=int, default=8)
|
38 |
+
parser.add_argument('--pos_weight', type=float, default=1.25)
|
39 |
+
parser.add_argument('--alpha_distil', type=float, default=1)
|
40 |
+
parser.add_argument('--distil', action='store_true')
|
41 |
+
parser.add_argument('--distil_att', action='store_true')
|
42 |
+
parser.add_argument('--distil_ckpt')
|
43 |
+
parser.add_argument('--use_umls', action='store_true')
|
44 |
+
parser.add_argument('--include_nolabel', action='store_true')
|
45 |
+
parser.add_argument('--truncate_train', action='store_true')
|
46 |
+
parser.add_argument('--truncate_eval', action='store_true')
|
47 |
+
parser.add_argument('--load_ckpt', action='store_true')
|
48 |
+
parser.add_argument('--gradio', action='store_true')
|
49 |
+
parser.add_argument('--optuna', action='store_true')
|
50 |
+
parser.add_argument('--mimic_data', action='store_true')
|
51 |
+
parser.add_argument('--eval_only', action='store_true')
|
52 |
+
parser.add_argument('--lr', type=float, default=4e-5)
|
53 |
+
parser.add_argument('--resample', default='')
|
54 |
+
parser.add_argument('--verbose', type=bool, default=True)
|
55 |
+
parser.add_argument('--use_crf', type=bool)
|
56 |
+
parser.add_argument('--print_spans', action='store_true')
|
57 |
+
args = parser.parse_args()
|
58 |
+
|
59 |
+
if args.task == 'seq' and args.pheno_id is not None:
|
60 |
+
args.num_labels = 1
|
61 |
+
elif args.task == 'seq':
|
62 |
+
args.num_labels = args.num_phenos
|
63 |
+
elif args.task == 'token':
|
64 |
+
if args.use_umls:
|
65 |
+
args.num_labels = args.num_umls_tags
|
66 |
+
else:
|
67 |
+
args.num_labels = args.num_decs
|
68 |
+
if args.label_encoding == 'multiclass':
|
69 |
+
args.num_labels = args.num_labels * 2 + 1
|
70 |
+
elif args.label_encoding == 'bo':
|
71 |
+
args.num_labels *= 2
|
72 |
+
elif args.label_encoding == 'boe':
|
73 |
+
args.num_labels *= 3
|
74 |
+
|
75 |
+
|
76 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
77 |
+
tokenizer = load_tokenizer(args.model_name)
|
78 |
+
model = load_model(args, device)[0]
|
79 |
+
model.eval()
|
80 |
+
torch.set_grad_enabled(False)
|
81 |
+
run_gradio(model, tokenizer)
|
data.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
from glob import glob
|
9 |
+
from collections.abc import Iterable
|
10 |
+
from collections import defaultdict
|
11 |
+
|
12 |
+
|
13 |
+
pheno_map = {'alcohol.abuse': 0,
|
14 |
+
'advanced.lung.disease': 1,
|
15 |
+
'advanced.heart.disease': 2,
|
16 |
+
'chronic.pain.fibromyalgia': 3,
|
17 |
+
'other.substance.abuse': 4,
|
18 |
+
'psychiatric.disorders': 5,
|
19 |
+
'obesity': 6,
|
20 |
+
'depression': 7,
|
21 |
+
'advanced.cancer': 8,
|
22 |
+
'chronic.neurological.dystrophies': 9,
|
23 |
+
'none': -1}
|
24 |
+
rev_pheno_map = {v: k for k,v in pheno_map.items()}
|
25 |
+
valid_cats = range(0,9)
|
26 |
+
|
27 |
+
umls_cats = ['T114', 'T029', 'T073', 'T058', 'T191', 'T200', 'T048', 'T019', 'T046', 'T023', 'T041', 'T059', 'T184', 'T034', 'T116', 'T039', 'T127', 'T201', 'T129', 'T067', 'T109', 'T197', 'T131', 'T130', 'T126', 'T061', 'T203', 'T047', 'T037', 'T074', 'T031', 'T195', 'T168']
|
28 |
+
umls_map = {s: i for i,s in enumerate(umls_cats)}
|
29 |
+
|
30 |
+
def gen_splits(args, phenos):
|
31 |
+
np.random.seed(0)
|
32 |
+
if args.task == 'token':
|
33 |
+
files = glob(os.path.join(args.data_dir, 'mimic_decisions/data/**/*'))
|
34 |
+
if args.use_umls:
|
35 |
+
files = ["/".join(x.split('/')[-1:]) for x in files]
|
36 |
+
else:
|
37 |
+
files = ["/".join(x.split('/')[-2:]) for x in files]
|
38 |
+
subjects = np.unique([os.path.basename(x).split('_')[0] for x in files])
|
39 |
+
elif phenos is not None:
|
40 |
+
subjects = phenos['subject_id'].unique()
|
41 |
+
else:
|
42 |
+
raise ValueError
|
43 |
+
|
44 |
+
phenos['phenotype_label'] = phenos['phenotype_label'].apply(lambda x: x.lower())
|
45 |
+
|
46 |
+
n = len(subjects)
|
47 |
+
train_count = int(0.8*n)
|
48 |
+
val_count = int(0.9*n) - int(0.8*n)
|
49 |
+
test_count = n - int(0.9*n)
|
50 |
+
|
51 |
+
train, val, test = [], [], []
|
52 |
+
np.random.shuffle(subjects)
|
53 |
+
subjects = list(subjects)
|
54 |
+
pheno_list = set(pheno_map.keys())
|
55 |
+
if args.unseen_pheno is not None:
|
56 |
+
test_phenos = {rev_pheno_map[args.unseen_pheno]}
|
57 |
+
unseen_pheno = rev_pheno_map[args.unseen_pheno]
|
58 |
+
train_phenos = pheno_list - test_phenos
|
59 |
+
else:
|
60 |
+
test_phenos = pheno_list
|
61 |
+
train_phenos = pheno_list
|
62 |
+
unseen_pheno = 'null'
|
63 |
+
while len(subjects) > 0:
|
64 |
+
if len(pheno_list) > 0:
|
65 |
+
for pheno in pheno_list:
|
66 |
+
if len(train) < train_count and pheno in train_phenos:
|
67 |
+
el = None
|
68 |
+
for i, subj in enumerate(subjects):
|
69 |
+
row = phenos[phenos.subject_id == subj]
|
70 |
+
if row['phenotype_label'].apply(lambda x: pheno in x and not unseen_pheno in x).any():
|
71 |
+
el = subjects.pop(i)
|
72 |
+
break
|
73 |
+
if el is not None:
|
74 |
+
train.append(el)
|
75 |
+
elif el is None:
|
76 |
+
pheno_list.remove(pheno)
|
77 |
+
break
|
78 |
+
if len(val) < val_count and (not args.pheno_id or len(val) <= (0.5*val_count)):
|
79 |
+
el = None
|
80 |
+
for i, subj in enumerate(subjects):
|
81 |
+
row = phenos[phenos.subject_id == subj]
|
82 |
+
if row['phenotype_label'].apply(lambda x: pheno in x).any():
|
83 |
+
el = subjects.pop(i)
|
84 |
+
break
|
85 |
+
if el is not None:
|
86 |
+
val.append(el)
|
87 |
+
elif el is None:
|
88 |
+
pheno_list.remove(pheno)
|
89 |
+
break
|
90 |
+
if len(test) < test_count or (args.unseen_pheno is not None and pheno in test_phenos):
|
91 |
+
el = None
|
92 |
+
for i, subj in enumerate(subjects):
|
93 |
+
row = phenos[phenos.subject_id == subj]
|
94 |
+
if row['phenotype_label'].apply(lambda x: pheno in x).any():
|
95 |
+
el = subjects.pop(i)
|
96 |
+
break
|
97 |
+
if el is not None:
|
98 |
+
test.append(el)
|
99 |
+
elif el is None:
|
100 |
+
pheno_list.remove(pheno)
|
101 |
+
break
|
102 |
+
else:
|
103 |
+
if len(train) < train_count:
|
104 |
+
el = subjects.pop()
|
105 |
+
if el is not None:
|
106 |
+
train.append(el)
|
107 |
+
if len(val) < val_count:
|
108 |
+
el = subjects.pop()
|
109 |
+
if el is not None:
|
110 |
+
val.append(el)
|
111 |
+
if len(test) < test_count:
|
112 |
+
el = subjects.pop()
|
113 |
+
if el is not None:
|
114 |
+
test.append(el)
|
115 |
+
|
116 |
+
if args.task == 'token':
|
117 |
+
train = [x for x in files if os.path.basename(x).split('_')[0] in train]
|
118 |
+
val = [x for x in files if os.path.basename(x).split('_')[0] in val]
|
119 |
+
test = [x for x in files if os.path.basename(x).split('_')[0] in test]
|
120 |
+
elif phenos is not None:
|
121 |
+
train = phenos[phenos.subject_id.isin(train)]
|
122 |
+
val = phenos[phenos.subject_id.isin(val)]
|
123 |
+
test = phenos[phenos.subject_id.isin(test)]
|
124 |
+
return train, val, test
|
125 |
+
|
126 |
+
class MyDataset(Dataset):
|
127 |
+
def __init__(self, args, tokenizer, data_source, phenos, train = False):
|
128 |
+
super().__init__()
|
129 |
+
self.tokenizer = tokenizer
|
130 |
+
self.data = []
|
131 |
+
self.train = train
|
132 |
+
self.pheno_ids = defaultdict(list)
|
133 |
+
self.dec_ids = {k: [] for k in pheno_map.keys()}
|
134 |
+
|
135 |
+
if args.task == 'seq':
|
136 |
+
for i, row in data_source.iterrows():
|
137 |
+
sample = self.load_phenos(args, row, i)
|
138 |
+
self.data.append(sample)
|
139 |
+
else:
|
140 |
+
for i, fn in enumerate(data_source):
|
141 |
+
sample = self.load_decisions(args, fn, i, phenos)
|
142 |
+
self.data.append(sample)
|
143 |
+
|
144 |
+
def load_phenos(self, args, row, idx):
|
145 |
+
txt_candidates = glob(os.path.join(args.data_dir,
|
146 |
+
f'mimic_decisions/raw_text/{row["subject_id"]}_{row["hadm_id"]}*.txt'))
|
147 |
+
text = open(txt_candidates[0]).read()
|
148 |
+
if args.pheno_n == 500:
|
149 |
+
|
150 |
+
file_dir = glob(os.path.join(args.data_dir,
|
151 |
+
f'mimic_decisions/data/*/{row["subject_id"]}_{row["hadm_id"]}*.json'))[0]
|
152 |
+
with open(file_dir) as f:
|
153 |
+
data = json.load(f, strict=False)
|
154 |
+
annots = data[0]['annotations']
|
155 |
+
|
156 |
+
if args.text_subset:
|
157 |
+
unlabeled_text = np.ones(len(text), dtype=bool)
|
158 |
+
labeled_text = np.zeros(len(text), dtype=bool)
|
159 |
+
for annot in annots:
|
160 |
+
cat = parse_cat(annot['category'])
|
161 |
+
start, end = map(int, (annot['start_offset'], annot['end_offset']))
|
162 |
+
if cat is not None:
|
163 |
+
unlabeled_text[start:end] = 0
|
164 |
+
if cat in args.text_subset:
|
165 |
+
labeled_text[start:end] = 1
|
166 |
+
|
167 |
+
combined_text = unlabeled_text | labeled_text if args.include_nolabel else labeled_text
|
168 |
+
text = "".join([c for i,c in enumerate(text) if combined_text[i]])
|
169 |
+
|
170 |
+
encoding = self.tokenizer.encode_plus(text,
|
171 |
+
truncation=args.truncate_train if self.train else args.truncate_eval)
|
172 |
+
|
173 |
+
ids = np.zeros((args.num_decs, len(encoding['input_ids'])))
|
174 |
+
for annot in annots:
|
175 |
+
start = int(annot['start_offset'])
|
176 |
+
|
177 |
+
enc_start = encoding.char_to_token(start)
|
178 |
+
i = 1
|
179 |
+
while enc_start is None:
|
180 |
+
enc_start = encoding.char_to_token(start+i)
|
181 |
+
i += 1
|
182 |
+
|
183 |
+
end = int(annot['end_offset'])
|
184 |
+
enc_end = encoding.char_to_token(end)
|
185 |
+
j = 1
|
186 |
+
while enc_end is None:
|
187 |
+
enc_end = encoding.char_to_token(end-j)
|
188 |
+
j += 1
|
189 |
+
|
190 |
+
if enc_start is None or enc_end is None:
|
191 |
+
raise ValueError
|
192 |
+
|
193 |
+
cat = parse_cat(annot['category'])
|
194 |
+
if not cat or cat not in valid_cats:
|
195 |
+
continue
|
196 |
+
ids[cat-1, enc_start:enc_end] = 1
|
197 |
+
else:
|
198 |
+
encoding = self.tokenizer.encode_plus(text,
|
199 |
+
truncation=args.truncate_train if self.train else args.truncate_eval)
|
200 |
+
ids = None
|
201 |
+
|
202 |
+
labels = np.zeros(args.num_phenos)
|
203 |
+
|
204 |
+
if args.pheno_n in (500, 800):
|
205 |
+
sample_phenos = row['phenotype_label']
|
206 |
+
if sample_phenos != 'none':
|
207 |
+
for pheno in sample_phenos.split(','):
|
208 |
+
labels[pheno_map[pheno.lower()]] = 1
|
209 |
+
|
210 |
+
elif args.pheno_n == 1500:
|
211 |
+
for k,v in pheno_map.items():
|
212 |
+
if row[k] == 1:
|
213 |
+
labels[v] = 1
|
214 |
+
|
215 |
+
if args.pheno_id is not None:
|
216 |
+
if args.pheno_id == -1:
|
217 |
+
labels = [0.0 if any(labels) else 1.0]
|
218 |
+
else:
|
219 |
+
labels = [labels[args.pheno_id]]
|
220 |
+
|
221 |
+
return encoding['input_ids'], labels, ids
|
222 |
+
|
223 |
+
def load_decisions(self, args, fn, idx, phenos):
|
224 |
+
basename = os.path.basename(fn).split("-")[0]
|
225 |
+
if args.use_umls:
|
226 |
+
file_dir = os.path.join(args.data_dir, 'mimic_decisions/umls', basename)
|
227 |
+
else:
|
228 |
+
file_dir = os.path.join(args.data_dir, 'mimic_decisions/data', fn)
|
229 |
+
|
230 |
+
pheno_id = "_".join(basename.split("_")[:3]) + '.txt'
|
231 |
+
txt_candidates = glob(os.path.join(args.data_dir,
|
232 |
+
f'mimic_decisions/raw_text/{basename}*.txt'))
|
233 |
+
text = open(txt_candidates[0]).read()
|
234 |
+
encoding = self.tokenizer.encode_plus(text,
|
235 |
+
max_length=args.max_len,
|
236 |
+
truncation=args.truncate_train if self.train else args.truncate_eval,
|
237 |
+
padding = 'max_length',
|
238 |
+
)
|
239 |
+
if pheno_id in phenos.index:
|
240 |
+
sample_phenos = phenos.loc[pheno_id]['phenotype_label']
|
241 |
+
for pheno in sample_phenos.split(','):
|
242 |
+
self.pheno_ids[pheno].append(idx)
|
243 |
+
|
244 |
+
|
245 |
+
with open(file_dir) as f:
|
246 |
+
data = json.load(f, strict=False)
|
247 |
+
if args.use_umls:
|
248 |
+
annots = data
|
249 |
+
else:
|
250 |
+
annots = data[0]['annotations']
|
251 |
+
|
252 |
+
if args.label_encoding == 'multiclass':
|
253 |
+
labels = np.full(len(encoding['input_ids']), args.num_labels-1, dtype=int)
|
254 |
+
else:
|
255 |
+
labels = np.zeros((len(encoding['input_ids']), args.num_labels))
|
256 |
+
for annot in annots:
|
257 |
+
start = int(annot['start_offset'])
|
258 |
+
|
259 |
+
enc_start = encoding.char_to_token(start)
|
260 |
+
i = 1
|
261 |
+
while enc_start is None and i < 10:
|
262 |
+
enc_start = encoding.char_to_token(start+i)
|
263 |
+
i += 1
|
264 |
+
if i == 10:
|
265 |
+
break
|
266 |
+
|
267 |
+
end = int(annot['end_offset'])
|
268 |
+
enc_end = encoding.char_to_token(end)
|
269 |
+
j = 1
|
270 |
+
while enc_end is None and j < 10:
|
271 |
+
enc_end = encoding.char_to_token(end-j)
|
272 |
+
j += 1
|
273 |
+
if j == 10:
|
274 |
+
enc_end = len(encoding.input_ids)
|
275 |
+
|
276 |
+
if enc_start is None or enc_end is None:
|
277 |
+
raise ValueError
|
278 |
+
|
279 |
+
if args.label_encoding == 'multiclass' and any([x in [2*y for y in range(args.num_labels//2)] for x in labels[enc_start:enc_end]]):
|
280 |
+
continue
|
281 |
+
|
282 |
+
if args.use_umls:
|
283 |
+
cat = umls_map.get(annot['category'], None)
|
284 |
+
else:
|
285 |
+
cat = parse_cat(annot['category'])
|
286 |
+
if cat:
|
287 |
+
cat -= 1
|
288 |
+
if cat is None or (not args.use_umls and cat not in valid_cats):
|
289 |
+
continue
|
290 |
+
if args.label_encoding == 'multiclass':
|
291 |
+
cat1 = cat * 2
|
292 |
+
cat2 = cat * 2 + 1
|
293 |
+
labels[enc_start] = cat1
|
294 |
+
labels[enc_start+1:enc_end] = cat2
|
295 |
+
elif args.label_encoding == 'bo':
|
296 |
+
cat1 = cat * 2
|
297 |
+
cat2 = cat * 2 + 1
|
298 |
+
labels[enc_start, cat1] = 1
|
299 |
+
labels[enc_start+1:enc_end, cat2] = 1
|
300 |
+
elif args.label_encoding == 'boe':
|
301 |
+
cat1 = cat * 3
|
302 |
+
cat2 = cat * 3 + 1
|
303 |
+
cat3 = cat * 3 + 2
|
304 |
+
labels[enc_start, cat1] = 1
|
305 |
+
labels[enc_start+1:enc_end-1, cat2] = 1
|
306 |
+
labels[enc_end-1, cat3] = 1
|
307 |
+
else:
|
308 |
+
labels[enc_start:enc_end, cat] = 1
|
309 |
+
|
310 |
+
return {'input_ids': encoding['input_ids'], 'labels': labels, 't2c': encoding.token_to_chars}
|
311 |
+
|
312 |
+
|
313 |
+
def __getitem__(self, idx):
|
314 |
+
return self.data[idx]
|
315 |
+
|
316 |
+
def __len__(self):
|
317 |
+
return len(self.data)
|
318 |
+
|
319 |
+
def parse_cat(cat):
|
320 |
+
for i,c in enumerate(cat):
|
321 |
+
if c.isnumeric():
|
322 |
+
if cat[i+1].isnumeric():
|
323 |
+
return int(cat[i:i+2])
|
324 |
+
return int(c)
|
325 |
+
return None
|
326 |
+
|
327 |
+
|
328 |
+
def load_phenos(args):
|
329 |
+
if args.pheno_n == 500:
|
330 |
+
phenos = pd.read_csv(os.path.join(args.data_dir,
|
331 |
+
'mimic_decisions/phenos500'),
|
332 |
+
sep='\t').rename(lambda x: x.strip(), axis=1)
|
333 |
+
phenos['raw_text'] = phenos['raw_text'].apply(lambda x: os.path.basename(x))
|
334 |
+
phenos[['SUBJECT_ID', 'HADM_ID', 'ROW_ID']] = \
|
335 |
+
[os.path.splitext(x)[0].split('_')[:3] for x in phenos['raw_text']]
|
336 |
+
phenos = phenos[phenos['phenotype_label'] != '?']
|
337 |
+
elif args.pheno_n == 800:
|
338 |
+
phenos = pd.read_csv(os.path.join(args.data_dir, 'mimic_decisions/phenos800.csv'))
|
339 |
+
phenos.rename({'Ham_ID': 'HADM_ID'}, inplace=True, axis=1)
|
340 |
+
phenos = phenos[phenos.phenotype_label != '?']
|
341 |
+
elif args.pheno_n == 1500:
|
342 |
+
phenos = pd.read_csv(os.path.join(args.data_dir, 'mimic_decisions/phenos1500.csv'))
|
343 |
+
phenos.rename({'Hospital.Admission.ID': 'HADM_ID',
|
344 |
+
'subject.id': 'SUBJECT_ID'}, inplace=True, axis=1)
|
345 |
+
phenos = phenos[phenos.Unsure != 1]
|
346 |
+
phenos['psychiatric.disorders'] = phenos['Dementia']\
|
347 |
+
| phenos['Developmental.Delay.Retardation']\
|
348 |
+
| phenos['Schizophrenia.and.other.Psychiatric.Disorders']
|
349 |
+
else:
|
350 |
+
raise ValueError
|
351 |
+
phenos.rename(lambda k: k.lower(), inplace=True, axis = 1)
|
352 |
+
return phenos
|
353 |
+
|
354 |
+
def downsample(dataset):
|
355 |
+
data = dataset.data
|
356 |
+
class0 = [x for x in data if x[1][0] == 0]
|
357 |
+
class1 = [x for x in data if x[1][0] == 1]
|
358 |
+
|
359 |
+
if len(class0) > len(class1):
|
360 |
+
class0 = resample(class0, replace=False, n_samples=len(class1), random_state=0)
|
361 |
+
else:
|
362 |
+
class1 = resample(class1, replace=False, n_samples=len(class0), random_state=0)
|
363 |
+
dataset.data = class0 + class1
|
364 |
+
|
365 |
+
def upsample(dataset):
|
366 |
+
data = dataset.data
|
367 |
+
class0 = [x for x in data if x[1][0] == 0]
|
368 |
+
class1 = [x for x in data if x[1][0] == 1]
|
369 |
+
|
370 |
+
if len(class0) > len(class1):
|
371 |
+
class1 = resample(class1, replace=True, n_samples=len(class0), random_state=0)
|
372 |
+
else:
|
373 |
+
class0 = resample(class0, replace=True, n_samples=len(class1), random_state=0)
|
374 |
+
dataset.data = class0 + class1
|
375 |
+
|
376 |
+
def load_tokenizer(name):
|
377 |
+
return AutoTokenizer.from_pretrained(name)
|
378 |
+
|
379 |
+
def load_data(args):
|
380 |
+
from sklearn.utils import resample
|
381 |
+
def collate_segment(batch):
|
382 |
+
xs = []
|
383 |
+
ys = []
|
384 |
+
t2cs = []
|
385 |
+
has_ids = 'ids' in batch[0]
|
386 |
+
if has_ids:
|
387 |
+
idss = []
|
388 |
+
else:
|
389 |
+
ids = None
|
390 |
+
masks = []
|
391 |
+
for i in range(len(batch)):
|
392 |
+
x = batch[i]['input_ids']
|
393 |
+
y = batch[i]['labels']
|
394 |
+
if has_ids:
|
395 |
+
ids = batch[i]['ids']
|
396 |
+
n = len(x)
|
397 |
+
if n > args.max_len:
|
398 |
+
start = np.random.randint(0, n - args.max_len + 1)
|
399 |
+
x = x[start:start + args.max_len]
|
400 |
+
if args.task == 'token':
|
401 |
+
y = y[start:start + args.max_len]
|
402 |
+
if has_ids:
|
403 |
+
new_ids = []
|
404 |
+
ids = [x[start:start + args.max_len] for x in ids]
|
405 |
+
for subids in ids:
|
406 |
+
subids = [idx for idx, x in enumerate(subids) if x]
|
407 |
+
new_ids.append(subids)
|
408 |
+
all_ids = set([y for x in new_ids for y in x])
|
409 |
+
nones = set(range(args.max_len)) - all_ids
|
410 |
+
new_ids.append(list(nones))
|
411 |
+
mask = [1] * args.max_len
|
412 |
+
elif n < args.max_len:
|
413 |
+
x = np.pad(x, (0, args.max_len - n))
|
414 |
+
if args.task == 'token':
|
415 |
+
y = np.pad(y, ((0, args.max_len - n), (0, 0)))
|
416 |
+
mask = [1] * n + [0] * (args.max_len - n)
|
417 |
+
else:
|
418 |
+
mask = [1] * n
|
419 |
+
xs.append(x)
|
420 |
+
ys.append(y)
|
421 |
+
t2cs.append(batch[i]['t2c'])
|
422 |
+
if has_ids:
|
423 |
+
idss.append(new_ids)
|
424 |
+
masks.append(mask)
|
425 |
+
|
426 |
+
xs = torch.tensor(xs)
|
427 |
+
ys = torch.tensor(ys)
|
428 |
+
masks = torch.tensor(masks)
|
429 |
+
return {'input_ids': xs, 'labels': ys, 'ids': ids, 'mask': masks, 't2c': t2cs}
|
430 |
+
|
431 |
+
def collate_full(batch):
|
432 |
+
lens = [len(x['input_ids']) for x in batch]
|
433 |
+
max_len = max(args.max_len, max(lens))
|
434 |
+
for i in range(len(batch)):
|
435 |
+
batch[i]['input_ids'] = np.pad(batch[i]['input_ids'], (0, max_len - lens[i]))
|
436 |
+
if args.task == 'token':
|
437 |
+
if args.label_encoding == 'multiclass':
|
438 |
+
batch[i]['labels'] = np.pad(batch[i]['labels'], (0, max_len - lens[i]), constant_values=-100)
|
439 |
+
else:
|
440 |
+
batch[i]['labels'] = np.pad(batch[i]['labels'], ((0, max_len - lens[i]), (0, 0)))
|
441 |
+
mask = [1] * lens[i] + [0] * (max_len - lens[i])
|
442 |
+
batch[i]['mask'] = mask
|
443 |
+
|
444 |
+
batch = {k: torch.tensor(np.array([sample[k] for sample in batch])) if isinstance(batch[0][k], Iterable) else
|
445 |
+
[sample[k] for sample in batch]
|
446 |
+
for k in batch[0].keys()}
|
447 |
+
return batch
|
448 |
+
|
449 |
+
tokenizer = load_tokenizer(args.model_name)
|
450 |
+
args.vocab_size = tokenizer.vocab_size
|
451 |
+
args.max_length = min(tokenizer.model_max_length, 512)
|
452 |
+
|
453 |
+
if args.mimic_data:
|
454 |
+
from datasets import Dataset
|
455 |
+
df = pd.read_csv('/data/mohamed/data/mimiciii/NOTEEVENTS.csv.gz',
|
456 |
+
usecols=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'TEXT'])
|
457 |
+
data = Dataset.from_pandas(df)
|
458 |
+
return data, tokenizer
|
459 |
+
else:
|
460 |
+
phenos = load_phenos(args)
|
461 |
+
train_files, val_files, test_files = gen_splits(args, phenos)
|
462 |
+
phenos.set_index('raw_text', inplace=True)
|
463 |
+
train_dataset = MyDataset(args, tokenizer, train_files, phenos, train=True)
|
464 |
+
|
465 |
+
if args.resample == 'down':
|
466 |
+
downsample(train_dataset)
|
467 |
+
elif args.resample == 'up':
|
468 |
+
upsample(train_dataset)
|
469 |
+
val_dataset = MyDataset(args, tokenizer, val_files, phenos)
|
470 |
+
test_dataset = MyDataset(args, tokenizer, test_files, phenos)
|
471 |
+
print('Train dataset:', len(train_dataset))
|
472 |
+
print('Val dataset:', len(val_dataset))
|
473 |
+
print('Test dataset:', len(test_dataset))
|
474 |
+
train_ns = DataLoader(train_dataset, 1, False,
|
475 |
+
collate_fn=collate_full,
|
476 |
+
)
|
477 |
+
train_dataloader = DataLoader(train_dataset, args.batch_size, True,
|
478 |
+
collate_fn=collate_segment,
|
479 |
+
)
|
480 |
+
val_dataloader = DataLoader(val_dataset, 1, False, collate_fn=collate_full)
|
481 |
+
test_dataloader = DataLoader(test_dataset, 1, False, collate_fn=collate_full)
|
482 |
+
|
483 |
+
train_files = [os.path.basename(x).split('-')[0] for x in train_files]
|
484 |
+
val_files = [os.path.basename(x).split('-')[0] for x in val_files]
|
485 |
+
test_files = [os.path.basename(x).split('-')[0] for x in test_files]
|
486 |
+
|
487 |
+
return train_dataloader, val_dataloader, test_dataloader, train_ns, [train_files, val_files, test_files]
|
demo.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from datetime import datetime
|
4 |
+
from dateutil import parser
|
5 |
+
from demo_assets import *
|
6 |
+
import re
|
7 |
+
|
8 |
+
categories = ['Contact related', 'Gathering additional information', 'Defining problem',
|
9 |
+
'Treatment goal', 'Drug related', 'Therapeutic procedure related', 'Evaluating test result',
|
10 |
+
'Deferment', 'Advice and precaution', 'Legal and insurance related']
|
11 |
+
unicode_symbols = [
|
12 |
+
"\U0001F91D", # Handshake
|
13 |
+
"\U0001F50D", # Magnifying glass
|
14 |
+
"\U0001F9E9", # Puzzle piece
|
15 |
+
"\U0001F3AF", # Target
|
16 |
+
"\U0001F48A", # Pill
|
17 |
+
"\U00002702", # Surgical scissors
|
18 |
+
"\U0001F9EA", # Test tube
|
19 |
+
"\U000023F0", # Alarm clock
|
20 |
+
"\U000026A0", # Warning sign
|
21 |
+
"\U0001F4C4" # Document
|
22 |
+
]
|
23 |
+
|
24 |
+
OTHERS_ID = 18
|
25 |
+
def postprocess_labels(text, logits, t2c):
|
26 |
+
tags = [None for _ in text]
|
27 |
+
labels = logits.argmax(-1)
|
28 |
+
for i,cat in enumerate(labels):
|
29 |
+
if cat != OTHERS_ID:
|
30 |
+
char_ids = t2c(i)
|
31 |
+
if char_ids is None:
|
32 |
+
continue
|
33 |
+
for idx in range(char_ids.start, char_ids.end):
|
34 |
+
if tags[idx] is None and idx < len(text):
|
35 |
+
tags[idx] = categories[cat // 2]
|
36 |
+
for i in range(len(text)-1):
|
37 |
+
if text[i] == ' ' and (text[i+1] == ' ' or tags[i-1] == tags[i+1]):
|
38 |
+
tags[i] = tags[i-1]
|
39 |
+
return tags
|
40 |
+
|
41 |
+
def indicators_to_spans(labels, t2c = None):
|
42 |
+
def add_span(c, start, end):
|
43 |
+
if t2c(start) is None or t2c(end) is None:
|
44 |
+
start, end = -1, -1
|
45 |
+
else:
|
46 |
+
start = t2c(start).start
|
47 |
+
end = t2c(end).end
|
48 |
+
span = (c, start, end)
|
49 |
+
spans.add(span)
|
50 |
+
|
51 |
+
spans = set()
|
52 |
+
num_tokens = len(labels)
|
53 |
+
num_classes = OTHERS_ID // 2
|
54 |
+
start = None
|
55 |
+
cls = None
|
56 |
+
for t in range(num_tokens):
|
57 |
+
if start and labels[t] == cls + 1:
|
58 |
+
continue
|
59 |
+
elif start:
|
60 |
+
add_span(cls // 2, start, t - 1)
|
61 |
+
start = None
|
62 |
+
# if not start and labels[t] in [2*x for x in range(num_classes)]:
|
63 |
+
if not start and labels[t] != OTHERS_ID:
|
64 |
+
start = t
|
65 |
+
cls = int(labels[t]) // 2 * 2
|
66 |
+
return spans
|
67 |
+
|
68 |
+
def extract_date(text):
|
69 |
+
pattern = r'(?<=Date: )\s*(\[\*\*.*?\*\*\]|\d{1,4}[-/]\d{1,2}[-/]\d{1,4})'
|
70 |
+
match = re.search(pattern, text).group(1)
|
71 |
+
start, end = None, None
|
72 |
+
for i, c in enumerate(match):
|
73 |
+
if start is None and c.isnumeric():
|
74 |
+
start = i
|
75 |
+
elif c.isnumeric():
|
76 |
+
end = i + 1
|
77 |
+
match = match[start:end]
|
78 |
+
return match
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
def run_gradio(model, tokenizer):
|
83 |
+
def predict(text):
|
84 |
+
encoding = tokenizer.encode_plus(text)
|
85 |
+
x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
|
86 |
+
mask = torch.ones_like(x)
|
87 |
+
output = model.generate(x, mask)[0]
|
88 |
+
return output, encoding.token_to_chars
|
89 |
+
|
90 |
+
def process(text):
|
91 |
+
if text is not None:
|
92 |
+
output, t2c = predict(text)
|
93 |
+
tags = postprocess_labels(text, output, t2c)
|
94 |
+
with open('log.csv', 'a') as f:
|
95 |
+
f.write(f'{datetime.now()},{text}\n')
|
96 |
+
return list(zip(text, tags))
|
97 |
+
else:
|
98 |
+
return text
|
99 |
+
|
100 |
+
def process_sum(*inputs):
|
101 |
+
global sum_c
|
102 |
+
dates = {}
|
103 |
+
for i in range(sum_c):
|
104 |
+
text = inputs[i]
|
105 |
+
output, t2c = predict(text)
|
106 |
+
spans = indicators_to_spans(output.argmax(-1), t2c)
|
107 |
+
date = extract_date(text)
|
108 |
+
present_decs = set(cat for cat, _, _ in spans)
|
109 |
+
decs = {k: [] for k in sorted(present_decs)}
|
110 |
+
for c, s, e in spans:
|
111 |
+
decs[c].append(text[s:e])
|
112 |
+
dates[date] = decs
|
113 |
+
|
114 |
+
out = ""
|
115 |
+
for date in sorted(dates.keys(), key = lambda x: parser.parse(x)):
|
116 |
+
out += f'## **[{date}]**\n\n'
|
117 |
+
decs = dates[date]
|
118 |
+
for c in decs:
|
119 |
+
out += f'### {unicode_symbols[c]} ***{categories[c]}***\n\n'
|
120 |
+
for dec in decs[c]:
|
121 |
+
out += f'{dec}\n\n'
|
122 |
+
|
123 |
+
return out
|
124 |
+
|
125 |
+
global sum_c
|
126 |
+
sum_c = 1
|
127 |
+
SUM_INPUTS = 20
|
128 |
+
def update_inputs(inputs):
|
129 |
+
outputs = []
|
130 |
+
if inputs is None:
|
131 |
+
c = 0
|
132 |
+
else:
|
133 |
+
inputs = [open(f.name).read() for f in inputs]
|
134 |
+
for i, text in enumerate(inputs):
|
135 |
+
outputs.append(gr.update(value=text, visible=True))
|
136 |
+
c = len(inputs)
|
137 |
+
|
138 |
+
n = SUM_INPUTS
|
139 |
+
for i in range(n - c):
|
140 |
+
outputs.append(gr.update(value='', visible=False))
|
141 |
+
global sum_c; sum_c = c
|
142 |
+
return outputs
|
143 |
+
|
144 |
+
def add_ex(*inputs):
|
145 |
+
global sum_c
|
146 |
+
new_idx = sum_c
|
147 |
+
if new_idx < SUM_INPUTS:
|
148 |
+
out = inputs[:new_idx] + (gr.update(visible=True),) + inputs[new_idx+1:]
|
149 |
+
sum_c += 1
|
150 |
+
else:
|
151 |
+
out = inputs
|
152 |
+
return out
|
153 |
+
|
154 |
+
def sub_ex(*inputs):
|
155 |
+
global sum_c
|
156 |
+
new_idx = sum_c - 1
|
157 |
+
if new_idx > 0:
|
158 |
+
out = inputs[:new_idx] + (gr.update(visible=False),) + inputs[new_idx+1:]
|
159 |
+
sum_c -= 1
|
160 |
+
else:
|
161 |
+
out = inputs
|
162 |
+
return out
|
163 |
+
|
164 |
+
|
165 |
+
device = model.backbone.device
|
166 |
+
# colors = ['aqua', 'blue', 'fuchsia', 'teal', 'green', 'olive', 'lime', 'silver', 'purple', 'red',
|
167 |
+
# 'yellow', 'navy', 'gray', 'white', 'maroon', 'black']
|
168 |
+
colors = ['#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#d9d9d9', '#bc80bd']
|
169 |
+
|
170 |
+
color_map = {cat: colors[i] for i,cat in enumerate(categories)}
|
171 |
+
|
172 |
+
det_desc = ['Admit, discharge, follow-up, referral',
|
173 |
+
'Ordering test, consulting colleague, seeking external information',
|
174 |
+
'Diagnostic conclusion, evaluation of health state, etiological inference, prognostic judgment',
|
175 |
+
'Quantitative or qualitative',
|
176 |
+
'Start, stop, alter, maintain, refrain',
|
177 |
+
'Start, stop, alter, maintain, refrain',
|
178 |
+
'Positive, negative, ambiguous test results',
|
179 |
+
'Transfer responsibility, wait and see, change subject',
|
180 |
+
'Advice or precaution',
|
181 |
+
'Sick leave, drug refund, insurance, disability']
|
182 |
+
|
183 |
+
desc = '### Zones (categories)\n'
|
184 |
+
desc += '| | |\n| --- | --- |\n'
|
185 |
+
for i,cat in enumerate(categories):
|
186 |
+
desc += f'| {unicode_symbols[i]} **{cat}** | {det_desc[i]}|\n'
|
187 |
+
|
188 |
+
#colors
|
189 |
+
#markdown labels
|
190 |
+
#legend and desc
|
191 |
+
#css font-size
|
192 |
+
css = '.category-legend {border:1px dashed black;}'\
|
193 |
+
'.text-sm {font-size: 1.5rem; line-height: 200%;}'\
|
194 |
+
'.gr-sample-textbox {width: 1000px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}'\
|
195 |
+
'.text-limit label textarea {height: 150px !important; overflow: scroll; }'\
|
196 |
+
'.text-gray-500 {color: #111827; font-weight: 600; font-size: 1.25em; margin-top: 1.6em; margin-bottom: 0.6em;'\
|
197 |
+
'line-height: 1.6;}'\
|
198 |
+
'#sum-out {border: 2px solid #007bff; padding: 20px; border-radius: 10px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);'
|
199 |
+
title='Clinical Decision Zoning'
|
200 |
+
with gr.Blocks(title=title, css=css) as demo:
|
201 |
+
gr.Markdown(f'# {title}')
|
202 |
+
with gr.Tab("Label a Clinical Note"):
|
203 |
+
with gr.Row():
|
204 |
+
with gr.Column():
|
205 |
+
gr.Markdown("## Enter a Discharge Summary or Clinical Note"),
|
206 |
+
text_input = gr.Textbox(
|
207 |
+
# value=examples[0],
|
208 |
+
label="",
|
209 |
+
placeholder="Enter text here...")
|
210 |
+
text_btn = gr.Button('Run')
|
211 |
+
with gr.Column():
|
212 |
+
gr.Markdown("## Labeled Summary or Note"),
|
213 |
+
text_out = gr.Highlight(label="", combine_adjacent=True, show_legend=False, color_map=color_map)
|
214 |
+
gr.Examples(text_examples, inputs=text_input)
|
215 |
+
with gr.Tab("Summarize Patient History"):
|
216 |
+
with gr.Row():
|
217 |
+
with gr.Column():
|
218 |
+
sum_inputs = [gr.Text(label='Clinical Note 1', elem_classes='text-limit')]
|
219 |
+
sum_inputs.extend([gr.Text(label='Clinical Note %d'%i, visible=False, elem_classes='text-limit')
|
220 |
+
for i in range(2, SUM_INPUTS + 1)])
|
221 |
+
sum_btn = gr.Button('Run')
|
222 |
+
with gr.Row():
|
223 |
+
ex_add = gr.Button("+")
|
224 |
+
ex_sub = gr.Button("-")
|
225 |
+
upload = gr.File(label='Upload clinical notes', file_type='text', file_count='multiple')
|
226 |
+
gr.Examples(sum_examples, inputs=upload,
|
227 |
+
fn = update_inputs, outputs=sum_inputs, run_on_click=True)
|
228 |
+
with gr.Column():
|
229 |
+
gr.Markdown("## Summarized Clinical Decision History")
|
230 |
+
sum_out = gr.Markdown(elem_id='sum-out')
|
231 |
+
gr.Markdown(desc)
|
232 |
+
|
233 |
+
# Functions
|
234 |
+
text_input.submit(process, inputs=text_input, outputs=text_out)
|
235 |
+
text_btn.click(process, inputs=text_input, outputs=text_out)
|
236 |
+
upload.change(update_inputs, inputs=upload, outputs=sum_inputs)
|
237 |
+
ex_add.click(add_ex, inputs=sum_inputs, outputs=sum_inputs)
|
238 |
+
ex_sub.click(sub_ex, inputs=sum_inputs, outputs=sum_inputs)
|
239 |
+
sum_btn.click(process_sum, inputs=sum_inputs, outputs=sum_out)
|
240 |
+
# demo = gr.TabbedInterface([text_demo, sum_demo], ["Label a Clinical Note", "Summarize Patient History"])
|
241 |
+
demo.launch(share=False)
|
demo_assets.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sum_examples = [
|
2 |
+
[['examples/note%d.txt'%i for i in range(1,n)]]
|
3 |
+
for n in range(5,1, -1)
|
4 |
+
]
|
5 |
+
|
6 |
+
text_examples = [
|
7 |
+
["a 72 year old female with chronic indwelling foley. GNR identified in her blood and there is concern for possible resistant pseudomonas. vanco discontinued and ceftriaxone replaced with zosyn. "],
|
8 |
+
["This is a 73 year old man with CMML with recent admission to OSH for emergent splenectomy after splenic rupture who is admitted for hypoxemia and worsening bilateral ground glass opacities. He was treated aggressively on the floor with antibiotics and other etiologies (PE, MI, etc) were appropriately addressed. He was fluid resusitated and continued on his CMML regimen. Despite this, the patient became progressively hypotensive and was transferred to the ICU for further care.\n"
|
9 |
+
"In the ICU the patient continued to deteriorate and developed progressive hypotension and acidosis despite aggressive fluid repletion, pressor support, and bicarbonate drip. He received >8L NS, 8amps bicarb, pressor support w/ levophed and vasopressin, and maximum ventilatory support. Despite these measures, his lactate continued to trend upwards and he became progressively more hypotensive on the PEEP settings required to adequately oxygenate him. Furthermore, the patient developed tumor lysis syndrome in the setting of his chemotherapy and became anuric producing only 40cc of urine over 8hr. Renal service was called emergently to consider dialysis but the family elected to change his code status to DNR/DNI and focus care on comfort as a priority, after discussion w/ his oncologist Dr [**First Name (STitle) 1557**] and to defer more aggressive therapy."],
|
10 |
+
["48 year old male with complicated past medical history, multiple problems notably including ESRD s/p renal transplant complicated by collapsing FSGS, recent MRSA line sepsis, here with fevers and hypotension at dialysis, code sepsis."
|
11 |
+
"He met sepsis criteria with fever, tachycardia and likely source of infection at site of tunneled dialysis catheter. Also had leukocytosis with L shift. CXR clear, urine not produced for sample. No central line placed [**3-5**] lack of access. Treated with 2 doses linezolid PO; d/w renal team - preferred vanco use, patient switched to vanco by level and d/c on vanco at HD. Underwent stim test; failed, started on hydrocort at stress dose levels (50 q6), d/w renal, felt uneccessary, patient started on prednisone taper back to home dose of 5 mg PO qd. Held HTN meds in setting of sepsis. Received dose of vanco prior to d/c."
|
12 |
+
"Dialysis Catheter - noted morning after admission to be clotted; question whether this was related to blood draw. Instilled tPA in catheter overnight; were able to use cath in AM for HD. "
|
13 |
+
"ESRD - Started on prograf; monitored levels, d/c on home dose. As per pharm, must continue to monitor levels in context of using itraconazole. Continued patient on bactrim for prophylaxis given tacrolimus use. To go to dialysis. 7 point HCT drop noted during admission; thought elevated HCT hemoconcentration. Hemolysis labs neg, no stool to guaiac. "
|
14 |
+
"PTT elevation - noted on admission, resolved in ICU. DIC labs negative. PT/PTT elevation at discharge c/w warfarin/SC heparin use."
|
15 |
+
"Hypertension: History of HTN, on lopressor and diltiazem. "
|
16 |
+
"Pulmonary Aspergillus: Stable. On itraconazole and followed by pulmonary as an outpatient. Continued in house"
|
17 |
+
"Atrial fibrillation: He is normally rate controlled with metoprolol and anticoagulated with coumadin. NSR on EKG here, continued warfarin, held beta blocker . "
|
18 |
+
"..."
|
19 |
+
"Call your PCP or return to the ED for fevers/chills/shakes, chest pain, shortness of breath, pain at the site of your dialysis catheter, nausea, vomiting, or swelling in your legs/feet. "]
|
20 |
+
]
|
electra-base.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64dc780def96ec19006340e180a60531dacd8db9e0e4e206ba57c720e79775d3
|
3 |
+
size 435711089
|
examples/note1.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Date: 2/12/2024
|
2 |
+
History: the patient is a 45 yo F here for nervousness. A few weeks ago she noticed that she was feeling more nervous than usual and that it has been worsening. It is exacerbated by family and work. She feels espeically nervous on Sunday night and Monday morning when she is preparing for the week. She is unable to fall asleep and doesn’t want to eat anything, though she does make herself eat. Nothing helps her nervousness. She otherwise denies significant changes in appetite, weight loss, or overall wellbeing. She denies fevers, chills, nausea, constipation, diarrhea, skin changes, racing heart, shortness of breath, dizziness, headaches or rashes.
|
3 |
+
ROS: otherwise negative
|
4 |
+
PMH: None; PSH: None
|
5 |
+
Meds: Tylenol for occasional HA
|
6 |
+
FHX: Father had an MI, died at 65yo
|
7 |
+
Allergies: NKDA
|
8 |
+
SH: Lives at home with husband, mother, and youngest son. Is an english literature professor at a local college.
|
9 |
+
Has 2 drinks/mo, no tobacco or drug use.
|
10 |
+
Physical Examination:
|
11 |
+
VS: Blood Pressure: 130/85 mm Hg
|
12 |
+
Heart Rate: 96/min
|
13 |
+
Gen: No acute distress, conversational, thin
|
14 |
+
Neck: No thyromegaly, no lymphadeopathy
|
15 |
+
Heart: RRR, no murmurs, rubs or gallops. Radial pulses +2 bilaterally
|
16 |
+
Lungs: Clear to ascultation bilaterally, no wheezes
|
17 |
+
Psych: Well-groomed. Non-pressured speech, linear though process.
|
examples/note2.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Date: 3/18/2024
|
2 |
+
History: the patient, a 45-year-old female, returns for a follow-up regarding her nervousness. She reports a slight improvement in her symptoms with reduced intensity of nervousness on Sunday nights and Monday mornings. However, she still experiences difficulty sleeping and occasional lack of appetite. She has tried meditation and deep breathing exercises, which have provided minimal relief. No new symptoms have emerged since the last visit. She continues to deny fevers, chills, nausea, or other systemic symptoms.
|
3 |
+
ROS: Negative except as noted.
|
4 |
+
PMH: No changes.
|
5 |
+
PSH: None.
|
6 |
+
Meds: Tylenol for occasional headaches. Started on a trial of low-dose sertraline since the last visit.
|
7 |
+
FHX: No changes.
|
8 |
+
Allergies: NKDA.
|
9 |
+
SH: No changes in social circumstances. Continues to work as an English literature professor.
|
10 |
+
Physical Examination:
|
11 |
+
VS: Blood Pressure: 128/82 mm Hg, Heart Rate: 92/min
|
12 |
+
Gen: Appears more relaxed than the previous visit.
|
13 |
+
Neck: No changes.
|
14 |
+
Heart: Unchanged.
|
15 |
+
Lungs: Clear to auscultation.
|
16 |
+
Psych: Appears slightly more at ease, maintains good eye contact, speech and thought process remain coherent.
|
17 |
+
Assessment/Plan: Improvement noted with sertraline. Will continue the current dose and re-evaluate in 3 months. Encouraged to continue non-pharmacological interventions like meditation and deep breathing exercises. Consider referral to therapy for additional support.
|
examples/note3.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Date: 6/2/2024
|
2 |
+
History: the patient, now 46, reports further improvement in her symptoms of nervousness. She has started seeing a therapist, which she finds helpful. Her sleep has improved, and she no longer experiences significant appetite loss. She has not developed any new symptoms and continues to deny systemic symptoms.
|
3 |
+
ROS: Negative except as noted.
|
4 |
+
PMH: No changes.
|
5 |
+
PSH: None.
|
6 |
+
Meds: Continues sertraline. No longer using Tylenol for headaches.
|
7 |
+
FHX: No changes.
|
8 |
+
Allergies: NKDA.
|
9 |
+
SH: Stable home and work life. Activities and responsibilities as an English literature professor are well-managed.
|
10 |
+
Physical Examination:
|
11 |
+
VS: Blood Pressure: 125/80 mm Hg, Heart Rate: 88/min
|
12 |
+
Gen: Appears comfortable and at ease.
|
13 |
+
Neck: No changes.
|
14 |
+
Heart: Unchanged.
|
15 |
+
Lungs: Clear to auscultation.
|
16 |
+
Psych: Noticeable improvement in mood and anxiety levels. Reports feeling more in control.
|
17 |
+
Assessment/Plan: Significant improvement with sertraline and therapy. Plan to continue current management and follow up in 6 months or as needed. Discuss potential for gradually reducing medication under supervision if improvement sustains.
|
examples/note4.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Date: 12/29/2024
|
2 |
+
History: the patient comes in for a scheduled follow-up. She feels much better and has managed to maintain her improvements. She expresses a desire to start tapering off sertraline under medical supervision. No new health concerns have been noted. She remains active at work and home.
|
3 |
+
ROS: Entirely negative.
|
4 |
+
PMH: No changes.
|
5 |
+
PSH: None.
|
6 |
+
Meds: Sertraline, with a plan to taper.
|
7 |
+
FHX: No changes.
|
8 |
+
Allergies: NKDA.
|
9 |
+
SH: Stable and positive home and work environment.
|
10 |
+
Physical Examination:
|
11 |
+
VS: Blood Pressure: 122/78 mm Hg, Heart Rate: 84/min
|
12 |
+
Gen: Looks healthy and content.
|
13 |
+
Neck: No changes.
|
14 |
+
Heart: Unchanged.
|
15 |
+
Lungs: Clear.
|
16 |
+
Psych: Maintained improvement in mental health. Ready for gradual medication reduction.
|
17 |
+
Assessment/Plan: Patient has shown sustained improvement and is interested in tapering off medication. Will initiate a slow tapering process of sertraline and monitor closely for any recurrence of symptoms. Continue therapy and supportive measures. Next follow-up scheduled in 3 months to assess progress.
|
model.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from transformers import AutoModel
|
5 |
+
from torch.optim import AdamW
|
6 |
+
from transformers import get_linear_schedule_with_warmup
|
7 |
+
# from torchcrf import CRF
|
8 |
+
|
9 |
+
class MyModel(nn.Module):
|
10 |
+
def __init__(self, args, backbone):
|
11 |
+
super().__init__()
|
12 |
+
self.args = args
|
13 |
+
self.backbone = backbone
|
14 |
+
self.cls_id = 0
|
15 |
+
hidden_dim = self.backbone.config.hidden_size
|
16 |
+
self.classifier = nn.Sequential(
|
17 |
+
nn.Dropout(0.1),
|
18 |
+
nn.Linear(hidden_dim, args.num_labels)
|
19 |
+
)
|
20 |
+
|
21 |
+
if args.distil_att:
|
22 |
+
self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))
|
23 |
+
|
24 |
+
def forward(self, x, mask):
|
25 |
+
x = x.to(self.backbone.device)
|
26 |
+
mask = mask.to(self.backbone.device)
|
27 |
+
out = self.backbone(x, attention_mask = mask, output_attentions=True)
|
28 |
+
return out, self.classifier(out.last_hidden_state)
|
29 |
+
|
30 |
+
def decisions(self, x, mask):
|
31 |
+
x = x.to(self.backbone.device)
|
32 |
+
mask = mask.to(self.backbone.device)
|
33 |
+
out = self.backbone(x, attention_mask = mask, output_attentions=False)
|
34 |
+
return out, self.classifier(out.last_hidden_state)
|
35 |
+
|
36 |
+
def phenos(self, x, mask):
|
37 |
+
x = x.to(self.backbone.device)
|
38 |
+
mask = mask.to(self.backbone.device)
|
39 |
+
out = self.backbone(x, attention_mask = mask, output_attentions=True)
|
40 |
+
return out, self.classifier(out.pooler_output)
|
41 |
+
|
42 |
+
def generate(self, x, mask, choice=None):
|
43 |
+
outs = []
|
44 |
+
if self.args.task == 'seq' or choice == 'seq':
|
45 |
+
for i, offset in enumerate(range(0, x.shape[1], self.args.max_len-1)):
|
46 |
+
if i == 0:
|
47 |
+
segment = x[:, offset:offset + self.args.max_len-1]
|
48 |
+
segment_mask = mask[:, offset:offset + self.args.max_len-1]
|
49 |
+
else:
|
50 |
+
segment = torch.cat((torch.ones((x.shape[0], 1), dtype=int).to(x.device)\
|
51 |
+
*self.cls_id,
|
52 |
+
x[:, offset:offset + self.args.max_len-1]), axis=1)
|
53 |
+
segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device),
|
54 |
+
mask[:, offset:offset + self.args.max_len-1]), axis=1)
|
55 |
+
logits = self.phenos(segment, segment_mask)[1]
|
56 |
+
outs.append(logits)
|
57 |
+
|
58 |
+
return torch.max(torch.stack(outs, 1), 1).values
|
59 |
+
elif self.args.task == 'token':
|
60 |
+
for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
|
61 |
+
segment = x[:, offset:offset + self.args.max_len]
|
62 |
+
segment_mask = mask[:, offset:offset + self.args.max_len]
|
63 |
+
h = self.decisions(segment, segment_mask)[0].last_hidden_state
|
64 |
+
outs.append(h)
|
65 |
+
h = torch.cat(outs, 1)
|
66 |
+
return self.classifier(h)
|
67 |
+
|
68 |
+
class CNN(nn.Module):
|
69 |
+
def __init__(self, args):
|
70 |
+
super().__init__()
|
71 |
+
self.emb = nn.Embedding(args.vocab_size, args.emb_size)
|
72 |
+
self.model = nn.Sequential(
|
73 |
+
nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
|
74 |
+
padding='same' if args.task == 'token' else 'valid'),
|
75 |
+
nn.ReLU(),
|
76 |
+
nn.MaxPool1d(1),
|
77 |
+
nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
|
78 |
+
padding='same' if args.task == 'token' else 'valid'),
|
79 |
+
nn.ReLU(),
|
80 |
+
nn.MaxPool1d(1),
|
81 |
+
nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
|
82 |
+
padding='same' if args.task == 'token' else 'valid'),
|
83 |
+
nn.ReLU(),
|
84 |
+
nn.MaxPool1d(1),
|
85 |
+
)
|
86 |
+
if args.task == 'seq':
|
87 |
+
out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
|
88 |
+
elif args.task == 'token':
|
89 |
+
out_shape = 1
|
90 |
+
self.classifier = nn.Linear(args.hidden_size*out_shape, args.num_labels)
|
91 |
+
self.dropout = nn.Dropout()
|
92 |
+
self.args = args
|
93 |
+
self.device = None
|
94 |
+
|
95 |
+
def forward(self, x, _):
|
96 |
+
x = x.to(self.device)
|
97 |
+
bs = x.shape[0]
|
98 |
+
x = self.emb(x)
|
99 |
+
x = x.transpose(1,2)
|
100 |
+
x = self.model(x)
|
101 |
+
x = self.dropout(x)
|
102 |
+
if self.args.task == 'token':
|
103 |
+
x = x.transpose(1,2)
|
104 |
+
h = self.classifier(x)
|
105 |
+
return x, h
|
106 |
+
elif self.args.task == 'seq':
|
107 |
+
x = x.reshape(bs, -1)
|
108 |
+
x = self.classifier(x)
|
109 |
+
return x
|
110 |
+
|
111 |
+
def generate(self, x, _):
|
112 |
+
outs = []
|
113 |
+
for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
|
114 |
+
segment = x[:, offset:offset + self.args.max_len]
|
115 |
+
n = segment.shape[1]
|
116 |
+
if n != self.args.max_len:
|
117 |
+
segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
|
118 |
+
if self.args.task == 'seq':
|
119 |
+
logits = self(segment, None)
|
120 |
+
outs.append(logits)
|
121 |
+
elif self.args.task == 'token':
|
122 |
+
h = self(segment, None)[0]
|
123 |
+
h = h[:,:n]
|
124 |
+
outs.append(h)
|
125 |
+
if self.args.task == 'seq':
|
126 |
+
return torch.max(torch.stack(outs, 1), 1).values
|
127 |
+
elif self.args.task == 'token':
|
128 |
+
h = torch.cat(outs, 1)
|
129 |
+
return self.classifier(h)
|
130 |
+
|
131 |
+
class LSTM(nn.Module):
|
132 |
+
def __init__(self, args):
|
133 |
+
super().__init__()
|
134 |
+
self.emb = nn.Embedding(args.vocab_size, args.emb_size)
|
135 |
+
self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
|
136 |
+
batch_first=True, bidirectional=True)
|
137 |
+
dim = 2*args.num_layers*args.hidden_size if args.task == 'seq' else 2*args.hidden_size
|
138 |
+
self.classifier = nn.Linear(dim, args.num_labels)
|
139 |
+
self.dropout = nn.Dropout()
|
140 |
+
self.args = args
|
141 |
+
self.device = None
|
142 |
+
|
143 |
+
def forward(self, x, _):
|
144 |
+
x = x.to(self.device)
|
145 |
+
x = self.emb(x)
|
146 |
+
o, (x, _) = self.model(x)
|
147 |
+
o_out = self.classifier(o) if self.args.task == 'token' else None
|
148 |
+
if self.args.task == 'seq':
|
149 |
+
x = torch.cat([h for h in x], 1)
|
150 |
+
x = self.dropout(x)
|
151 |
+
x = self.classifier(x)
|
152 |
+
return (x, o), o_out
|
153 |
+
|
154 |
+
def generate(self, x, _):
|
155 |
+
outs = []
|
156 |
+
for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
|
157 |
+
segment = x[:, offset:offset + self.args.max_len]
|
158 |
+
if self.args.task == 'seq':
|
159 |
+
logits = self(segment, None)[0][0]
|
160 |
+
outs.append(logits)
|
161 |
+
elif self.args.task == 'token':
|
162 |
+
h = self(segment, None)[0][1]
|
163 |
+
outs.append(h)
|
164 |
+
if self.args.task == 'seq':
|
165 |
+
return torch.max(torch.stack(outs, 1), 1).values
|
166 |
+
elif self.args.task == 'token':
|
167 |
+
h = torch.cat(outs, 1)
|
168 |
+
return self.classifier(h)
|
169 |
+
|
170 |
+
def load_model(args, device):
|
171 |
+
if args.model == 'lstm':
|
172 |
+
model = LSTM(args).to(device)
|
173 |
+
model.device = device
|
174 |
+
elif args.model == 'cnn':
|
175 |
+
model = CNN(args).to(device)
|
176 |
+
model.device = device
|
177 |
+
else:
|
178 |
+
model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
|
179 |
+
if args.ckpt:
|
180 |
+
model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)
|
181 |
+
if args.distil:
|
182 |
+
args2 = copy.deepcopy(args)
|
183 |
+
args2.task = 'token'
|
184 |
+
# args2.num_labels = args.num_decs
|
185 |
+
args2.num_labels = args.num_umls_tags
|
186 |
+
model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
|
187 |
+
model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
|
188 |
+
for p in model_B.parameters():
|
189 |
+
p.requires_grad = False
|
190 |
+
else:
|
191 |
+
model_B = None
|
192 |
+
if args.label_encoding == 'multiclass':
|
193 |
+
if args.use_crf:
|
194 |
+
crit = CRF(args.num_labels, batch_first = True).to(device)
|
195 |
+
else:
|
196 |
+
crit = nn.CrossEntropyLoss(reduction='none')
|
197 |
+
else:
|
198 |
+
crit = nn.BCEWithLogitsLoss(
|
199 |
+
pos_weight=torch.ones(args.num_labels).to(device)*args.pos_weight,
|
200 |
+
reduction='none'
|
201 |
+
)
|
202 |
+
optimizer = AdamW(model.parameters(), lr=args.lr)
|
203 |
+
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
|
204 |
+
int(0.1*args.total_steps), args.total_steps)
|
205 |
+
|
206 |
+
return model, crit, optimizer, lr_scheduler, model_B
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
argparse
|
2 |
+
numpy
|
3 |
+
pandas
|
4 |
+
torch==1.13.1
|
5 |
+
transformers==4.38.1
|