from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from typing import Dict, List
from collections import Counter, defaultdict
from itertools import chain
import torch


class MimicTransformer(Module):
    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """
        :param num_labels: number of DRG classes for sequence classification
        :param tokenizer_name: shorthand for the pretrained model/tokenizer to load (see find_tokenizer)
        :param cutoff: maximum sequence length; overridden by the model's own limit for Longformer variants
        """
        super().__init__()
        self.tokenizer_name = self.find_tokenizer(tokenizer_name)
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
        # Case-insensitive check so 'yikuan8/Clinical-Longformer' is detected as a Longformer.
        if 'longformer' in self.tokenizer_name.lower():
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # Projects each token's attention row (length = cutoff) down to a single explanation logit.
        self.linear = Linear(in_features=self.cutoff, out_features=1)

    def parse_icds(self, instances: List[Dict]):
        """Tokenize ICD code mentions and assign each token id a B-ATTN/I-ATTN tag,
        resolving conflicts so every token id ends up with a single tag."""
        token_list = defaultdict(set)
        token_freq_list = []
        for instance in instances:
            icds = list(chain(*instance['icd']))
            # De-duplicate ICD mentions that share the same start offset.
            icd_dict_list = list({icd['start']: icd for icd in icds}.values())
            for icd_dict in icd_dict_list:
                icd_ent = icd_dict['text']
                icd_tokenized = self.tokenizer(icd_ent, add_special_tokens=False)['input_ids']
                icd_dict['tokens'] = icd_tokenized
                icd_dict['labels'] = []
                for i, token in enumerate(icd_tokenized):
                    label = "B-ATTN" if i == 0 else "I-ATTN"
                    icd_dict['labels'].append(label)
                    token_list[token].add(label)
                    token_freq_list.append(str(token) + ": " + label)
        token_tag_freqs = Counter(token_freq_list)
        # If a token id was seen with both tags, keep only the more frequent one.
        for token in token_list:
            if len(token_list[token]) == 2:
                inside_count = token_tag_freqs[str(token) + ": I-ATTN"]
                begin_count = token_tag_freqs[str(token) + ": B-ATTN"]
                if begin_count > inside_count:
                    token_list[token].remove('I-ATTN')
                else:
                    token_list[token].remove('B-ATTN')
        return token_list

    def collate_mimic(self, instances: List[Dict], device='cuda'):
        """Batch MIMIC instances: tokenize descriptions, build DRG labels, and mark
        token positions that belong to an ICD mention as explanation (xai) targets."""
        tokenized = [
            self.tokenizer.encode(
                ' '.join(instance['description']),
                max_length=self.cutoff,
                truncation=True,
                padding='max_length'
            ) for instance in instances
        ]
        entries = [instance['entry'] for instance in instances]
        labels = torch.tensor([x['drg'] for x in instances], dtype=torch.long).to(device).unsqueeze(1)
        inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
        icds = self.parse_icds(instances)
        xai_labels = torch.zeros(size=inputs.shape, dtype=torch.float32).to(device)
        for i, row in enumerate(inputs):
            for j, ele in enumerate(row):
                if ele.item() in icds:
                    xai_labels[i][j] = 1
        return {
            'text': inputs,
            'drg': labels,
            'entry': entries,
            'icds': icds,
            'xai': xai_labels
        }

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
        # Truth-testing a label tensor is ambiguous, so compare against None explicitly.
        if drg_labels is not None:
            cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels,
                                     output_attentions=True)
        else:
            cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        # last_attn = cls_results[-1][-1]  # (batch, attn_heads, tokens, tokens)
        # cls_results[-1] is the tuple of per-layer attentions; average over layers,
        # then over the retained heads, to get one (batch, tokens, tokens) map.
        last_attn = torch.mean(torch.stack(cls_results[-1]), dim=0)
        last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
        xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
        return (cls_results, xai_logits)

    def find_tokenizer(self, tokenizer_name):
        """
        Map a shorthand tokenizer name to its full Hugging Face model identifier.
        :param tokenizer_name: 'clinical_longformer', 'clinical', or anything else (falls back to BERT)
        :return: a Hugging Face model hub identifier
        """
        if tokenizer_name == 'clinical_longformer':
            return 'yikuan8/Clinical-Longformer'
        if tokenizer_name == 'clinical':
            return 'emilyalsentzer/Bio_ClinicalBERT'
        # standard transformer
        return 'bert-base-uncased'
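

# A minimal usage sketch (not part of the original module). It assumes each MIMIC
# instance is a dict with 'description' (list of words), 'entry', 'drg' (class index),
# and 'icd' (a list of lists of {'start', 'text'} mention dicts), as implied by
# collate_mimic and parse_icds above; the field values below are illustrative only.
if __name__ == '__main__':
    model = MimicTransformer(num_labels=738, tokenizer_name='clinical', cutoff=512)
    example_instances = [
        {
            'description': ['patient', 'admitted', 'with', 'acute', 'renal', 'failure'],
            'entry': 'example-entry-1',
            'drg': 0,
            'icd': [[{'start': 22, 'text': 'acute renal failure'}]],
        }
    ]
    # Collate on CPU for the sketch; pass device='cuda' when a GPU is available.
    batch = model.collate_mimic(example_instances, device='cpu')
    cls_results, xai_logits = model(batch['text'], drg_labels=batch['drg'])
    print(xai_logits.shape)  # expected: (batch_size, cutoff)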