Spaces (status: Runtime error)
Commit eabdb1c by ttengwang
1 Parent(s): 12dc496
improve chat box; add an enable_wiki button
Browse files:
- app.py (+64 -38)
- caption_anything.py (+2 -2)
- text_refiner/text_refiner.py (+8 -6)
app.py CHANGED

@@ -120,28 +120,44 @@ def update_click_state(click_state, caption, click_mode):
         raise NotImplementedError
 
 
-def chat_with_points(chat_input, click_state, state, text_refiner):
+def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
     if text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
-        return state, state
+        return state, state, chat_state
 
     points, labels, captions = click_state
-    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!
+    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
+    suffix = '\nHuman: {chat_input}\nAI: '
+    qa_template = '\nHuman: {q}\nAI: {a}'
     # # "The image is of width {width} and height {height}."
-    point_chat_prompt = "
+    point_chat_prompt = "I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps} \n Now, let's chat!"
     prev_visual_context = ""
-    pos_points = [
+    pos_points = []
+    pos_captions = []
+    for i in range(len(points)):
+        if labels[i] == 1:
+            pos_points.append(f"({points[i][0]}, {points[i][0]})")
+            pos_captions.append(captions[i])
+            prev_visual_context = prev_visual_context + '\n' + 'Points: ' + ', '.join(pos_points) + '. Description: ' + pos_captions[-1]
+
+    context_length_thres = 500
+    prev_history = ""
+    for i in range(len(chat_state)):
+        q, a = chat_state[i]
+        if len(prev_history) < context_length_thres:
+            prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
+        else:
+            break
+
+    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
+    print('\nchat_prompt: ', chat_prompt)
     response = text_refiner.llm(chat_prompt)
     state = state + [(chat_input, response)]
+    chat_state = chat_state + [(chat_input, response)]
+    return state, state, chat_state
 
-def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
+def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                       length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
 
     model = build_caption_anything_with_models(
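The rewritten chat_with_points above assembles its LLM prompt from three parts: a fixed instruction listing the positively-labelled click points with their captions, a slice of the prior Q&A history capped at a rough character budget (context_length_thres = 500), and the current user turn. Below is a minimal, self-contained sketch of that assembly with made-up click and chat data; the variable names mirror the hunk, but the data and the abridged instruction string are illustrative only. (Note the committed loop formats each point as (points[i][0], points[i][0]), i.e., the x coordinate twice; the sketch assumes (x, y) was intended.)

```python
# Illustrative data shaped like click_state = [points, labels, captions] and chat_state.
points = [(120, 80), (300, 240)]
labels = [1, 1]                      # 1 = positive click, 0 = negative click
captions = ["a red kite", "a child on the beach"]
chat_state = [("What is in the sky?", "A red kite is flying.")]
chat_input = "What color is the kite?"

# Abridged stand-in for the instruction string used in the commit.
point_chat_prompt = ("I am an AI trained to chat with you about an image based on specific "
                     "points (w, h) and their descriptions: {points_with_caps}\nNow, let's chat!")
qa_template = '\nHuman: {q}\nAI: {a}'
suffix = '\nHuman: {chat_input}\nAI: '
context_length_thres = 500           # rough character budget for prior turns

# 1) Describe every positively-labelled click.
prev_visual_context = ""
pos_points, pos_captions = [], []
for i in range(len(points)):
    if labels[i] == 1:
        pos_points.append(f"({points[i][0]}, {points[i][1]})")   # assumed (x, y)
        pos_captions.append(captions[i])
        prev_visual_context += '\nPoints: ' + ', '.join(pos_points) + '. Description: ' + pos_captions[-1]

# 2) Keep earlier Q&A turns until the character budget is exceeded.
prev_history = ""
for q, a in chat_state:
    if len(prev_history) < context_length_thres:
        prev_history += qa_template.format(q=q, a=a)
    else:
        break

# 3) Final prompt = instruction + truncated history + current turn.
chat_prompt = (point_chat_prompt.format(points_with_caps=prev_visual_context)
               + prev_history + suffix.format(chat_input=chat_input))
print(chat_prompt)
```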
@@ -173,11 +189,12 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
     prompt = get_prompt(coordinate, click_state, click_mode)
     print('prompt: ', prompt, 'controls: ', controls)
 
+    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
     # for k, v in out['generated_captions'].items():
     # state = state + [(f'{k}: {v}', None)]
-    state = state + [("
+    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
     wiki = out['generated_captions'].get('wiki', "")
 
     update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
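gr.Radio hands the handler its value as a string ("Yes"/"No" here), so the hunk above normalizes it to a real boolean before calling model.inference. A one-function sketch that is equivalent to the committed one-liner:

```python
def to_bool(enable_wiki):
    # Any of these spellings (or an actual True) switch the wiki lookup on;
    # everything else, including "No", falls through to False.
    return enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes']

assert to_bool("Yes") and to_bool(True) and not to_bool("No")
```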
@@ -191,15 +208,18 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
 
     yield state, state, click_state, chat_input, image_input, wiki
     if not args.disable_gpt and model.text_refiner:
-        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
+        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
         # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
+        wiki = refined_caption['wiki']
+        state = state + [(None, f"caption: {new_cap}")]
         refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
         yield state, state, click_state, chat_input, refined_image_input, wiki
 
 
 def upload_callback(image_input, state):
-    state = [] + [('Image size: ' + str(image_input.size)
+    state = [] + [(None, 'Image size: ' + str(image_input.size))]
+    chat_state = []
     click_state = [[], [], []]
     res = 1024
     width, height = image_input.size
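inference_seg_cap is a generator callback: it yields once with the fast raw caption so the chat box and image update immediately, then (when GPT is enabled) yields a second time with the refined caption, the bubble-framed image, and the wiki text pulled from refined_caption['wiki']. A Gradio-free sketch of that two-stage pattern, with the segmenter/captioner/refiner replaced by stand-ins:

```python
import time

def caption_handler(raw_caption, use_refiner=True, enable_wiki=False):
    # Stage 1: cheap result, shown right away.
    state = [(None, f"raw_caption: {raw_caption}")]
    yield state, ""

    if use_refiner:
        time.sleep(0.1)                                   # stand-in for the LLM call
        refined = raw_caption.capitalize() + "."
        wiki = "Kite: a tethered flying object..." if enable_wiki else ""
        state = state + [(None, f"caption: {refined}")]
        # Stage 2: refined caption plus (optionally) wiki text.
        yield state, wiki

for state, wiki in caption_handler("a red kite in the sky", enable_wiki=True):
    print(state[-1], "| wiki:", wiki)
```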
@@ -219,7 +239,7 @@ def upload_callback(image_input, state):
     image_embedding = model.segmenter.image_embedding
     original_size = model.segmenter.predictor.original_size
     input_size = model.segmenter.predictor.input_size
-    return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
+    return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size
 
 with gr.Blocks(
     css='''
@@ -229,6 +249,7 @@ with gr.Blocks(
 ) as iface:
     state = gr.State([])
     click_state = gr.State([[],[],[]])
+    chat_state = gr.State([])
     origin_image = gr.State(None)
     image_embedding = gr.State(None)
     text_refiner = gr.State(None)
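chat_state is a new per-session gr.State: unlike the visible chatbot history in state, it is only used to rebuild the LLM context inside chat_with_points, so it needs its own slot in every callback's inputs/outputs and in the clear handlers. A minimal sketch of the pattern (assuming a Gradio 3.x install like the one this app targets; the components here are illustrative, not the app's full layout):

```python
import gradio as gr

def respond(message, chat_state):
    # Keep an invisible copy of the dialogue alongside what the Chatbot displays.
    chat_state = chat_state + [(message, f"echo: {message}")]
    return chat_state, chat_state        # first -> gr.Chatbot, second -> gr.State

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chat_state = gr.State([])            # per-session, never rendered
    box = gr.Textbox()
    box.submit(respond, [box, chat_state], [chatbot, chat_state])

# demo.launch()
```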
@@ -260,14 +281,13 @@ with gr.Blocks(
                 clear_button_image = gr.Button(value="Clear Image", interactive=True)
             with gr.Column(visible=False) as modules_need_gpt:
                 with gr.Row(scale=1.0):
+                    language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+                    sentiment = gr.Radio(
+                        choices=["Positive", "Natural", "Negative"],
+                        value="Natural",
+                        label="Sentiment",
+                        interactive=True,
+                    )
                 with gr.Row(scale=1.0):
                     factuality = gr.Radio(
                         choices=["Factual", "Imagination"],
@@ -281,8 +301,13 @@ with gr.Blocks(
                         value=10,
                         step=1,
                         interactive=True,
-                        label="Length",
-                    )
+                        label="Generated Caption Length",
+                    )
+                    enable_wiki = gr.Radio(
+                        choices=["Yes", "No"],
+                        value="No",
+                        label="Enable Wiki",
+                        interactive=True)
             with gr.Column(visible=True) as modules_not_need_gpt3:
                 gr.Examples(
                     examples=examples,
@@ -303,7 +328,7 @@ with gr.Blocks(
             with gr.Column(visible=False) as modules_not_need_gpt2:
                 chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
             with gr.Column(visible=False) as modules_need_gpt3:
-                chat_input = gr.Textbox(
+                chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(container=False)
                 with gr.Row():
                     clear_button_text = gr.Button(value="Clear Text", interactive=True)
                     submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
@@ -320,30 +345,30 @@ with gr.Blocks(
         show_progress=False
     )
     clear_button_image.click(
-        lambda: (None, [], [], [[], [], []], "", ""),
+        lambda: (None, [], [], [], [[], [], []], "", ""),
         [],
-        [image_input, chatbot, state, click_state, wiki_output, origin_image],
+        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
         queue=False,
         show_progress=False
     )
     clear_button_text.click(
-        lambda: ([], [], [[], [], []]),
+        lambda: ([], [], [[], [], [], []], []),
         [],
-        [chatbot, state, click_state],
+        [chatbot, state, click_state, chat_state],
         queue=False,
         show_progress=False
     )
     image_input.clear(
-        lambda: (None, [], [], [[], [], []], "", ""),
+        lambda: (None, [], [], [], [[], [], []], "", ""),
        [],
-        [image_input, chatbot, state, click_state, wiki_output, origin_image],
+        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
         queue=False,
         show_progress=False
     )
 
-    image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
-    chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
-    example_image.change(upload_callback,[example_image, state], [
+    image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
+    chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner], [chatbot, state, chat_state])
+    example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
 
     # select coordinate
     image_input.select(inference_seg_cap,
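The clear handlers above return one default per output component, positionally, so adding chat_state means adding a matching empty default to each lambda. A small self-check in that spirit (component names copied from the hunk, represented as strings). One detail worth flagging: clear_button_text's lambda appears to return a four-list click_state ([[], [], [], []]) even though click_state is defined as three lists elsewhere in this commit.

```python
# clear_button_image.click / image_input.clear: outputs and their reset values.
outputs  = ["image_input", "chatbot", "state", "chat_state", "click_state", "wiki_output", "origin_image"]
defaults = (None,           [],        [],      [],           [[], [], []],  "",            "")
assert len(outputs) == len(defaults)

for name, value in zip(outputs, defaults):
    print(f"{name:>13} <- {value!r}")
```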
@@ -351,6 +376,7 @@ with gr.Blocks(
                        origin_image,
                        point_prompt,
                        click_mode,
+                       enable_wiki,
                        language,
                        sentiment,
                        factuality,
caption_anything.py CHANGED

@@ -30,7 +30,7 @@ class CaptionAnything():
             self.text_refiner = None
             print('OpenAI GPT is not available')
 
-    def inference(self, image, prompt, controls, disable_gpt=False):
+    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
@@ -59,7 +59,7 @@ class CaptionAnything():
         if self.args.context_captions:
             context_captions.append(self.captioner.inference(image))
         if not disable_gpt and self.text_refiner is not None:
-            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
+            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
         else:
             refined_caption = {'raw_caption': caption}
         out = {'generated_captions': refined_caption,
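The only change to CaptionAnything is threading the new keyword through: inference grows an enable_wiki=False parameter and forwards it to the text refiner when GPT refinement runs. A stub with the same signature, just to show how callers are affected (the real class wires up a segmenter, captioner and refiner; everything below is a stand-in):

```python
class CaptionAnythingStub:
    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        caption = "a red kite"                                 # stand-in for the captioner
        if disable_gpt:
            return {'generated_captions': {'raw_caption': caption}}
        refined = {
            'raw_caption': caption,
            'caption': caption.capitalize() + ".",
            'wiki': "Kite: a tethered flying object..." if enable_wiki else "",
        }
        return {'generated_captions': refined}

stub = CaptionAnythingStub()
print(stub.inference(None, {}, {}, disable_gpt=True))          # fast path, no wiki
print(stub.inference(None, {}, {}, enable_wiki=True))          # refined caption + wiki text
```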
text_refiner/text_refiner.py CHANGED

@@ -39,7 +39,7 @@ class TextRefiner:
         print('prompt: ', input)
         return input
 
-    def inference(self, query: str, controls: dict, context: list=[]):
+    def inference(self, query: str, controls: dict, context: list=[], enable_wiki=False):
         """
         query: the caption of the region of interest, generated by captioner
         controls: a dict of control singals, e.g., {"length": 5, "sentiment": "positive"}
@@ -58,15 +58,17 @@ class TextRefiner:
         response = self.llm(input)
         response = self.parse(response)
 
+        response_wiki = ""
+        if enable_wiki:
+            tmp_configs = {"query": query}
+            prompt_wiki = self.wiki_prompts.format(**tmp_configs)
+            response_wiki = self.llm(prompt_wiki)
+            response_wiki = self.parse2(response_wiki)
         out = {
             'raw_caption': query,
             'caption': response,
             'wiki': response_wiki
         }
         print(out)
         return out
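In TextRefiner.inference the wiki lookup becomes opt-in: response_wiki defaults to an empty string, and the extra LLM call (formatting self.wiki_prompts with the query and parsing the reply) only happens when enable_wiki is true. A minimal standalone sketch of that flow; the template string and llm below are stand-ins, not the repo's actual wiki_prompts or OpenAI wrapper:

```python
def refine(query, llm, enable_wiki=False,
           wiki_prompt_template="Give one factual sentence about: {query}"):
    caption = llm(f"Refine this caption: {query}")
    response_wiki = ""
    if enable_wiki:                          # extra LLM round-trip only when the toggle is on
        response_wiki = llm(wiki_prompt_template.format(query=query))
    return {'raw_caption': query, 'caption': caption, 'wiki': response_wiki}

fake_llm = lambda prompt: f"[reply to: {prompt}]"
print(refine("a red kite", fake_llm))                    # wiki stays ""
print(refine("a red kite", fake_llm, enable_wiki=True))  # wiki populated
```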