import os

import spacy
import streamlit as st
import torch
from annotated_text import annotated_text
from datasets import ClassLabel, Dataset, Features, Sequence, Value, load_dataset
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel

# NOTE(review): the ``find_dep_depth`` imported here is shadowed by the local
# definition further down; the local version is the one this module uses.
from .nugget_model_utils import (
    find_dep_depth,
    find_nearest_nugget_features,
    tokenize_and_align_labels_with_pos_ner_dep,
)
from .utils import event_nugget_list, get_idxs_from_text

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Dependency depths are clamped to this value so that they fit the
# ``depth_spacy`` ClassLabel built in create_dataloader (classes 0.._MAX_DEP_DEPTH).
_MAX_DEP_DEPTH = 16


def find_dep_depth(token):
    """Return *token*'s depth in its spaCy dependency tree, capped at 16.

    A token whose ``head`` is itself (the sentence root) has depth 0; every
    step up the ``head`` chain adds one.
    """
    depth = 0
    current_token = token
    # Stop climbing once the cap is reached: deeper levels would be clamped anyway.
    while depth < _MAX_DEP_DEPTH and current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return depth


nlp = spacy.load("en_core_web_sm")

pos_spacy_tag_list = [
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
    "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "SPACE", "X",
]
ner_spacy_tag_list = [
    bio + entity
    for entity in list(nlp.get_pipe("ner").labels)
    for bio in ["B-", "I-"]
] + ["O"]
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# The nugget model is loaded by the caller and passed to predict() /
# get_event_nuggets() / show_annotations(), e.g.:
# model_nugget = NuggetModel(num_classes = 11)
# model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
# model_nugget.eval()


def _label_sequence(names):
    """Build a variable-length ``Sequence`` feature of ClassLabels over *names*."""
    return Sequence(
        feature=ClassLabel(num_classes=len(names), names=names, names_file=None, id=None),
        length=-1,
        id=None,
    )


def create_dataloader(text_input, batch_size=4):
    """Prepare a DataLoader of tokenized, feature-annotated chunks of *text_input*.

    The text is parsed with spaCy to get per-token POS, NER (BIO), dependency
    label and dependency-depth features. If the whole text would exceed the
    tokenizer's ``model_max_length``, it is split into several chunks at "."
    tokens. Each chunk becomes one dataset row.

    Args:
        text_input: The raw input text.
        batch_size: Number of chunks per DataLoader batch (default 4).

    Returns:
        A ``(dataloader, tokenized_dataset_ner)`` tuple.
        - ``dataloader`` iterates over the tokenized chunks in order.
        - ``tokenized_dataset_ner`` is the underlying tokenized dataset in
          torch format, with the ``"tokens"`` column removed. Its rows line up
          with the (unshuffled) DataLoader order.
    """
    doc = nlp(text_input)
    # Normalise quote tokens and drop "$" before locating words in the text.
    lookup_words = [
        tok.text.replace("``", '"').replace("''", '"').replace("$", "") for tok in doc
    ]
    content_idx_dict = get_idxs_from_text(text_input, lookup_words)
    # NOTE(review): assumes get_idxs_from_text returns one entry per spaCy token,
    # so that ``words`` aligns with the per-token feature lists below — TODO confirm.
    words = [content["word"] for content in content_idx_dict]

    pos_spacy = [tok.pos_ for tok in doc]
    # Tokens with no IOB annotation (ent_iob_ == "") are treated as "O"; the
    # previous "-" tag was not a valid ner_spacy_tag_list entry.
    ner_spacy = [
        f"{tok.ent_iob_}-{tok.ent_type_}" if tok.ent_iob_ in ("B", "I") else "O"
        for tok in doc
    ]
    dep_spacy = [tok.dep_ for tok in doc]
    depth_spacy = [find_dep_depth(tok) for tok in doc]

    def _make_row(start, end=None):
        """Slice every per-token list to ``[start:end]`` and bundle them as one row."""
        return {
            "tokens": words[start:end],
            "pos_spacy": pos_spacy[start:end],
            "ner_spacy": ner_spacy[start:end],
            "dep_spacy": dep_spacy[start:end],
            "depth_spacy": depth_spacy[start:end],
        }

    content_token_len = len(
        tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"]
    )
    max_len = tokenizer.model_max_length

    data = []
    if content_token_len > max_len:
        # Aim for a few more chunks than strictly needed and cut at the first
        # "." past each running threshold.
        # NOTE(review): a stretch with no "." can still yield a chunk longer
        # than max_len.
        no_split = content_token_len // max_len + 2
        split_len = len(words) // no_split + 1
        last_id = 0
        threshold = split_len
        for i, token in enumerate(words):
            if token == "." and i > threshold:
                data.append(_make_row(last_id, i + 1))
                last_id = i + 1
                threshold += split_len
        # Skip the trailing chunk when the last split ended exactly at the final token.
        if last_id < len(words):
            data.append(_make_row(last_id))
    else:
        data.append(_make_row(0))

    ner_features = Features({
        "tokens": Sequence(feature=Value(dtype="string", id=None), length=-1, id=None),
        "pos_spacy": _label_sequence(pos_spacy_tag_list),
        "ner_spacy": _label_sequence(ner_spacy_tag_list),
        "dep_spacy": _label_sequence(dep_spacy_tag_list),
        "depth_spacy": _label_sequence(list(range(_MAX_DEP_DEPTH + 1))),
    })
    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(
        tokenize_and_align_labels_with_pos_ner_dep,
        fn_kwargs={"tokenizer": tokenizer},
        batched=True,
        load_from_cache_file=False,
    )
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
    tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    # TODO : context_idx_dict should be used to index the words
    return dataloader, tokenized_dataset_ner


def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if it has none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def predict(model_nugget, dataloader):
    """Run *model_nugget* over *dataloader* and return arg-max class ids.

    Every tensor in a batch is moved to the device of the model's parameters
    before the forward pass. Predictions are collected on the CPU.

    Args:
        model_nugget: The token-classification model. It is called as
            ``model_nugget(**batch)`` and its output's last dimension is taken
            as the class dimension.
        dataloader: A DataLoader yielding dict batches, e.g. from
            create_dataloader().

    Returns:
        A CPU tensor of predicted class ids.
        - Batches are concatenated along dim 0, so row ``i`` corresponds to
          dataset example ``i``. This assumes the model's logits are
          ``(batch, seq_len, num_classes)`` — TODO confirm.
        - The result is an empty ``(0, 0)`` tensor if the DataLoader yields
          nothing.
    """
    target_device = _model_device(model_nugget)
    batch_predictions = []
    with torch.no_grad():
        for batch in dataloader:
            batch = {
                key: value.to(target_device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
            logits = model_nugget(**batch)
            batch_predictions.append(logits.argmax(-1).cpu())
    if not batch_predictions:
        return torch.empty((0, 0), dtype=torch.long)
    # dim=0: stack batches example-wise (dim=-1 would join along the sequence axis).
    return torch.cat(batch_predictions, dim=0)


def _decode_chunk(input_ids):
    """Drop ids <= 2 from one chunk and recover its sub-word tokens and text.

    Ids 0-2 are presumably ``<s>``, ``<pad>`` and ``</s>`` for this
    RoBERTa-style tokenizer — TODO confirm.

    Returns:
        A ``(token_mask, tokens, text)`` tuple:
        - ``token_mask``: boolean mask of the kept positions.
        - ``tokens``: cleaned sub-word strings.
        - ``text``: the decoded text of the kept ids.
    """
    token_mask = input_ids > 2
    kept_ids = input_ids[token_mask].tolist()
    tokens = tokenizer.convert_ids_to_tokens(kept_ids, skip_special_tokens=True)
    # Undo byte-level BPE markers: "Ġ" (space), "Ċ" (newline), "âĢĻ" (right quote).
    tokens = [
        token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens
    ]
    text = tokenizer.decode(kept_ids)
    return token_mask, tokens, text


def _label_to_subtype(label):
    """Map a predicted class id to its event-nugget tag without the BIO prefix.

    Only the first "-" is split on, so subtypes that contain "-" themselves
    are kept intact. Tags without a "-" (e.g. "O") are returned unchanged.
    """
    tag = event_nugget_list[int(label)]
    return tag.split("-", 1)[1] if "-" in tag else tag


def _append_segment(segments, segment_text, label):
    """Append *segment_text* as plain text for "O", otherwise as a ``(text, label)`` tuple."""
    segments.append(segment_text if label == "O" else (segment_text, label))


def show_annotations(text_input, model_nugget=None):
    """Render *text_input* in Streamlit with predicted event nuggets highlighted.

    Consecutive sub-word tokens that share a predicted subtype are merged into
    one segment. "O" segments are shown as plain text and every other segment
    as a ``(text, subtype)`` annotation. Each chunk from create_dataloader()
    is rendered with its own ``annotated_text`` call.

    Args:
        text_input: The text to annotate.
        model_nugget: The trained nugget model passed to predict(). It has a
            default only for signature compatibility; it is required.

    Raises:
        TypeError: If *model_nugget* is not provided.
    """
    if model_nugget is None:
        # predict() needs the model; fail early with a clear message.
        raise TypeError("show_annotations() requires a model_nugget argument")

    st.title("Event Nuggets")
    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(model_nugget, dataloader)
    for chunk_idx, labels in enumerate(predicted_label):
        token_mask, tokens, text = _decode_chunk(
            tokenized_dataset_ner[chunk_idx]["input_ids"]
        )
        spans = get_idxs_from_text(text, tokens)
        labels = labels[token_mask]

        segments = []
        last_label = ""
        cumulative_tokens = ""
        last_end = 0
        for span, label in zip(spans, labels):
            label_short = _label_to_subtype(label)
            if label_short == last_label:
                # Extend the current segment, keeping the original spacing.
                cumulative_tokens += text[last_end:span["end_idx"]]
            else:
                if last_label:
                    _append_segment(segments, cumulative_tokens, last_label)
                    # Keep the text between segments so words do not run together.
                    gap = text[last_end:span["start_idx"]]
                    if gap:
                        segments.append(gap)
                last_label = label_short
                cumulative_tokens = span["word"]
            last_end = span["end_idx"]
        if last_label:
            _append_segment(segments, cumulative_tokens, last_label)
        annotated_text(segments)


def _nugget_record(text_input, offset, start_idx, end_idx, subtype):
    """Build the output dict for one labelled span, or None if it is not a nugget.

    *start_idx*/*end_idx* are relative to the chunk that begins at *offset*
    in *text_input*.
    """
    if subtype in ("", "O") or end_idx <= start_idx:
        return None
    return {
        "startOffset": offset + start_idx,
        "endOffset": offset + end_idx,
        "subtype": subtype,
        "text": text_input[offset + start_idx : offset + end_idx],
    }


def get_event_nuggets(model_nugget, text_input):
    """Extract predicted event nuggets from *text_input*.

    Consecutive sub-word tokens with the same predicted subtype (BIO prefix
    removed) form one nugget. Spans predicted as "O" are skipped.

    Args:
        model_nugget: The trained nugget model passed to predict().
        text_input: The text to analyse.

    Returns:
        A list of dicts with keys ``startOffset``, ``endOffset``, ``subtype``
        and ``text``. The offsets are character positions in *text_input*
        (end exclusive).
    """
    dataloader, tokenized_dataset_ner = create_dataloader(text_input)
    predicted_label = predict(model_nugget, dataloader)

    predicted_event_nuggets = []
    text_length = 0  # character offset of the current chunk within text_input
    for chunk_idx, labels in enumerate(predicted_label):
        token_mask, tokens, _ = _decode_chunk(
            tokenized_dataset_ner[chunk_idx]["input_ids"]
        )
        # Locate this chunk's tokens in the not-yet-consumed part of the text.
        spans = get_idxs_from_text(text_input[text_length:], tokens)
        labels = labels[token_mask]

        start_idx = 0
        end_idx = 0
        last_label = ""
        last_span = None
        for span, label in zip(spans, labels):
            label_short = _label_to_subtype(label)
            if label_short == last_label:
                end_idx = span["end_idx"]
            else:
                record = _nugget_record(text_input, text_length, start_idx, end_idx, last_label)
                if record is not None:
                    predicted_event_nuggets.append(record)
                start_idx = span["start_idx"]
                end_idx = span["start_idx"] + len(span["word"])
                last_label = label_short
            last_span = span

        # Flush the final segment: a nugget ending the chunk would otherwise be lost.
        record = _nugget_record(text_input, text_length, start_idx, end_idx, last_label)
        if record is not None:
            predicted_event_nuggets.append(record)

        # A chunk with no located tokens consumes no text.
        if last_span is not None:
            text_length += last_span["end_idx"]
    return predicted_event_nuggets