import argparse import json from pathlib import Path import re from typing import Dict, Optional, Union import torch import torch.nn.functional as F from modules.layers import LstmSeq2SeqEncoder from modules.base import InstructBase from modules.evaluator import Evaluator, greedy_search from modules.span_rep import SpanRepLayer from modules.token_rep import TokenRepLayer from torch import nn from torch.nn.utils.rnn import pad_sequence from huggingface_hub import PyTorchModelHubMixin, hf_hub_download from huggingface_hub.utils import HfHubHTTPError class GLiNER(InstructBase, PyTorchModelHubMixin): def __init__(self, config): super().__init__(config) self.config = config # [ENT] token self.entity_token = "<>" self.sep_token = "<>" # usually a pretrained bidirectional transformer, returns first subtoken representation self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune, subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size, add_tokens=[self.entity_token, self.sep_token]) # hierarchical representation of tokens self.rnn = LstmSeq2SeqEncoder( input_size=config.hidden_size, hidden_size=config.hidden_size // 2, num_layers=1, bidirectional=True, ) # span representation self.span_rep_layer = SpanRepLayer( span_mode=config.span_mode, hidden_size=config.hidden_size, max_width=config.max_width, dropout=config.dropout, ) # prompt representation (FFN) self.prompt_rep_layer = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size * 4), nn.Dropout(config.dropout), nn.ReLU(), nn.Linear(config.hidden_size * 4, config.hidden_size) ) def compute_score_train(self, x): span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1) new_length = x['seq_length'].clone() new_tokens = [] all_len_prompt = [] num_classes_all = [] # add prompt to the tokens for i in range(len(x['tokens'])): all_types_i = list(x['classes_to_id'][i].keys()) # multiple entity types in all_types. Prompt is appended at the start of tokens entity_prompt = [] num_classes_all.append(len(all_types_i)) # add enity types to prompt for entity_type in all_types_i: entity_prompt.append(self.entity_token) # [ENT] token entity_prompt.append(entity_type) # entity type entity_prompt.append(self.sep_token) # [SEP] token # prompt format: # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP] # add prompt to the tokens tokens_p = entity_prompt + x['tokens'][i] # input format: # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n # update length of the sequence (add prompt length to the original length) new_length[i] = new_length[i] + len(entity_prompt) # update tokens new_tokens.append(tokens_p) # store prompt length all_len_prompt.append(len(entity_prompt)) # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise) max_num_classes = max(num_classes_all) entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to( x['span_mask'].device) entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to( x['span_mask'].device) # [batch_size, max_num_classes] # compute all token representations bert_output = self.token_rep_layer(new_tokens, new_length) word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt) mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt) # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP]) word_rep = [] # word representation (after [SEP]) mask = [] # mask (after [SEP]) entity_type_rep = [] # entity type representation (before [SEP]) for i in range(len(x['tokens'])): prompt_entity_length = all_len_prompt[i] # length of prompt for this example # get word representation (after [SEP]) word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]]) # get mask (after [SEP]) mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]]) # get entity type representation (before [SEP]) entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP] entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one entity_type_rep.append(entity_rep) # padding for word_rep, mask and entity_type_rep word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size] mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len] entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size] # compute span representation word_rep = self.rnn(word_rep, mask) span_rep = self.span_rep_layer(word_rep, span_idx) # compute final entity type representation (FFN) entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size) num_classes = entity_type_rep.shape[1] # number of entity types # similarity score scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep) return scores, num_classes, entity_type_mask def forward(self, x): # compute span representation scores, num_classes, entity_type_mask = self.compute_score_train(x) batch_size = scores.shape[0] # loss for filtering classifier logits_label = scores.view(-1, num_classes) labels = x["span_label"].view(-1) # (batch_size * num_spans) mask_label = labels != -1 # (batch_size * num_spans) labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0 # one-hot encoding labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device) labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1 labels_one_hot = labels_one_hot[:, 1:] # Remove the first column # Shape of labels_one_hot: (batch_size * num_spans, num_classes) # compute loss (without reduction) all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot, reduction='none') # mask loss using entity_type_mask (B, C) masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1) all_losses = masked_loss.view(-1, num_classes) # expand mask_label to all_losses mask_label = mask_label.unsqueeze(-1).expand_as(all_losses) # put lower loss for in label_one_hot (2 for positive, 1 for negative) weight_c = labels_one_hot + 1 # apply mask all_losses = all_losses * mask_label.float() * weight_c return all_losses.sum() def compute_score_eval(self, x, device): # check if classes_to_id is dict assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict" span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device) all_types = list(x['classes_to_id'].keys()) # multiple entity types in all_types. Prompt is appended at the start of tokens entity_prompt = [] # add enity types to prompt for entity_type in all_types: entity_prompt.append(self.entity_token) entity_prompt.append(entity_type) entity_prompt.append(self.sep_token) prompt_entity_length = len(entity_prompt) # add prompt tokens_p = [entity_prompt + tokens for tokens in x['tokens']] seq_length_p = x['seq_length'] + prompt_entity_length out = self.token_rep_layer(tokens_p, seq_length_p) word_rep_w_prompt = out["embeddings"] mask_w_prompt = out["mask"] # remove prompt word_rep = word_rep_w_prompt[:, prompt_entity_length:, :] mask = mask_w_prompt[:, prompt_entity_length:] # get_entity_type_rep entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :] # extract [ENT] tokens (which are at even positions in entity_type_rep) entity_type_rep = entity_type_rep[:, 0::2, :] entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size) word_rep = self.rnn(word_rep, mask) span_rep = self.span_rep_layer(word_rep, span_idx) local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep) return local_scores @torch.no_grad() def predict(self, x, flat_ner=False, threshold=0.5): self.eval() local_scores = self.compute_score_eval(x, device=next(self.parameters()).device) spans = [] for i, _ in enumerate(x["tokens"]): local_i = local_scores[i] wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)] span_i = [] for s, k, c in zip(*wh_i): if s + k < len(x["tokens"][i]): span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c])) span_i = greedy_search(span_i, flat_ner) spans.append(span_i) return spans def predict_entities(self, text, labels, flat_ner=True, threshold=0.5): tokens = [] start_token_idx_to_text_idx = [] end_token_idx_to_text_idx = [] for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text): tokens.append(match.group()) start_token_idx_to_text_idx.append(match.start()) end_token_idx_to_text_idx.append(match.end()) input_x = {"tokenized_text": tokens, "ner": None} x = self.collate_fn([input_x], labels) output = self.predict(x, flat_ner=flat_ner, threshold=threshold) entities = [] for start_token_idx, end_token_idx, ent_type in output[0]: start_text_idx = start_token_idx_to_text_idx[start_token_idx] end_text_idx = end_token_idx_to_text_idx[end_token_idx] entities.append({ "start": start_token_idx_to_text_idx[start_token_idx], "end": end_token_idx_to_text_idx[end_token_idx], "text": text[start_text_idx:end_text_idx], "label": ent_type, }) return entities def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None): self.eval() data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False) device = next(self.parameters()).device all_preds = [] all_trues = [] for x in data_loader: for k, v in x.items(): if isinstance(v, torch.Tensor): x[k] = v.to(device) batch_predictions = self.predict(x, flat_ner, threshold) all_preds.extend(batch_predictions) all_trues.extend(x["entities"]) evaluator = Evaluator(all_trues, all_preds) out, f1 = evaluator.evaluate() return out, f1 @classmethod def _from_pretrained( cls, *, model_id: str, revision: Optional[str], cache_dir: Optional[Union[str, Path]], force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", strict: bool = False, **model_kwargs, ): # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data filenames = ["gliner_base.pt", "gliner_multi.pt"] for filename in filenames: model_file = Path(model_id) / filename if not model_file.exists(): try: model_file = hf_hub_download( repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) except HfHubHTTPError: continue dict_load = torch.load(model_file, map_location=torch.device(map_location)) config = dict_load["config"] state_dict = dict_load["model_weights"] config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base" model = cls(config) model.load_state_dict(state_dict, strict=strict, assign=True) # Required to update flair's internals as well: model.to(map_location) return model # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json" from train import load_config_as_namespace model_file = Path(model_id) / "pytorch_model.bin" if not model_file.exists(): model_file = hf_hub_download( repo_id=model_id, filename="pytorch_model.bin", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) config_file = Path(model_id) / "gliner_config.json" if not config_file.exists(): config_file = hf_hub_download( repo_id=model_id, filename="gliner_config.json", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) config = load_config_as_namespace(config_file) model = cls(config) state_dict = torch.load(model_file, map_location=torch.device(map_location)) model.load_state_dict(state_dict, strict=strict, assign=True) model.to(map_location) return model def save_pretrained( self, save_directory: Union[str, Path], *, config: Optional[Union[dict, "DataclassInstance"]] = None, repo_id: Optional[str] = None, push_to_hub: bool = False, **push_to_hub_kwargs, ) -> Optional[str]: """ Save weights in local directory. Args: save_directory (`str` or `Path`): Path to directory in which the model weights and configuration will be saved. config (`dict` or `DataclassInstance`, *optional*): Model configuration specified as a key/value dictionary or a dataclass instance. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Huggingface Hub after saving it. repo_id (`str`, *optional*): ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if not provided. kwargs: Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. """ save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) # save model weights/files torch.save(self.state_dict(), save_directory / "pytorch_model.bin") # save config (if provided) if config is None: config = self.config if config is not None: if isinstance(config, argparse.Namespace): config = vars(config) (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2)) # push to the Hub if required if push_to_hub: kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input if config is not None: # kwarg for `push_to_hub` kwargs["config"] = config if repo_id is None: repo_id = save_directory.name # Defaults to `save_directory` name return self.push_to_hub(repo_id=repo_id, **kwargs) return None def to(self, device): super().to(device) import flair flair.device = device return self