sitammeur committed on
Commit 0e71e76
1 Parent(s): 568a984

Upload initial file structure

app.py ADDED
@@ -0,0 +1,38 @@
+ # Importing the requirements
+ import gradio as gr
+ from src.task import ocr_task
+
+
+ # Image input for the interface
+ image = gr.Image(type="pil", label="Image")
+
+ # Outputs for the interface (image and text)
+ ocr_text_output = gr.Textbox(label="OCR Text")
+ ocr_image_output = gr.Image(type="pil", label="Output Image")
+
+ # Examples for the interface (paths to the images added under images/)
+ examples = [
+     ["images/ocr_image_1.png"],
+     ["images/ocr_image_2.png"],
+     ["images/ocr_image_3.jpg"],
+ ]
+
+ # Title, description, and article for the interface
+ title = "OCR Text Extraction and Visualization"
+ description = "Gradio demo for the Florence-2-large-ft vision language model. This application performs Optical Character Recognition (OCR) on images, returning both the extracted text and an image with bounding boxes drawn around the detected text regions. To use it, simply upload your image and click 'Submit'. This is useful for OCR-related tasks such as document digitization, text extraction, and visual verification of detected text regions."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.06242' target='_blank'>Florence-2: Advancing a Unified Representation for a Variety of Vision Tasks</a> | <a href='https://huggingface.co/microsoft/Florence-2-large-ft' target='_blank'>Model Page</a></p>"
+
+
+ # Define and launch the interface
+ interface = gr.Interface(
+     fn=ocr_task,
+     inputs=[image],
+     outputs=[ocr_image_output, ocr_text_output],
+     examples=examples,
+     title=title,
+     description=description,
+     article=article,
+     theme="soft",
+     allow_flagging="never",
+ )
+ interface.launch(debug=False)
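
Off Spaces, the running app can also be exercised programmatically. A minimal client-side sketch, not part of this commit: it assumes the app is already serving locally via `python app.py` on Gradio's default port 7860, that gradio_client >= 1.0 is installed, and that `/predict` is the default endpoint name for a gr.Interface.

# Hypothetical client-side smoke test (not in this commit).
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
boxed_image_path, ocr_text = client.predict(
    handle_file("images/ocr_image_3.jpg"), api_name="/predict"
)
print(ocr_text)  # second output is the cleaned text; the first is a path to the boxed image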
images/ocr_image_1.png ADDED
images/ocr_image_2.png ADDED
images/ocr_image_3.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy==2.0.0
+ Pillow==10.4.0
+ gradio==4.38.1
+ transformers==4.42.4
+ timm==1.0.7
src/__init__.py ADDED
File without changes
src/model.py ADDED
@@ -0,0 +1,52 @@
+ # Importing necessary libraries
+ import spaces
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+
+ # Load the model and processor from Hugging Face
+ model_id = "microsoft/Florence-2-large-ft"
+ model = (
+     AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
+ )
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+
+ @spaces.GPU(duration=120)
+ def run_example(task_prompt, image, text_input=None):
+     """
+     Runs the model on the given task prompt and image.
+
+     Args:
+         task_prompt (str): The task prompt (e.g., "<OCR>").
+         image (PIL.Image.Image): The image to be processed.
+         text_input (str, optional): Additional text appended to the task prompt. Defaults to None.
+
+     Returns:
+         dict: The parsed answer generated by the model, keyed by the task prompt.
+     """
+
+     # If there is no text input, use the task prompt alone
+     if text_input is None:
+         prompt = task_prompt
+     else:
+         prompt = task_prompt + text_input
+
+     # Process the image and text input
+     inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+     # Generate the answer using the model
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"].cuda(),
+         pixel_values=inputs["pixel_values"].cuda(),
+         max_new_tokens=1024,
+         early_stopping=False,
+         do_sample=False,
+         num_beams=3,
+     )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = processor.post_process_generation(
+         generated_text, task=task_prompt, image_size=(image.width, image.height)
+     )
+
+     # Return the parsed answer
+     return parsed_answer
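
For a quick local check of the model wrapper, a sketch like the following can be used. It is not part of this commit and assumes a CUDA device plus `pip install spaces` alongside the pinned requirements (the `spaces` decorator is the Hugging Face ZeroGPU helper that src/model.py imports).

# Hypothetical smoke test for src/model.py (not in this commit).
from PIL import Image
from src.model import run_example

image = Image.open("images/ocr_image_3.jpg").convert("RGB")
result = run_example("<OCR>", image)  # dict keyed by the task prompt
print(result["<OCR>"])  # raw OCR text; clean_text() strips leftover tokens like </s>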
src/task.py ADDED
@@ -0,0 +1,34 @@
+ import copy
+ from src.model import run_example
+ from src.utils import clean_text, draw_ocr_bboxes
+
+
+ def ocr_task(image):
+     """
+     Perform OCR (Optical Character Recognition) on the given image.
+
+     Args:
+         image (PIL.Image.Image): The input image to perform OCR on.
+
+     Returns:
+         tuple: The output image with OCR bounding boxes drawn and the cleaned OCR text.
+
+     """
+
+     # Task prompts
+     ocr_prompt = "<OCR>"
+     ocr_with_region_prompt = "<OCR_WITH_REGION>"
+
+     # Get the OCR text
+     ocr_results = run_example(ocr_prompt, image)
+     cleaned_text = clean_text(ocr_results["<OCR>"])
+
+     # Get OCR results with regions and draw them on a copy of the image
+     ocr_with_region_results = run_example(ocr_with_region_prompt, image)
+     output_image = copy.deepcopy(image)
+     output_image = draw_ocr_bboxes(
+         output_image, ocr_with_region_results["<OCR_WITH_REGION>"]
+     )
+
+     # Return the output image and cleaned OCR text
+     return output_image, cleaned_text
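
End to end, the task can also be run without the Gradio UI. A standalone sketch, under the same GPU assumptions as the model module (not part of this commit):

# Hypothetical standalone run of the OCR task (not in this commit).
from PIL import Image
from src.task import ocr_task

image = Image.open("images/ocr_image_3.jpg").convert("RGB")
boxed_image, text = ocr_task(image)
boxed_image.save("ocr_output.jpg")  # the image with bounding boxes drawn
print(text)                         # the cleaned OCR text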
src/utils.py ADDED
@@ -0,0 +1,173 @@
+ from PIL import ImageDraw
+ import numpy as np
+ import re
+
+
+ # Color map for the bounding boxes
+ colormap = [
+     "#0000FF",
+     "#FFA500",
+     "#008000",
+     "#800080",
+     "#A52A2A",
+     "#FFC0CB",
+     "#808080",
+     "#808000",
+     "#00FFFF",
+     "#FF0000",
+     "#00FF00",
+     "#4B0082",
+     "#4B0082",
+     "#EE82EE",
+     "#00FFFF",
+     "#FF00FF",
+     "#FF7F50",
+     "#FFD700",
+     "#87CEEB",
+ ]
+
+
+ # Text cleaning function
+ def clean_text(text):
+     """
+     Cleans the given text by removing unwanted tokens and extra spaces,
+     and ensures proper spacing between words and after periods.
+
+     Args:
+         text (str): The input text to be cleaned.
+
+     Returns:
+         str: The cleaned and properly formatted text.
+     """
+
+     # Remove unwanted tokens
+     text = text.replace("<pad>", "").replace("</s>", "").strip()
+
+     # Split the text into lines and clean each line
+     lines = text.split("\n")
+     cleaned_lines = [line.strip() for line in lines if line.strip()]
+
+     # Join the cleaned lines into a single string with a space between each line
+     cleaned_text = " ".join(cleaned_lines)
+
+     # Ensure proper spacing between words and after periods using regex
+     cleaned_text = re.sub(
+         r"\s+", " ", cleaned_text
+     )  # Replace multiple spaces with a single space
+     cleaned_text = re.sub(
+         r"(?<=[.])(?=[^\s])", r" ", cleaned_text
+     )  # Add a space after a period if not followed by one
+
+     # Return the cleaned text
+     return cleaned_text
+
+
+ # Convert a hex color to RGBA with the given alpha
+ def hex_to_rgba(hex_color, alpha):
+     """
+     Convert a hexadecimal color code to RGBA format.
+
+     Args:
+         hex_color (str): The hexadecimal color code (e.g., "#FF0000").
+         alpha (int): The alpha value for the RGBA color (0-255).
+
+     Returns:
+         tuple: A tuple representing the RGBA color values (red, green, blue, alpha).
+     """
+     hex_color = hex_color.lstrip("#")
+     r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
+     return (r, g, b, alpha)
+
+
+ # Draw OCR bounding boxes with enhanced visual elements
+ def draw_ocr_bboxes(image, prediction):
+     """
+     Draw bounding boxes with enhanced visual elements on the given image based on the OCR prediction.
+
+     Args:
+         image (PIL.Image.Image): The input image on which the bounding boxes will be drawn.
+         prediction (dict): The OCR prediction containing 'quad_boxes' and 'labels'.
+
+     Returns:
+         PIL.Image.Image: The image with the bounding boxes drawn.
+     """
+
+     # Create a drawing object for the image with RGBA mode
+     draw = ImageDraw.Draw(image, "RGBA")
+
+     # Extract bounding boxes and labels from the prediction
+     bboxes, labels = prediction["quad_boxes"], prediction["labels"]
+
+     for i, (box, label) in enumerate(zip(bboxes, labels)):
+         # Select a color for the bounding box and label
+         color = colormap[i % len(colormap)]
+         new_box = np.array(box).tolist()
+
+         # Define the outline width and corner radius for the bounding box
+         box_outline_width = 3
+         corner_radius = 10
+
+         # Draw rounded corners for the bounding box
+         for j in range(4):
+             start_x, start_y = new_box[j * 2], new_box[j * 2 + 1]
+             end_x, end_y = new_box[(j * 2 + 2) % 8], new_box[(j * 2 + 3) % 8]
+
+             # Draw the arcs for the rounded corners
+             draw.arc(
+                 [
+                     (start_x - corner_radius, start_y - corner_radius),
+                     (start_x + corner_radius, start_y + corner_radius),
+                 ],
+                 90 + j * 90,
+                 180 + j * 90,
+                 fill=color,
+                 width=box_outline_width,
+             )
+             draw.arc(
+                 [
+                     (end_x - corner_radius, end_y - corner_radius),
+                     (end_x + corner_radius, end_y + corner_radius),
+                 ],
+                 j * 90,
+                 90 + j * 90,
+                 fill=color,
+                 width=box_outline_width,
+             )
+
+             # Draw the lines connecting the arcs
+             if j in [0, 1, 2]:
+                 draw.line(
+                     [
+                         (start_x + corner_radius if j != 1 else start_x, start_y),
+                         (end_x - corner_radius if j != 1 else end_x, end_y),
+                     ],
+                     fill=color,
+                     width=box_outline_width,
+                 )
+             else:
+                 draw.line(
+                     [
+                         (start_x, start_y + corner_radius),
+                         (end_x, end_y - corner_radius),
+                     ],
+                     fill=color,
+                     width=box_outline_width,
+                 )
+
+         # Calculate the position for the text label
+         text_x, text_y = min(new_box[0::2]), min(new_box[1::2]) - 20
+         text_bbox = draw.textbbox((0, 0), label)  # textsize was removed in Pillow 10
+         text_w, text_h = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
+         rgba_color = hex_to_rgba(color, 200)  # Semi-transparent background for the text
+
+         # Draw the background rectangle for the text
+         draw.rectangle(
+             [text_x, text_y, text_x + text_w + 10, text_y + text_h + 10],
+             fill=rgba_color,
+         )
+
+         # Draw the text label
+         draw.text((text_x + 5, text_y + 5), label, fill=(0, 0, 0, 255))
+
+     # Return the image with the OCR boxes drawn
+     return image
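
The two model-free helpers can be checked in isolation. A small sketch of their expected behavior (not part of this commit):

# Quick checks for the pure helpers in src/utils.py (not in this commit).
from src.utils import clean_text, hex_to_rgba

print(clean_text("<pad>Hello\nworld.Next</s>"))  # -> "Hello world. Next"
print(hex_to_rgba("#FF7F50", 200))               # -> (255, 127, 80, 200)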