import copy

import torch
from torch import nn
from torch.optim import AdamW
from transformers import AutoModel, get_linear_schedule_with_warmup
# from torchcrf import CRF  # pytorch-crf; uncomment when args.use_crf is enabled (CRF(...) below depends on it)


class MyModel(nn.Module):
    """Transformer backbone with a shared dropout + linear classification head."""

    def __init__(self, args, backbone):
        super().__init__()
        self.args = args
        self.backbone = backbone
        self.cls_id = 0
        hidden_dim = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, args.num_labels)
        )
        if args.distil_att:
            self.distil_att = nn.Parameter(torch.ones(self.backbone.config.hidden_size))

    def forward(self, x, mask):
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        # Token-level logits from the last hidden states, plus the full backbone output.
        return out, self.classifier(out.last_hidden_state)

    def decisions(self, x, mask):
        # Token-level logits without attention maps (cheaper than forward()).
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=False)
        return out, self.classifier(out.last_hidden_state)

    def phenos(self, x, mask):
        # Sequence-level logits from the pooled [CLS] representation.
        x = x.to(self.backbone.device)
        mask = mask.to(self.backbone.device)
        out = self.backbone(x, attention_mask=mask, output_attentions=True)
        return out, self.classifier(out.pooler_output)

    def generate(self, x, mask, choice=None):
        """Sliding-window inference for inputs longer than args.max_len."""
        outs = []
        if self.args.task == 'seq' or choice == 'seq':
            # Stride by max_len-1 so a CLS token (self.cls_id) can be prepended
            # to every segment after the first; segment logits are max-pooled.
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len - 1)):
                if i == 0:
                    segment = x[:, offset:offset + self.args.max_len - 1]
                    segment_mask = mask[:, offset:offset + self.args.max_len - 1]
                else:
                    segment = torch.cat((torch.ones((x.shape[0], 1), dtype=int).to(x.device) * self.cls_id,
                                         x[:, offset:offset + self.args.max_len - 1]), axis=1)
                    segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device),
                                              mask[:, offset:offset + self.args.max_len - 1]), axis=1)
                logits = self.phenos(segment, segment_mask)[1]
                outs.append(logits)
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            # Concatenate per-segment hidden states, then classify every token.
            for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
                segment = x[:, offset:offset + self.args.max_len]
                segment_mask = mask[:, offset:offset + self.args.max_len]
                h = self.decisions(segment, segment_mask)[0].last_hidden_state
                outs.append(h)
            h = torch.cat(outs, 1)
            return self.classifier(h)


class CNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.Sequential(
            nn.Conv1d(args.emb_size, args.hidden_size, args.kernels[0],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[1],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
            nn.Conv1d(args.hidden_size, args.hidden_size, args.kernels[2],
                      padding='same' if args.task == 'token' else 'valid'),
            nn.ReLU(),
            nn.MaxPool1d(1),
        )
        if args.task == 'seq':
            # Output length after three 'valid' convolutions over a fixed input length of 512.
            out_shape = 512 - args.kernels[0] - args.kernels[1] - args.kernels[2] + 3
        elif args.task == 'token':
            out_shape = 1
        self.classifier = nn.Linear(args.hidden_size * out_shape, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        bs = x.shape[0]
        x = self.emb(x)
        x = x.transpose(1, 2)
        x = self.model(x)
        x = self.dropout(x)
        if self.args.task == 'token':
            x = x.transpose(1, 2)
            h = self.classifier(x)
            return x, h
        elif self.args.task == 'seq':
            x = x.reshape(bs, -1)
            x = self.classifier(x)
            return x

    def generate(self, x, _):
        """Sliding-window inference; short trailing segments are zero-padded."""
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            n = segment.shape[1]
            if n != self.args.max_len:
                segment = torch.nn.functional.pad(segment, (0, self.args.max_len - n))
            if self.args.task == 'seq':
                logits = self(segment, None)
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0]
                h = h[:, :n]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)


class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.emb = nn.Embedding(args.vocab_size, args.emb_size)
        self.model = nn.LSTM(args.emb_size, args.hidden_size, num_layers=args.num_layers,
                             batch_first=True, bidirectional=True)
        # 'seq' classifies the concatenated final hidden states of all layers/directions;
        # 'token' classifies each bidirectional output step.
        dim = 2 * args.num_layers * args.hidden_size if args.task == 'seq' else 2 * args.hidden_size
        self.classifier = nn.Linear(dim, args.num_labels)
        self.dropout = nn.Dropout()
        self.args = args
        self.device = None

    def forward(self, x, _):
        x = x.to(self.device)
        x = self.emb(x)
        o, (x, _) = self.model(x)
        o_out = self.classifier(o) if self.args.task == 'token' else None
        if self.args.task == 'seq':
            x = torch.cat([h for h in x], 1)
            x = self.dropout(x)
            x = self.classifier(x)
        return (x, o), o_out

    def generate(self, x, _):
        outs = []
        for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)):
            segment = x[:, offset:offset + self.args.max_len]
            if self.args.task == 'seq':
                logits = self(segment, None)[0][0]
                outs.append(logits)
            elif self.args.task == 'token':
                h = self(segment, None)[0][1]
                outs.append(h)
        if self.args.task == 'seq':
            return torch.max(torch.stack(outs, 1), 1).values
        elif self.args.task == 'token':
            h = torch.cat(outs, 1)
            return self.classifier(h)


def load_model(args, device):
    """Build the model, an optional frozen distillation model, the criterion,
    the optimizer, and the LR schedule."""
    if args.model == 'lstm':
        model = LSTM(args).to(device)
        model.device = device
    elif args.model == 'cnn':
        model = CNN(args).to(device)
        model.device = device
    else:
        model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device)
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location=device), strict=False)

    if args.distil:
        # Frozen token-level model used as a distillation target.
        args2 = copy.deepcopy(args)
        args2.task = 'token'
        # args2.num_labels = args.num_decs
        args2.num_labels = args.num_umls_tags
        model_B = MyModel(args2, AutoModel.from_pretrained(args.model_name)).to(device)
        model_B.load_state_dict(torch.load(args.distil_ckpt, map_location=device), strict=False)
        for p in model_B.parameters():
            p.requires_grad = False
    else:
        model_B = None

    if args.label_encoding == 'multiclass':
        if args.use_crf:
            crit = CRF(args.num_labels, batch_first=True).to(device)
        else:
            crit = nn.CrossEntropyLoss(reduction='none')
    else:
        crit = nn.BCEWithLogitsLoss(
            pos_weight=torch.ones(args.num_labels).to(device) * args.pos_weight,
            reduction='none'
        )

    optimizer = AdamW(model.parameters(), lr=args.lr)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer, int(0.1 * args.total_steps), args.total_steps)

    return model, crit, optimizer, lr_scheduler, model_B
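

# Illustrative smoke test only: a minimal sketch of how load_model might be called.
# The Namespace fields below are assumptions inferred from the attribute accesses in
# this file (not a documented configuration); adjust them to the real training args.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        model='lstm', task='seq', vocab_size=30522, emb_size=128, hidden_size=256,
        num_layers=1, num_labels=5, max_len=128, ckpt=None, distil=False,
        distil_att=False, label_encoding='multilabel', use_crf=False,
        pos_weight=1.0, lr=1e-3, total_steps=1000, model_name=None,
    )
    model, crit, optimizer, lr_scheduler, model_B = load_model(args, 'cpu')

    # Two documents of 256 token ids, split into max_len windows by generate().
    x = torch.randint(0, args.vocab_size, (2, 256))
    print(model.generate(x, None).shape)  # expected: torch.Size([2, 5])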