yizhangliu commited on
Commit
607dee4
β€’
1 Parent(s): c861f8d

update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -672,7 +672,7 @@ def load_kolors_inpainting(inpaint_prompt, input_image, mask_image):
672
  return None
673
 
674
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
675
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
676
  text_prompt = getTextTrans(text_prompt, source='zh', target='en')
677
  inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
678
 
@@ -716,7 +716,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
716
  return [], gr.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
717
 
718
  file_temp = int(time.time())
719
- logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
720
 
721
  output_images = []
722
 
@@ -739,6 +739,9 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
739
  size = image_pil.size
740
  H, W = size[1], size[0]
741
 
 
 
 
742
  # run grounding dino model
743
  if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw:
744
  pass
@@ -772,7 +775,11 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
772
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
773
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
774
  image = np.array(input_img)
775
- if sam_predictor:
 
 
 
 
776
  sam_predictor.set_image(image)
777
 
778
  for i in range(boxes_filt.size(0)):
@@ -780,7 +787,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
780
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
781
  boxes_filt[i][2:] += boxes_filt[i][:2]
782
 
783
- if sam_predictor:
784
  boxes_filt = boxes_filt.to(sam_device)
785
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
786
 
@@ -1019,7 +1026,8 @@ def main_gradio(args):
1019
  remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
1020
  with gr.Column(scale=1):
1021
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
1022
-
 
1023
  with gr.Column():
1024
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
1025
  ) #.style(preview=True, columns=[5], object_fit="scale-down", height="auto")
@@ -1055,7 +1063,7 @@ def main_gradio(args):
1055
 
1056
  run_button.click(fn=run_anything_task, inputs=[
1057
  input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
1058
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
1059
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
1060
 
1061
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
 
672
  return None
673
 
674
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
675
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, remove_use_segment, num_relation, kosmos_input, cleaner_size_limit=1080):
676
  text_prompt = getTextTrans(text_prompt, source='zh', target='en')
677
  inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
678
 
 
716
  return [], gr.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
717
 
718
  file_temp = int(time.time())
719
+ logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}/{remove_use_segment}_[{text_prompt}]/[{inpaint_prompt}]___1_')
720
 
721
  output_images = []
722
 
 
739
  size = image_pil.size
740
  H, W = size[1], size[0]
741
 
742
+ if remove_use_segment and task_type == 'remove':
743
+ remove_mode = 'rectangle'
744
+
745
  # run grounding dino model
746
  if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw:
747
  pass
 
775
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
776
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
777
  image = np.array(input_img)
778
+ use_sam_predictor = True
779
+ if remove_use_segment and task_type == 'remove':
780
+ use_sam_predictor = False
781
+
782
+ if sam_predictor and use_sam_predictor:
783
  sam_predictor.set_image(image)
784
 
785
  for i in range(boxes_filt.size(0)):
 
787
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
788
  boxes_filt[i][2:] += boxes_filt[i][:2]
789
 
790
+ if sam_predictor and use_sam_predictor:
791
  boxes_filt = boxes_filt.to(sam_device)
792
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
793
 
 
1026
  remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
1027
  with gr.Column(scale=1):
1028
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
1029
+ with gr.Column(scale=1):
1030
+ remove_use_segment = gr.Checkbox(value=True, elem_id='remove_use_segment', label="use segment for removing?", info="")
1031
  with gr.Column():
1032
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
1033
  ) #.style(preview=True, columns=[5], object_fit="scale-down", height="auto")
 
1063
 
1064
  run_button.click(fn=run_anything_task, inputs=[
1065
  input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
1066
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, remove_use_segment, num_relation, kosmos_input],
1067
  outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
1068
 
1069
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],