import math
import re
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union, Dict

import numpy as np
import torch
import torch.nn as nn
from transformers import ErnieModel, ErniePreTrainedModel, PretrainedConfig, PreTrainedTokenizerFast
from transformers.utils import ModelOutput


@dataclass
class UIEModelOutput(ModelOutput):
    """
    Output class for outputs of UIE.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions` and `end_positions` are provided):
            Total span-extraction loss, the average of the binary cross-entropy losses for the start and end
            positions.
        start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (after sigmoid).
        end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (after sigmoid).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding
            layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_prob: torch.FloatTensor = None
    end_prob: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class UIE(ErniePreTrainedModel):
    """
    UIE model based on the ERNIE encoder.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.encoder = ErnieModel(config)
        self.config = config
        hidden_size = self.config.hidden_size

        self.linear_start = nn.Linear(hidden_size, 1)
        self.linear_end = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

        self.post_init()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                token_type_ids: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None
                ) -> Union[Tuple, UIEModelOutput]:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Segment token indices to indicate first and second portions of the inputs. Indices are selected in
                `[0, 1]`:
                - 0 corresponds to a *sentence A* token,
                - 1 corresponds to a *sentence B* token.
                [What are token type IDs?](../glossary#token-type-ids)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Selected in the range
                `[0, config.max_position_embeddings - 1]`.
                [What are position IDs?](../glossary#position-ids)
            head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
                representation. This is useful if you want more control over how to convert `input_ids` indices
                into associated vectors than the model's internal embedding lookup matrix.
            start_positions (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Binary labels (0. or 1.) marking the start token of each labelled span, used as targets of the
                binary cross-entropy loss over `start_prob`.
            end_positions (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Binary labels (0. or 1.) marking the end token of each labelled span, used as targets of the
                binary cross-entropy loss over `end_prob`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned
                tensors for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        start_logits = self.linear_start(sequence_output)
        start_logits = torch.squeeze(start_logits, -1)
        start_prob = self.sigmoid(start_logits)
        end_logits = self.linear_end(sequence_output)
        end_logits = torch.squeeze(end_logits, -1)
        end_prob = self.sigmoid(end_logits)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.BCELoss()
            start_loss = loss_fct(start_prob, start_positions)
            end_loss = loss_fct(end_prob, end_positions)
            total_loss = (start_loss + end_loss) / 2.0

        if not return_dict:
            output = (start_prob, end_prob) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return UIEModelOutput(
            loss=total_loss,
            start_prob=start_prob,
            end_prob=end_prob,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
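
    # A minimal shape-check sketch for forward (the dummy batch below is purely
    # illustrative, not part of the model's API):
    #
    #   model = UIE(config)                          # config: an ERNIE PretrainedConfig
    #   input_ids = torch.randint(0, config.vocab_size, (2, 16))
    #   labels = torch.zeros(2, 16)                  # 0./1. floats, BCE targets
    #   out = model(input_ids=input_ids, start_positions=labels, end_positions=labels)
    #   out.start_prob.shape                         # torch.Size([2, 16]); out.loss is a scalar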

    def predict(self, schema: Union[Dict, List[str], str], input_texts: Union[List[str], str],
                tokenizer: PreTrainedTokenizerFast, max_length: int = 512, batch_size: int = 32,
                position_prob: float = 0.5, progress_hook=None) -> List[Dict]:
        """
        Run schema-guided extraction over one or more texts.

        Args:
            schema (Union[Dict, List[str], str]): the extraction target(s).
            input_texts (Union[List[str], str]): the text(s) to extract from.
            tokenizer (PreTrainedTokenizerFast): tokenizer matching the model checkpoint.
            max_length (int): maximum encoded length of prompt + text.
            batch_size (int): batch size used during inference.
            position_prob (float): probability threshold for start/end positions.
            progress_hook: optional tqdm-like hook used to report progress.

        Returns:
            result (List[Dict]): one result dict per input text.
        """
        predictor = UIEPredictor(self, tokenizer=tokenizer, schema=schema, max_length=max_length,
                                 position_prob=position_prob, batch_size=batch_size, hook=progress_hook)
        input_texts = [input_texts] if isinstance(input_texts, str) else input_texts
        return predictor.predict(input_texts)
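
    # The schema argument accepts three forms (illustrative values):
    #   "时间"                            -> a single entity type
    #   ["时间", "地点"]                  -> several entity types
    #   {"竞赛名称": ["主办方", "时间"]}  -> entity with relations, extracted in stages
    # See the runnable sketch at the bottom of this module.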


class UIEPredictor(object):
    def __init__(self, model, tokenizer, schema, max_length=512, position_prob=0.5, batch_size=32, hook=None):
        self.model = model
        self._tokenizer = tokenizer

        self._position_prob = position_prob
        self.max_length = max_length
        self._batch_size = batch_size
        self._multilingual = getattr(self.model.config, 'multilingual', False)
        self._schema_tree = self.set_schema(schema)
        self._hook = hook

    def set_schema(self, schema):
        if isinstance(schema, (dict, str)):
            schema = [schema]
        return self._build_tree(schema)

    @classmethod
    def _build_tree(cls, schema, name="root"):
        """
        Build the schema tree.
        """
        schema_tree = SchemaTree(name)
        for s in schema:
            if isinstance(s, str):
                schema_tree.add_child(SchemaTree(s))
            elif isinstance(s, dict):
                for k, v in s.items():
                    if isinstance(v, str):
                        child = [v]
                    elif isinstance(v, list):
                        child = v
                    else:
                        raise TypeError(
                            "Invalid schema, value for each key:value pair should be a list or a string, "
                            "but {} received".format(type(v))
                        )
                    schema_tree.add_child(cls._build_tree(child, name=k))
            else:
                raise TypeError("Invalid schema, element should be a string or a dict, "
                                "but {} received".format(type(s)))
        return schema_tree
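
    # For illustration, a nested schema such as
    #
    #   {"竞赛名称": ["主办方", "时间"]}
    #
    # builds the tree root -> 竞赛名称 -> (主办方, 时间): the parent entity is
    # extracted first, and each extracted span is prepended to the child
    # prompts in the multi-stage pass below.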

    def _single_stage_predict(self, inputs):
        input_texts = []
        prompts = []
        for i in range(len(inputs)):
            input_texts.append(inputs[i]["text"])
            prompts.append(inputs[i]["prompt"])

        # Reserve room for the longest prompt and the three special tokens ([CLS], 2 x [SEP]).
        max_predict_len = self.max_length - max(len(p) for p in prompts) - 3
        short_input_texts, self.input_mapping = Utils.auto_splitter(input_texts, max_predict_len, split_sentence=False)

        short_texts_prompts = []
        for k, v in self.input_mapping.items():
            short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
        short_inputs = [
            {"text": short_input_texts[i], "prompt": short_texts_prompts[i]} for i in range(len(short_input_texts))
        ]

        prompts = []
        texts = []
        for s in short_inputs:
            prompts.append(s["prompt"])
            texts.append(s["text"])

        if self._multilingual:
            padding_type = "max_length"
        else:
            padding_type = "longest"

        encoded_inputs = self._tokenizer(
            text=prompts,
            text_pair=texts,
            stride=2,
            truncation=True,
            max_length=self.max_length,
            padding=padding_type,
            add_special_tokens=True,
            return_offsets_mapping=True,
            return_tensors="np")

        offset_maps = encoded_inputs["offset_mapping"]
        start_probs = []
        end_probs = []
        for idx in range(0, len(texts), self._batch_size):
            l, r = idx, idx + self._batch_size

            input_ids = encoded_inputs["input_ids"][l:r]
            token_type_ids = encoded_inputs["token_type_ids"][l:r]
            attention_mask = encoded_inputs["attention_mask"][l:r]

            if self._multilingual:
                input_ids = np.array(input_ids, dtype="int64")
                attention_mask = np.array(attention_mask, dtype="int64")
                position_ids = (np.cumsum(np.ones_like(input_ids), axis=1)
                                - np.ones_like(input_ids)) * attention_mask
                input_dict = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "position_ids": position_ids
                }
            else:
                input_dict = {
                    "input_ids": np.array(input_ids, dtype="int64"),
                    "token_type_ids": np.array(token_type_ids, dtype="int64"),
                    "attention_mask": np.array(attention_mask, dtype="int64")
                }

            start_prob, end_prob = self._infer(input_dict)
            start_probs.extend(start_prob.tolist())
            end_probs.extend(end_prob.tolist())
            if self._hook is not None:
                self._hook.update(1)
        start_ids_list = Utils.get_bool_ids_greater_than(start_probs, limit=self._position_prob, return_prob=True)
        end_ids_list = Utils.get_bool_ids_greater_than(end_probs, limit=self._position_prob, return_prob=True)
        sentence_ids = []
        probs = []
        for start_ids, end_ids, offset_map in zip(start_ids_list, end_ids_list, offset_maps.tolist()):
            span_list = Utils.get_span(start_ids, end_ids, with_prob=True)
            sentence_id, prob = Utils.get_id_and_prob(span_list, offset_map)
            sentence_ids.append(sentence_id)
            probs.append(prob)
        results = Utils.convert_ids_to_results(short_inputs, sentence_ids, probs)
        results = Utils.auto_joiner(results, short_input_texts, self.input_mapping)
        return results

    def _multi_stage_predict(self, data):
        """
        Traverse the schema tree and do multi-stage prediction.

        Args:
            data (list): a list of strings
        Returns:
            list: a list of predictions, whose length equals the length of `data`
        """
        results = [{} for _ in range(len(data))]

        if len(data) < 1 or self._schema_tree is None:
            return results

        _per_node_total = len(data) // self._batch_size + (1 if len(data) % self._batch_size else 0)
        _finish_node = 0
        if self._hook is not None:
            self._hook.reset(total=self._schema_tree.shape * _per_node_total)

        schema_list = self._schema_tree.children[:]
        while len(schema_list) > 0:
            node = schema_list.pop(0)
            examples = []
            input_map = {}
            cnt = 0
            idx = 0
            if not node.prefix:
                for one_data in data:
                    examples.append({"text": one_data, "prompt": Utils.dbc2sbc(node.name)})
                    input_map[cnt] = [idx]
                    idx += 1
                    cnt += 1
            else:
                for pre, one_data in zip(node.prefix, data):
                    if len(pre) == 0:
                        input_map[cnt] = []
                    else:
                        for p in pre:
                            examples.append({"text": one_data, "prompt": Utils.dbc2sbc(p + node.name)})
                        input_map[cnt] = [i + idx for i in range(len(pre))]
                        idx += len(pre)
                    cnt += 1
            if len(examples) == 0:
                result_list = []
            else:
                result_list = self._single_stage_predict(examples)

            if not node.parent_relations:
                relations = [[] for _ in range(len(data))]
                for k, v in input_map.items():
                    for idx in v:
                        if len(result_list[idx]) == 0:
                            continue
                        if node.name not in results[k].keys():
                            results[k][node.name] = result_list[idx]
                        else:
                            results[k][node.name].extend(result_list[idx])
                    if node.name in results[k].keys():
                        relations[k].extend(results[k][node.name])
            else:
                relations = node.parent_relations
                for k, v in input_map.items():
                    for i in range(len(v)):
                        if len(result_list[v[i]]) == 0:
                            continue
                        if "relations" not in relations[k][i].keys():
                            relations[k][i]["relations"] = {node.name: result_list[v[i]]}
                        elif node.name not in relations[k][i]["relations"].keys():
                            relations[k][i]["relations"][node.name] = result_list[v[i]]
                        else:
                            relations[k][i]["relations"][node.name].extend(result_list[v[i]])
                new_relations = [[] for _ in range(len(data))]
                for i in range(len(relations)):
                    for j in range(len(relations[i])):
                        if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
                            for k in range(len(relations[i][j]["relations"][node.name])):
                                new_relations[i].append(relations[i][j]["relations"][node.name][k])
                relations = new_relations

            prefix = [[] for _ in range(len(data))]
            for k, v in input_map.items():
                for idx in v:
                    for i in range(len(result_list[idx])):
                        # "的" is the Chinese possessive particle: the child prompt
                        # becomes "<parent span>的<child name>".
                        prefix[k].append(result_list[idx][i]["text"] + "的")
            for child in node.children:
                child.prefix = prefix
                child.parent_relations = relations
                schema_list.append(child)
            _finish_node += 1
            if self._hook is not None:
                self._hook.n = _finish_node * _per_node_total
        if self._hook is not None:
            self._hook.close()
        return results

    def _infer(self, input_dict):
        for input_name, input_value in input_dict.items():
            input_dict[input_name] = torch.LongTensor(input_value).to(self.model.device)
        outputs = self.model(**input_dict)
        return outputs.start_prob.detach().cpu().numpy(), outputs.end_prob.detach().cpu().numpy()

    def predict(self, input_data):
        results = self._multi_stage_predict(data=input_data)
        return results


class SchemaTree(object):
    """
    Implementation of SchemaTree
    """

    def __init__(self, name="root", children=None):
        self.name = name
        self.children = []
        self.prefix = None
        self.parent_relations = None
        if children is not None:
            for child in children:
                self.add_child(child)
        self._total_nodes = 0

    @property
    def shape(self):
        return len(self.children) + sum([child.shape for child in self.children])

    def __repr__(self):
        return self.name

    def add_child(self, node):
        assert isinstance(node, SchemaTree), "The children of a node should be an instance of SchemaTree."
        self._total_nodes += 1
        self.children.append(node)


class Utils:

    @classmethod
    def dbc2sbc(cls, s):
        """
        Convert full-width (DBC) characters to their half-width (SBC) equivalents.
        """
        rs = ""
        for char in s:
            code = ord(char)
            if code == 0x3000:
                code = 0x0020
            else:
                code -= 0xFEE0
            if not (0x0021 <= code <= 0x7E):
                rs += char
                continue
            rs += chr(code)
        return rs
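
    # For illustration: Utils.dbc2sbc("UIE:2022") == "UIE:2022".
    # Full-width letters, digits, and punctuation map to ASCII; all other
    # characters pass through unchanged.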

    @classmethod
    def cut_chinese_sent(cls, para):
        """
        Cut Chinese sentences more precisely, following
        "https://blog.csdn.net/blmoistawinde/article/details/82379256".
        """
        para = re.sub(r'([。!??])([^”’])', r"\1\n\2", para)
        para = re.sub(r'(\.{6})([^”’])', r"\1\n\2", para)
        para = re.sub(r'(…{2})([^”’])', r"\1\n\2", para)
        para = re.sub(r'([。!??][”’])([^,。!??])', r'\1\n\2', para)
        para = para.rstrip()
        return para.split("\n")
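
    # For illustration:
    #   Utils.cut_chinese_sent("今天天气很好。我们去公园吧!")
    #   -> ["今天天气很好。", "我们去公园吧!"]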

    @classmethod
    def get_bool_ids_greater_than(cls, probs, limit=0.5, return_prob=False):
        """
        Get the indices along the last dimension of the probability arrays whose values exceed a threshold.

        Args:
            probs (List[List[float]]): The input probability arrays.
            limit (float): The probability threshold.
            return_prob (bool): Whether to return the probability together with the index.
        Returns:
            List[List[int]]: The indices of the last dimension that meet the condition.
        """
        probs = np.array(probs)
        dim_len = len(probs.shape)
        if dim_len > 1:
            result = []
            for p in probs:
                result.append(cls.get_bool_ids_greater_than(p, limit, return_prob))
            return result
        else:
            result = []
            for i, p in enumerate(probs):
                if p > limit:
                    if return_prob:
                        result.append((i, p))
                    else:
                        result.append(i)
            return result
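
    # For illustration:
    #   Utils.get_bool_ids_greater_than([[0.1, 0.9, 0.7]], limit=0.5)
    #   -> [[1, 2]]
    # With return_prob=True the entries become (index, probability) tuples.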

    @classmethod
    def get_span(cls, start_ids, end_ids, with_prob=False):
        """
        Get the span set from the lists of start and end positions.

        Args:
            start_ids (List[int]/List[tuple]): The start index list.
            end_ids (List[int]/List[tuple]): The end index list.
            with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).
        Returns:
            set: The span set without overlapping; every id can only be used once.
        """
        if with_prob:
            start_ids = sorted(start_ids, key=lambda x: x[0])
            end_ids = sorted(end_ids, key=lambda x: x[0])
        else:
            start_ids = sorted(start_ids)
            end_ids = sorted(end_ids)

        start_pointer = 0
        end_pointer = 0
        len_start = len(start_ids)
        len_end = len(end_ids)
        couple_dict = {}
        while start_pointer < len_start and end_pointer < len_end:
            if with_prob:
                start_id = start_ids[start_pointer][0]
                end_id = end_ids[end_pointer][0]
            else:
                start_id = start_ids[start_pointer]
                end_id = end_ids[end_pointer]

            if start_id == end_id:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                end_pointer += 1
                continue
            if start_id < end_id:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                continue
            if start_id > end_id:
                end_pointer += 1
                continue
        result = [(couple_dict[end], end) for end in couple_dict]
        result = set(result)
        return result
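
    # For illustration, the two-pointer pairing above yields:
    #   Utils.get_span([1, 4], [3, 5])  -> {(1, 3), (4, 5)}
    #   Utils.get_span([1, 4], [3])     -> {(1, 3)}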

    @classmethod
    def get_id_and_prob(cls, span_set, offset_mapping: List[List[int]]):
        """
        Return the text-level indices and probabilities of the predicted spans.

        Args:
            span_set (set): set of predicted spans.
            offset_mapping (List[List[int]]): one [start_char, end_char] pair per token, indexing into the
                original text pair (prompt + text).
        Returns:
            sentence_id (list[tuple]): index of start and end char in the original text.
            prob (list[float]): probabilities of the predicted spans.
        """
        # The first [0, 0] entry after [CLS] is the [SEP] that closes the prompt.
        prompt_end_token_id = offset_mapping[1:].index([0, 0])
        bias = offset_mapping[prompt_end_token_id][1] + 1
        # Shift prompt-token offsets negative so that spans inside the prompt
        # can be told apart from spans inside the text.
        for index in range(1, prompt_end_token_id + 1):
            offset_mapping[index][0] -= bias
            offset_mapping[index][1] -= bias

        sentence_id = []
        prob = []
        for start, end in span_set:
            prob.append(start[1] * end[1])
            start_id = offset_mapping[start[0]][0]
            end_id = offset_mapping[end[0]][1]
            sentence_id.append((start_id, end_id))
        return sentence_id, prob

    @classmethod
    def auto_splitter(cls, input_texts, max_text_len, split_sentence=False):
        """
        Split the raw texts automatically for model inference.

        Args:
            input_texts (List[str]): input raw texts.
            max_text_len (int): maximum length of each chunk.
            split_sentence (bool): If True, sentence-level splitting is performed first.
        Returns:
            short_input_texts (List[str]): the short input texts for model inference.
            input_mapping (dict): mapping between raw texts and short input texts.
        """
        input_mapping = {}
        short_input_texts = []
        cnt_org = 0
        cnt_short = 0
        for text in input_texts:
            if not split_sentence:
                sens = [text]
            else:
                sens = Utils.cut_chinese_sent(text)
            for sen in sens:
                lens = len(sen)
                if lens <= max_text_len:
                    short_input_texts.append(sen)
                    if cnt_org not in input_mapping.keys():
                        input_mapping[cnt_org] = [cnt_short]
                    else:
                        input_mapping[cnt_org].append(cnt_short)
                    cnt_short += 1
                else:
                    temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
                    short_input_texts.extend(temp_text_list)
                    short_idx = cnt_short
                    cnt_short += math.ceil(lens / max_text_len)
                    temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
                    if cnt_org not in input_mapping.keys():
                        input_mapping[cnt_org] = temp_text_id
                    else:
                        input_mapping[cnt_org].extend(temp_text_id)
            cnt_org += 1
        return short_input_texts, input_mapping
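
    # For illustration:
    #   Utils.auto_splitter(["abcdefghij"], max_text_len=4)
    #   -> (["abcd", "efgh", "ij"], {0: [0, 1, 2]})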

    @classmethod
    def convert_ids_to_results(cls, examples, sentence_ids, probs):
        """
        Convert ids to raw text in a single stage.
        """
        results = []
        for example, sentence_id, prob in zip(examples, sentence_ids, probs):
            if len(sentence_id) == 0:
                results.append([])
                continue
            result_list = []
            text = example["text"]
            prompt = example["prompt"]
            for i in range(len(sentence_id)):
                start, end = sentence_id[i]
                if start < 0 and end >= 0:
                    continue
                if end < 0:
                    # Negative offsets point into the prompt (classification-style result).
                    start += len(prompt) + 1
                    end += len(prompt) + 1
                    result = {"text": prompt[start:end], "probability": prob[i]}
                    result_list.append(result)
                else:
                    result = {"text": text[start:end], "start": start, "end": end, "probability": prob[i]}
                    result_list.append(result)
            results.append(result_list)
        return results

    @classmethod
    def auto_joiner(cls, short_results, short_inputs, input_mapping):
        concat_results = []
        # Decide whether this is a classification-style task: such results carry
        # no "start"/"end" offsets.
        is_cls_task = False
        for short_result in short_results:
            if not short_result:
                continue
            elif "start" not in short_result[0].keys() and "end" not in short_result[0].keys():
                is_cls_task = True
                break
            else:
                break
        for k, vs in input_mapping.items():
            if is_cls_task:
                cls_options = {}
                for v in vs:
                    if len(short_results[v]) == 0:
                        continue
                    if short_results[v][0]["text"] not in cls_options.keys():
                        cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
                    else:
                        cls_options[short_results[v][0]["text"]][0] += 1
                        cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
                if len(cls_options) != 0:
                    cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
                    concat_results.append([{"text": cls_res, "probability": cls_info[1] / cls_info[0]}])
                else:
                    concat_results.append([])
            else:
                # Chunked spans are shifted back by the cumulative length of the
                # preceding chunks of the same original text.
                offset = 0
                single_results = []
                for v in vs:
                    if v == 0:
                        single_results = short_results[v]
                        offset += len(short_inputs[v])
                    else:
                        for i in range(len(short_results[v])):
                            if "start" not in short_results[v][i] or "end" not in short_results[v][i]:
                                continue
                            short_results[v][i]["start"] += offset
                            short_results[v][i]["end"] += offset
                        offset += len(short_inputs[v])
                        single_results.extend(short_results[v])
                concat_results.append(single_results)
        return concat_results
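

# A minimal, runnable end-to-end sketch. The checkpoint name below is an
# assumption for illustration only; point it at whichever UIE checkpoint
# (converted to this transformers-style layout) you actually have.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "uie-base-chinese"  # hypothetical local/hub path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = UIE.from_pretrained(checkpoint)
    model.eval()

    schema = {"竞赛名称": ["主办方", "时间"]}  # parent entity with two relations
    text = "2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办。"
    print(model.predict(schema=schema, input_texts=text, tokenizer=tokenizer))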