#!/usr/bin/env python
"""Urdu punctuation restoration using a BERT-based token-classification model.

Each word is tagged with one of four labels which map back to punctuation:
"F" -> full stop "۔", "C" -> comma "،", "Q" -> question mark "؟", "O" -> none.
"""
import os
import re
import string

# Quiet TensorFlow logging / disable oneDNN notices. These must be set before
# the TF stack loads; NERModel is imported lazily in BERTmodel.__init__, which
# always runs after this module-level code.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"


class BERTmodel:
    """Restores punctuation in Urdu text via NER-style word labeling."""

    def __init__(self, normalization="full", wrds_per_pred=256):
        """
        Args:
            normalization: "partial" strips diacritics/Urdu punctuation via a
                regex; "full" drops every ASCII punctuation char plus the
                listed Urdu marks.
            wrds_per_pred: words scored per model call; chunks additionally
                overlap by ``overlap_wrds`` words for context at boundaries.
        """
        self.normalization = normalization
        self.wrds_per_pred = wrds_per_pred
        self.overlap_wrds = 32
        self.valid_labels = ["O", "F", "C", "Q"]
        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
        # Imported lazily so the heavy transformers/TF stack is only loaded
        # when a model is actually instantiated; the text utilities below
        # stay importable without it.
        from simpletransformers.ner import NERModel

        self.model = NERModel(
            "bert",
            "/code/models/urdu",
            use_cuda=False,
            labels=self.valid_labels,
            args={"silent": True, "max_seq_length": 512},
        )
        self.patterns = {
            # Character-class regex: Urdu diacritics and punctuation.
            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
            # Membership string: ASCII punctuation plus Urdu marks.
            "full": string.punctuation + "،؛؟۔٪ء‘’",
        }

    def punctuation_removal(self, text: str) -> str:
        """Strip existing punctuation according to ``self.normalization``."""
        if self.normalization == "partial":
            # Regex-based removal, then trim surrounding whitespace.
            return re.sub(self.patterns["partial"], "", text).strip()
        # "full": drop every character found in the membership string.
        forbidden = self.patterns[self.normalization]
        return "".join(ch for ch in text if ch not in forbidden)

    def punctuate(self, text: str):
        """End-to-end restoration: strip old punctuation, chunk, predict
        labels per chunk, merge overlapping predictions, re-insert marks."""
        text = self.punctuation_removal(text)
        chunks = self.split_on_tokens(text)
        raw_preds = [self.predict(chunk["text"]) for chunk in chunks]
        # NERModel.predict returns (predictions, raw_outputs); predictions is
        # a one-element list because we pass a single sentence per call.
        preds = [p[0][0] for p in raw_preds]
        return self.punctuate_texts(self.combine_results(text, preds))

    def predict(self, input_slice):
        """Run the NER model on one chunk of text."""
        return self.model.predict([input_slice])

    def split_on_tokens(self, text):
        """Split ``text`` into chunks of ``wrds_per_pred`` words, each
        extended with ``overlap_wrds`` following words for context.

        Returns:
            list of dicts ``{"text", "start_idx", "end_idx"}`` where the
            indices are character offsets of the non-overlap core within the
            normalized (newline-free, single-spaced) text.
        """
        wrds = text.replace("\n", " ").split()
        chunks = []
        core_start = 0
        i = 0
        while True:
            core = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
            overlap = wrds[
                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred
                + self.overlap_wrds
            ]
            if not core:
                # Past the end of the word list (the overlap slice starts even
                # further right, so it is necessarily empty too).
                break
            core_text = " ".join(core)
            chunks.append(
                {
                    "text": " ".join(core + overlap),
                    "start_idx": core_start,
                    "end_idx": core_start + len(core_text),
                }
            )
            # BUG FIX: advance to the character after the core. The original
            # used "+=", compounding the running offset and drifting the
            # indices from the third chunk onward.
            core_start = core_start + len(core_text) + 1
            i += 1
        return chunks

    def combine_results(self, full_text: str, text_slices):
        """Stitch overlapping per-chunk predictions into one aligned list.

        Args:
            full_text: the normalized text the chunks were cut from.
            text_slices: one list per chunk of single-entry ``{word: label}``
                dicts, as produced by NERModel.predict.

        Returns:
            list of ``(word, label)`` tuples covering every word exactly once.
        """
        words = [w for w in full_text.replace("\n", " ").split(" ") if w]
        total = len(words)
        output_text = []
        if not text_slices:
            # Guard: no predictions is only consistent with no words.
            assert not words
            return []
        # A trailing slice of <= 3 predictions is residual overlap already
        # covered by the previous chunk; drop it.
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]
        index = 0
        for chunk in text_slices:  # renamed from `slice` (shadowed builtin)
            chunk_len = len(chunk)
            is_last = text_slices[-1] == chunk
            for ix, wrd in enumerate(chunk):
                if index == total:
                    break
                word, label = next(iter(wrd.items()))
                if words[index] != str(word):
                    continue
                # Non-final chunks: skip the tail predictions (the overlap
                # region) — the next chunk supplies them with more context.
                if is_last or ix <= chunk_len - 3:
                    index += 1
                    output_text.append((word, label))
        # Sanity check: every word of the text received exactly one label.
        assert [w for w, _ in output_text] == words
        return output_text

    def punctuate_texts(self, full_pred: list):
        """Reassemble ``(word, label)`` pairs into punctuated text.

        Ensures the final sentence ends with a full stop, and places marks
        inside a closing quotation "‘‘".
        """
        punct_resp = []
        for wrd, label in full_pred:
            punct = self.label_to_punct[label]
            if wrd.endswith("‘‘"):
                # Place the mark inside the closing quote, e.g. کہا‘‘ -> کہا۔‘‘
                # BUG FIX: the original performed this check *after* appending
                # the mark, so the branch could never fire for F/C/Q labels.
                wrd = wrd[:-2] + punct + "‘‘"
            else:
                wrd += punct
            punct_resp.append(wrd)
        result = " ".join(punct_resp)
        # Terminate the text with a full stop if it ends on a plain word.
        # (Guard against empty input — the original indexed [-1] unchecked.)
        if result and result[-1].isalnum():
            result += "۔"
        return result


class Urdu:
    """Thin facade exposing a single punctuate() entry point."""

    def __init__(self):
        self.model = BERTmodel()

    def punctuate(self, data: str):
        return self.model.punctuate(data)