import random
from collections import defaultdict

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class InstructBase(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        # map (start, end) -> class id; missing keys default to 0 (null label)
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):
        max_len = self.base_config.max_len
        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        # enumerate all candidate spans up to max_width tokens wide
        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask spans whose end index falls outside the sequence
        invalid_span_mask = spans_idx[:, 1] > length - 1

        # invalid positions get label -1 and are masked out downstream
        span_label = span_label.masked_fill(invalid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        # batch_list: list of dicts containing tokenized_text and ner
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                # negs = b["negative"]
                random.shuffle(negs)
                # negs = negs[:sampled_neg]
                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                if len(types) != 0:
                    # random drop: keep a random subset (between 1 and all) of the types
                    if self.base_config.random_drop:
                        num_ents = random.randint(1, len(types))
                        types = types[:num_ents]

                # maximum number of entity types
                types = types[:int(self.base_config.max_types)]

                # supervised training: use the provided label set instead
                if "label" in b:
                    types = sorted(b["label"])

                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {v: k for k, v in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i])
                for i, b in enumerate(batch_list)
            ]

        else:
            # fixed label set shared by the whole batch (e.g. at inference)
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {v: k for k, v in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids)
                for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )

        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        # collect the entity types seen across the batch to use as negatives
        ent_types = []
        for b in batch_list:
            types = set([el[-1] for el in b['ner']])
            ent_types.extend(list(types))
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
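

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module's API).
# It assumes a config object exposing the attributes read above; the attribute
# values, the toy examples, and the (start, end, type) span format are
# assumptions made for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        max_width=4,           # maximum span width enumerated per start token
        max_len=64,            # sequences are truncated to this many tokens
        max_neg_type_ratio=1,  # upper bound on negative types per positive
        random_drop=False,     # disable random dropping of entity types
        max_types=25,          # cap on entity types per example
    )

    # Each example: tokenized text plus NER spans as (start, end, type).
    data = [
        {"tokenized_text": ["John", "lives", "in", "Paris"],
         "ner": [(0, 0, "person"), (3, 3, "location")]},
        {"tokenized_text": ["Apple", "hired", "Jane"],
         "ner": [(0, 0, "organization"), (2, 2, "person")]},
    ]

    model = InstructBase(config)
    loader = model.create_dataloader(data, batch_size=2, shuffle=False)
    batch = next(iter(loader))
    # span_idx: (batch, num_spans, 2); span_label: (batch, num_spans)
    print(batch["span_idx"].shape, batch["span_label"].shape)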