# TrOCR_EN_ICR / app.py
import warnings

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

warnings.filterwarnings("ignore")

# Load the pre-trained TrOCR processor and model for handwritten text.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
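
# Optional sketch (not part of the original app): move the model to a GPU when one
# is available. Assumes torch is importable, which it is as a transformers
# dependency; pixel_values would then also need .to(device) inside extract_text.
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)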

def extract_text(image):
    """Run TrOCR on a single cropped line image and return the decoded text."""
    # Calling the processor is equivalent to calling the feature extractor.
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
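
# Usage sketch (hypothetical file name): extract_text can be called directly on a
# single cropped line image, independently of the Gradio interface.
#   line_img = Image.open("sample_line.png").convert("RGB")
#   print(extract_text(line_img))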

def hand_written(image_raw):
    """Segment a handwritten page into text lines and OCR each line with TrOCR."""
    image_raw = np.array(image_raw)
    # Binarize the page: grayscale -> blur -> inverted threshold (ink becomes white).
    image = cv2.cvtColor(image_raw, cv2.COLOR_RGB2GRAY)
    image = cv2.GaussianBlur(image, (5, 5), 0)
    image = cv2.threshold(image, 200, 255, cv2.THRESH_BINARY_INV)[1]
    # Dilate with a wide, flat kernel so characters on the same line merge into one blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 1))
    image = cv2.dilate(image, kernel, iterations=5)
    contours, hier = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    all_box = []
    for contour in contours:
        bbox = cv2.boundingRect(contour)
        all_box.append(bbox)
    # Calculate the maximum rectangle height.
    c = np.array(all_box)
    max_height = np.max(c[:, 3])
    # Sort the bounding boxes by their y-coordinate.
    by_y = sorted(all_box, key=lambda box: box[1])
    line_y = by_y[0][1]  # y of the topmost box
    line = 1
    by_line = []
    # Assign a line number to each box: a box starting more than max_height
    # below the current line opens a new line.
    for x, y, w, h in by_y:
        if y > line_y + max_height:
            line_y = y
            line += 1
        by_line.append((line, x, y, w, h))
    # Sorting the tuples orders boxes by line number first, then by x within a line.
    contours_sorted = [(x, y, w, h) for line, x, y, w, h in sorted(by_line)]
    text = ""
    for x, y, w, h in contours_sorted:
        cropped_image = image_raw[y:y + h, x:x + w]
        try:
            extracted = extract_text(cropped_image)
            # Skip spurious detections that TrOCR decodes as "0 0" / "0 1".
            if extracted not in ("0 0", "0 1"):
                text = "\n".join([text, extracted])
        except Exception:
            print("skipping")
    return text
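
# Debugging sketch (not part of the original app): to inspect the line segmentation,
# the boxes in contours_sorted can be drawn onto a copy of the page inside
# hand_written before cropping:
#   debug = image_raw.copy()
#   for x, y, w, h in contours_sorted:
#       cv2.rectangle(debug, (x, y), (x + w, y + h), (255, 0, 0), 2)
#   cv2.imwrite("lines_debug.png", cv2.cvtColor(debug, cv2.COLOR_RGB2BGR))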

title = "TrOCR + EN_ICR demo"
description = "TrOCR Handwritten Recognizer"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models</a> | <a href='https://github.com/microsoft/unilm/tree/master/trocr'>Github Repo</a></p>"

# Image example from the IAM handwriting database, bundled with the Space.
examples = [["image_0.png"]]

# gr.Image / gr.Textbox replace the deprecated gr.inputs / gr.outputs namespaces.
iface = gr.Interface(fn=hand_written,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Textbox(),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.launch(debug=True, share=True)