#!/usr/bin/env python
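"""Urdu punctuation restoration.

A BERT token-classification model (loaded through simpletransformers' NERModel)
predicts, for every word, which punctuation mark should follow it (full stop,
comma, question mark, or none), and the text is rebuilt with those marks.
"""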
import os
import re
import string
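
# Quiet TensorFlow before it is imported downstream: hide INFO/WARNING logs
# and disable oneDNN custom ops (avoids the numerical-results notice).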
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
from simpletransformers.ner import NERModel
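

# BERTmodel: the token classifier assigns one label per Urdu word, and each
# label maps to the punctuation mark that should follow that word:
#   F -> ۔ (full stop)   C -> ، (comma)   Q -> ؟ (question mark)   O -> none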
class BERTmodel:
    def __init__(self, normalization="full", wrds_per_pred=256):
        self.normalization = normalization
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 32
        self.valid_labels = ["O", "F", "C", "Q"]
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        self.model = NERModel(
            "bert",
            "/code/models/urdu",
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        self.patterns = {
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }
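
    # Strip existing punctuation so the model always sees unpunctuated input:
    # "partial" mode removes diacritics and the sentence marks ، ۔ ؟ via regex,
    # "full" mode drops every character in the punctuation set.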
    def punctuation_removal(self, text: str) -> str:
        if self.normalization == "partial":
            return re.sub(self.patterns[self.normalization], "", text).strip()
        return "".join(ch for ch in text if ch not in self.patterns[self.normalization])
    def punctuate(self, text: str):
        text = self.punctuation_removal(text)
        splits = self.split_on_tokens(text)
        full_preds_lst = [self.predict(split["text"]) for split in splits]
        preds_lst = [pred[0][0] for pred in full_preds_lst]
        combined_preds = self.combine_results(text, preds_lst)
        return self.punctuate_texts(combined_preds)
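
    # NERModel.predict returns (predictions, raw_outputs); predictions holds one
    # entry per input text, each a list of {word: label} dicts.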
    def predict(self, input_slice):
        return self.model.predict([input_slice])
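
    # Split the word list into chunks of wrds_per_pred words, each extended with
    # the next overlap_wrds words so context is preserved across chunk borders.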
    def split_on_tokens(self, text):
        wrds = text.replace("\n", " ").split()
        response = []
        lst_chunk_idx = 0
        i = 0
        while True:
            wrds_len = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
            wrds_ovlp = wrds[
                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred + self.overlap_wrds
            ]
            wrds_split = wrds_len + wrds_ovlp
            if not wrds_split:
                break
            response_obj = {
                "text": " ".join(wrds_split),
                "start_idx": lst_chunk_idx,
                "end_idx": lst_chunk_idx + len(" ".join(wrds_len)),
            }
            response.append(response_obj)
            # Advance the character cursor past this chunk and the joining space;
            # assigning (rather than "+=") keeps start/end offsets absolute.
            lst_chunk_idx = response_obj["end_idx"] + 1
            i += 1
        return response
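
    # Stitch the per-chunk predictions back into a single list aligned with the
    # words of the original text, skipping duplicate predictions from the
    # overlapping regions; a near-empty trailing chunk is discarded.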
    def combine_results(self, full_text: str, text_slices):
        split_full_text = full_text.replace("\n", " ").split(" ")
        split_full_text = [wrd for wrd in split_full_text if wrd]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]
        for text_slice in text_slices:
            slice_wrds = len(text_slice)
            for ix, wrd in enumerate(text_slice):
                if index == split_full_text_len:
                    break
                if (
                    split_full_text[index] == str(list(wrd.keys())[0])
                    and ix <= slice_wrds - 3
                    and text_slices[-1] != text_slice
                ):
                    index += 1
                    output_text.append(list(wrd.items())[0])
                elif (
                    split_full_text[index] == str(list(wrd.keys())[0])
                    and text_slices[-1] == text_slice
                ):
                    index += 1
                    output_text.append(list(wrd.items())[0])
        assert [wrd for wrd, _ in output_text] == split_full_text
        return output_text
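
    # Re-attach punctuation: append the mark implied by each label, placing it
    # before a closing ‘‘ quote when present, and make sure the final text ends
    # with a full stop.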
    def punctuate_texts(self, full_pred: list):
        punct_resp = []
        for punct_wrd, label in full_pred:
            if punct_wrd.endswith("‘‘"):
                # Insert the mark before the closing quote rather than after it.
                punct_wrd = punct_wrd[:-2] + self.label_to_punct[label] + "‘‘"
            else:
                punct_wrd += self.label_to_punct[label]
            punct_resp.append(punct_wrd)
        punct_resp = " ".join(punct_resp)
        # Close the text with an Urdu full stop if it ends on a bare word.
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "۔"
        return punct_resp
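

# Thin public wrapper: Urdu().punctuate(text) is the only entry point callers need.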
class Urdu:
    def __init__(self):
        self.model = BERTmodel()

    def punctuate(self, data: str):
        return self.model.punctuate(data)
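

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original app): it assumes the
    # fine-tuned model is available at /code/models/urdu, as configured above.
    # The unpunctuated sample sentence below is illustrative only.
    sample = "کیا آپ کل دفتر آئیں گے مجھے آپ سے ایک ضروری بات کرنی ہے"
    urdu = Urdu()
    print(urdu.punctuate(sample))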