# Hugging Face Space: sivan22/TrOCR-handwritten-hebrew — Gradio demo app.
# (The Space currently shows a "Runtime error" banner; see fixes below.)
import gradio as gr
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, BertTokenizerFast
import requests
from PIL import Image

# Sample handwritten-Hebrew line images, used as clickable examples in the UI.
urls = ['https://huggingface.co/spaces/sivan22/TrOCR-handwritten-hebrew/resolve/main/article_1_page_10line_1.png', 'https://huggingface.co/spaces/sivan22/TrOCR-handwritten-hebrew/resolve/main/article_1_page_10line_10.png',
        'https://huggingface.co/spaces/sivan22/TrOCR-handwritten-hebrew/resolve/main/article_1_page_10line_11.png']
for url in urls:
    image = Image.open(requests.get(url, stream=True).raw)
    # BUG FIX: save each image under its URL basename so the files match the
    # `examples` list passed to gr.Interface below. The original saved them as
    # image_{idx}.png, leaving the example paths (article_1_page_10line_*.png)
    # dangling and the examples broken.
    image.save(url.rsplit("/", 1)[-1])
from transformers import BertTokenizer, BasicTokenizer
from transformers.tokenization_utils import _is_punctuation


class OurBasicTokenizer(BasicTokenizer):
    """BasicTokenizer variant that keeps Hebrew abbreviation marks inside words.

    Standard BERT pre-tokenization splits on every punctuation character.
    Here ' (geresh) never splits, and " (gershayim) does not split when the
    next character is a non-punctuation character (i.e. it sits inside a word).
    """

    def _run_split_on_punc(self, text, never_split=None):
        """Split *text* on punctuation, honoring the exceptions above.

        Returns a list of string pieces whose concatenation equals *text*.
        """
        # Tokens explicitly protected from splitting pass through unchanged.
        if text in self.never_split or (never_split and text in never_split):
            return [text]

        def splits_here(pos, ch):
            # ' never splits; " splits only when word-final or followed by
            # further punctuation (i.e. when it acts as a closing quote).
            if not _is_punctuation(ch) or ch == "'":
                return False
            if ch == '"' and pos + 1 < len(text) and not _is_punctuation(text[pos + 1]):
                return False
            return True

        pieces = []
        in_word = False
        for pos, ch in enumerate(text):
            if splits_here(pos, ch):
                pieces.append(ch)       # punctuation stands alone
                in_word = False
            elif in_word:
                pieces[-1] += ch        # extend the current word piece
            else:
                pieces.append(ch)       # start a new word piece
                in_word = True
        return pieces
def RabbinicTokenizer(tok):
    """Swap *tok*'s basic tokenizer for the abbreviation-aware variant.

    Mutates *tok* in place and returns it for convenient chaining.
    """
    basic = tok.basic_tokenizer
    tok.basic_tokenizer = OurBasicTokenizer(basic.do_lower_case, basic.never_split)
    return tok
# Pull the OCR pipeline pieces from the Hugging Face Hub (downloaded on first
# run): a BERT tokenizer wrapped so Hebrew abbreviation marks survive, the
# Swin-V2 image preprocessor, and the fine-tuned vision-encoder-decoder model.
# The three loads are independent of one another.
tokenizer = RabbinicTokenizer(BertTokenizer.from_pretrained("sivan22/BEREL"))
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
model = VisionEncoderDecoderModel.from_pretrained("sivan22/ABBA-HTR")
def process_image(image):
    """Run handwritten-text recognition on a single line image.

    Parameters
    ----------
    image : PIL.Image.Image
        Line image supplied by the Gradio Image input (type="pil").

    Returns
    -------
    str
        The decoded text for the image.
    """
    # prepare image
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # generate (no beam search)
    generated_ids = model.generate(pixel_values)
    # decode. batch_decode returns a one-element list for our single image;
    # BUG FIX: return the string itself rather than the list, so the Textbox
    # shows the text instead of "['...']".
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
title = "讛讚讙诪讛: 驻注谞讜讞 讻转讘 讬讚 讘讗诪爪注讜转 讘讬谞讛 诪诇讗讻讜转讬转"
description = "注诇 讘住讬住 讟讻谞讜诇讜讙讬讬转 trOCR"
article = "<p style='text-align: center'>sivan22</p>"
examples = [["article_1_page_10line_1.png"], ["article_1_page_10line_10.png"], ["article_1_page_10line_11.png"]]
#css = """.output_image, .input_image {height: 600px !important}"""

# BUG FIX: the gr.inputs / gr.outputs namespaces were removed in modern Gradio
# releases; use the top-level component classes instead. This is the likely
# cause of the Space's "Runtime error" banner.
iface = gr.Interface(fn=process_image,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Textbox(),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.launch(debug=True)