from torch import nn
from transformers import CanineModel, CanineForTokenClassification, CaninePreTrainedModel, CanineTokenizer
from transformers.modeling_outputs import TokenClassifierOutput
import gradio as gr
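
# Mapping from Arabic characters (letters and harakat) to their Hebrew taatik equivalents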
arabic_to_hebrew = {
    # regular letters
    "ا": "א", "أ": "א", "إ": "א", "ء": "א", "ئ": "א", "ؤ": "א",
    "آ": "אא", "ى": "א", "ب": "ב", "ت": "ת", "ث": "ת'", "ج": "ג'",
    "ح": "ח", "خ": "ח'", "د": "ד", "ذ": "ד'", "ر": "ר", "ز": "ז",
    "س": "ס", "ش": "ש", "ص": "צ", "ض": "צ'", "ط": "ט", "ظ": "ט'",
    "ع": "ע", "غ": "ע'", "ف": "פ", "ق": "ק", "ك": "כ", "ل": "ל",
    "م": "מ", "ن": "נ", "ه": "ה", "و": "ו", "ي": "י", "ة": "ה",
    # special characters
    "،": ",", "َ": "ַ", "ُ": "ֻ", "ِ": "ִ",
}
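
# Hebrew letters that take a final (sofit) form when they end a word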
final_letters = {
    "ن": "ן", "م": "ם", "ص": "ץ", "ض": "ץ'", "ف": "ף",
}
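
# Transliterate an Arabic string into a list of Hebrew strings (one per input character),
# using final letter forms before a space, period, or Arabic comma, or at the end of the text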
def to_taatik(arabic):
    taatik = []
    for index, letter in enumerate(arabic):
        if (
            (index == len(arabic) - 1 or arabic[index + 1] in {" ", ".", "،"}) and
            letter in final_letters
        ):
            taatik.append(final_letters[letter])
        elif letter not in arabic_to_hebrew:
            taatik.append(letter)
        else:
            taatik.append(arabic_to_hebrew[letter])
    return taatik
class TaatikModel(CaninePreTrainedModel):
    # based on CanineForTokenClassification,
    # slightly modified for multilabel classification
    def __init__(self, config, num_labels=7):
        # Note: one label for each nikud type, plus one for the deletion flag
        super().__init__(config)
        config.num_labels = num_labels
        self.num_labels = config.num_labels
        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
        self.criterion = nn.BCEWithLogitsLoss()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            loss = self.criterion(logits, labels)
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# tokenizer = CanineTokenizer.from_pretrained("google/canine-c")
# model = TashkeelModel.from_pretrained("google/canine-c")
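
# Load the character-level CANINE tokenizer and the fine-tuned diacritization/transliteration model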
tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
# model = TaatikModel.from_pretrained("google/canine-s")
# model = TaatikModel.from_pretrained("./checkpoint-19034/")
model = TaatikModel.from_pretrained("guymorlan/Arabic2Taatik")
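
# Translate Hebrew nikud label names into the Arabic harakat names used to rediacritize the source text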
def convert_nikkud_to_harakat(nikkud):
    labels = []
    if "SHADDA" in nikkud:
        labels.append("SHADDA")
    if "TSERE" in nikkud:
        labels.append("KASRA")
    if "HOLAM" in nikkud:
        labels.append("DAMMA")
    if "PATACH" in nikkud:
        labels.append("FATHA")
    if "SHVA" in nikkud:
        labels.append("SUKUN")
    if "KUBUTZ" in nikkud:
        labels.append("DAMMA")
    if "HIRIQ" in nikkud:
        labels.append("KASRA")
    return labels
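
# Decode the model's 7 binary outputs per character; bit order is
# SHADDA, TSERE, HOLAM, PATACH, SHVA, KUBUTZ, HIRIQ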
def convert_binary_to_labels(binary_labels):
    labels = []
    if binary_labels[0] == 1:
        labels.append("SHADDA")
    if binary_labels[1] == 1:
        labels.append("TSERE")
    if binary_labels[2] == 1:
        labels.append("HOLAM")
    if binary_labels[3] == 1:
        labels.append("PATACH")
    if binary_labels[4] == 1:
        labels.append("SHVA")
    if binary_labels[5] == 1:
        labels.append("KUBUTZ")
    if binary_labels[6] == 1:
        labels.append("HIRIQ")
    return labels
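
# Map a label name to its Unicode combining mark (Hebrew nikud or Arabic haraka)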
def convert_label_names_to_chars(label):
    if label == "SHADDA":
        return "ّ"
    if label == "TSERE":
        return "ֵ"
    if label == "HOLAM":
        return "ֹ"
    if label == "PATACH":
        return "ַ"
    if label == "SHVA":
        return "ְ"
    if label == "KUBUTZ":
        return "ֻ"
    if label == "HIRIQ":
        return "ִ"
    # for these, return arabic harakat
    if label == "DAMMA":
        return "ُ"
    if label == "KASRA":
        return "ِ"
    if label == "FATHA":
        return "َ"
    if label == "SUKUN":
        return "ْ"
    return ""
def predict(input, prefix="P "):
    print(input)
    input_tok = tokenizer(prefix + input, return_tensors="pt")
    print(input_tok)
    outputs = model(**input_tok)
    print(outputs)
    labels = outputs.logits.sigmoid().round().int()
    # drop [CLS] and the two prefix characters at the start, and [SEP] at the end
    labels = labels.tolist()[0][3:-1]
    print(labels)
    labels_hebrew = [convert_binary_to_labels(x) for x in labels]
    labels_arabic = [convert_nikkud_to_harakat(x) for x in labels_hebrew]
    print(f"labels_hebrew: {labels_hebrew}")
    print(f"labels_arabic: {labels_arabic}")
    hebrew = [[x] for x in to_taatik(input)]
    print(hebrew)
    arabic = [[x] for x in input]
    print(arabic)
    print(f"len hebrew: {len(hebrew)}")
    print(f"len arabic: {len(arabic)}")
    print(f"len labels_hebrew: {len(labels_hebrew)}")
    print(f"len labels_arabic: {len(labels_arabic)}")
    for i in range(len(hebrew)):
        hebrew[i].extend([convert_label_names_to_chars(x) for x in labels_hebrew[i]])
        arabic[i].extend([convert_label_names_to_chars(x) for x in labels_arabic[i]])
    hebrew = ["".join(x) for x in hebrew]
    arabic = ["".join(x) for x in arabic]
    # if an apostrophe (geresh) sits right after the base letter, move it to the end
    # of the cluster so it follows the nikud marks
    for i in range(len(hebrew)):
        if len(hebrew[i]) > 1 and hebrew[i][1] == "'":
            hebrew[i] = hebrew[i][0] + hebrew[i][2:] + hebrew[i][1]
    hebrew = "".join(hebrew)
    arabic = "".join(arabic)
return f"<p dir='rtl' style='font-size: 1.5em; font-family: Arial Unicode MS;'>{hebrew}</p><p dir='rtl' style='font-size: 1.5em; font-family: Noto;'>{arabic}</p>"
font = "Arial Unicode MS, Tahoma, sans-serif"
return f"<p dir='rtl' style='font-size: 1.5em; font-family: {font};'>{hebrew}</p><p dir='rtl' style='font-size: 1.5em; font-family: {font};'>{arabic}</p>"
return f"<p dir='rtl' style='font-size: 1.5em; font-family: Heebo;'>{hebrew}</p><p dir='rtl' style='font-size: 1.5em; font-family: Heebo;'>{arabic}</p>"
# return f"<p dir='rtl' style='font-size: 1.5em'>{hebrew}</p><p dir='rtl' style='font-size: 1.5em'>{arabic}</p>"
font_url = "<link href='https://fonts.googleapis.com/css2?family=Heebo&display=swap' rel='stylesheet'>"
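
# Gradio UI: Arabic text in, HTML out (Hebrew taatik above the diacritized Arabic)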
with gr.Blocks(theme=gr.themes.Soft(), title="Ammiya Diacritizer") as demo:
    gr.HTML("<h2><span style='color: #2563eb'>Colloquial Arabic</span></h2> Diacritizer and Hebrew Transliterator" + font_url)
    with gr.Row():
        with gr.Column():
            input = gr.Textbox(label="Input", placeholder="Enter Arabic text", lines=1)
            gr.Examples(["بديش اروح معك"], input)
            btn = gr.Button("Analyze")
        with gr.Column():
            with gr.Box():
                html = gr.HTML()
    btn.click(predict, inputs=[input], outputs=[html])
    input.submit(predict, inputs=[input], outputs=[html])

demo.load()
demo.launch()