zjr committed on
Commit
2461d7d
1 Parent(s): 388219c
Files changed (2) hide show
  1. app.py +5 -3
  2. image_editing_utils.py +54 -8
app.py CHANGED
@@ -59,7 +59,7 @@ ckpt_url_map = {
59
  'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
60
  'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
61
  }
62
- os.makedirs('result', exist_ok=True)
63
  args = parse_augment()
64
 
65
  checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
@@ -202,6 +202,8 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
202
  # chat_input = click_coordinate
203
  prompt = get_prompt(coordinate, click_state, click_mode)
204
  print('prompt: ', prompt, 'controls: ', controls)
 
 
205
 
206
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
207
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
@@ -218,7 +220,7 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
218
  input_mask = np.array(out['mask'].convert('P'))
219
  image_input = mask_painter(np.array(image_input), input_mask)
220
  origin_image_input = image_input
221
- image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
222
 
223
  yield state, state, click_state, chat_input, image_input, wiki
224
  if not args.disable_gpt and model.text_refiner:
@@ -227,7 +229,7 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
227
  new_cap = refined_caption['caption']
228
  wiki = refined_caption['wiki']
229
  state = state + [(None, f"caption: {new_cap}")]
230
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
231
  yield state, state, click_state, chat_input, refined_image_input, wiki
232
 
233
 
 
59
  'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
60
  'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
61
  }
62
+
63
  args = parse_augment()
64
 
65
  checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
 
202
  # chat_input = click_coordinate
203
  prompt = get_prompt(coordinate, click_state, click_mode)
204
  print('prompt: ', prompt, 'controls: ', controls)
205
+ input_points = prompt['input_point']
206
+ input_labels = prompt['input_label']
207
 
208
  enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
209
  out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
 
220
  input_mask = np.array(out['mask'].convert('P'))
221
  image_input = mask_painter(np.array(image_input), input_mask)
222
  origin_image_input = image_input
223
+ image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
224
 
225
  yield state, state, click_state, chat_input, image_input, wiki
226
  if not args.disable_gpt and model.text_refiner:
 
229
  new_cap = refined_caption['caption']
230
  wiki = refined_caption['wiki']
231
  state = state + [(None, f"caption: {new_cap}")]
232
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]), input_mask, input_points=input_points, input_labels=input_labels)
233
  yield state, state, click_state, chat_input, refined_image_input, wiki
234
 
235
 
image_editing_utils.py CHANGED
@@ -1,6 +1,7 @@
1
  from PIL import Image, ImageDraw, ImageFont
2
  import copy
3
- import numpy as np
 
4
 
5
  def wrap_text(text, font, max_width):
6
  lines = []
@@ -17,7 +18,7 @@ def wrap_text(text, font, max_width):
17
  lines.append(current_line)
18
  return lines
19
 
20
- def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.025):
21
  # Load the image
22
  if type(image) == np.ndarray:
23
  image = Image.fromarray(image)
@@ -29,6 +30,7 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
29
  total_chars = len(text)
30
  max_text_width = int(0.4 * width)
31
  font_size = int(height * font_size_ratio)
 
32
 
33
  # Load the font
34
  font = ImageFont.truetype(font_path, font_size)
@@ -45,20 +47,33 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
45
  bubble_height = text_height + 2 * padding
46
 
47
  # Create a new image for the bubble frame
48
- bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 255, 255, 0))
49
 
50
  # Draw the bubble frame on the new image
51
  draw = ImageDraw.Draw(bubble)
52
  # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
53
-
 
54
  # Draw the wrapped text line by line
55
  y_text = padding
56
  for line in lines:
57
- draw.text((padding, y_text), line, font=font, fill=(255, 255, 255, 255))
58
  y_text += font.getsize(line)[1]
59
-
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Calculate the bubble frame position
61
- x, y = point
62
  if x + bubble_width > width:
63
  x = width - bubble_width
64
  if y + bubble_height > height:
@@ -66,4 +81,35 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
66
 
67
  # Paste the bubble frame onto the image
68
  image.paste(bubble, (x, y), bubble)
69
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from PIL import Image, ImageDraw, ImageFont
2
  import copy
3
+ import numpy as np
4
+ import cv2
5
 
6
  def wrap_text(text, font, max_width):
7
  lines = []
 
18
  lines.append(current_line)
19
  return lines
20
 
21
+ def create_bubble_frame(image, text, point, segmask, input_points, input_labels, font_path='times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
22
  # Load the image
23
  if type(image) == np.ndarray:
24
  image = Image.fromarray(image)
 
30
  total_chars = len(text)
31
  max_text_width = int(0.4 * width)
32
  font_size = int(height * font_size_ratio)
33
+ point_size = max(int(height * point_size_ratio), 1)
34
 
35
  # Load the font
36
  font = ImageFont.truetype(font_path, font_size)
 
47
  bubble_height = text_height + 2 * padding
48
 
49
  # Create a new image for the bubble frame
50
+ bubble = Image.new('RGBA', (bubble_width, bubble_height), (255,248, 220, 0))
51
 
52
  # Draw the bubble frame on the new image
53
  draw = ImageDraw.Draw(bubble)
54
  # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
55
+ draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
56
+ fill=(255,248, 220, 120), outline=None, width=2)
57
  # Draw the wrapped text line by line
58
  y_text = padding
59
  for line in lines:
60
+ draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
61
  y_text += font.getsize(line)[1]
62
+
63
+ # Determine the point by the min area rect of mask
64
+ try:
65
+ ret, thresh = cv2.threshold(segmask, 127, 255, 0)
66
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
67
+ largest_contour = max(contours, key=cv2.contourArea)
68
+ min_area_rect = cv2.minAreaRect(largest_contour)
69
+ box = cv2.boxPoints(min_area_rect)
70
+ sorted_points = box[np.argsort(box[:, 0])]
71
+ right_most_points = sorted_points[-2:]
72
+ right_down_most_point = right_most_points[np.argsort(right_most_points[:, 1])][-1]
73
+ x, y = int(right_down_most_point[0]), int(right_down_most_point[1])
74
+ except:
75
+ x, y = point
76
  # Calculate the bubble frame position
 
77
  if x + bubble_width > width:
78
  x = width - bubble_width
79
  if y + bubble_height > height:
 
81
 
82
  # Paste the bubble frame onto the image
83
  image.paste(bubble, (x, y), bubble)
84
+ draw = ImageDraw.Draw(image)
85
+ colors = [(0, 191, 255, 255), (255, 106, 106, 255)]
86
+ for p, label in zip(input_points, input_labels):
87
+ point_x, point_y = p[0], p[1]
88
+ left = point_x - point_size
89
+ top = point_y - point_size
90
+ right = point_x + point_size
91
+ bottom = point_y + point_size
92
+ draw.ellipse((left, top, right, bottom), fill=colors[label])
93
+ return image
94
+
95
+
96
def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, width=1):
    """Draw a rectangle with rounded corners using basic drawing primitives.

    The shape is composed of two overlapping rectangles (one shortened
    vertically, one shortened horizontally) plus a quarter-circle pie slice
    in each corner.

    Args:
        draw: Drawing object exposing ``rectangle`` and ``pieslice`` methods
            (presumably a ``PIL.ImageDraw.ImageDraw`` -- confirm with caller).
        xy: ``(x1, y1, x2, y2)`` bounding box, top-left then bottom-right.
        corner_radius: Requested corner radius. It is clamped to the range
            ``[0, min(box_width, box_height) // 2]``.
        fill: Fill colour forwarded unchanged to every primitive.
        outline: Outline colour forwarded unchanged to every primitive.
        width: Outline width forwarded unchanged to every primitive.

    NOTE(review): ``outline`` is forwarded to each primitive, so a non-None
    outline is also stroked along the internal seams of the composed shape.
    """
    x1, y1, x2, y2 = xy

    # Clamp the radius: without this, a radius larger than half the shorter
    # side (or a negative one) produces boxes whose second corner precedes
    # the first, e.g. y1 + r > y2 - r. Recent Pillow releases are understood
    # to reject such boxes with ValueError.
    max_radius = max(min(x2 - x1, y2 - y1) // 2, 0)
    radius = min(max(corner_radius, 0), max_radius)

    if radius == 0:
        # No rounding possible or requested: a single plain rectangle.
        draw.rectangle((x1, y1, x2, y2), fill=fill, outline=outline, width=width)
        return

    diameter = radius * 2

    # Band spanning the full width, shortened at top and bottom.
    draw.rectangle(
        (x1, y1 + radius, x2, y2 - radius),
        fill=fill,
        outline=outline,
        width=width,
    )
    # Band spanning the full height, shortened at left and right.
    draw.rectangle(
        (x1 + radius, y1, x2 - radius, y2),
        fill=fill,
        outline=outline,
        width=width,
    )

    # Quarter-circle slices, in order: top-left, top-right, bottom-right,
    # bottom-left. Each entry is (bounding box, start angle, end angle).
    corners = (
        ((x1, y1, x1 + diameter, y1 + diameter), 180, 270),
        ((x2 - diameter, y1, x2, y1 + diameter), 270, 360),
        ((x2 - diameter, y2 - diameter, x2, y2), 0, 90),
        ((x1, y2 - diameter, x1 + diameter, y2), 90, 180),
    )
    for box, start, end in corners:
        draw.pieslice(box, start, end, fill=fill, outline=outline, width=width)
+ draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline, width=width)