gaviego committed on
Commit
c7fd94d
1 Parent(s): 4541bfd
Files changed (2) hide show
  1. app.py +96 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from PIL import Image, ImageDraw
4
+ from unittest.mock import patch
5
+ import gradio as gr
6
+ import ast
7
+ from transformers import AutoModelForCausalLM, AutoProcessor
8
+ from transformers.dynamic_module_utils import get_imports
9
+
10
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of a remote-code module, dropping ``flash_attn`` for Florence-2.

    Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    Intended as a drop-in replacement for
    ``transformers.dynamic_module_utils.get_imports``: every file is delegated
    to the real ``get_imports``; only the result for ``modeling_florence2.py``
    is filtered.

    Args:
        filename: Path of the module file whose imports are being collected.

    Returns:
        The list produced by ``get_imports``, without any ``"flash_attn"``
        entry when ``filename`` is ``modeling_florence2.py``.
    """
    imports = get_imports(filename)
    # Compare the base name so the check does not depend on the path separator
    # or on the file being given with a directory prefix.
    if os.path.basename(str(filename)) != "modeling_florence2.py":
        return imports
    # Filter rather than list.remove(): remove() raises ValueError when the
    # module does not list flash_attn at all.
    return [name for name in imports if name != "flash_attn"]
17
+
18
# Hub checkpoint shared by the model weights and the matching processor.
MODEL_ID = "microsoft/Florence-2-large-ft"

# While the patch is active, transformers' get_imports is replaced by
# fixed_get_imports, so the remote Florence-2 code is loaded through it.
# NOTE(review): the processor load is kept inside the patch context; the
# original nesting was not recoverable from the flattened source — confirm.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
21
+
22
def draw_boxes(image, quad_boxes):
    """Outline every entry of ``quad_boxes`` on ``image`` in red.

    The drawing happens directly on ``image`` (no copy is made here), and that
    same object is returned for convenience.

    Args:
        image: PIL image to draw on.
        quad_boxes: Iterable of coordinate sequences, each passed unchanged to
            ``ImageDraw.polygon``.

    Returns:
        The ``image`` that was passed in, now carrying the outlines.
    """
    canvas = ImageDraw.Draw(image)
    for quad in quad_boxes:
        canvas.polygon(quad, outline="red", width=2)
    return image
27
+
28
def run_example(image, task, additional_text=""):
    """Run one Florence-2 task on ``image`` and return its text and image results.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        task: Task name without angle brackets (e.g. ``"CAPTION"``); it is
            wrapped as ``<task>`` to form the task token.
        additional_text: Caption to ground; only used when ``task`` is
            ``"CAPTION_TO_PHRASE_GROUNDING"`` and the text is non-empty.

    Returns:
        A ``(result_text, result_image)`` tuple. ``result_text`` is the string
        form of the post-processed answer and ``result_image`` is a copy of the
        input, with red outlines added for ``"OCR_WITH_REGION"`` when drawing
        succeeds. If ``image`` is ``None`` the tuple is
        ``("Please upload an image.", None)``.
    """
    if image is None:
        return "Please upload an image.", None

    task_prompt = f"<{task}>"
    # The Florence-2 model card's sample code supplies the grounding caption by
    # appending it to the task token; the caption is not a separate processor
    # keyword argument.
    if task == "CAPTION_TO_PHRASE_GROUNDING" and additional_text:
        prompt = task_prompt + additional_text
    else:
        prompt = task_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Post-processing is keyed by the bare task token, never the token + caption.
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    result_text = str(parsed_answer)
    result_image = image.copy()

    if task == "OCR_WITH_REGION":
        # Drawing is best-effort: on failure the text result is still returned.
        # Read the boxes straight from the parsed answer instead of re-parsing
        # its string form.
        try:
            quad_boxes = parsed_answer["<OCR_WITH_REGION>"]["quad_boxes"]
            result_image = draw_boxes(result_image, quad_boxes)
        except Exception as exc:  # UI boundary: report the failure, keep the text
            print(f"Failed to draw bounding boxes: {exc}")

    return result_text, result_image
59
+
60
def update_additional_text_visibility(task):
    """Build a Gradio update that shows a component only for phrase grounding.

    Args:
        task: Currently selected task name.

    Returns:
        ``gr.update(visible=True)`` when ``task`` is
        ``"CAPTION_TO_PHRASE_GROUNDING"``, otherwise ``gr.update(visible=False)``.
    """
    needs_caption = task == "CAPTION_TO_PHRASE_GROUNDING"
    return gr.update(visible=needs_caption)
62
+
63
# Task names offered in the dropdown, in display order.
TASK_CHOICES = [
    "CAPTION", "DETAILED_CAPTION", "MORE_DETAILED_CAPTION",
    "CAPTION_TO_PHRASE_GROUNDING", "OD", "DENSE_REGION_CAPTION",
    "REGION_PROPOSAL", "OCR", "OCR_WITH_REGION",
]

# Define the Gradio interface.
# NOTE(review): the row/column nesting below was reconstructed from a flattened
# source — confirm it matches the intended layout.
with gr.Blocks() as iface:
    gr.Markdown("# Florence-2 Image Analysis")

    # Inputs: the image alongside the task controls.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an image")
        with gr.Column():
            task_dropdown = gr.Dropdown(
                choices=TASK_CHOICES,
                label="Select Task",
                value="CAPTION",
            )
            # Hidden initially; toggled by the dropdown change handler below.
            additional_text = gr.Textbox(
                label="Additional Text (for Caption to Phrase Grounding)",
                placeholder="Enter caption here",
                visible=False,
            )
            submit_button = gr.Button("Analyze Image")

    # Outputs: textual result and the processed image.
    with gr.Row():
        text_output = gr.Textbox(label="Result")
        image_output = gr.Image(label="Processed Image")

    # Event wiring.
    task_dropdown.change(
        fn=update_additional_text_visibility,
        inputs=task_dropdown,
        outputs=additional_text,
    )
    submit_button.click(
        fn=run_example,
        inputs=[image_input, task_dropdown, additional_text],
        outputs=[text_output, image_output],
    )

# Launch the interface
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
gradio==4.40.0
# Pillow was pinned three times at conflicting versions (9.1.0, 8.0.0, 10.4.0),
# which pip cannot install together; only the newest pin is kept.
Pillow==10.4.0
Requests==2.32.3
# NOTE(review): 4.24.0 looks older than the Florence-2 checkpoint that app.py
# loads with trust_remote_code — confirm this pin works and bump it if not.
# NOTE(review): app.py runs a PyTorch model ("pt" tensors) but torch is not
# listed here; the model's remote code may also need extra packages — verify.
transformers==4.24.0