Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,427 Bytes
4d6899e b36fd77 7efa162 bcc0c7f b36fd77 bcc0c7f 7efa162 bcc0c7f 7efa162 bcc0c7f 6b59652 6d41dc3 bcc0c7f 7264ba7 bcc0c7f 6d41dc3 bcc0c7f 6d41dc3 bcc0c7f 7efa162 6f55e41 d567ef9 7efa162 bcc0c7f 944e71d 7efa162 bcc0c7f 031ad86 bcc0c7f 7efa162 bcc0c7f 7efa162 e899574 d567ef9 6f55e41 7efa162 944e71d 7efa162 d567ef9 7efa162 d567ef9 7efa162 d567ef9 6f55e41 bc40199 e899574 7efa162 d567ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import spaces
import gradio as gr
from shakkala import Shakkala
from pathlib import Path
import torch
from eo_pl import TashkeelModel as TashkeelModelEO
from ed_pl import TashkeelModel as TashkeelModelED
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
# Initialize the Shakkala model
# The helper object `sh` is built once at import time and reused by
# infer_shakkala for input preparation and output decoding.
sh = Shakkala(version=3)
# get_model() yields a (model, graph) pair; only `model` is referenced by the
# rest of the visible code, `graph` is kept but unused here.
model, graph = sh.get_model()
def infer_shakkala(input_text):
    """Add tashkeel (diacritics) to ``input_text`` using the Shakkala model.

    The text is encoded with ``sh.prepare_input``, passed through
    ``model.predict``, and the first entry of the prediction is decoded with
    ``sh.logits_to_text``. The decoded harakat are then merged back into the
    original text by ``sh.get_final_text``.

    Args:
        input_text: The text entered by the user.

    Returns:
        The value produced by ``sh.get_final_text`` (also printed to stdout).
    """
    encoded = sh.prepare_input(input_text)
    # Only the first element of the prediction result is used.
    first_logits = model.predict(encoded)[0]
    harakat = sh.logits_to_text(first_logits)
    diacritized = sh.get_final_text(input_text, harakat)
    print(diacritized)
    return diacritized
# CaTT setup: tokenizer, checkpoint locations, and both model variants.
_app_dir = Path(__file__).parent
tokenizer = TashkeelTokenizer()
eo_ckpt_path = _app_dir / 'models/best_eo_mlm_ns_epoch_193.pt'
ed_ckpt_path = _app_dir / 'models/best_ed_mlm_ns_epoch_178.pt'
device = 'cpu'
max_seq_len = 1024


def _restore_weights(net, ckpt_path):
    """Load the checkpoint at ``ckpt_path`` into ``net``, then set eval mode on ``device``."""
    state = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(state)
    net.eval().to(device)


print('Creating Model...')
# Both variants share the tokenizer and sequence settings; they differ in
# class and in the number of layers.
_shared_kwargs = dict(max_seq_len=max_seq_len, learnable_pos_emb=False)
eo_model = TashkeelModelEO(tokenizer, n_layers=6, **_shared_kwargs)
ed_model = TashkeelModelED(tokenizer, n_layers=3, **_shared_kwargs)

# Weights are loaded on `device` (CPU) at import time; infer_catt moves the
# selected model to CUDA when it is called.
_restore_weights(eo_model, eo_ckpt_path)
_restore_weights(ed_model, ed_ckpt_path)
@spaces.GPU()
def infer_catt(input_text, choose_model):
    """Add tashkeel to ``input_text`` with one of the two CaTT models.

    The input is first passed through ``remove_non_arabic``. The chosen model
    is moved to CUDA and asked to process a single-item batch.

    Args:
        input_text: The text entered by the user.
        choose_model: ``'Encoder-Only'`` selects ``eo_model``; any other value
            selects ``ed_model``.

    Returns:
        The first (and only) element of the list returned by
        ``do_tashkeel_batch``.
    """
    cleaned = remove_non_arabic(input_text)
    batch_size, verbose = 16, True

    # Anything other than the exact 'Encoder-Only' label falls back to the
    # encoder-decoder model.
    net = eo_model if choose_model == 'Encoder-Only' else ed_model
    net.to("cuda")

    results = net.do_tashkeel_batch([cleaned], batch_size, verbose)
    return results[0]
# Sample sentences offered under both tabs.
examples = ["السلام عليكم ورحمة الله وبركاته", "العلم نور", "الحمد لله"]

# Keyword arguments shared by every text box in the UI.
_rtl_box = dict(rtl=True, text_align="right")


def _reset_fields():
    """Return empty strings for an (input, output) text box pair."""
    return "", ""


# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation -- confirm it against the deployed layout.
with gr.Blocks(title="Arabic Tashkeel") as demo:
    gr.HTML("<center><h1>Arabic Tashkeel</h1></center>")
    gr.HTML(
        "<center><p>Compare different methods for adding tashkeel to Arabic text.</p></center>"
    )

    # Tab 1: CaTT, with a choice between the two model variants.
    with gr.Tab(label="CATT"):
        gr.HTML("<center><h2>CATT: Character-based Arabic Tashkeel Transformer</h2></center>")
        gr.HTML("<center><a href='https://github.com/abjadai/catt'>GitHub</a> - <a href='https://arxiv.org/abs/2407.03236'>Arxiv Paper</a></center>")
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Input Text", **_rtl_box)
                choose_model = gr.Radio(
                    label="Choose Model",
                    value="Encoder-Decoder",
                    choices=["Encoder-Only", "Encoder-Decoder"],
                )
                with gr.Row():
                    clear_button1 = gr.Button(value="Clear", variant="secondary")
                    submit_button1 = gr.Button(value="Add Tashkeel", variant="primary")
            with gr.Column():
                text_output1 = gr.Textbox(label="Output Text", **_rtl_box)

        gr.Examples(examples=examples, inputs=text_input1, cache_examples=False)

        submit_button1.click(
            fn=infer_catt,
            inputs=[text_input1, choose_model],
            outputs=text_output1,
        )
        clear_button1.click(fn=_reset_fields, outputs=[text_input1, text_output1])

    # Tab 2: Shakkala, a single model with no extra options.
    with gr.Tab(label="Shakkala"):
        gr.HTML("<center><h2>Shakkala: Arabic Diacritization</h2></center>")
        gr.HTML("<center><a href='https://github.com/Barqawiz/Shakkala'>GitHub</a> - <a href='https://pypi.org/project/shakkala/'>PyPi Package</a></center>")
        with gr.Row():
            with gr.Column():
                text_input2 = gr.Textbox(label="Input Text", **_rtl_box)
                with gr.Row():
                    clear_button2 = gr.Button(value="Clear", variant="secondary")
                    submit_button2 = gr.Button(value="Apply Tashkeel", variant="primary")
            with gr.Column():
                text_output2 = gr.Textbox(lines=1, label="Output Text", **_rtl_box)

        submit_button2.click(fn=infer_shakkala, inputs=text_input2, outputs=text_output2)
        clear_button2.click(fn=_reset_fields, outputs=[text_input2, text_output2])

        gr.Examples(examples=examples, inputs=text_input2, cache_examples=False)
if __name__ == '__main__':
    # Enable the request queue, then start the server on the queued app.
    queued_app = demo.queue()
    queued_app.launch()
|