# gliner_base / GLiNER / model.py
# Tom Aarsen
# Add cloned GLiNER repository
# 914502f
import argparse
import json
from pathlib import Path
import re
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from modules.layers import LstmSeq2SeqEncoder
from modules.base import InstructBase
from modules.evaluator import Evaluator, greedy_search
from modules.span_rep import SpanRepLayer
from modules.token_rep import TokenRepLayer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
class GLiNER(InstructBase, PyTorchModelHubMixin):
    """Span-based named-entity recognizer driven by textual entity-type prompts.

    The entity types are written into a prompt
    ``[ENT] type_1 [ENT] type_2 ... [ENT] type_m [SEP]`` that is prepended to
    the input tokens. The encoder output at every ``[ENT]`` marker (passed
    through ``prompt_rep_layer``) represents that entity type. Every candidate
    span representation is scored against every entity type with a dot
    product (``einsum``).
    """

    def __init__(self, config):
        """Build all sub-layers from ``config``.

        Args:
            config: namespace-like object. The attributes read here are
                ``model_name``, ``fine_tune``, ``subtoken_pooling``,
                ``hidden_size``, ``span_mode``, ``max_width`` and ``dropout``.
                It is stored as ``self.config`` and reused by
                ``save_pretrained``.
        """
        super().__init__(config)

        self.config = config

        # Special prompt markers; they are registered with the token layer
        # via ``add_tokens`` below.
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # usually a pretrained bidirectional transformer, returns first subtoken representation
        self.token_rep_layer = TokenRepLayer(
            model_name=config.model_name,
            fine_tune=config.fine_tune,
            subtoken_pooling=config.subtoken_pooling,
            hidden_size=config.hidden_size,
            add_tokens=[self.entity_token, self.sep_token],
        )

        # hierarchical representation of tokens (bidirectional, so each
        # direction gets half of hidden_size)
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # span representation
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # prompt representation (FFN)
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )

    def compute_score_train(self, x):
        """Score candidate spans against each example's own entity types.

        Unlike ``compute_score_eval``, every example ``i`` gets its own prompt
        built from ``x['classes_to_id'][i]``, so prompt lengths (and type
        counts) differ across the batch and are padded afterwards.

        Args:
            x: training batch dict. Keys read here: ``'tokens'`` (one token
                list per example), ``'seq_length'`` (per-example lengths,
                a tensor), ``'span_idx'``, ``'span_mask'`` and
                ``'classes_to_id'`` (one dict per example whose keys are the
                entity-type strings).

        Returns:
            Tuple ``(scores, num_classes, entity_type_mask)``:
            ``scores`` has the einsum layout ``(batch, L, K, num_classes)``
            (presumably L = span start and K = span width — confirm against
            SpanRepLayer). ``num_classes`` is the padded number of entity
            types. ``entity_type_mask`` is a bool ``(batch, max_num_classes)``
            tensor that is True for real (non-padding) types.
        """
        # zero out the indices of padding spans
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # add prompt to the tokens
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())
            # multiple entity types in all_types. Prompt is appended at the start of tokens
            entity_prompt = []
            num_classes_all.append(len(all_types_i))
            # add entity types to prompt
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)  # [ENT] token
                entity_prompt.append(entity_type)  # entity type
            entity_prompt.append(self.sep_token)  # [SEP] token

            # prompt format:
            # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]

            # add prompt to the tokens
            tokens_p = entity_prompt + x['tokens'][i]

            # input format:
            # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n

            # update length of the sequence (add prompt length to the original length)
            new_length[i] = new_length[i] + len(entity_prompt)
            # update tokens
            new_tokens.append(tokens_p)
            # store prompt length
            all_len_prompt.append(len(entity_prompt))

        # create a mask using num_classes_all (False if the index exceeds the number of classes, True otherwise)
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
            x['span_mask'].device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
            x['span_mask'].device)  # [batch_size, max_num_classes]

        # compute all token representations
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
        mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)

        # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
        word_rep = []  # word representation (after [SEP])
        mask = []  # mask (after [SEP])
        entity_type_rep = []  # entity type representation (before [SEP])
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]  # length of prompt for this example
            # get word representation (after [SEP])
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
            # get mask (after [SEP])
            mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])

            # get entity type representation (before [SEP])
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]  # remove [SEP]
            # take every second position starting at index 0, i.e. the [ENT]
            # markers (the prompt alternates [ENT], type, [ENT], type, ...)
            entity_rep = entity_rep[0::2]
            entity_type_rep.append(entity_rep)

        # padding for word_rep, mask and entity_type_rep
        word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
        mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

        # compute span representation
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # compute final entity type representation (FFN)
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
        num_classes = entity_type_rep.shape[1]  # number of entity types (after padding)

        # similarity score
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        """Compute the summed training loss for a batch.

        Args:
            x: training batch. In addition to the keys read by
                ``compute_score_train`` it must hold ``"span_label"``, whose
                entries are 1-based class ids, 0 for "no entity" and -1 for
                padding spans. The batch is not modified.

        Returns:
            Scalar tensor: the sum over all non-padding spans and real entity
            types of the binary cross-entropy, with positives weighted 2 and
            negatives 1.
        """
        # compute span representation
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        # loss for filtering classifier
        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        # Set the labels of padding spans to 0. This is done out-of-place:
        # ``labels`` is a view of x["span_label"], so an in-place fill would
        # silently rewrite the caller's batch tensor.
        labels = labels.masked_fill(~mask_label, 0)

        # one-hot encoding; column 0 stands for "no entity" and is dropped below
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32, device=scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)  # Set the corresponding index to 1
        labels_one_hot = labels_one_hot[:, 1:]  # Remove the "no entity" column
        # Shape of labels_one_hot: (batch_size * num_spans, num_classes)

        # compute loss (without reduction)
        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')
        # mask loss using entity_type_mask (B, C) so padded entity types contribute nothing
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)
        # expand mask_label to all_losses
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
        # weight positives twice as much as negatives (2 for positive, 1 for negative)
        weight_c = labels_one_hot + 1
        # apply padding-span mask and class weights
        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        """Score candidate spans against one set of entity types shared by the batch.

        Args:
            x: batch dict. Keys read here: ``'classes_to_id'`` (a single dict
                whose keys are the entity-type strings, shared by every
                example), ``'tokens'``, ``'seq_length'``, ``'span_idx'`` and
                ``'span_mask'``.
            device: device that the span indices are moved to.

        Returns:
            Tensor with einsum layout ``(batch, L, K, num_types)`` of raw
            (pre-sigmoid) span/type scores.

        Raises:
            AssertionError: if ``x['classes_to_id']`` is not a dict.
        """
        # check if classes_to_id is dict
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        all_types = list(x['classes_to_id'].keys())
        # multiple entity types in all_types. Prompt is appended at the start of tokens
        entity_prompt = []
        # add entity types to prompt
        for entity_type in all_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)
        entity_prompt.append(self.sep_token)

        # every example shares the same prompt, hence the same prompt length
        prompt_entity_length = len(entity_prompt)

        # add prompt
        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # remove prompt
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # get_entity_type_rep (drop the trailing [SEP])
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        # extract [ENT] tokens (which are at even positions in entity_type_rep)
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

        word_rep = self.rnn(word_rep, mask)

        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        """Decode entity spans for a batch (puts the model in eval mode).

        Args:
            x: batch as accepted by ``compute_score_eval``, plus
                ``"id_to_classes"`` mapping 1-based class ids to labels.
            flat_ner: forwarded to ``greedy_search`` (overlap policy).
            threshold: minimum sigmoid probability for a span/type pair to be
                kept as a candidate.

        Returns:
            One list per example with whatever ``greedy_search`` returns for
            the candidates ``(start, end, label, score)``. Here ``end`` is an
            inclusive token index and ``score`` is the raw (pre-sigmoid) score.
        """
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, tokens_i in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # (start, width, class) index triples above the threshold
            wh_i = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                # k is a width offset: the span ends at token s + k (inclusive);
                # drop spans that would run past the end of the sequence
                if s + k < len(tokens_i):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            span_i = greedy_search(span_i, flat_ner)
            spans.append(span_i)
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Extract entities of the given types from a raw string.

        The text is split into words (runs of word characters, optionally
        joined by ``-`` or ``_``) and single non-space characters.

        Args:
            text: input string.
            labels: entity-type strings to look for (passed to
                ``collate_fn``).
            flat_ner: forwarded to ``predict``.
            threshold: forwarded to ``predict``.

        Returns:
            List of dicts with keys ``"start"`` / ``"end"`` (character offsets
            into ``text``, end exclusive), ``"text"`` and ``"label"``.
        """
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        # NOTE(review): assumes greedy_search drops the score and yields
        # (start, end, label) triples with an inclusive end token — verify.
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Run prediction over ``test_data`` and score it with ``Evaluator``.

        Args:
            test_data: dataset passed to ``self.create_dataloader``
                (presumably provided by ``InstructBase`` — confirm).
            flat_ner: forwarded to ``predict``.
            threshold: forwarded to ``predict``.
            batch_size: dataloader batch size.
            entity_types: forwarded to ``create_dataloader``.

        Returns:
            The ``(out, f1)`` pair returned by ``Evaluator.evaluate``.
        """
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # move only tensor entries; lists/dicts stay as they are
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a model from a local directory or a Hugging Face Hub repo.

        Two checkpoint layouts are tried in order:

        1. Legacy: a single ``gliner_base.pt`` or ``gliner_multi.pt`` holding
           both ``"config"`` and ``"model_weights"``. The backbone
           ``model_name`` is chosen from the filename.
        2. Newer: ``pytorch_model.bin`` (state dict) plus
           ``gliner_config.json``.

        Local files under ``model_id`` take precedence over Hub downloads.
        Errors from the newer-format download are not caught.
        """
        # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except HfHubHTTPError:
                    # this legacy file is not in the repo; try the next one
                    continue
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code — only load checkpoints from trusted sources.
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # Required to update flair's internals as well:
            model.to(map_location)
            return model

        # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
        from train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code — only load checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        # also updates flair's device (see ``to``)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
                Defaults to `self.config`; an `argparse.Namespace` or dataclass instance is
                converted to a dict before being written to `gliner_config.json`.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            The result of `push_to_hub` when `push_to_hub=True`, otherwise `None`.
        """
        import dataclasses

        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights/files
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (if provided)
        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            elif dataclasses.is_dataclass(config) and not isinstance(config, type):
                # json.dumps cannot serialize dataclass instances directly
                config = dataclasses.asdict(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, *args, **kwargs):
        """Move and/or cast the module like ``nn.Module.to`` and sync flair's device.

        Accepts the same arguments as ``nn.Module.to`` (a device, a dtype,
        both, a tensor, or ``device=`` / ``dtype=`` keywords), so it stays
        compatible with the original ``to(device)`` usage. ``flair.device``
        is updated only when a target device is actually given. A dtype-only
        call leaves it untouched.

        Returns:
            ``self``, for chaining.
        """
        super().to(*args, **kwargs)

        device = kwargs.get("device")
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                # nn.Module.to(tensor) adopts the tensor's device
                device = first.device
            elif not isinstance(first, torch.dtype):
                device = first

        if device is not None:
            import flair
            flair.device = device
        return self