Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import os | |
import re | |
import string | |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" | |
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" | |
from simpletransformers.ner import NERModel | |
class BERTmodel:
    """Restore punctuation in Urdu text with a BERT token classifier.

    Existing punctuation is stripped, the words are fed to ``NERModel`` in
    overlapping windows, each word receives one of the labels in
    ``valid_labels``, and the matching mark from ``label_to_punct`` is
    appended to it.
    """

    def __init__(
        self,
        normalization="full",
        wrds_per_pred=256,
        overlap_wrds=32,
        model_path="/code/models/urdu",
    ):
        """Load the classifier and configure normalization and windowing.

        Args:
            normalization: ``"partial"`` (regex removal of the marks in
                ``patterns["partial"]``) or ``"full"`` (drop every character
                listed in ``patterns["full"]``).
            wrds_per_pred: how many words each window advances by.
            overlap_wrds: extra words of right-hand context appended to each
                window. Must be at least 3: non-final windows discard their
                last two predictions, and a trailing window of up to three
                words is dropped in favour of the previous window's overlap.
            model_path: model directory handed to ``NERModel``.

        Raises:
            ValueError: if ``wrds_per_pred`` < 1 or ``overlap_wrds`` < 3.
        """
        if wrds_per_pred < 1:
            raise ValueError("wrds_per_pred must be at least 1")
        if overlap_wrds < 3:
            raise ValueError("overlap_wrds must be at least 3")
        self.normalization = normalization
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = overlap_wrds
        self.valid_labels = ["O", "F", "C", "Q"]
        # Label -> mark appended after the word ("O" appends nothing).
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        self.model = NERModel(
            "bert",
            model_path,
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        self.patterns = {
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            # NOTE(review): "ء" (U+0621) is the Arabic letter hamza rather than
            # punctuation, so "full" mode deletes it from words; confirm this
            # matches the normalization the model was trained with.
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }

    def punctuation_removal(self, text: str) -> str:
        """Strip existing punctuation according to ``self.normalization``.

        ``"partial"`` removes regex matches and trims surrounding whitespace;
        otherwise every character listed in the selected pattern is dropped.
        """
        pattern = self.patterns[self.normalization]
        if self.normalization == "partial":
            return re.sub(pattern, "", text).strip()
        return "".join(ch for ch in text if ch not in pattern)

    def punctuate(self, text: str):
        """Return ``text`` with its punctuation re-predicted by the model.

        Input that is empty after punctuation removal yields ``""`` without
        calling the model (previously this raised ``IndexError``).
        """
        text = self.punctuation_removal(text)
        windows = self.split_on_tokens(text)
        if not windows:
            return ""
        # self.predict(...)[0][0] is this window's list of {word: label} dicts.
        slice_preds = [self.predict(window["text"])[0][0] for window in windows]
        combined = self.combine_results(text, slice_preds)
        return self.punctuate_texts(combined)

    def predict(self, input_slice):
        """Run the classifier on a single window of text.

        Returns whatever ``NERModel.predict`` returns for a one-item batch;
        ``punctuate`` reads ``result[0][0]`` as the per-word
        ``{word: label}`` dicts for ``input_slice``.
        """
        return self.model.predict([input_slice])

    def split_on_tokens(self, text):
        """Split ``text`` into overlapping word windows.

        Window ``n`` owns words ``[n*W, (n+1)*W)`` (``W = wrds_per_pred``) and
        additionally carries the next ``overlap_wrds`` words as context.

        Returns:
            A list of dicts with ``"text"`` (owned + context words joined by
            single spaces) and ``"start_idx"`` / ``"end_idx"``: the half-open
            character span of the owned words inside ``" ".join(text.split())``.
        """
        words = text.split()
        step, extra = self.wrds_per_pred, self.overlap_wrds
        windows = []
        char_start = 0
        for first in range(0, len(words), step):
            own = words[first : first + step]
            lookahead = words[first + step : first + step + extra]
            own_text = " ".join(own)
            windows.append(
                {
                    "text": " ".join(own + lookahead),
                    "start_idx": char_start,
                    "end_idx": char_start + len(own_text),
                }
            )
            # The next window begins after this window's own words plus the
            # single separating space.
            char_start += len(own_text) + 1
        return windows

    def combine_results(self, full_text: str, text_slices):
        """Merge per-window predictions into one ``(word, label)`` per word.

        ``text_slices[n]`` is the prediction list (single-entry
        ``{word: label}`` dicts) for the window that ``split_on_tokens``
        starts at word ``n * self.wrds_per_pred``. Predictions are aligned by
        position rather than by matching word text, so a word repeated inside
        the overlap region cannot take another occurrence's label.

        Non-final windows discard their last two predictions (they lack
        right-hand context) and the next window takes over. A trailing window
        of at most three words is dropped; the previous window's overlap
        covers it.

        Raises:
            ValueError: if the predictions do not cover every word of
                ``full_text`` exactly once, in order (e.g. the model
                truncated a window).
        """
        words = full_text.split()
        slices = list(text_slices)
        if len(slices) > 1 and len(slices[-1]) <= 3:
            slices = slices[:-1]
        output = []
        last = len(slices) - 1
        for n, preds in enumerate(slices):
            window_start = n * self.wrds_per_pred
            stop = len(preds) if n == last else len(preds) - 2
            offset = len(output) - window_start
            if offset < 0:
                raise ValueError(
                    f"no prediction for word {len(output)}: "
                    f"window {n} starts at word {window_start}"
                )
            for pred in preds[offset:stop]:
                pos = len(output)
                if pos == len(words):
                    break
                word, label = next(iter(pred.items()))
                if str(word) != words[pos]:
                    raise ValueError(
                        f"predicted word {word!r} does not match input word "
                        f"{words[pos]!r} at position {pos}"
                    )
                output.append((word, label))
        if len(output) != len(words):
            raise ValueError(
                f"predictions cover {len(output)} of {len(words)} words"
            )
        return output

    def punctuate_texts(self, full_pred: list):
        """Join ``(word, label)`` pairs into text, appending each label's mark.

        When a word ends with ``‘‘`` the mark is placed before those quote
        characters. If the result ends in an alphanumeric character, ``۔`` is
        appended. An empty prediction list yields ``""``.
        """
        quotes = "‘‘"
        pieces = []
        for word, label in full_pred:
            mark = self.label_to_punct[label]
            if mark and word.endswith(quotes):
                word = word[: -len(quotes)] + mark + quotes
            else:
                word += mark
            pieces.append(word)
        result = " ".join(pieces)
        if result and result[-1].isalnum():
            result += "۔"
        return result
class Urdu:
    """Public entry point: restore punctuation in Urdu text."""

    def __init__(self):
        # Model loading and windowing are delegated to a default BERTmodel.
        self.model = BERTmodel()

    def punctuate(self, data: str):
        """Return ``data`` re-punctuated by the wrapped ``BERTmodel``."""
        punctuator = self.model
        return punctuator.punctuate(data)