Spaces:
Runtime error
Runtime error
File size: 6,223 Bytes
10240e0 13c1c2e 10240e0 13c1c2e 10240e0 13c1c2e 5c74464 3421695 10240e0 5c74464 13c1c2e 5c74464 eabdb1c 10240e0 13c1c2e 10240e0 13c1c2e 10240e0 5c74464 10240e0 5c74464 10240e0 5c74464 eabdb1c 10240e0 13c1c2e 10240e0 5c74464 10240e0 5c74464 10240e0 5c74464 10240e0 13c1c2e 10240e0 5c74464 10240e0 13c1c2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from captioner import build_captioner, BaseCaptioner
from segmenter import build_segmenter
from text_refiner import build_text_refiner
import os
import argparse
import pdb
import time
from PIL import Image
import cv2
import numpy as np
class CaptionAnything:
    """Segment-then-caption pipeline.

    A segmenter turns a visual prompt into a mask. A captioner describes the
    masked region. Optionally, the caption is then passed, together with
    language controls, to a text refiner (an LLM-backed component).
    """

    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
        """Build (or adopt) the pipeline components.

        Args:
            args: parsed options (see ``parse_augment``). This reads
                ``captioner``, ``segmenter``, ``device`` and ``disable_gpt``
                here, plus further fields in ``init_refiner`` and
                ``inference``.
            api_key: key forwarded to ``build_text_refiner``.
            captioner: pre-built captioner to reuse. When ``None`` it is
                built from ``args``.
            segmenter: pre-built segmenter to reuse. When ``None`` it is
                built from ``args``.
            text_refiner: pre-built refiner to reuse. When ``None`` (and GPT
                is not disabled) it is built via ``init_refiner``.
        """
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            else:
                self.init_refiner(api_key)

    def init_refiner(self, api_key):
        """Try to build the text refiner and probe it with a single LLM call.

        Any failure during building or probing disables the refiner
        (``self.text_refiner = None``) instead of aborting. Captioning then
        continues with raw captions.
        """
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            # Probe once so an unusable refiner is detected here, not mid-inference.
            self.text_refiner.llm('hi')
        except Exception as err:  # best-effort, but let KeyboardInterrupt/SystemExit propagate
            self.text_refiner = None
            print('OpenAI GPT is not available')
            print(f'  reason: {type(err).__name__}: {err}')

    @staticmethod
    def _smooth_mask(seg_mask, kernel_size=6):
        """Return a boolean mask after a morphological opening then closing.

        Both operations use a ``kernel_size`` x ``kernel_size`` all-ones
        kernel. Opening drops small foreground specks and closing fills small
        holes. The mask is processed as a single channel, which gives the
        same result as running each channel of a replicated 3-channel copy.
        """
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        mask = 255 * seg_mask.astype(np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel=kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel=kernel)
        return mask > 0

    @staticmethod
    def _save_mask(seg_mask):
        """Save ``seg_mask`` as an RGB PNG under ``result/``.

        The file name is timestamped. Returns ``(path, image)``.
        """
        mask_save_path = f'result/mask_{time.time()}.png'
        # exist_ok avoids a check-then-create race between concurrent requests.
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)
        seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
        if seg_mask_img.mode != 'RGB':
            seg_mask_img = seg_mask_img.convert('RGB')
        seg_mask_img.save(mask_save_path)
        return mask_save_path, seg_mask_img

    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        """Segment the prompted region of ``image``, caption it, and optionally refine it.

        Args:
            image: passed unchanged to the segmenter and captioner. The
                ``__main__`` demo passes a file path.
            prompt: visual prompt forwarded to ``segmenter.inference``.
            controls: language controls forwarded to the text refiner.
            disable_gpt: skip the text refiner for this call even if one is
                available.
            enable_wiki: forwarded to ``text_refiner.inference``.

        Returns:
            A dict with these keys:

            - ``generated_captions``: the refiner output, or
              ``{'raw_caption': caption}`` when no refiner is used.
            - ``crop_save_path``: as returned by the captioner.
            - ``mask_save_path``: path of the saved mask PNG.
            - ``mask``: the RGB PIL image of the mask.
            - ``context_captions``: list of whole-image captions; empty
              unless ``args.context_captions`` is set.
        """
        print("CA prompt: ", prompt, "CA controls", controls)
        # Keep the first mask returned (presumably masks are stacked on axis 0 — confirm
        # against the segmenter implementation).
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        if self.args.enable_morphologyex:
            seg_mask = self._smooth_mask(seg_mask)
        mask_save_path, seg_mask_img = self._save_mask(seg_mask)
        print('seg_mask path: ', mask_save_path)
        print("seg_mask.shape: ", seg_mask.shape)

        # Caption the masked region; both captioner entry points share one signature.
        caption_fn = (self.captioner.inference_with_reduced_tokens
                      if self.args.enable_reduce_tokens
                      else self.captioner.inference_seg)
        caption, crop_save_path = caption_fn(
            image,
            seg_mask,
            crop_mode=self.args.seg_crop_mode,
            filter=self.args.clip_filter,
            disable_regular_box=self.args.disable_regular_box,
        )

        context_captions = []
        if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))

        if not disable_gpt and self.text_refiner is not None:
            refined_caption = self.text_refiner.inference(
                query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
        else:
            refined_caption = {'raw_caption': caption}
        return {
            'generated_captions': refined_caption,
            'crop_save_path': crop_save_path,
            'mask_save_path': mask_save_path,
            'mask': seg_mask_img,
            'context_captions': context_captions,
        }
def parse_augment(argv=None):
    """Declare and parse the Caption-Anything command-line options.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous behavior. Passing a list lets other entry points and
            tests reuse this parser without touching the real command line.

    Returns:
        argparse.Namespace with every option declared below. When
        ``--debug`` is set, the namespace is printed before it is returned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--captioner', type=str, default="blip2")
    parser.add_argument('--segmenter', type=str, default="base")
    parser.add_argument('--text_refiner', type=str, default="base")
    parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
    parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
    parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
    parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
    parser.add_argument('--disable_regular_box', action="store_true", default=False, help="crop image with a regular box")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--gradio_share', action="store_true")
    parser.add_argument('--disable_gpt', action="store_true")
    parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
    parser.add_argument('--disable_reuse_features', action="store_true", default=False)
    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
    args = parser.parse_args(argv)
    if args.debug:
        print(args)
    return args
if __name__ == "__main__":
    # Demo: caption two click-prompted regions of one test image.
    args = parse_augment()
    # Alternative test image: 'test_img/img3.jpg'
    image_path = 'test_img/img13.jpg'
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [1000, 500]],
            "input_label": [1, 0],
            "multimask_output": "True",
        },
        {
            "prompt_type": ["click"],
            "input_point": [[900, 800]],
            "input_label": [1],
            "multimask_output": "True",
        },
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        "imagination": "False",  # or "True"
        "language": "English",
    }
    # .get() so a missing OPENAI_API_KEY no longer raises KeyError. The key
    # is only used when --disable_gpt is not set.
    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))
    for prompt in prompts:
        print('*' * 30)
        print('Image path: ', image_path)
        # Open only to report the image; the context manager closes the file handle.
        with Image.open(image_path) as image:
            print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        out = model.inference(image_path, prompt, controls)
        print('Generated captions:\n', out['generated_captions'])
|