import random
import re
from collections import defaultdict
from typing import (
    Iterable,
    Iterator,
    List,
    MutableSet,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import torch
import torch.nn.functional as F
from rex.data.collate_fn import GeneralCollateFn
from rex.data.transforms.base import CachedTransformBase, CachedTransformOneBase
from rex.metrics import calc_p_r_f1_from_tp_fp_fn
from rex.utils.io import load_json
from rex.utils.iteration import windowed_queue_iter
from rex.utils.logging import logger
from transformers import AutoTokenizer
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
    DebertaV2TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from src.utils import (
    decode_nnw_nsw_thw_mat,
    decode_nnw_thw_mat,
    encode_nnw_nsw_thw_mat,
    encode_nnw_thw_mat,
)

Filled = TypeVar("Filled")


class PaddingMixin:
    max_seq_len: int

    def pad_seq(
        self, batch_seqs: List[List[Filled]], fill: Filled
    ) -> List[List[Filled]]:
        max_len = max(len(seq) for seq in batch_seqs)
        assert max_len <= self.max_seq_len
        for i in range(len(batch_seqs)):
            batch_seqs[i] = batch_seqs[i] + [fill] * (max_len - len(batch_seqs[i]))
        return batch_seqs

    def pad_mat(
        self, mats: List[torch.Tensor], fill: Union[int, float]
    ) -> List[torch.Tensor]:
        max_len = max(mat.shape[0] for mat in mats)
        assert max_len <= self.max_seq_len
        for i in range(len(mats)):
            num_add = max_len - mats[i].shape[0]
            # right-pad the two sequence dims (-3 and -2) to max_len; the last dim is untouched
            mats[i] = F.pad(
                mats[i], (0, 0, 0, num_add, 0, num_add), mode="constant", value=fill
            )
        return mats


class PointerTransformMixin:
    tokenizer: BertTokenizerFast
    max_seq_len: int
    space_token: str = "[unused1]"

    def build_ins(
        self,
        query_tokens: list[str],
        context_tokens: list[str],
        answer_indexes: list[list[int]],
        add_context_tokens: Optional[list[str]] = None,
    ) -> Tuple:
        # -3: [CLS], [SEP] after the query, and [SEP] after the context
        reserved_seq_len = self.max_seq_len - 3 - len(query_tokens)
        # reserve at least 20 tokens for the context
        if reserved_seq_len < 20:
            raise ValueError(
                f"Query {query_tokens} too long: {len(query_tokens)} "
                f"while max seq len is {self.max_seq_len}"
            )

        input_tokens = [self.tokenizer.cls_token]
        input_tokens += query_tokens
        input_tokens += [self.tokenizer.sep_token]
        offset = len(input_tokens)
        input_tokens += context_tokens[:reserved_seq_len]
        available_token_range = range(
            offset, offset + len(context_tokens[:reserved_seq_len])
        )
        input_tokens += [self.tokenizer.sep_token]

        add_context_len = 0
        max_add_context_len = self.max_seq_len - len(input_tokens) - 1
        add_context_flag = False
        if add_context_tokens and len(add_context_tokens) > 0:
            add_context_flag = True
            add_context_len = len(add_context_tokens[:max_add_context_len])
            input_tokens += add_context_tokens[:max_add_context_len]
            input_tokens += [self.tokenizer.sep_token]

        # replace whitespace-only tokens with a dedicated space token
        new_tokens = []
        for t in input_tokens:
            if len(t.strip()) > 0:
                new_tokens.append(t)
            else:
                new_tokens.append(self.space_token)
        input_tokens = new_tokens
        input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)

        mask = [1]
        mask += [2] * len(query_tokens)
        mask += [3]
        mask += [4] * len(context_tokens[:reserved_seq_len])
        mask += [5]
        if add_context_flag:
            mask += [6] * add_context_len
            mask += [7]
        assert len(mask) == len(input_ids) <= self.max_seq_len

        # shift answer indexes by the query offset and drop spans that fall
        # outside the (possibly truncated) context
        available_spans = [tuple(i + offset for i in index) for index in answer_indexes]
        available_spans = list(
            filter(
                lambda index: all(i in available_token_range for i in index),
                available_spans,
            )
        )

        token_len = len(input_ids)
        pad_len = self.max_seq_len - token_len
        input_tokens += pad_len * [self.tokenizer.pad_token]
        input_ids += pad_len * [self.tokenizer.pad_token_id]
        mask += pad_len * [0]
        return input_tokens, input_ids, mask, offset, available_spans
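
    # Illustrative layout sketch (comments only, not executed): build_ins produces
    #   [CLS] query [SEP] context [SEP] (optional add_context [SEP]) [PAD] ...
    # and the parallel `mask` tags each segment with its own id:
    #   1=[CLS], 2=query, 3=[SEP], 4=context, 5=[SEP],
    #   6=add_context, 7=closing [SEP], 0=padding.
    # `offset` is the index of the first context token, so an answer index i in the
    # raw context appears as i + offset in `available_spans`.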
    def update_labels(self, data: dict) -> dict:
        bs = len(data["input_ids"])
        seq_len = self.max_seq_len
        labels = torch.zeros((bs, 2, seq_len, seq_len))
        for i, batch_spans in enumerate(data["available_spans"]):
            # offset = data["offset"][i]
            # pad_len = data["mask"].count(0)
            # token_len = seq_len - pad_len
            for span in batch_spans:
                if len(span) == 1:
                    labels[i, :, span[0], span[0]] = 1
                else:
                    for s, e in windowed_queue_iter(span, 2, 1, drop_last=True):
                        labels[i, 0, s, e] = 1
                    labels[i, 1, span[-1], span[0]] = 1
            # labels[i, :, 0:offset, :] = -100
            # labels[i, :, :, 0:offset] = -100
            # labels[i, :, :, token_len:] = -100
            # labels[i, :, token_len:, :] = -100
        data["labels"] = labels
        return data

    def update_consecutive_span_labels(self, data: dict) -> dict:
        bs = len(data["input_ids"])
        seq_len = self.max_seq_len
        labels = torch.zeros((bs, 1, seq_len, seq_len))
        for i, batch_spans in enumerate(data["available_spans"]):
            for span in batch_spans:
                assert span == tuple(sorted(set(span)))
                if len(span) == 1:
                    labels[i, 0, span[0], span[0]] = 1
                else:
                    labels[i, 0, span[0], span[-1]] = 1
        data["labels"] = labels
        return data


class CachedPointerTaggingTransform(CachedTransformBase, PointerTransformMixin):
    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        ent_type2query_filepath: str,
        mode: str = "w2",
        negative_sample_prob: float = 1.0,
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)
        self.ent_type2query: dict = load_json(ent_type2query_filepath)
        self.negative_sample_prob = negative_sample_prob

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
            },
            guessing=False,
            missing_key_as_null=True,
        )
        if mode == "w2":
            self.collate_fn.update_before_tensorify = self.update_labels
        elif mode == "cons":
            self.collate_fn.update_before_tensorify = (
                self.update_consecutive_span_labels
            )
        else:
            raise ValueError(f"Mode: {mode} not recognizable")

    def transform(
        self,
        transform_loader: Iterator,
        dataset_name: str = None,
        **kwargs,
    ) -> Iterable:
        final_data = []
        # tp = fp = fn = 0
        for data in transform_loader:
            ent_type2ents = defaultdict(set)
            for ent in data["ents"]:
                ent_type2ents[ent["type"]].add(tuple(ent["index"]))
            for ent_type in self.ent_type2query:
                gold_ents = ent_type2ents[ent_type]
                if (
                    len(gold_ents) < 1
                    and dataset_name == "train"
                    and random.random() > self.negative_sample_prob
                ):
                    # skip negative samples
                    continue
                # res = self.build_ins(ent_type, data["tokens"], gold_ents)
                query = self.ent_type2query[ent_type]
                query_tokens = self.tokenizer.tokenize(query)
                try:
                    res = self.build_ins(query_tokens, data["tokens"], gold_ents)
                except (ValueError, AssertionError):
                    continue
                input_tokens, input_ids, mask, offset, available_spans = res
                ins = {
                    "id": data.get("id", str(len(final_data))),
                    "ent_type": ent_type,
                    "gold_ents": gold_ents,
                    "raw_tokens": data["tokens"],
                    "input_tokens": input_tokens,
                    "input_ids": input_ids,
                    "mask": mask,
                    "offset": offset,
                    "available_spans": available_spans,
                    # labels are dynamically padded in collate fn
                    "labels": None,
                    # "labels": labels.tolist(),
                }
                final_data.append(ins)

                # # upper bound analysis
                # pred_spans = set(decode_nnw_thw_mat(labels.unsqueeze(0))[0])
                # g_ents = set(available_spans)
                # tp += len(g_ents & pred_spans)
                # fp += len(pred_spans - g_ents)
                # fn += len(g_ents - pred_spans)

        # # upper bound results
        # measures = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
        # logger.info(f"Upper Bound: {measures}")

        return final_data
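
    # Illustrative usage sketch (comments only; `plm_dir` and `query_file` are
    # hypothetical placeholders):
    #   transform = CachedPointerTaggingTransform(512, plm_dir, query_file)
    #   instances = transform.predict_transform(["Joe Biden visited Paris."])
    # predict_transform below wraps each text as a record with empty `ents`, so
    # transform() emits one query-specific instance per entity type in
    # ent_type2query, ready for batch inference.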
    def predict_transform(self, texts: List[str]):
        dataset = []
        for text_id, text in enumerate(texts):
            data_id = f"Prediction#{text_id}"
            tokens = self.tokenizer.tokenize(text)
            dataset.append(
                {
                    "id": data_id,
                    "tokens": tokens,
                    "ents": [],
                }
            )
        final_data = self(dataset, disable_pbar=True)
        return final_data


class CachedPointerMRCTransform(CachedTransformBase, PointerTransformMixin):
    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        mode: str = "w2",
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(plm_dir)

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
            },
            guessing=False,
            missing_key_as_null=True,
        )
        if mode == "w2":
            self.collate_fn.update_before_tensorify = self.update_labels
        elif mode == "cons":
            self.collate_fn.update_before_tensorify = (
                self.update_consecutive_span_labels
            )
        else:
            raise ValueError(f"Mode: {mode} not recognizable")

    def transform(
        self,
        transform_loader: Iterator,
        dataset_name: str = None,
        **kwargs,
    ) -> Iterable:
        final_data = []
        for data in transform_loader:
            try:
                res = self.build_ins(
                    data["query_tokens"],
                    data["context_tokens"],
                    data["answer_index"],
                    data.get("background_tokens"),
                )
            except (ValueError, AssertionError):
                continue
            input_tokens, input_ids, mask, offset, available_spans = res
            ins = {
                "id": data.get("id", str(len(final_data))),
                "gold_spans": sorted(set(tuple(x) for x in data["answer_index"])),
                "raw_tokens": data["context_tokens"],
                "input_tokens": input_tokens,
                "input_ids": input_ids,
                "mask": mask,
                "offset": offset,
                "available_spans": available_spans,
                "labels": None,
            }
            final_data.append(ins)
        return final_data

    def predict_transform(self, data: list[dict]):
        """
        Args:
            data: a list of dicts with query, context, and (optional) background strings
        """
        dataset = []
        for idx, ins in enumerate(data):
            data_id = f"Prediction#{idx}"
            dataset.append(
                {
                    "id": data_id,
                    "query_tokens": list(ins["query"]),
                    "context_tokens": list(ins["context"]),
                    # background may be absent; fall back to an empty token list
                    "background_tokens": list(ins.get("background") or ""),
                    "answer_index": [],
                }
            )
        final_data = self(dataset, disable_pbar=True, num_samples=0)
        return final_data


class CachedLabelPointerTransform(CachedTransformOneBase):
    """Transform for label-token linking with consecutive and skip (discontinuous) spans."""

    def __init__(
        self,
        max_seq_len: int,
        plm_dir: str,
        mode: str = "w2",
        label_span: str = "tag",
        include_instructions: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        self.max_seq_len: int = max_seq_len
        self.mode = mode
        self.label_span = label_span
        self.include_instructions = include_instructions
        self.tokenizer: DebertaV2TokenizerFast = DebertaV2TokenizerFast.from_pretrained(
            plm_dir
        )
        self.lc_token = "[LC]"
        self.lm_token = "[LM]"
        self.lr_token = "[LR]"
        self.i_token = "[I]"
        self.tl_token = "[TL]"
        self.tp_token = "[TP]"
        self.b_token = "[B]"
        num_added = self.tokenizer.add_tokens(
            [
                self.lc_token,
                self.lm_token,
                self.lr_token,
                self.i_token,
                self.tl_token,
                self.tp_token,
                self.b_token,
            ]
        )
        assert num_added == 7

        self.collate_fn: GeneralCollateFn = GeneralCollateFn(
            {
                "input_ids": torch.long,
                "mask": torch.long,
                "labels": torch.long,
                "spans": None,
            },
            guessing=False,
            missing_key_as_null=True,
            # only for pre-training
            discard_missing=False,
        )
        self.collate_fn.update_before_tensorify = self.skip_consecutive_span_labels

    def transform(self, instance: dict, **kwargs):
        # input
        tokens = [self.tokenizer.cls_token]
        mask = [1]
        label_map = {"lc": {}, "lm": {}, "lr": {}}
        # (2, 3): {"type": "lc", "task": "cls/ent/rel/event/hyper_rel/discontinuous_ent", "string": ""}
        span_to_label = {}

        def _update_seq(
            label: str,
            label_type: str,
            task: str = "",
            label_mask: int = 4,
            content_mask: int = 5,
        ):
            if label not in label_map[label_type]:
                label_token_map = {
                    "lc": self.lc_token,
                    "lm": self.lm_token,
                    "lr": self.lr_token,
                }
                label_tag_start_idx = len(tokens)
                tokens.append(label_token_map[label_type])
                mask.append(label_mask)
                label_tag_end_idx = len(tokens) - 1  # exact end position
                label_tokens = self.tokenizer(label, add_special_tokens=False).tokens()
                label_content_start_idx = len(tokens)
                tokens.extend(label_tokens)
                mask.extend([content_mask] * len(label_tokens))
                label_content_end_idx = len(tokens) - 1  # exact end position
                if self.label_span == "tag":
                    start_idx = label_tag_start_idx
                    end_idx = label_tag_end_idx
                elif self.label_span == "content":
                    start_idx = label_content_start_idx
                    end_idx = label_content_end_idx
                else:
                    raise ValueError(f"label_span={self.label_span} is not supported")
                if end_idx == start_idx:
                    label_map[label_type][label] = (start_idx,)
                else:
                    label_map[label_type][label] = (start_idx, end_idx)
                span_to_label[label_map[label_type][label]] = {
                    "type": label_type,
                    "task": task,
                    "string": label,
                }
            return label_map[label_type][label]

        if self.include_instructions:
            instruction = instance.get("instruction")
            if not instruction:
                logger.warning(
                    "include_instructions=True, while the instruction is empty!"
                )
        else:
            instruction = ""
        if instruction:
            tokens.append(self.i_token)
            mask.append(2)
            instruction_tokens = self.tokenizer(
                instruction, add_special_tokens=False
            ).tokens()
            tokens.extend(instruction_tokens)
            mask.extend([3] * len(instruction_tokens))

        types = instance["schema"].get("cls")
        if types:
            for t in types:
                _update_seq(t, "lc", task="cls")
        mention_types = instance["schema"].get("ent")
        if mention_types:
            for mt in mention_types:
                _update_seq(mt, "lm", task="ent")
        discon_ent_types = instance["schema"].get("discontinuous_ent")
        if discon_ent_types:
            for mt in discon_ent_types:
                _update_seq(mt, "lm", task="discontinuous_ent")
        rel_types = instance["schema"].get("rel")
        if rel_types:
            for rt in rel_types:
                _update_seq(rt, "lr", task="rel")
        hyper_rel_schema = instance["schema"].get("hyper_rel")
        if hyper_rel_schema:
            for rel, qualifiers in hyper_rel_schema.items():
                _update_seq(rel, "lr", task="hyper_rel")
                for qualifier in qualifiers:
                    _update_seq(qualifier, "lr", task="hyper_rel")
        event_schema = instance["schema"].get("event")
        if event_schema:
            for event_type, roles in event_schema.items():
                _update_seq(event_type, "lm", task="event")
                for role in roles:
                    _update_seq(role, "lr", task="event")

        text = instance.get("text")
        if text:
            text_tokenized = self.tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            if any(val for val in label_map.values()):
                text_label_token = self.tl_token
            else:
                text_label_token = self.tp_token
            tokens.append(text_label_token)
            mask.append(6)
            remain_token_len = self.max_seq_len - 1 - len(tokens)
            if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
                return None
            text_off = len(tokens)
            text_tokens = text_tokenized.tokens()[:remain_token_len]
            tokens.extend(text_tokens)
            mask.extend([7] * len(text_tokens))
        else:
            text_tokenized = None

        bg = instance.get("bg")
        if bg:
            bg_tokenized = self.tokenizer(
                bg, return_offsets_mapping=True, add_special_tokens=False
            )
            tokens.append(self.b_token)
            mask.append(8)
            remain_token_len = self.max_seq_len - 1 - len(tokens)
            if remain_token_len < 5 and kwargs.get("dataset_name", "train") == "train":
                return None
            bg_tokens = bg_tokenized.tokens()[:remain_token_len]
            tokens.extend(bg_tokens)
            mask.extend([9] * len(bg_tokens))
        else:
            bg_tokenized = None

        tokens.append(self.tokenizer.sep_token)
        mask.append(10)
        # labels
        # spans: [[(label start, label end), (ent start, ent end)], ...]
        # each part is an inclusive (start,) or (start, end) token span
        spans = []  # one span may have many parts
        if "cls" in instance["ans"]:
            for t in instance["ans"]["cls"]:
                part = label_map["lc"][t]
                spans.append([part])
        if "ent" in instance["ans"]:
            for ent in instance["ans"]["ent"]:
                label_part = label_map["lm"][ent["type"]]
                position_seq = self.char_to_token_span(
                    ent["span"], text_tokenized, text_off
                )
                spans.append([label_part, position_seq])
        if "discontinuous_ent" in instance["ans"]:
            for ent in instance["ans"]["discontinuous_ent"]:
                label_part = label_map["lm"][ent["type"]]
                ent_span = [label_part]
                for part in ent["span"]:
                    position_seq = self.char_to_token_span(
                        part, text_tokenized, text_off
                    )
                    ent_span.append(position_seq)
                spans.append(ent_span)
        if "rel" in instance["ans"]:
            for rel in instance["ans"]["rel"]:
                label_part = label_map["lr"][rel["relation"]]
                head_position_seq = self.char_to_token_span(
                    rel["head"]["span"], text_tokenized, text_off
                )
                tail_position_seq = self.char_to_token_span(
                    rel["tail"]["span"], text_tokenized, text_off
                )
                spans.append([label_part, head_position_seq, tail_position_seq])
        if "hyper_rel" in instance["ans"]:
            for rel in instance["ans"]["hyper_rel"]:
                label_part = label_map["lr"][rel["relation"]]
                head_position_seq = self.char_to_token_span(
                    rel["head"]["span"], text_tokenized, text_off
                )
                tail_position_seq = self.char_to_token_span(
                    rel["tail"]["span"], text_tokenized, text_off
                )
                # rel_span = [label_part, head_position_seq, tail_position_seq]
                for q in rel["qualifiers"]:
                    q_label_part = label_map["lr"][q["label"]]
                    q_position_seq = self.char_to_token_span(
                        q["span"], text_tokenized, text_off
                    )
                    spans.append(
                        [
                            label_part,
                            head_position_seq,
                            tail_position_seq,
                            q_label_part,
                            q_position_seq,
                        ]
                    )
        if "event" in instance["ans"]:
            for event in instance["ans"]["event"]:
                event_type_label_part = label_map["lm"][event["event_type"]]
                trigger_position_seq = self.char_to_token_span(
                    event["trigger"]["span"], text_tokenized, text_off
                )
                trigger_part = [event_type_label_part, trigger_position_seq]
                spans.append(trigger_part)
                for arg in event["args"]:
                    role_label_part = label_map["lr"][arg["role"]]
                    arg_position_seq = self.char_to_token_span(
                        arg["span"], text_tokenized, text_off
                    )
                    arg_part = [role_label_part, trigger_position_seq, arg_position_seq]
                    spans.append(arg_part)
        if "span" in instance["ans"]:
            # Extractive-QA or Extractive-MRC tasks
            for span in instance["ans"]["span"]:
                span_position_seq = self.char_to_token_span(
                    span["span"], text_tokenized, text_off
                )
                spans.append([span_position_seq])

        if self.mode == "w2":
            # expand every (start, end) part into the full run of token positions
            new_spans = []
            for parts in spans:
                new_parts = []
                for part in parts:
                    new_parts.append(tuple(range(part[0], part[-1] + 1)))
                new_spans.append(new_parts)
            spans = new_spans
        elif self.mode == "span":
            spans = spans
        else:
            raise ValueError(f"mode={self.mode} is not supported")

        ins = {
            "raw": instance,
            "tokens": tokens,
            "input_ids": self.tokenizer.convert_tokens_to_ids(tokens),
            "mask": mask,
            "spans": spans,
            "label_map": label_map,
            "span_to_label": span_to_label,
            "labels": None,  # labels are calculated dynamically in collate_fn
        }
        return ins

    def char_to_token_span(
        self, span: list[int], tokenized: BatchEncoding, offset: int = 0
    ) -> Tuple[int, ...]:
        token_s = tokenized.char_to_token(span[0])
        token_e = tokenized.char_to_token(span[1] - 1)
        if token_e == token_s:
            position_seq = (offset + token_s,)
        else:
            position_seq = (offset + token_s, offset + token_e)
        return position_seq
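
    # Illustrative sketch (comments only, assuming a word-level tokenization):
    # for text "Alice Smith sings" tokenized as ["Alice", "Smith", "sings"] and
    # offset 10, the half-open character span [0, 11) ("Alice Smith") maps to the
    # inclusive token pair (10, 11), while [0, 5) ("Alice") collapses to (10,).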
    def skip_consecutive_span_labels(self, data: dict) -> dict:
        bs = len(data["input_ids"])
        max_seq_len = max(len(input_ids) for input_ids in data["input_ids"])
        batch_seq_len = min(self.max_seq_len, max_seq_len)
        for i in range(bs):
            data["input_ids"][i] = data["input_ids"][i][:batch_seq_len]
            data["mask"][i] = data["mask"][i][:batch_seq_len]
            assert len(data["input_ids"][i]) == len(data["mask"][i])
            pad_len = batch_seq_len - len(data["mask"][i])
            data["input_ids"][i] = (
                data["input_ids"][i] + [self.tokenizer.pad_token_id] * pad_len
            )
            data["mask"][i] = data["mask"][i] + [0] * pad_len
            data["labels"][i] = encode_nnw_nsw_thw_mat(data["spans"][i], batch_seq_len)

            # # for debugging only
            # pred_spans = decode_nnw_nsw_thw_mat(data["labels"][i].unsqueeze(0))[0]
            # sorted_gold = sorted(set(tuple(x) for x in data["spans"][i]))
            # sorted_pred = sorted(set(tuple(x) for x in pred_spans))
            # if sorted_gold != sorted_pred:
            #     breakpoint()

        # # for pre-training only
        # del data["spans"]
        return data
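

# Illustrative sketch of the instance format consumed by
# CachedLabelPointerTransform.transform (comments only; field values are
# hypothetical, other tasks such as "cls", "discontinuous_ent", "hyper_rel",
# "event", and "span" follow the keys referenced in the code above):
#   {
#       "instruction": "Please extract entities and relations.",
#       "schema": {"ent": ["person"], "rel": ["works for"]},
#       "text": "Alice joined Acme.",
#       "bg": "",
#       "ans": {
#           "ent": [{"type": "person", "span": [0, 5]}],
#           "rel": [
#               {
#                   "relation": "works for",
#                   "head": {"span": [0, 5]},
#                   "tail": {"span": [13, 17]},
#               }
#           ],
#       },
#   }
# Character spans are half-open [start, end) offsets into "text", matching
# char_to_token_span above.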