#!/usr/bin/env python
import os
import string
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union

import onnxruntime as ort
import torch
from omegaconf import OmegaConf
from sentencepiece import SentencePieceProcessor
from torch.utils.data import DataLoader, Dataset

ACRONYM_TOKEN = "<ACRONYM>"

# CPU-only inference: disable autograd and cuDNN, and hide any CUDA devices.
torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = False
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

@dataclass
class PunctCapConfigONNX:
    spe_filename: str = "xlm_roberta_encoding.model"
    model_filename: str = "nemo_model.onnx"
    config_filename: str = "config.yaml"
    directory: Optional[str] = None

class PunctCapModelONNX:
    def __init__(self, cfg: PunctCapConfigONNX):
        self._spe_path = os.path.join(cfg.directory, cfg.spe_filename)
        onnx_path = os.path.join(cfg.directory, cfg.model_filename)
        config_path = os.path.join(cfg.directory, cfg.config_filename)
        self._tokenizer: SentencePieceProcessor = SentencePieceProcessor(self._spe_path)
        self._ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)
        self._config = OmegaConf.load(config_path)
        self._max_len = self._config.max_length
        self._pre_labels: List[str] = self._config.pre_labels
        self._post_labels: List[str] = self._config.post_labels
        self._languages: List[str] = self._config.languages
        self._null_token = self._config.get("null_token", "<NULL>")

    def _setup_dataloader(self, texts: List[str], batch_size_tokens: int, overlap: int) -> DataLoader:
        dataset: TextInferenceDataset = TextInferenceDataset(
            texts=texts,
            batch_size_tokens=batch_size_tokens,
            overlap=overlap,
            max_length=self._max_len,
            spe_model_path=self._spe_path,
        )
        return DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_sampler=dataset.sampler,
        )

    def punctuation_removal(self, texts: List[str]) -> List[str]:
        punkt = string.punctuation + """`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…–ـ""" + """!?。。"""
        punkt = punkt.replace("-", "")
        punkt = punkt.replace("'", "")
        punkt += "„“"
        return [text.translate(str.maketrans("", "", punkt)).lower().strip() for text in texts]
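
    # Illustrative example: punctuation_removal(["Hello, World!"]) returns
    # ["hello world"] -- existing punctuation and casing are stripped so the
    # model starts from plain, unpunctuated text.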
    def infer(
        self,
        texts: List[str],
        apply_sbd: bool = False,
        batch_size_tokens: int = 4096,
        overlap: int = 16,
    ) -> Union[List[str], List[List[str]]]:
        # Normalize inputs, then accumulate predictions in one collector per input text.
        texts = self.punctuation_removal(texts)
        collectors: List[PunctCapCollector] = [
            PunctCapCollector(sp_model=self._tokenizer, apply_sbd=apply_sbd, overlap=overlap)
            for _ in range(len(texts))
        ]
        dataloader: DataLoader = self._setup_dataloader(
            texts=texts, batch_size_tokens=batch_size_tokens, overlap=overlap
        )
        for batch in dataloader:
            input_ids, batch_indices, input_indices, lengths = batch
            pre_preds, post_preds, cap_preds, seg_preds = self._ort_session.run(
                None, {"input_ids": input_ids.numpy()}
            )
            batch_size = input_ids.shape[0]
            for i in range(batch_size):
                length = lengths[i].item()
                batch_idx = batch_indices[i].item()
                input_idx = input_indices[i].item()
                # Drop the BOS/EOS positions before collecting predictions.
                segment_ids = input_ids[i, 1 : length - 1].tolist()
                segment_pre_preds = pre_preds[i, 1 : length - 1].tolist()
                segment_post_preds = post_preds[i, 1 : length - 1].tolist()
                segment_cap_preds = cap_preds[i, 1 : length - 1].tolist()
                segment_sbd_preds = seg_preds[i, 1 : length - 1].tolist()
                # Map label indices to tokens; the null token means "no punctuation here".
                pre_tokens = [self._pre_labels[p] for p in segment_pre_preds]
                post_tokens = [self._post_labels[p] for p in segment_post_preds]
                pre_tokens = [x if x != self._null_token else None for x in pre_tokens]
                post_tokens = [x if x != self._null_token else None for x in post_tokens]
                collectors[batch_idx].collect(
                    ids=segment_ids,
                    pre_preds=pre_tokens,
                    post_preds=post_tokens,
                    cap_preds=segment_cap_preds,
                    sbd_preds=segment_sbd_preds,
                    idx=input_idx,
                )
        outputs: Union[List[str], List[List[str]]] = [x.produce() for x in collectors]
        return outputs

@dataclass
class TokenizedSegment:
    input_ids: List[int]
    batch_idx: int
    input_idx: int

    def __len__(self) -> int:
        return len(self.input_ids)

class TokenBatchSampler(Iterable):
    def __init__(self, segments: List[TokenizedSegment], batch_size_tokens: int):
        self._batches = self._make_batches(segments, batch_size_tokens)

    def _make_batches(self, segments: List[TokenizedSegment], batch_size_tokens: int) -> List[List[int]]:
        # Sort segments longest-first, then greedily pack them so that
        # (padded length) * (batch size) never exceeds the token budget.
        segments_with_index = [(segment, i) for i, segment in enumerate(segments)]
        segments_with_index.sort(key=lambda x: len(x[0]), reverse=True)
        batches, current_batch_elements, current_max_len = [], [], 0
        for segment, idx in segments_with_index:
            potential_max_len = max(current_max_len, len(segment))
            if potential_max_len * (len(current_batch_elements) + 1) > batch_size_tokens:
                batches.append(current_batch_elements)
                current_batch_elements, current_max_len = [], 0
            current_batch_elements.append(idx)
            # Recompute the running max after a possible reset so a fresh batch
            # starts from this segment's length, not the previous batch's max.
            current_max_len = max(current_max_len, len(segment))
        if current_batch_elements:
            batches.append(current_batch_elements)
        return batches

    def __iter__(self) -> Iterator:
        yield from self._batches

    def __len__(self) -> int:
        return len(self._batches)

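# Illustrative packing example: with batch_size_tokens=20 and segments of
# lengths [10, 8, 3], the sampler packs the 10- and 8-token segments together
# (padded cost 10 * 2 = 20) and puts the 3-token segment into a second batch,
# since 10 * 3 = 30 would exceed the budget.
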
class TextInferenceDataset(Dataset):
    def __init__(
        self,
        texts: List[str],
        spe_model_path: str,
        batch_size_tokens: int = 4096,
        max_length: int = 512,
        overlap: int = 32,
    ):
        self._spe_model = SentencePieceProcessor(spe_model_path)
        self._segments = self._tokenize_inputs(texts, max_length, overlap)
        self._sampler = TokenBatchSampler(self._segments, batch_size_tokens)

    @property
    def sampler(self) -> Iterable:
        return self._sampler

    def _tokenize_inputs(self, texts: List[str], max_len: int, overlap: int) -> List[TokenizedSegment]:
        # Reserve two positions for the BOS/EOS tokens added in __getitem__.
        max_len -= 2
        segments = []
        for batch_idx, text in enumerate(texts):
            ids, start, input_idx = self._spe_model.EncodeAsIds(text), 0, 0
            # Split long inputs into overlapping windows so that predictions near
            # window boundaries can be trimmed when the outputs are merged.
            while start < len(ids):
                adjusted_start = start - overlap if input_idx else 0
                segments.append(
                    TokenizedSegment(
                        ids[adjusted_start : adjusted_start + max_len],
                        batch_idx,
                        input_idx,
                    )
                )
                start += max_len - overlap
                input_idx += 1
        return segments

    def __len__(self) -> int:
        return len(self._segments)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
        segment = self._segments[idx]
        input_ids = torch.tensor(
            [self._spe_model.bos_id(), *segment.input_ids, self._spe_model.eos_id()], dtype=torch.long
        )
        return input_ids, segment.batch_idx, segment.input_idx

    def collate_fn(
        self, batch: List[Tuple[torch.Tensor, int, int]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Right-pad every sequence in the batch to the longest length using the pad id.
        input_ids = [x[0] for x in batch]
        lengths = torch.tensor([x.shape[0] for x in input_ids])
        max_len = lengths.max().item()
        batched_ids = torch.full((len(input_ids), max_len), self._spe_model.pad_id())
        for idx, ids in enumerate(input_ids):
            batched_ids[idx, : lengths[idx]] = ids
        return (
            batched_ids,
            torch.tensor([x[1] for x in batch]),
            torch.tensor([x[2] for x in batch]),
            lengths,
        )

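    # Illustrative shapes: for two items of token lengths 5 and 3, collate_fn returns
    # a (2, 5) id tensor whose second row is right-padded with the tokenizer's pad id,
    # plus the per-item batch indices, input indices, and true lengths.
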
@dataclass
class PCSegment:
    ids: List[int]
    pre_preds: List[Optional[str]]
    post_preds: List[Optional[str]]
    cap_preds: List[List[int]]
    sbd_preds: List[int]

    def __len__(self):
        return len(self.ids)

class PunctCapCollector:
    def __init__(self, apply_sbd: bool, overlap: int, sp_model: SentencePieceProcessor):
        self._segments: Dict[int, PCSegment] = {}
        self._apply_sbd = apply_sbd
        self._overlap = overlap
        self._sp_model = sp_model

    def collect(
        self,
        ids: List[int],
        pre_preds: List[Optional[str]],
        post_preds: List[Optional[str]],
        sbd_preds: List[int],
        cap_preds: List[List[int]],
        idx: int,
    ):
        self._segments[idx] = PCSegment(
            ids=ids,
            pre_preds=pre_preds,
            post_preds=post_preds,
            sbd_preds=sbd_preds,
            cap_preds=cap_preds,
        )

    def produce(self) -> Union[List[str], str]:
        # Stitch the overlapping segments back together, trimming half the overlap
        # from each interior boundary so every position is used exactly once.
        ids: List[int] = []
        pre_preds: List[Optional[str]] = []
        post_preds: List[Optional[str]] = []
        cap_preds: List[List[int]] = []
        sbd_preds: List[int] = []
        for i in range(len(self._segments)):
            segment = self._segments[i]
            start = 0
            stop = len(segment)
            if i > 0:
                start += self._overlap // 2
            if i < len(self._segments) - 1:
                stop -= self._overlap // 2
            ids.extend(segment.ids[start:stop])
            pre_preds.extend(segment.pre_preds[start:stop])
            post_preds.extend(segment.post_preds[start:stop])
            sbd_preds.extend(segment.sbd_preds[start:stop])
            cap_preds.extend(segment.cap_preds[start:stop])
        # Detokenize while applying pre-punctuation, per-character capitalization,
        # post-punctuation (a "." after every character for acronyms), and, if
        # requested, sentence boundary detection.
        input_tokens = [self._sp_model.IdToPiece(x) for x in ids]
        output_texts: List[str] = []
        current_chars: List[str] = []
        for token_idx, token in enumerate(input_tokens):
            if token.startswith("▁") and current_chars:
                current_chars.append(" ")
            char_start = 1 if token.startswith("▁") else 0
            for token_char_idx, char in enumerate(token[char_start:], start=char_start):
                if token_char_idx == char_start and pre_preds[token_idx] is not None:
                    current_chars.append(pre_preds[token_idx])
                if cap_preds[token_idx][token_char_idx]:
                    char = char.upper()
                current_chars.append(char)
                label = post_preds[token_idx]
                if label == ACRONYM_TOKEN:
                    current_chars.append(".")
                elif token_char_idx == len(token) - 1 and post_preds[token_idx] is not None:
                    current_chars.append(post_preds[token_idx])
                if self._apply_sbd and token_char_idx == len(token) - 1 and sbd_preds[token_idx]:
                    output_texts.append("".join(current_chars))
                    current_chars = []
        if current_chars:
            output_texts.append("".join(current_chars))
        if not self._apply_sbd:
            if len(output_texts) > 1:
                raise ValueError(f"Not applying SBD but got more than one result: {output_texts}")
            return output_texts[0]
        return output_texts

class MultiLingual:
    def __init__(self):
        cfg = PunctCapConfigONNX(directory="/code/models/multilingual")
        self._punctuator = PunctCapModelONNX(cfg)

    def punctuate(self, data: str) -> str:
        return self._punctuator.infer([data])[0]
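

# Minimal usage sketch (illustrative): assumes the ONNX model, SentencePiece model,
# and config.yaml referenced by PunctCapConfigONNX exist under /code/models/multilingual.
# The example text and output shown in the comment are hypothetical and model-dependent.
if __name__ == "__main__":
    punctuator = MultiLingual()
    # e.g. "hello world how are you" -> "Hello world, how are you?"
    print(punctuator.punctuate("hello world how are you"))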