Spaces:
Running
on
T4
Running
on
T4
liuyizhang
committed on
Commit
•
e12d135
1
Parent(s):
5ee6e09
update app.py
Browse files- api_client.py +27 -11
- app.py +87 -63
api_client.py
CHANGED
@@ -52,18 +52,34 @@ def base64_to_PILImage(im_b64):
|
|
52 |
pil_img = Image.open(io.BytesIO(im_bytes))
|
53 |
return pil_img
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
image_file = 'dog.png'
|
56 |
-
|
57 |
-
'extend': 20,
|
58 |
-
'img': imgFile_to_base64(image_file),
|
59 |
-
}
|
60 |
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
plt.clf()
|
69 |
|
|
|
52 |
pil_img = Image.open(io.BytesIO(im_bytes))
|
53 |
return pil_img
|
54 |
|
55 |
+
def cleaner_img(image_file, remove_texts, mask_extend=20, disp_debug=True):
    """Post an image to the cleaner service and return its final result image.

    The request payload carries ``remove_texts``, ``mask_extend`` and the
    base64-encoded contents of ``image_file``; it is sent with
    ``request_post`` to the module-level ``url`` using a 600 second timeout.

    Args:
        image_file: Path of the image to send (encoded by ``imgFile_to_base64``).
        remove_texts: Text prompt forwarded under the ``'remove_texts'`` key.
        mask_extend: Value forwarded under the ``'mask_extend'`` key.
        disp_debug: When true, every returned image is shown with matplotlib.

    Returns:
        A ``(pil_image, ret)`` tuple. ``ret`` is the response returned by
        ``request_post``. ``pil_image`` is the last image of
        ``ret['result']['imgs']`` decoded to a PIL image, or ``None`` when
        ``ret['code']`` is non-zero or the image list is empty.
    """
    data = {
        'remove_texts': remove_texts,
        'mask_extend': mask_extend,
        'img': imgFile_to_base64(image_file),
    }
    ret = request_post(url, data, timeout=600, headers=None)

    # A non-zero code means the service did not produce a result.
    if ret['code'] != 0:
        return None, ret

    imgs = ret['result']['imgs']
    if disp_debug:
        for img in imgs:
            plt.imshow(base64_to_PILImage(img))
            plt.show()
            plt.clf()
            plt.close('all')

    # Guard against a successful response that carries no images; indexing
    # the last element of an empty list would raise IndexError.
    if not imgs:
        return None, ret

    return base64_to_PILImage(imgs[-1]), ret
|
74 |
+
|
75 |
# Demo invocation of cleaner_img on a local sample image.
image_file = 'dog.png'
# Prompt listing what to remove; the two entries are "puppy" and "chair".
remove_texts = "小狗 . 椅子"

mask_extend = 20
pil_image, ret = cleaner_img(image_file, remove_texts, mask_extend, disp_debug=False)

if pil_image is None:
    # cleaner_img returns None for the image when the service reports a
    # non-zero code; plt.imshow(None) would raise, so report and skip display.
    print(f'cleaner_img failed: {ret}')
else:
    plt.imshow(pil_image)
    plt.show()
    plt.clf()
    plt.close()
|
|
|
85 |
|
app.py
CHANGED
@@ -3,7 +3,17 @@ import warnings
|
|
3 |
warnings.filterwarnings('ignore')
|
4 |
|
5 |
import subprocess, io, os, sys, time
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import gradio as gr
|
8 |
|
9 |
from loguru import logger
|
@@ -35,7 +45,10 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
|
|
35 |
|
36 |
import cv2
|
37 |
import numpy as np
|
38 |
-
import matplotlib
|
|
|
|
|
|
|
39 |
|
40 |
groundingdino_enable = True
|
41 |
sam_enable = True
|
@@ -332,60 +345,63 @@ def load_lama_cleaner_model(device):
|
|
332 |
)
|
333 |
|
334 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
size_limit
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
config.sd_seed
|
|
|
376 |
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
389 |
return image
|
390 |
|
391 |
class Ram_Predictor(RamPredictor):
|
@@ -691,6 +707,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
691 |
plt.axis('off')
|
692 |
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
693 |
plt.savefig(image_path, bbox_inches="tight")
|
|
|
|
|
694 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
695 |
os.remove(image_path)
|
696 |
output_images.append(Image.fromarray(segment_image_result))
|
@@ -757,6 +775,10 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
757 |
|
758 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
759 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
|
|
|
|
|
|
|
|
760 |
# output_images.append(image_inpainting)
|
761 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
762 |
|
@@ -975,7 +997,10 @@ class API_Starter:
|
|
975 |
request_data = request.data.decode('utf-8')
|
976 |
data = json.loads(request_data)
|
977 |
result = self.handle_data(data)
|
978 |
-
|
|
|
|
|
|
|
979 |
return jsonify(ret_json)
|
980 |
|
981 |
self.app = app
|
@@ -996,15 +1021,18 @@ class API_Starter:
|
|
996 |
inpaint_mode = "merge",
|
997 |
mask_source_radio = "type what to detect below",
|
998 |
remove_mode = "rectangle", # ["segment", "rectangle"]
|
999 |
-
remove_mask_extend = "
|
1000 |
num_relation = 5,
|
1001 |
kosmos_input = None,
|
1002 |
cleaner_size_limit = -1,
|
1003 |
)
|
1004 |
output_images = results[0]
|
|
|
|
|
1005 |
ret_json_images = []
|
1006 |
file_temp = int(time.time())
|
1007 |
count = 0
|
|
|
1008 |
for image_pil in output_images:
|
1009 |
try:
|
1010 |
img_format = image_pil.format.lower()
|
@@ -1086,16 +1114,12 @@ if __name__ == "__main__":
|
|
1086 |
# print(f'ram_model__{get_model_device(ram_model)}')
|
1087 |
# print(f'kosmos_model__{get_model_device(kosmos_model)}')
|
1088 |
|
1089 |
-
if
|
1090 |
# Provide gradio services
|
1091 |
main_gradio(args)
|
1092 |
else:
|
1093 |
-
|
1094 |
-
|
1095 |
-
main_api(args)
|
1096 |
-
else:
|
1097 |
-
# Provide gradio services
|
1098 |
-
main_gradio(args)
|
1099 |
|
1100 |
|
1101 |
|
|
|
3 |
warnings.filterwarnings('ignore')
|
4 |
|
5 |
import subprocess, io, os, sys, time
|
6 |
+
|
7 |
+
# Select the service mode: gradio is the default, and defining the
# IS_MY_DEBUG environment variable (with any value) switches it off.
# The flag is consulted in the __main__ section to choose between
# main_gradio(args) and main_api(args).
run_gradio = os.environ.get('IS_MY_DEBUG') is None

if run_gradio:
    # Install the pinned gradio version before "import gradio" below.
    # "python -m pip" targets the interpreter running this script, whereas a
    # bare "pip" shell command may resolve to a different Python installation.
    # check=False keeps the original best-effort behaviour (a failed install
    # does not abort start-up).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "gradio==3.40.1"],
        check=False,
    )
|
16 |
+
|
17 |
import gradio as gr
|
18 |
|
19 |
from loguru import logger
|
|
|
45 |
|
46 |
import cv2
|
47 |
import numpy as np
|
48 |
+
import matplotlib
# Select the Agg backend before pyplot is loaded.
matplotlib.use('AGG')
# Import the submodule explicitly: a bare "import matplotlib" does not
# guarantee that the "matplotlib.pyplot" attribute exists, so relying on it
# only works if some other module happened to import pyplot first.
import matplotlib.pyplot as plt
|
52 |
|
53 |
groundingdino_enable = True
|
54 |
sam_enable = True
|
|
|
345 |
)
|
346 |
|
347 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
    """Inpaint ``image`` under ``mask`` with the module-level lama cleaner model.

    Args:
        image: Image as a numpy array (the visible caller passes
            ``np.array(image_pil)``). It is indexed with three axes when the
            orientation fix below is applied.
        mask: Mask as a numpy array (the visible caller passes
            ``np.array(mask_pil.convert("L"))``).
        cleaner_size_limit: Passed to ``resize_max_size`` as ``size_limit``
            after ``int()`` conversion; ``-1`` means "use the largest entry
            of ``image.shape``".

    Returns:
        The model output re-opened as a PIL image from PNG bytes, or ``None``
        if any step raised (the failure is logged).
    """
    try:
        # When the mask's (rows, cols) equal the image's (cols, rows) and the
        # mask is not square, the two are in different orientations; re-orient
        # the image so its first two axes match the mask.
        if (mask.shape[0] == image.shape[1]
                and mask.shape[1] == image.shape[0]
                and mask.shape[0] != mask.shape[1]):
            image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]

        interpolation = cv2.INTER_CUBIC

        size_limit = cleaner_size_limit
        if size_limit == -1:
            size_limit = max(image.shape)
        else:
            size_limit = int(size_limit)

        config = lama_Config(
            ldm_steps=25,
            ldm_sampler='plms',
            zits_wireframe=True,
            hd_strategy='Original',
            hd_strategy_crop_margin=196,
            hd_strategy_crop_trigger_size=1280,
            hd_strategy_resize_limit=2048,
            prompt='',
            use_croper=False,
            croper_x=0,
            croper_y=0,
            croper_height=512,
            croper_width=512,
            sd_mask_blur=5,
            sd_strength=0.75,
            sd_steps=50,
            sd_guidance_scale=7.5,
            sd_sampler='ddim',
            sd_seed=42,
            cv2_flag='INPAINT_NS',
            cv2_radius=5,
        )

        # -1 requests a random seed; with the fixed seed above this branch is
        # only taken if the configured value is changed.
        if config.sd_seed == -1:
            import random
            config.sd_seed = random.randint(1, 999999999)

        # Image and mask are resized with the same limit and interpolation.
        image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
        mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)

        res_np_img = lama_cleaner_model(image, mask, config)
        torch.cuda.empty_cache()

        image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
    except Exception:
        # Callers treat None as "inpainting failed"; record why instead of
        # discarding the error silently.
        logger.exception('lama_cleaner_process failed')
        image = None
    return image
|
406 |
|
407 |
class Ram_Predictor(RamPredictor):
|
|
|
707 |
plt.axis('off')
|
708 |
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
709 |
plt.savefig(image_path, bbox_inches="tight")
|
710 |
+
plt.clf()
|
711 |
+
plt.close('all')
|
712 |
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
713 |
os.remove(image_path)
|
714 |
output_images.append(Image.fromarray(segment_image_result))
|
|
|
775 |
|
776 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
777 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
778 |
+
if image_inpainting is None:
|
779 |
+
logger.info(f'run_anything_task_failed_')
|
780 |
+
return None, None, None, None, None, None, None
|
781 |
+
|
782 |
# output_images.append(image_inpainting)
|
783 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
784 |
|
|
|
997 |
request_data = request.data.decode('utf-8')
|
998 |
data = json.loads(request_data)
|
999 |
result = self.handle_data(data)
|
1000 |
+
if result is None:
|
1001 |
+
ret_json = {'code': -2, 'reason':'handle error'}
|
1002 |
+
else:
|
1003 |
+
ret_json = {'code': 0, 'result':result}
|
1004 |
return jsonify(ret_json)
|
1005 |
|
1006 |
self.app = app
|
|
|
1021 |
inpaint_mode = "merge",
|
1022 |
mask_source_radio = "type what to detect below",
|
1023 |
remove_mode = "rectangle", # ["segment", "rectangle"]
|
1024 |
+
remove_mask_extend = f"{data['mask_extend']}",
|
1025 |
num_relation = 5,
|
1026 |
kosmos_input = None,
|
1027 |
cleaner_size_limit = -1,
|
1028 |
)
|
1029 |
output_images = results[0]
|
1030 |
+
if output_images is None:
|
1031 |
+
return None
|
1032 |
ret_json_images = []
|
1033 |
file_temp = int(time.time())
|
1034 |
count = 0
|
1035 |
+
output_images = output_images[-1:]
|
1036 |
for image_pil in output_images:
|
1037 |
try:
|
1038 |
img_format = image_pil.format.lower()
|
|
|
1114 |
# print(f'ram_model__{get_model_device(ram_model)}')
|
1115 |
# print(f'kosmos_model__{get_model_device(kosmos_model)}')
|
1116 |
|
1117 |
+
if run_gradio:
|
1118 |
# Provide gradio services
|
1119 |
main_gradio(args)
|
1120 |
else:
|
1121 |
+
# Provide API services
|
1122 |
+
main_api(args)
|
|
|
|
|
|
|
|
|
1123 |
|
1124 |
|
1125 |
|