mturan committed
Commit 48b5e1d
1 Parent(s): 8c82d61

Add application file

Dockerfile ADDED
@@ -0,0 +1,18 @@
+FROM python:3.9-slim
+
+WORKDIR /code
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential gcc \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir Cython
+
+COPY ./requirements.txt /code/
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+COPY ./multi_lingual.py ./urdu_punkt.py ./main.py /code/
+COPY ./models/ /code/models/
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
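
Note: the image serves the FastAPI app with uvicorn on port 7860, and main.py (below) mounts the Gradio UI at /punctuate. A minimal smoke test against a locally running container, assuming the port was published (e.g. docker run -p 7860:7860 <image>); only the route and port come from the files above, the rest is illustrative:

from urllib.request import urlopen

# The Gradio interface mounted in main.py should answer on this route once
# both models have finished loading.
with urlopen("http://localhost:7860/punctuate", timeout=30) as resp:
    print(resp.status)  # expect 200
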
main.py ADDED
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+
+import gradio as gr
+from fastapi import FastAPI
+from urdu_punkt import Urdu
+from multi_lingual import MultiLingual
+from langdetect import detect, DetectorFactory
+
+CUSTOM_PATH = "/punctuate"
+DetectorFactory.seed = 42
+
+app = FastAPI()
+nemo_model = Urdu()
+multi_model = MultiLingual()
+
+
+def punctuate(text: str) -> str:
+    if detect(text) == "ur":
+        return nemo_model.punctuate(text)
+    else:
+        return multi_model.punctuate(text)
+
+
+title = "SELMA H2020 — Multilingual Punctuation & Casing Prediction"
+description = "Supported languages are: Amharic, Bengali, German, English, Spanish, French, Hindi, Italian, Latvian, Pashto, Portuguese, Russian, Tamil and Urdu."
+article = "<p style='text-align: center'><a href='https://selma-project.eu' target='_blank'>SELMA-H2020</a></p>"
+
+text_input = gr.Textbox(label="Enter some text")
+result_output = gr.Textbox(label="Result")
+
+io = gr.Interface(
+    fn=punctuate,
+    title=title,
+    description=description,
+    article=article,
+    theme=gr.themes.Soft(),
+    inputs=text_input,
+    outputs=result_output,
+    allow_flagging="never",
+    css="footer {visibility: hidden}",
+)
+
+app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)
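
Note: the routing in punctuate() rests entirely on langdetect: text detected as Urdu goes to the BERT-based Urdu model, everything else to the multilingual ONNX model. A small sketch of that decision in isolation (the sample strings are illustrative):

from langdetect import detect, DetectorFactory

DetectorFactory.seed = 42  # same fixed seed as above, keeps detection deterministic

print(detect("یہ ایک سادہ سا جملہ ہے"))    # expected "ur"  -> Urdu()
print(detect("this is a plain sentence"))  # expected "en"  -> MultiLingual()
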
models/multilingual/config.yaml ADDED
@@ -0,0 +1,43 @@
+# am: Amharic
+# bn: Bengali
+# de: German
+# en: English
+# es: Spanish
+# fr: French
+# hi: Hindi
+# it: Italian
+# lv: Latvian
+# ps: Pashto
+# pt: Portuguese
+# ru: Russian
+# ta: Tamil
+
+languages: ["am", "bn", "de", "en", "es", "fr", "hi", "it", "lv", "ps", "pt", "ru", "ta"]
+
+max_length: 256
+
+# just for Spanish
+pre_labels: [
+    "<NULL>",
+    "¿",
+]
+
+post_labels: [
+    "<NULL>",
+    "<ACRONYM>",
+    ".",
+    ",",
+    "?",
+    "？",
+    "，",
+    "。",
+    "、",
+    "・",
+    "।",
+    "؟",
+    "،",
+    ";",
+    "።",
+    "፣",
+    "፧",
+]
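
Note: multi_lingual.py reads this file with OmegaConf and treats the two label lists as the output vocabularies of the ONNX heads; null_token is not set here, so the code falls back to "<NULL>". A short sketch of that loading step:

from omegaconf import OmegaConf

cfg = OmegaConf.load("models/multilingual/config.yaml")  # path relative to the repo root
print(cfg.max_length)                    # 256
print(list(cfg.languages))               # ['am', 'bn', ..., 'ta']
print(cfg.get("null_token", "<NULL>"))   # "<NULL>" (key absent, default used)
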
models/multilingual/nemo_model.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c43ca686dabc237c3b06be834b9423c07580fef7e2b1a6c09976f7d60caa5d89
+size 1112481438
models/multilingual/xlm_roberta_encoding.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f944d0be93b275f62e1913fd409f378ddbba108e57fe4a9cb47e8c047f6bef1
+size 5069059
models/urdu/config.json ADDED
@@ -0,0 +1,43 @@
+{
+  "_name_or_path": "bert-base-multilingual-cased",
+  "architectures": [
+    "BertForTokenClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "classifier_dropout": null,
+  "directionality": "bidi",
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "O",
+    "1": "F",
+    "2": "C",
+    "3": "Q"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "C": 2,
+    "F": 1,
+    "O": 0,
+    "Q": 3
+  },
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "position_embedding_type": "absolute",
+  "torch_dtype": "float32",
+  "transformers_version": "4.33.1",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 119547
+}
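
Note: the token-classification head predicts four classes (O, F, C, Q); urdu_punkt.py turns them into punctuation marks appended to each word. A sketch of that mapping, with values copied from this config and from urdu_punkt.py below:

id2label = {0: "O", 1: "F", 2: "C", 3: "Q"}                  # from config.json above
label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}     # from urdu_punkt.py below
print("کتاب" + label_to_punct[id2label[1]])                  # word + Urdu full stop "۔"
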
models/urdu/model_args.json ADDED
@@ -0,0 +1 @@
+{"adafactor_beta1": null, "adafactor_clip_threshold": 1.0, "adafactor_decay_rate": -0.8, "adafactor_eps": [1e-30, 0.001], "adafactor_relative_step": true, "adafactor_scale_parameter": true, "adafactor_warmup_init": true, "adam_betas": [0.9, 0.999], "adam_epsilon": 1e-08, "best_model_dir": "./titanen_outputs/best_model/", "cache_dir": "./titanen_cache/", "config": {}, "cosine_schedule_num_cycles": 0.5, "custom_layer_parameters": [], "custom_parameter_groups": [], "dataloader_num_workers": 0, "do_lower_case": false, "dynamic_quantize": false, "early_stopping_consider_epochs": false, "early_stopping_delta": 0, "early_stopping_metric": "eval_loss", "early_stopping_metric_minimize": true, "early_stopping_patience": 3, "encoding": null, "eval_batch_size": 8, "evaluate_during_training": true, "evaluate_during_training_silent": true, "evaluate_during_training_steps": 2000, "evaluate_during_training_verbose": true, "evaluate_each_epoch": true, "fp16": true, "gradient_accumulation_steps": 1, "learning_rate": 4e-05, "local_rank": -1, "logging_steps": 50, "loss_type": null, "loss_args": {}, "manual_seed": 42, "max_grad_norm": 1.0, "max_seq_length": 512, "model_name": "bert-base-multilingual-cased", "model_type": "bert", "multiprocessing_chunksize": -1, "n_gpu": 1, "no_cache": false, "no_save": false, "not_saved_args": [], "num_train_epochs": 3, "optimizer": "AdamW", "output_dir": "./titanen_outputs/", "overwrite_output_dir": true, "polynomial_decay_schedule_lr_end": 1e-07, "polynomial_decay_schedule_power": 1.0, "process_count": 18, "quantized_model": false, "reprocess_input_data": false, "save_best_model": true, "save_eval_checkpoints": false, "save_model_every_epoch": true, "save_optimizer_and_scheduler": true, "save_steps": -1, "scheduler": "linear_schedule_with_warmup", "silent": false, "skip_special_tokens": true, "tensorboard_dir": null, "thread_count": null, "tokenizer_name": null, "tokenizer_type": null, "train_batch_size": 8, "train_custom_parameters_only": false, "use_cached_eval_features": false, "use_early_stopping": false, "use_hf_datasets": false, "use_multiprocessing": true, "use_multiprocessing_for_evaluation": true, "wandb_kwargs": {"name": "bert-base-multilingual-titanen", "entity": "tugtekin", "notes": "Training punctuation prediction using BERT.", "tags": ["urdu", "bert", "punctuation"]}, "wandb_project": "urdu-punctuation", "warmup_ratio": 0.06, "warmup_steps": 30073, "weight_decay": 0.0, "model_class": "NERModel", "classification_report": false, "labels_list": ["O", "F", "C", "Q"], "lazy_loading": true, "lazy_loading_start_line": 0, "onnx": false, "special_tokens_list": []}
models/urdu/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:026c4bec45da24b0b7aac90ce21be9daa184910ac543f31df32ae34e9a9ce73b
+size 1418293317
models/urdu/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39fdcf468bb421bcfc6f9028c220032130f9e4d5f685b6c293d4c7484d41ff29
+size 709131433
models/urdu/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:091e952a51388e2a6e71cfa983884d27b6b07a02767ade8eb751f3348703458b
+size 627
models/urdu/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+{
+  "cls_token": "[CLS]",
+  "mask_token": "[MASK]",
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "unk_token": "[UNK]"
+}
models/urdu/tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
+{
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": false,
+  "mask_token": "[MASK]",
+  "model_max_length": 512,
+  "never_split": null,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "unk_token": "[UNK]"
+}
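
Note: vocab.txt, special_tokens_map.json, and this config follow the standard BertTokenizer layout, so the directory can also be loaded directly with transformers (pulled in via simpletransformers). A hedged sketch, assuming models/urdu is available locally:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("models/urdu")  # adjust the path outside the container
print(tok.tokenize("یہ ایک جملہ ہے"))               # multilingual WordPiece pieces
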
models/urdu/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf79804213bb807517d0519a6bcc18e290c8770eda02d4538f4a90507cfce545
3
+ size 3259
models/urdu/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
multi_lingual.py ADDED
@@ -0,0 +1,310 @@
+#!/usr/bin/env python
+
+import os
+import torch
+import string
+import onnxruntime as ort
+from dataclasses import dataclass
+from omegaconf import OmegaConf
+from typing import List, Optional, Union, Dict
+from sentencepiece import SentencePieceProcessor
+from torch.utils.data import Dataset, DataLoader
+from typing import Iterator, List, Iterable, Tuple
+
+ACRONYM_TOKEN = "<ACRONYM>"
+torch.set_grad_enabled(False)
+torch.backends.cudnn.enabled = False
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+@dataclass
+class PunctCapConfigONNX:
+    spe_filename: str = "xlm_roberta_encoding.model"
+    model_filename: str = "nemo_model.onnx"
+    config_filename: str = "config.yaml"
+    directory: Optional[str] = None
+
+
+class PunctCapModelONNX:
+    def __init__(self, cfg: PunctCapConfigONNX):
+        self._spe_path = os.path.join(cfg.directory, cfg.spe_filename)
+        onnx_path = os.path.join(cfg.directory, cfg.model_filename)
+        config_path = os.path.join(cfg.directory, cfg.config_filename)
+
+        self._tokenizer: SentencePieceProcessor = SentencePieceProcessor(self._spe_path)
+        self._ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)
+        self._config = OmegaConf.load(config_path)
+        self._max_len = self._config.max_length
+        self._pre_labels: List[str] = self._config.pre_labels
+        self._post_labels: List[str] = self._config.post_labels
+        self._languages: List[str] = self._config.languages
+        self._null_token = self._config.get("null_token", "<NULL>")
+
+    def _setup_dataloader(self, texts: List[str], batch_size_tokens: int, overlap: int) -> DataLoader:
+        dataset: TextInferenceDataset = TextInferenceDataset(
+            texts=texts,
+            batch_size_tokens=batch_size_tokens,
+            overlap=overlap,
+            max_length=self._max_len,
+            spe_model_path=self._spe_path,
+        )
+        return DataLoader(
+            dataset=dataset,
+            collate_fn=dataset.collate_fn,
+            batch_sampler=dataset.sampler,
+        )
+
+    def punctuation_removal(self, texts: List[str]) -> List[str]:
+        punkt = string.punctuation + """`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…–ـ""" + """!?。。"""
+        punkt = punkt.replace("-", "")
+        punkt = punkt.replace("'", "")
+        punkt += "„“"
+        return [text.translate(str.maketrans("", "", punkt)).lower().strip() for text in texts]
+
+    def infer(
+        self,
+        texts: List[str],
+        apply_sbd: bool = False,
+        batch_size_tokens: int = 4096,
+        overlap: int = 16,
+    ) -> Union[List[str], List[List[str]]]:
+        texts = self.punctuation_removal(texts)
+
+        collectors: List[PunctCapCollector] = [
+            PunctCapCollector(sp_model=self._tokenizer, apply_sbd=apply_sbd, overlap=overlap)
+            for _ in range(len(texts))
+        ]
+        dataloader: DataLoader = self._setup_dataloader(texts=texts, batch_size_tokens=batch_size_tokens, overlap=overlap)
+        for batch in dataloader:
+            input_ids, batch_indices, input_indices, lengths = batch
+            pre_preds, post_preds, cap_preds, seg_preds = self._ort_session.run(None, {"input_ids": input_ids.numpy()})
+            batch_size = input_ids.shape[0]
+            for i in range(batch_size):
+                length = lengths[i].item()
+                batch_idx = batch_indices[i].item()
+                input_idx = input_indices[i].item()
+                segment_ids = input_ids[i, 1 : length - 1].tolist()
+                segment_pre_preds = pre_preds[i, 1 : length - 1].tolist()
+                segment_post_preds = post_preds[i, 1 : length - 1].tolist()
+                segment_cap_preds = cap_preds[i, 1 : length - 1].tolist()
+                segment_sbd_preds = seg_preds[i, 1 : length - 1].tolist()
+                pre_tokens = [self._pre_labels[i] for i in segment_pre_preds]
+                post_tokens = [self._post_labels[i] for i in segment_post_preds]
+                pre_tokens = [x if x != self._null_token else None for x in pre_tokens]
+                post_tokens = [x if x != self._null_token else None for x in post_tokens]
+                collectors[batch_idx].collect(
+                    ids=segment_ids,
+                    pre_preds=pre_tokens,
+                    post_preds=post_tokens,
+                    cap_preds=segment_cap_preds,
+                    sbd_preds=segment_sbd_preds,
+                    idx=input_idx,
+                )
+        outputs: Union[List[str], List[List[str]]] = [x.produce() for x in collectors]
+        return outputs
+
+
+@dataclass
+class TokenizedSegment:
+    input_ids: List[int]
+    batch_idx: int
+    input_idx: int
+
+    def __len__(self) -> int:
+        return len(self.input_ids)
+
+
+class TokenBatchSampler(Iterable):
+    def __init__(self, segments: List[TokenizedSegment], batch_size_tokens: int):
+        self._batches = self._make_batches(segments, batch_size_tokens)
+
+    def _make_batches(self, segments: List[TokenizedSegment], batch_size_tokens: int) -> List[List[int]]:
+        segments_with_index = [(segment, i) for i, segment in enumerate(segments)]
+        segments_with_index.sort(key=lambda x: len(x[0]), reverse=True)
+
+        batches, current_batch_elements, current_max_len = [], [], 0
+
+        for segment, idx in segments_with_index:
+            potential_max_len = max(current_max_len, len(segment))
+
+            if potential_max_len * (len(current_batch_elements) + 1) > batch_size_tokens:
+                batches.append(current_batch_elements)
+                current_batch_elements, current_max_len = [], 0
+
+            current_batch_elements.append(idx)
+            current_max_len = potential_max_len
+
+        if current_batch_elements:
+            batches.append(current_batch_elements)
+
+        return batches
+
+    def __iter__(self) -> Iterator:
+        yield from self._batches
+
+    def __len__(self) -> int:
+        return len(self._batches)
+
+
+class TextInferenceDataset(Dataset):
+    def __init__(
+        self,
+        texts: List[str],
+        spe_model_path: str,
+        batch_size_tokens: int = 4096,
+        max_length: int = 512,
+        overlap: int = 32,
+    ):
+        self._spe_model = SentencePieceProcessor(spe_model_path)
+        self._segments = self._tokenize_inputs(texts, max_length, overlap)
+        self._sampler = TokenBatchSampler(self._segments, batch_size_tokens)
+
+    @property
+    def sampler(self) -> Iterable:
+        return self._sampler
+
+    def _tokenize_inputs(self, texts: List[str], max_len: int, overlap: int) -> List[TokenizedSegment]:
+        max_len -= 2
+        segments = []
+
+        for batch_idx, text in enumerate(texts):
+            ids, start, input_idx = self._spe_model.EncodeAsIds(text), 0, 0
+
+            while start < len(ids):
+                adjusted_start = start - overlap if input_idx else 0
+                segments.append(
+                    TokenizedSegment(
+                        ids[adjusted_start : adjusted_start + max_len],
+                        batch_idx,
+                        input_idx,
+                    )
+                )
+                start += max_len - overlap
+                input_idx += 1
+
+        return segments
+
+    def __len__(self) -> int:
+        return len(self._segments)
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
+        segment = self._segments[idx]
+        input_ids = torch.Tensor([self._spe_model.bos_id(), *segment.input_ids, self._spe_model.eos_id()])
+        return input_ids, segment.batch_idx, segment.input_idx
+
+    def collate_fn(self, batch: List[Tuple[torch.Tensor, int, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        input_ids = [x[0] for x in batch]
+        lengths = torch.tensor([x.shape[0] for x in input_ids])
+        max_len = lengths.max().item()
+
+        batched_ids = torch.full((len(input_ids), max_len), self._spe_model.pad_id())
+        for idx, ids in enumerate(input_ids):
+            batched_ids[idx, : lengths[idx]] = ids
+
+        return (
+            batched_ids,
+            torch.tensor([x[1] for x in batch]),
+            torch.tensor([x[2] for x in batch]),
+            lengths,
+        )
+
+
+@dataclass
+class PCSegment:
+    ids: List[int]
+    pre_preds: List[Optional[str]]
+    post_preds: List[Optional[str]]
+    cap_preds: List[List[int]]
+    sbd_preds: List[int]
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class PunctCapCollector:
+    def __init__(self, apply_sbd: bool, overlap: int, sp_model: SentencePieceProcessor):
+        self._segments: Dict[int, PCSegment] = {}
+        self._apply_sbd = apply_sbd
+        self._overlap = overlap
+        self._sp_model = sp_model
+
+    def collect(
+        self,
+        ids: List[int],
+        pre_preds: List[Optional[str]],
+        post_preds: List[Optional[str]],
+        sbd_preds: List[int],
+        cap_preds: List[List[int]],
+        idx: int,
+    ):
+        self._segments[idx] = PCSegment(
+            ids=ids,
+            pre_preds=pre_preds,
+            post_preds=post_preds,
+            sbd_preds=sbd_preds,
+            cap_preds=cap_preds,
+        )
+
+    def produce(self) -> Union[List[str], str]:
+        ids: List[int] = []
+        pre_preds: List[Optional[str]] = []
+        post_preds: List[Optional[str]] = []
+        cap_preds: List[List[int]] = []
+        sbd_preds: List[int] = []
+
+        for i in range(len(self._segments)):
+            segment = self._segments[i]
+            start = 0
+            stop = len(segment)
+            if i > 0:
+                start += self._overlap // 2
+            if i < len(self._segments) - 1:
+                stop -= self._overlap // 2
+
+            ids.extend(segment.ids[start:stop])
+            pre_preds.extend(segment.pre_preds[start:stop])
+            post_preds.extend(segment.post_preds[start:stop])
+            sbd_preds.extend(segment.sbd_preds[start:stop])
+            cap_preds.extend(segment.cap_preds[start:stop])
+
+        input_tokens = [self._sp_model.IdToPiece(x) for x in ids]
+        output_texts: List[str] = []
+        current_chars: List[str] = []
+
+        for token_idx, token in enumerate(input_tokens):
+            if token.startswith("▁") and current_chars:
+                current_chars.append(" ")
+            char_start = 1 if token.startswith("▁") else 0
+
+            for token_char_idx, char in enumerate(token[char_start:], start=char_start):
+                if token_char_idx == char_start and pre_preds[token_idx] is not None:
+                    current_chars.append(pre_preds[token_idx])
+                if cap_preds[token_idx][token_char_idx]:
+                    char = char.upper()
+                current_chars.append(char)
+
+                label = post_preds[token_idx]
+                if label == ACRONYM_TOKEN:
+                    current_chars.append(".")
+                elif token_char_idx == len(token) - 1 and post_preds[token_idx] is not None:
+                    current_chars.append(post_preds[token_idx])
+                if self._apply_sbd and token_char_idx == len(token) - 1 and sbd_preds[token_idx]:
+                    output_texts.append("".join(current_chars))
+                    current_chars = []
+
+        if current_chars:
+            output_texts.append("".join(current_chars))
+        if not self._apply_sbd:
+            if len(output_texts) > 1:
+                raise ValueError(f"Not applying SBD but got more than one result: {output_texts}")
+            return output_texts[0]
+        return output_texts
+
+
+class MultiLingual:
+    def __init__(self):
+        cfg = PunctCapConfigONNX(directory="/code/models/multilingual")
+        self._punctuator = PunctCapModelONNX(cfg)
+
+    def punctuate(self, data: str) -> str:
+        return self._punctuator.infer([data])[0]
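
Note: MultiLingual hard-codes the container path /code/models/multilingual; for local experiments the same pipeline can be driven through PunctCapConfigONNX directly. A hedged usage sketch (the directory and sample texts are illustrative):

from multi_lingual import PunctCapConfigONNX, PunctCapModelONNX

cfg = PunctCapConfigONNX(directory="models/multilingual")   # local copy of the model files
model = PunctCapModelONNX(cfg)

# infer() first lower-cases and strips punctuation, then restores punctuation
# and casing; with apply_sbd=False it returns one string per input text.
print(model.infer(["hola como estas", "guten morgen wie geht es dir"]))
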
requirements.txt ADDED
@@ -0,0 +1,11 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+fastapi==0.103.1
+gradio==3.44.3
+langdetect==1.0.9
+onnxruntime==1.15.1
+omegaconf==2.3.0
+pandas==2.1.0
+six==1.16.0
+simpletransformers==0.64.3
+tensorflow-datasets==4.9.3
+torch==1.13.1+cpu
urdu_punkt.py ADDED
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+
+import os
+import re
+import string
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+from simpletransformers.ner import NERModel
+
+
+class BERTmodel:
+    def __init__(self, normalization="full", wrds_per_pred=256):
+        self.normalization = normalization
+        self.wrds_per_pred = wrds_per_pred
+        self.overlap_wrds = 32
+        self.valid_labels = ["O", "F", "C", "Q"]
+        self.label_to_punct = {"F": "۔", "C": "،", "Q": "؟", "O": ""}
+        self.model = NERModel(
+            "bert",
+            "/code/models/urdu",
+            use_cuda=False,
+            labels=self.valid_labels,
+            args={"silent": True, "max_seq_length": 512},
+        )
+        self.patterns = {
+            "partial": r"[ً-٠ٰ۟-ۤۧ-۪ۨ-ۭ،۔؟]+",
+            "full": string.punctuation + "،؛؟۔٪ء‘’",
+        }
+
+    def punctuation_removal(self, text: str) -> str:
+        if self.normalization == "partial":
+            return re.sub(self.patterns[self.normalization], "", text).strip()
+        else:
+            return "".join(ch for ch in text if ch not in self.patterns[self.normalization])
+
+    def punctuate(self, text: str):
+        text = self.punctuation_removal(text)
+        splits = self.split_on_tokens(text)
+        full_preds_lst = [self.predict(i["text"]) for i in splits]
+        preds_lst = [i[0][0] for i in full_preds_lst]
+        combined_preds = self.combine_results(text, preds_lst)
+        punct_text = self.punctuate_texts(combined_preds)
+        return punct_text
+
+    def predict(self, input_slice):
+        return self.model.predict([input_slice])
+
+    def split_on_tokens(self, text):
+        wrds = text.replace("\n", " ").split()
+        response = []
+        lst_chunk_idx = 0
+        i = 0
+
+        while True:
+            wrds_len = wrds[i * self.wrds_per_pred : (i + 1) * self.wrds_per_pred]
+            wrds_ovlp = wrds[
+                (i + 1) * self.wrds_per_pred : (i + 1) * self.wrds_per_pred + self.overlap_wrds
+            ]
+            wrds_split = wrds_len + wrds_ovlp
+
+            if not wrds_split:
+                break
+
+            response_obj = {
+                "text": " ".join(wrds_split),
+                "start_idx": lst_chunk_idx,
+                "end_idx": lst_chunk_idx + len(" ".join(wrds_len)),
+            }
+
+            response.append(response_obj)
+            lst_chunk_idx += response_obj["end_idx"] + 1
+            i += 1
+
+        return response
+
+    def combine_results(self, full_text: str, text_slices):
+        split_full_text = full_text.replace("\n", " ").split(" ")
+        split_full_text = [i for i in split_full_text if i]
+        split_full_text_len = len(split_full_text)
+        output_text = []
+        index = 0
+
+        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
+            text_slices = text_slices[:-1]
+
+        for slice in text_slices:
+            slice_wrds = len(slice)
+            for ix, wrd in enumerate(slice):
+                if index == split_full_text_len:
+                    break
+
+                if (
+                    split_full_text[index] == str(list(wrd.keys())[0])
+                    and ix <= slice_wrds - 3
+                    and text_slices[-1] != slice
+                ):
+                    index += 1
+                    pred_item_tuple = list(wrd.items())[0]
+                    output_text.append(pred_item_tuple)
+                elif (
+                    split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == slice
+                ):
+                    index += 1
+                    pred_item_tuple = list(wrd.items())[0]
+                    output_text.append(pred_item_tuple)
+
+        assert [i[0] for i in output_text] == split_full_text
+        return output_text
+
+    def punctuate_texts(self, full_pred: list):
+        punct_resp = []
+        for punct_wrd, label in full_pred:
+            punct_wrd += self.label_to_punct[label]
+            if punct_wrd.endswith("‘‘"):
+                punct_wrd = punct_wrd[:-2] + self.label_to_punct[label] + "‘‘"
+            punct_resp.append(punct_wrd)
+
+        punct_resp = " ".join(punct_resp)
+        if punct_resp[-1].isalnum():
+            punct_resp += "۔"
+
+        return punct_resp
+
+
+class Urdu:
+    def __init__(self):
+        self.model = BERTmodel()
+
+    def punctuate(self, data: str):
+        return self.model.punctuate(data)
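
Note: Urdu wraps BERTmodel, which loads its weights from the hard-coded /code/models/urdu directory, so the sketch below only runs as-is inside the container (or after adjusting that path); the sample sentence is illustrative:

from urdu_punkt import Urdu

model = Urdu()
# Existing punctuation is stripped first, then ۔ ، ؟ are re-inserted by the
# token classifier, and a final ۔ is appended if the text ends unpunctuated.
print(model.punctuate("میں کل بازار گیا وہاں بہت رش تھا"))
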