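"""Gradio demo for UDOP (microsoft/udop-large).

Upload a document image, optionally draw a box around a region of interest,
and give a task-prefixed prompt; easyocr supplies the words and boxes that
UDOP's processor needs, since its built-in OCR is disabled (apply_ocr=False).
"""
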
import gradio as gr
import numpy as np
import torch
from PIL import Image
from gradio_image_prompter import ImagePrompter
from transformers import AutoProcessor, UdopForConditionalGeneration
import easyocr
import spaces
from typing import Optional, List, TypedDict, Union


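# Value produced by gradio_image_prompter's ImagePrompter: the uploaded
# image plus a list of point/box prompts drawn on it.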
class PromptValue(TypedDict):
    image: Optional[Union[Image.Image, str]]
    points: Optional[List[List[float]]]
    
processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

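# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the
# duration of each call; outside Spaces the decorator is a no-op.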
@spaces.GPU
def udop_box_inference(image, text_prompt, box_coordinates):
    # gr.Examples passes the image as a file path; load it before use.
    if isinstance(image, str):
        image = Image.open(image)

    if box_coordinates != []:
        # An ImagePrompter box prompt is [x1, y1, type, x2, y2, type];
        # keep only the four corner coordinates.
        box_coordinates = [box_coordinates[0], box_coordinates[1], box_coordinates[3], box_coordinates[4]]
  
    extracted_image = extract_box(image, box_coordinates)

    # Run OCR on the (possibly cropped) page in memory rather than via a
    # shared temp file; easyocr treats 3-channel arrays as OpenCV-style BGR,
    # so flip the RGB channels.
    reader = easyocr.Reader(['en'])
    result = reader.readtext(np.array(extracted_image.convert("RGB"))[:, :, ::-1].copy())
    texts = []
    bboxs = []
    for (bbox, text, prob) in result:
        texts.append(text)
        # easyocr returns four corner points per detection;
        # keep top-left and bottom-right as [x1, y1, x2, y2].
        bboxs.append([bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]])

    width, height = image.size
    image = image.convert("RGB")
    norm_boxes = [normalize_bbox(box, width, height) for box in bboxs]

    encoding = processor(image, text_prompt, texts, boxes=norm_boxes, return_tensors="pt").to(device)
    predicted_ids = model.generate(**encoding)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def normalize_bbox(bbox, width, height):
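    # UDOP's processor expects word boxes normalized to a 0-1000 grid,
    # the same convention as the LayoutLM family.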
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]


def extract_box(image, coordinates):
    # Crop to the user-drawn box; with no box, return the whole page.
    if isinstance(image, str):
        image = Image.open(image)
    if coordinates == []:
        return image
    x, y, x2, y2 = coordinates
    return image.crop((x, y, x2, y2))


def infer_box(prompts, text_prompts):
    # prompts is a PromptValue dict: the uploaded image plus any box prompts
    # drawn on it by the user.
    image = prompts["image"]
    if image is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    try:
        points = prompts["points"][0]
    except (KeyError, IndexError, TypeError):
        points = []
    return udop_box_inference(image, text_prompts, points)


with gr.Blocks(title="UDOP") as demo:
    gr.Markdown("# UDOP")
    gr.Markdown("UDOP is a cutting-edge foundation model for document understanding and generation.")
    gr.Markdown("Try UDOP in this demo. Simply upload a document, draw a box around the part of the image you'd like UDOP to work on, and enter a prompt. If you don't draw a box, the model will take the whole image into account. You can try one of the examples to see how it works.")

    with gr.Row():
        with gr.Column():
            im = ImagePrompter(type="pil", label="Input Document")
            text_prompt = gr.Textbox(label="Text Prompt with Task Prefix")
            btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(label="UDOP Output")

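    # Prompts start with a task prefix, e.g. "Question answering." or
    # "Document Classification.", followed by the actual query.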
    with gr.Row():
        gr.Examples(
            examples=[
                [PromptValue(image="./dummy_pdf.png",
                             points=[[87.0, 908.0, 2.0, 456.0, 972.0, 3.0]]),
                 "Question answering. What is the objective?"],
                [PromptValue(image="./docvqa_example (3).png",
                             points=[[]]),
                 "Question answering. How much is the total?"],
                [PromptValue(image="./docvqa_example (3).png",
                             points=[[]]),
                 "Document Classification."],
            ],
            inputs=[im, text_prompt],
            outputs=output,
            fn=infer_box,
        )
    btn.click(infer_box, inputs=[im, text_prompt], outputs=[output])

demo.launch(debug=True)