Spaces:
Runtime error
Runtime error
from torch import nn | |
from transformers import CanineModel, CanineForTokenClassification, CaninePreTrainedModel, CanineTokenizer | |
from transformers.modeling_outputs import TokenClassifierOutput | |
import gradio as gr | |
# Transliteration table: each Arabic letter maps to its Hebrew taatik
# equivalent. Letters with no exact Hebrew counterpart get a geresh (')
# after the closest Hebrew letter (e.g. tha -> ת').
arabic_to_hebrew = {
    # regular letters
    "ا": "א", "أ": "א", "إ": "א", "ء": "א", "ئ": "א", "ؤ": "א",
    "آ": "אא", "ى": "א", "ب": "ב", "ت": "ת", "ث": "ת'", "ج": "ג'",
    "ح": "ח", "خ": "ח'", "د": "ד", "ذ": "ד'", "ر": "ר", "ز": "ז",
    "س": "ס", "ش": "ש", "ص": "צ", "ض": "צ'", "ط": "ט", "ظ": "ט'",
    "ع": "ע", "غ": "ע'", "ف": "פ", "ق": "ק", "ك": "כ", "ل": "ל",
    "م": "מ", "ن": "נ", "ه": "ה", "و": "ו", "ي": "י", "ة": "ה",
    # special characters: Arabic comma and short vowels -> Hebrew nikud
    "،": ",", "َ": "ַ", "ُ": "ֻ", "ِ": "ִ",
}

# Hebrew final (sofit) forms, used when the Arabic letter ends a word.
final_letters = {
    "ن": "ן", "م": "ם", "ص": "ץ", "ض": "ץ'", "ف": "ף",
}


def to_taatik(arabic):
    """Transliterate an Arabic string into a list of Hebrew taatik cells.

    Returns one string per input character (a cell may hold more than one
    Hebrew character, e.g. "אא" or "ת'"). A letter followed by a space,
    period, Arabic comma — or standing at the end of the text — is treated
    as word-final and mapped through ``final_letters`` when possible.
    Characters absent from both tables pass through unchanged.
    """
    word_breaks = {" ", ".", "،"}
    last_index = len(arabic) - 1
    cells = []
    for position, char in enumerate(arabic):
        is_word_final = (
            position == last_index or arabic[position + 1] in word_breaks
        )
        if is_word_final and char in final_letters:
            cells.append(final_letters[char])
        else:
            # Unknown characters (Latin, digits, punctuation) fall through.
            cells.append(arabic_to_hebrew.get(char, char))
    return cells
class TaatikModel(CaninePreTrainedModel):
    """CANINE encoder with a per-character multilabel classification head.

    Based on CanineForTokenClassification, but each character position
    produces ``num_labels`` independent logits (one per nikud type plus a
    deletion flag), trained with BCEWithLogitsLoss instead of softmax
    cross-entropy.
    """

    def __init__(self, config, num_labels=7):
        # Note: one label for each nikud type, plus one for the deletion flag.
        super().__init__(config)
        config.num_labels = num_labels
        self.num_labels = config.num_labels
        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing.
        self.post_init()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Run the encoder and classification head.

        ``labels`` (when given) is expected to be a float multi-hot tensor
        matching the logits' shape — presumably (batch, seq, num_labels);
        TODO confirm against the training code.
        """
        encoder_outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden = self.dropout(encoder_outputs[0])
        logits = self.classifier(hidden)
        loss = self.criterion(logits, labels) if labels is not None else None
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Character-level tokenizer matching the google/canine-s checkpoint the
# fine-tuned model below was trained from.
tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
# Fine-tuned multilabel diacritization/transliteration model (TaatikModel),
# downloaded from the Hugging Face Hub.
model = TaatikModel.from_pretrained("guymorlan/Arabic2Taatik")
def convert_nikkud_to_harakat(nikkud):
    """Map Hebrew nikud label names to Arabic harakat label names.

    SHADDA passes through unchanged; several nikud collapse onto the same
    haraka (TSERE/HIRIQ -> KASRA, HOLAM/KUBUTZ -> DAMMA), so the output
    may contain duplicates. Output order follows the fixed mapping order
    below, not the order of ``nikkud``.
    """
    nikkud_to_haraka = (
        ("SHADDA", "SHADDA"),
        ("TSERE", "KASRA"),
        ("HOLAM", "DAMMA"),
        ("PATACH", "FATHA"),
        ("SHVA", "SUKUN"),
        ("KUBUTZ", "DAMMA"),
        ("HIRIQ", "KASRA"),
    )
    return [haraka for name, haraka in nikkud_to_haraka if name in nikkud]
def convert_binary_to_labels(binary_labels):
    """Decode a multi-hot 0/1 vector into the nikud label names it encodes.

    Index order (matching the model's output head):
    SHADDA, TSERE, HOLAM, PATACH, SHVA, KUBUTZ, HIRIQ.

    Improvements over the original if-chain: no repeated indexing (the
    original raised IndexError on vectors shorter than 7); extra trailing
    positions, if any, are ignored.

    Args:
        binary_labels: sequence of 0/1 ints (model sigmoid outputs rounded).

    Returns:
        List of label-name strings for every position set to 1, in index order.
    """
    names = ("SHADDA", "TSERE", "HOLAM", "PATACH", "SHVA", "KUBUTZ", "HIRIQ")
    return [name for flag, name in zip(binary_labels, names) if flag == 1]
def convert_label_names_to_chars(label):
    """Return the combining diacritic character for a label name.

    Hebrew nikud labels yield Hebrew points, Arabic haraka labels yield
    Arabic harakat, and SHADDA yields the Arabic shadda mark. Unknown
    labels yield the empty string.
    """
    label_to_char = {
        "SHADDA": "ّ",   # U+0651 Arabic shadda
        # Hebrew nikud
        "TSERE": "ֵ",
        "HOLAM": "ֹ",
        "PATACH": "ַ",
        "SHVA": "ְ",
        "KUBUTZ": "ֻ",
        "HIRIQ": "ִ",
        # Arabic harakat
        "DAMMA": "ُ",
        "KASRA": "ِ",
        "FATHA": "َ",
        "SUKUN": "ْ",
    }
    return label_to_char.get(label, "")
def predict(input, prefix="P "):
    """Diacritize Arabic text and transliterate it to Hebrew taatik.

    Runs the CANINE multilabel model on ``prefix + input``, decodes each
    character's sigmoid outputs into nikud/harakat marks, and returns an
    HTML fragment showing the Hebrew transliteration and the diacritized
    Arabic, both right-to-left.

    Fixes over the original: removed three unreachable ``return``
    statements that followed the first return, and the debug prints.

    Args:
        input: Arabic source text. (Name kept for Gradio callers even
            though it shadows the builtin.)
        prefix: Task prefix prepended before tokenization; the model was
            presumably trained with it, and its predictions are discarded.

    Returns:
        HTML string with two <p dir='rtl'> paragraphs (Hebrew, Arabic).
    """
    input_tok = tokenizer(prefix + input, return_tensors="pt")
    outputs = model(**input_tok)
    # Independent sigmoid per label, rounded to a 0/1 multi-hot vector
    # for every character position.
    labels = outputs.logits.sigmoid().round().int().tolist()[0]
    # Slice [3:-1]: skip 3 leading positions ([CLS] plus the 2 characters
    # of "P ", presumably — TODO confirm) and the trailing [SEP], so the
    # predictions align 1:1 with the characters of `input`.
    labels = labels[3:-1]
    labels_hebrew = [convert_binary_to_labels(x) for x in labels]
    labels_arabic = [convert_nikkud_to_harakat(x) for x in labels_hebrew]

    # One cell per input character; each cell collects the base character
    # followed by its predicted diacritic marks.
    hebrew = [[cell] for cell in to_taatik(input)]
    arabic = [[ch] for ch in input]
    for i in range(len(hebrew)):
        hebrew[i].extend(convert_label_names_to_chars(x) for x in labels_hebrew[i])
        arabic[i].extend(convert_label_names_to_chars(x) for x in labels_arabic[i])
    hebrew = ["".join(cell) for cell in hebrew]
    arabic = ["".join(cell) for cell in arabic]

    # A geresh from the transliteration (e.g. ג') must render after any
    # nikud attached to the letter, so move it to the end of its cell.
    for i in range(len(hebrew)):
        if len(hebrew[i]) > 1 and hebrew[i][1] == "'":
            hebrew[i] = hebrew[i][0] + hebrew[i][2:] + hebrew[i][1]

    hebrew = "".join(hebrew)
    arabic = "".join(arabic)
    return f"<p dir='rtl' style='font-size: 1.5em; font-family: Arial Unicode MS;'>{hebrew}</p><p dir='rtl' style='font-size: 1.5em; font-family: Noto;'>{arabic}</p>"
# Inject the Heebo webfont so the Hebrew output renders consistently.
font_url = "<link href='https://fonts.googleapis.com/css2?family=Heebo&display=swap' rel='stylesheet'>"

with gr.Blocks(theme=gr.themes.Soft(), title="Ammiya Diacritizer") as demo:
    gr.HTML("<h2><span style='color: #2563eb'>Colloquial Arabic</span></h2> Diacritizer and Hebrew Transliterator" + font_url)
    with gr.Row():
        with gr.Column():
            # NOTE(review): `input` shadows the builtin; kept because the
            # event wiring below references it by this name.
            input = gr.Textbox(label="Input", placeholder="Enter Arabic text", lines=1)
            gr.Examples(["بديش اروح معك"], input)
            btn = gr.Button(label="Analyze")
        with gr.Column():
            # NOTE(review): gr.Box is deprecated in newer Gradio releases;
            # verify against the Gradio version pinned for this Space.
            with gr.Box():
                html = gr.HTML()
    # Both clicking the button and pressing Enter in the textbox run predict.
    btn.click(predict, inputs=[input], outputs=[html])
    input.submit(predict, inputs = [input], outputs=[html])

demo.load()
demo.launch()