# DRGCoder / model.py
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from typing import Dict, List
from collections import Counter, defaultdict
from itertools import chain
import torch


class MimicTransformer(Module):
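    """Sequence-classification model for MIMIC DRG coding with attention-based explanations.

    Wraps a Hugging Face AutoModelForSequenceClassification and adds a linear
    head over the layer-averaged self-attention maps to produce token-level
    explanation (XAI) logits alongside the DRG logits.
    """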
    def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
        """
        :param num_labels: number of DRG classes to predict
        :param tokenizer_name: shorthand name resolved to a Hugging Face
            model id by find_tokenizer
        :param cutoff: maximum input length in tokens (overridden for
            longformer models)
        """
        super().__init__()
self.tokenizer_name = self.find_tokenizer(tokenizer_name)
self.num_labels = num_labels
self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
        if 'longformer' in self.tokenizer_name:
            # Longformer models support longer inputs; use their full position range
            self.cutoff = self.model.config.max_position_embeddings
        else:
            self.cutoff = cutoff
        # Projects each token's cutoff-long attention row down to one XAI logit
        self.linear = Linear(in_features=self.cutoff, out_features=1)

def parse_icds(self, instances: List[Dict]):
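        """Collect BIO-style attention tags for every ICD token id.

        :param instances: list of instance dicts; each instance['icd'] holds
            nested lists of dicts with 'start' and 'text' fields
        :return: dict mapping token id to {'B-ATTN'} or {'I-ATTN'}
        """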
token_list = defaultdict(set)
token_freq_list = []
for instance in instances:
            icds = list(chain(*instance['icd']))
            # Deduplicate ICD mentions that share the same start offset
            icd_dict_list = list({icd['start']: icd for icd in icds}.values())
for icd_dict in icd_dict_list:
icd_ent = icd_dict['text']
icd_tokenized = self.tokenizer(icd_ent, add_special_tokens=False)['input_ids']
icd_dict['tokens'] = icd_tokenized
icd_dict['labels'] = []
                for i, token in enumerate(icd_tokenized):
                    # BIO tagging: first subword is B-ATTN, the rest are I-ATTN
                    label = "B-ATTN" if i == 0 else "I-ATTN"
                    icd_dict['labels'].append(label)
token_list[token].add(label)
token_freq_list.append(str(token) + ": " + label)
token_tag_freqs = Counter(token_freq_list)
        for token in token_list:
            # A token tagged both ways keeps only its more frequent tag
            if len(token_list[token]) == 2:
inside_count = token_tag_freqs[str(token) + ": I-ATTN"]
begin_count = token_tag_freqs[str(token) + ": B-ATTN"]
if begin_count > inside_count:
token_list[token].remove('I-ATTN')
else:
token_list[token].remove('B-ATTN')
return token_list

    def collate_mimic(self, instances: List[Dict], device='cuda'):
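        """Batch instances into model inputs plus token-level XAI targets.

        :param instances: list of instance dicts with 'description', 'entry',
            'drg', and 'icd' fields
        :param device: device the tensors are placed on
        :return: dict with 'text', 'drg', 'entry', 'icds', and 'xai' entries
        """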
tokenized = [
self.tokenizer.encode(
' '.join(instance['description']), max_length=self.cutoff, truncation=True, padding='max_length'
) for instance in instances
]
entries = [instance['entry'] for instance in instances]
labels = torch.tensor([x['drg'] for x in instances], dtype=torch.long).to(device).unsqueeze(1)
inputs = torch.tensor(tokenized, dtype=torch.long).to(device)
icds = self.parse_icds(instances)
xai_labels = torch.zeros(size=inputs.shape, dtype=torch.float32).to(device)
        for i, row in enumerate(inputs):
            for j, ele in enumerate(row):
                # Mark tokens whose id appears in any ICD mention
                if ele.item() in icds:
                    xai_labels[i][j] = 1
return {
'text': inputs,
'drg': labels,
'entry': entries,
'icds': icds,
'xai': xai_labels
}

    def forward(self, input_ids, attention_mask=None, drg_labels=None):
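        """
        :param input_ids: (batch, seq_len) token ids
        :param attention_mask: optional attention mask
        :param drg_labels: optional DRG targets; when given, the classifier
            output also includes the loss
        :return: tuple of (classifier outputs, token-level XAI logits)
        """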
        if drg_labels is not None:
            cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
        else:
            cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        # cls_results[-1] is the tuple of per-layer attentions, each (batch, attn_heads, tokens, tokens)
        # Average the attention maps across all layers
        last_attn = torch.mean(torch.stack(cls_results[-1]), dim=0)
        # Drop the last 3 heads and average the rest -> (batch, tokens, tokens)
        last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
        # Project each token's attention row to a single logit -> (batch, tokens)
        xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
return (cls_results, xai_logits)

    def find_tokenizer(self, tokenizer_name):
        """
        :param tokenizer_name: shorthand name ('clinical_longformer',
            'clinical', or anything else for a standard BERT)
        :return: Hugging Face model identifier
        """
        if tokenizer_name == 'clinical_longformer':
            return 'yikuan8/Clinical-Longformer'
        if tokenizer_name == 'clinical':
            return 'emilyalsentzer/Bio_ClinicalBERT'
        # standard transformer
        return 'bert-base-uncased'
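

if __name__ == '__main__':
    # Minimal smoke-test sketch (an assumption, not part of the original file):
    # the instance schema below is inferred from collate_mimic/parse_icds and
    # may not match the real MIMIC data format.
    model = MimicTransformer(tokenizer_name='clinical', cutoff=512)
    instance = {
        'description': ['patient', 'admitted', 'with', 'acute', 'appendicitis'],
        'entry': 'example-entry-id',  # hypothetical identifier
        'drg': 0,                     # hypothetical DRG class index
        'icd': [[{'start': 0, 'text': 'acute appendicitis'}]],
    }
    batch = model.collate_mimic([instance], device='cpu')
    cls_results, xai_logits = model(batch['text'], drg_labels=batch['drg'])
    # Expect (1, 738) DRG logits and (1, 512) token-level XAI logits
    print(cls_results.logits.shape, xai_logits.shape)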