# Clinical_Decisions / model.py
import copy

import torch
from torch import nn
from torch.optim import AdamW
from transformers import AutoModel, get_linear_schedule_with_warmup

# torchcrf is an optional dependency: the CRF criterion is only constructed
# when args.use_crf is set, so keep the module importable without it.
try:
    from torchcrf import CRF
except ImportError:
    CRF = None

class MyModel(nn.Module):
    def __init__(self, args, backbone):
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.cls_id = 0  # token id prepended to every window after the first in generate()
        hidden_dim = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, args.num_labels),
        )
        if args.distil_att:
            self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))

    def forward(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        # Per-token logits from the last hidden states.
        return out, self.classifier(out.last_hidden_state)

    def decisions(self, x, mask):
        # Token-level head: classify every position (attentions not needed).
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=False)
        return out, self.classifier(out.last_hidden_state)

    def phenos(self, x, mask):
        # Sequence-level head: classify the pooled representation.
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.pooler_output)

    def generate(self, x, mask, choice=None):
        # Run inference on inputs longer than the backbone window by chunking.
        outs = []
        if self.args.task == 'seq' or choice == 'seq':
            # Windows of max_len-1 tokens; every window after the first gets
            # self.cls_id prepended so the pooled output remains well-defined.
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len - 1)):
                if i == 0:
                    segment = x[:, offset:offset + self.args.max_len - 1]
                    segment_mask = mask[:, offset:offset + self.args.max_len - 1]
                else:
                    cls_col = torch.ones((x.shape[0], 1), dtype=int).to(x.device) * self.cls_id
                    segment = torch.cat((cls_col,
                                         x[:, offset:offset + self.args.max_len - 1]), axis=1)
                    segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device),
                                              mask[:, offset:offset + self.args.max_len - 1]), axis=1)
                logits = self.phenos(segment, segment_mask)[1]
                outs.append(logits)
            # Max-pool the per-window logits across windows.
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
                segment = x[:, offset:offset + self.args.max_len]
                segment_mask = mask[:, offset:offset + self.args.max_len]
                h = self.decisions(segment, segment_mask)[0].last_hidden_state
                outs.append(h)
            # Concatenate window features back to full length, then classify each token.
            h = torch.cat(outs, 1)
            return self.classifier(h)
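
# Hedged usage sketch (not from the original repo): `generate` accepts inputs
# longer than args.max_len, splits them into windows, and max-pools the
# per-window logits for sequence tasks.  A hypothetical call might look like:
#
#     logits = model.generate(input_ids, attention_mask, choice='seq')
#     preds = (torch.sigmoid(logits) > 0.5).long()
#
# where `input_ids` / `attention_mask` come from the project's tokenizer and
# the 0.5 multi-label threshold is an assumption, not a value from this file.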

class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        # Three Conv1d/ReLU blocks; 'same' padding keeps the sequence length
        # for token tagging, 'valid' shrinks it for sequence classification.
        self.model = nn.Sequential(
            nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
        )
        if args.task == 'seq':
            # Length left after three 'valid' convolutions on a 512-token
            # window: each conv removes kernel_size - 1 positions.
            out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
        elif args.task == 'token':
            out_shape = 1
        self.classifier = nn.Linear(args.hidden_size * out_shape, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        bs = x.shape[0]
        x = self.emb(x)
        x = x.transpose(1, 2)  # (batch, emb, length) for Conv1d
        x = self.model(x)
        x = self.dropout(x)
        if self.args.task == 'token':
            x = x.transpose(1, 2)  # back to (batch, length, hidden)
            h = self.classifier(x)
            return x, h
        elif self.args.task == 'seq':
            x = x.reshape(bs, -1)
            x = self.classifier(x)
            return x

    def generate(self, x, _):
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            n = segment.shape[1]
            if n != self.args.max_len:
                # Right-pad the last window so every window has the same length.
                segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
            if self.args.task == 'seq':
                logits = self(segment, None)
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0]
                h = h[:, :n]  # drop the padded positions
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
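
# Worked example of the 'seq' output-size arithmetic (kernel sizes here are
# illustrative assumptions, not values from this project): with a 512-token
# window and kernels = [3, 4, 5], three 'valid' convolutions leave
# 512 - (3-1) - (4-1) - (5-1) = 503 positions, i.e. 512 - 3 - 4 - 5 + 3,
# so the flattened classifier input above is hidden_size * 503 features.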

class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
                             batch_first=True, bidirectional=True)
        # Sequence task classifies the concatenated final hidden states of all
        # layers and directions; token task classifies each bidirectional output.
        dim = 2 * args.num_layers * args.hidden_size if args.task == 'seq' else 2 * args.hidden_size
        self.classifier = nn.Linear(dim, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        x = self.emb(x)
        o, (x, _) = self.model(x)
        # Per-token logits (only computed for the token task).
        o_out = self.classifier(o) if self.args.task == 'token' else None
        if self.args.task == 'seq':
            x = torch.cat([h for h in x], 1)  # (batch, 2*num_layers*hidden_size)
            x = self.dropout(x)
            x = self.classifier(x)
        return (x, o), o_out

    def generate(self, x, _):
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            if self.args.task == 'seq':
                logits = self(segment, None)[0][0]
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0][1]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)
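
# Shape sketch (illustrative values, not from this project): with
# num_layers=2 and hidden_size=256, the final hidden state returned by the
# bidirectional LSTM has shape (2*2, batch, 256); concatenating along the
# feature dimension for the 'seq' task yields (batch, 1024), matching the
# classifier input dim = 2 * num_layers * hidden_size used above.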

def load_model(args, device):
    # Build the main model.
    if args.model == 'lstm':
        model = LSTM(args).to(device)
        model.device = device
    elif args.model == 'cnn':
        model = CNN(args).to(device)
        model.device = device
    else:
        model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)

    # Optional frozen teacher for distillation: a token-level model over UMLS tags.
    if args.distil:
        args2 = copy.deepcopy(args)
        args2.task = 'token'
        # args2.num_labels = args.num_decs
        args2.num_labels = args.num_umls_tags
        model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
        model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
        for p in model_B.parameters():
            p.requires_grad = False
    else:
        model_B = None

    # Training criterion: CRF or cross-entropy for multiclass labels,
    # weighted BCE-with-logits otherwise.
    if args.label_encoding == 'multiclass':
        if args.use_crf:
            crit = CRF(args.num_labels, batch_first=True).to(device)
        else:
            crit = nn.CrossEntropyLoss(reduction='none')
    else:
        crit = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(args.num_labels).to(device) * args.pos_weight,
            reduction='none',
        )

    optimizer = AdamW(model.parameters(), lr=args.lr)
    # Linear decay with 10% warmup steps.
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   int(0.1 * args.total_steps), args.total_steps)
    return model, crit, optimizer, lr_scheduler, model_B
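
# Hedged usage sketch (not part of the original file): every field below is
# read by load_model or by the LSTM constructor; the concrete values are
# illustrative assumptions, not settings from this project.
if __name__ == '__main__':
    from argparse import Namespace

    demo_args = Namespace(
        model='lstm', ckpt=None, distil=False,
        label_encoding='multilabel', use_crf=False,
        vocab_size=30522, emb_size=128, hidden_size=256, num_layers=1,
        task='seq', num_labels=10, max_len=512,
        pos_weight=1.0, lr=1e-3, total_steps=1000,
    )
    demo_model, demo_crit, demo_optimizer, demo_scheduler, _ = load_model(demo_args, 'cpu')
    print(demo_model)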