import argparse
import os
import time

import cv2
import numpy as np
from PIL import Image

from captioner import build_captioner, BaseCaptioner
from segmenter import build_segmenter
from text_refiner import build_text_refiner

class CaptionAnything:
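    """Pipeline that segments a user-prompted region (SAM), captions the
    cropped region, and optionally refines the caption with an LLM."""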
    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
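        """Build (or reuse) the captioner, segmenter, and optional text refiner."""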
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            else:
                self.init_refiner(api_key)

    def init_refiner(self, api_key):
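        """Build the text refiner; fall back to None if the LLM is unreachable."""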
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            self.text_refiner.llm('hi')  # smoke test: check that the LLM endpoint responds
        except Exception as e:
            self.text_refiner = None
            print('OpenAI GPT is not available:', e)
    
    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
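        """Segment the region selected by `prompt`, caption it, and (optionally)
        refine the caption according to `controls`.

        Returns a dict with the generated captions, crop/mask save paths,
        the mask image, and any context captions.
        """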
        # segment with prompt
        print("CA prompt:", prompt, "CA controls:", controls)
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        if self.args.enable_morphologyex:
            # denoise the binary mask with a morphological opening then closing
            seg_mask = 255 * seg_mask.astype(np.uint8)
            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
            seg_mask = seg_mask[:, :, 0] > 0
        mask_save_path = f'result/mask_{time.time()}.png'
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)
        seg_mask_img = Image.fromarray(seg_mask.astype(np.uint8) * 255)  # PIL cannot handle float64 arrays, so use uint8
        if seg_mask_img.mode != 'RGB':
            seg_mask_img = seg_mask_img.convert('RGB')
        seg_mask_img.save(mask_save_path)
        print('seg_mask path:', mask_save_path)
        print('seg_mask.shape:', seg_mask.shape)
        # captioning with mask
        if self.args.enable_reduce_tokens:
            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(
                image, seg_mask, crop_mode=self.args.seg_crop_mode,
                filter=self.args.clip_filter, disable_regular_box=self.args.disable_regular_box)
        else:
            caption, crop_save_path = self.captioner.inference_seg(
                image, seg_mask, crop_mode=self.args.seg_crop_mode,
                filter=self.args.clip_filter, disable_regular_box=self.args.disable_regular_box)
        # refining with TextRefiner
        context_captions = []
        if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))
        if not disable_gpt and self.text_refiner is not None:
            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
        else:
            refined_caption = {'raw_caption': caption}                
        out = {'generated_captions': refined_caption,
               'crop_save_path': crop_save_path,
               'mask_save_path': mask_save_path,
               'mask': seg_mask_img,
               'context_captions': context_captions}
        return out
    
def parse_augment():
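    """Parse command-line arguments for the pipeline."""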
    parser = argparse.ArgumentParser()
    parser.add_argument('--captioner', type=str, default="blip2")
    parser.add_argument('--segmenter', type=str, default="base")
    parser.add_argument('--text_refiner', type=str, default="base")
    parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
    parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
    parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
    parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
    parser.add_argument('--disable_regular_box', action="store_true", default=False, help="do not crop the image with a regular (rectangular) box around the mask")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")  
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--gradio_share', action="store_true")
    parser.add_argument('--disable_gpt', action="store_true")
    parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
    parser.add_argument('--disable_reuse_features', action="store_true", default=False)
    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
    args = parser.parse_args()

    if args.debug:
        print(args)
    return args

if __name__ == "__main__":
    args = parse_augment()
    # image_path = 'test_img/img3.jpg'
    image_path = 'test_img/img13.jpg'
    prompts = [
        {
            "prompt_type":["click"],
            "input_point":[[500, 300], [1000, 500]],
            "input_label":[1, 0],
            "multimask_output":"True",
        },
        {
            "prompt_type":["click"],
            "input_point":[[900, 800]],
            "input_label":[1],
            "multimask_output":"True",
        }
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        # "imagination": "True",
        "imagination": "False",
        "language": "English",
    }
    
    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))  # the refiner is skipped if the key is missing or invalid
    for prompt in prompts:
        print('*'*30)
        print('Image path: ', image_path)
        image = Image.open(image_path)
        print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        out = model.inference(image_path, prompt, controls)
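        # print the result; 'generated_captions' holds the refined caption, or
        # just {'raw_caption': ...} when the GPT refiner is disabled/unavailable
        print('Output:', out['generated_captions'])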