gaviego commited on
Commit
3569078
1 Parent(s): 57005fe
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -5,14 +5,31 @@ from unittest.mock import patch
5
  import gradio as gr
6
  import ast
7
  from transformers import AutoModelForCausalLM, AutoProcessor
 
8
 
9
- model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
10
- processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
 
 
 
 
 
11
 
12
- def draw_boxes(image, quad_boxes):
 
 
 
 
13
  draw = ImageDraw.Draw(image)
14
- for box in quad_boxes:
15
- draw.polygon(box, outline="red", width=2)
 
 
 
 
 
 
 
16
  return image
17
 
18
  def run_example(image, task, additional_text=""):
@@ -37,13 +54,16 @@ def run_example(image, task, additional_text=""):
37
  result_text = str(parsed_answer)
38
  result_image = image.copy()
39
 
40
- if task == "OCR_WITH_REGION":
41
- try:
42
- result_dict = ast.literal_eval(result_text)
43
- quad_boxes = result_dict['<OCR_WITH_REGION>']['quad_boxes']
44
- result_image = draw_boxes(result_image, quad_boxes)
45
- except:
46
- print("Failed to draw bounding boxes.")
 
 
 
47
 
48
  return result_text, result_image
49
 
 
5
  import gradio as gr
6
  import ast
7
  from transformers import AutoModelForCausalLM, AutoProcessor
8
+ from transformers.dynamic_module_utils import get_imports
9
 
10
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
11
+ """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
12
+ if not str(filename).endswith("/modeling_florence2.py"):
13
+ return get_imports(filename)
14
+ imports = get_imports(filename)
15
+ imports.remove("flash_attn")
16
+ return imports
17
 
18
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
19
+ model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
20
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
21
+
22
+ def draw_boxes(image, boxes, box_type='bbox', labels=None):
23
  draw = ImageDraw.Draw(image)
24
+ for i, box in enumerate(boxes):
25
+ if box_type == 'quad':
26
+ draw.polygon(box, outline="red", width=2)
27
+ elif box_type == 'bbox':
28
+ draw.rectangle(box, outline="red", width=2)
29
+
30
+ if labels and i < len(labels):
31
+ draw.text((box[0], box[1] - 10), labels[i], fill="red")
32
+
33
  return image
34
 
35
  def run_example(image, task, additional_text=""):
 
54
  result_text = str(parsed_answer)
55
  result_image = image.copy()
56
 
57
+ try:
58
+ result_dict = ast.literal_eval(result_text)
59
+ task_key = f"<{task}>"
60
+ if task_key in result_dict:
61
+ if 'quad_boxes' in result_dict[task_key]:
62
+ result_image = draw_boxes(result_image, result_dict[task_key]['quad_boxes'], 'quad')
63
+ elif 'bboxes' in result_dict[task_key]:
64
+ result_image = draw_boxes(result_image, result_dict[task_key]['bboxes'], 'bbox', result_dict[task_key].get('labels'))
65
+ except:
66
+ print(f"Failed to draw bounding boxes for task: {task}")
67
 
68
  return result_text, result_image
69