Spaces:
Runtime error
Runtime error
ttengwang
commited on
Commit
•
10240e0
1
Parent(s):
36b57c3
update lastest
Browse files- .gitignore +135 -0
- README.md +45 -11
- app.py +28 -23
- app_huggingface.py +268 -0
- app_old.py +5 -5
- caption_anything.py +114 -0
- captioner/base_captioner.py +3 -2
- env.sh +1 -1
- image_editing_utils.py +2 -2
- requirements.txt +1 -0
- tools.py +7 -1
.gitignore
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
result/
|
2 |
+
model_cache/
|
3 |
+
*.pth
|
4 |
+
teng_grad_start.sh
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
result/
|
11 |
+
|
12 |
+
# C extensions
|
13 |
+
*.so
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
build/
|
18 |
+
develop-eggs/
|
19 |
+
dist/
|
20 |
+
downloads/
|
21 |
+
eggs/
|
22 |
+
.eggs/
|
23 |
+
lib/
|
24 |
+
lib64/
|
25 |
+
parts/
|
26 |
+
sdist/
|
27 |
+
var/
|
28 |
+
wheels/
|
29 |
+
pip-wheel-metadata/
|
30 |
+
share/python-wheels/
|
31 |
+
*.egg-info/
|
32 |
+
.installed.cfg
|
33 |
+
*.egg
|
34 |
+
MANIFEST
|
35 |
+
|
36 |
+
# PyInstaller
|
37 |
+
# Usually these files are written by a python script from a template
|
38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
39 |
+
*.manifest
|
40 |
+
*.spec
|
41 |
+
|
42 |
+
# Installer logs
|
43 |
+
pip-log.txt
|
44 |
+
pip-delete-this-directory.txt
|
45 |
+
|
46 |
+
# Unit test / coverage reports
|
47 |
+
htmlcov/
|
48 |
+
.tox/
|
49 |
+
.nox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache
|
53 |
+
nosetests.xml
|
54 |
+
coverage.xml
|
55 |
+
*.cover
|
56 |
+
*.py,cover
|
57 |
+
.hypothesis/
|
58 |
+
.pytest_cache/
|
59 |
+
|
60 |
+
# Translations
|
61 |
+
*.mo
|
62 |
+
*.pot
|
63 |
+
|
64 |
+
# Django stuff:
|
65 |
+
*.log
|
66 |
+
local_settings.py
|
67 |
+
db.sqlite3
|
68 |
+
db.sqlite3-journal
|
69 |
+
|
70 |
+
# Flask stuff:
|
71 |
+
instance/
|
72 |
+
.webassets-cache
|
73 |
+
|
74 |
+
# Scrapy stuff:
|
75 |
+
.scrapy
|
76 |
+
|
77 |
+
# Sphinx documentation
|
78 |
+
docs/_build/
|
79 |
+
|
80 |
+
# PyBuilder
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
.python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
101 |
+
__pypackages__/
|
102 |
+
|
103 |
+
# Celery stuff
|
104 |
+
celerybeat-schedule
|
105 |
+
celerybeat.pid
|
106 |
+
|
107 |
+
# SageMath parsed files
|
108 |
+
*.sage.py
|
109 |
+
|
110 |
+
# Environments
|
111 |
+
.env
|
112 |
+
.venv
|
113 |
+
env/
|
114 |
+
venv/
|
115 |
+
ENV/
|
116 |
+
env.bak/
|
117 |
+
venv.bak/
|
118 |
+
|
119 |
+
# Spyder project settings
|
120 |
+
.spyderproject
|
121 |
+
.spyproject
|
122 |
+
|
123 |
+
# Rope project settings
|
124 |
+
.ropeproject
|
125 |
+
|
126 |
+
# mkdocs documentation
|
127 |
+
/site
|
128 |
+
|
129 |
+
# mypy
|
130 |
+
.mypy_cache/
|
131 |
+
.dmypy.json
|
132 |
+
dmypy.json
|
133 |
+
|
134 |
+
# Pyre type checker
|
135 |
+
.pyre/
|
README.md
CHANGED
@@ -1,13 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
title: Caption Anything
|
3 |
-
emoji: 📚
|
4 |
-
colorFrom: green
|
5 |
-
colorTo: green
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.24.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Caption-Anything
|
2 |
+
<!-- ![](./Image/title.svg) -->
|
3 |
+
**Caption-Anything** is a versatile image processing tool that combines the capabilities of [Segment Anything](https://github.com/facebookresearch/segment-anything), Visual Captioning, and [ChatGPT](https://openai.com/blog/chatgpt). Our solution generates descriptive captions for any object within an image, offering a range of language styles to accommodate diverse user preferences. **Caption-Anything** supports visual controls (mouse click) and language controls (length, sentiment, factuality, and language).
|
4 |
+
* visual controls and language controls for text generation
|
5 |
+
* Chat about selected object for detailed understanding
|
6 |
+
* Interactive demo
|
7 |
+
![](./Image/UI.png)
|
8 |
+
|
9 |
+
<!-- <a src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" href="https://huggingface.co/spaces/wybertwang/Caption-Anything">
|
10 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" alt="Open in Spaces">
|
11 |
+
</a> -->
|
12 |
+
|
13 |
+
<!-- <a src="https://colab.research.google.com/assets/colab-badge.svg" href="">
|
14 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
|
15 |
+
</a> -->
|
16 |
+
|
17 |
+
### Demo
|
18 |
+
Explore the interactive demo of Caption-Anything, which showcases its powerful capabilities in generating captions for various objects within an image. The demo allows users to control visual aspects by clicking on objects, as well as to adjust textual properties such as length, sentiment, factuality, and language.
|
19 |
+
![](./Image/demo1.png)
|
20 |
+
|
21 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
![](./Image/demo2.png)
|
24 |
+
|
25 |
+
### Getting Started
|
26 |
+
|
27 |
+
|
28 |
+
* Clone the repository:
|
29 |
+
```bash
|
30 |
+
git clone https://github.com/ttengwang/caption-anything.git
|
31 |
+
```
|
32 |
+
* Install dependencies:
|
33 |
+
```bash
|
34 |
+
cd caption-anything
|
35 |
+
pip install -r requirements.txt
|
36 |
+
```
|
37 |
+
* Download the [SAM checkpoints](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and place it to `./segmenter/sam_vit_h_4b8939.pth.`
|
38 |
+
|
39 |
+
* Run the Caption-Anything gradio demo.
|
40 |
+
```bash
|
41 |
+
# Configure the necessary ChatGPT APIs
|
42 |
+
export OPENAI_API_KEY={Your_Private_Openai_Key}
|
43 |
+
python app.py --regular_box --captioner blip2 --port 6086
|
44 |
+
```
|
45 |
+
|
46 |
+
## Acknowledgement
|
47 |
+
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), BLIP/BLIP-2, [ChatGPT](https://openai.com/blog/chatgpt). Thanks for the authors for their efforts.
|
app.py
CHANGED
@@ -2,12 +2,12 @@ from io import BytesIO
|
|
2 |
import string
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
-
from
|
6 |
import torch
|
7 |
import json
|
8 |
import sys
|
9 |
import argparse
|
10 |
-
from
|
11 |
import numpy as np
|
12 |
import PIL.ImageDraw as ImageDraw
|
13 |
from image_editing_utils import create_bubble_frame
|
@@ -47,6 +47,9 @@ examples = [
|
|
47 |
]
|
48 |
|
49 |
args = parse_augment()
|
|
|
|
|
|
|
50 |
# args.device = 'cuda:5'
|
51 |
# args.disable_gpt = False
|
52 |
# args.enable_reduce_tokens = True
|
@@ -81,9 +84,9 @@ def chat_with_points(chat_input, click_state, state):
|
|
81 |
return state, state
|
82 |
|
83 |
points, labels, captions = click_state
|
84 |
-
point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
85 |
-
# "The image is of width {width} and height {height}."
|
86 |
-
|
87 |
prev_visual_context = ""
|
88 |
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
89 |
if len(captions):
|
@@ -114,9 +117,10 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
|
|
114 |
|
115 |
out = model.inference(image_input, prompt, controls)
|
116 |
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
117 |
-
for k, v in out['generated_captions'].items():
|
118 |
-
|
119 |
-
|
|
|
120 |
click_state[2].append(out['generated_captions']['raw_caption'])
|
121 |
|
122 |
text = out['generated_captions']['raw_caption']
|
@@ -127,12 +131,13 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
|
|
127 |
origin_image_input = image_input
|
128 |
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
129 |
|
130 |
-
yield state, state, click_state, chat_input, image_input
|
131 |
if not args.disable_gpt and hasattr(model, "text_refiner"):
|
132 |
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
133 |
-
new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
|
|
134 |
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
135 |
-
yield state, state, click_state, chat_input, refined_image_input
|
136 |
|
137 |
|
138 |
def upload_callback(image_input, state):
|
@@ -195,28 +200,29 @@ with gr.Blocks(
|
|
195 |
with gr.Column(scale=0.5):
|
196 |
openai_api_key = gr.Textbox(
|
197 |
placeholder="Input your openAI API key and press Enter",
|
198 |
-
show_label=
|
199 |
label = "OpenAI API Key",
|
200 |
lines=1,
|
201 |
type="password"
|
202 |
)
|
203 |
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
204 |
-
|
|
|
205 |
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
206 |
with gr.Row():
|
207 |
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
208 |
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
209 |
clear_button_clike.click(
|
210 |
-
lambda x: ([[], [], []], x),
|
211 |
[origin_image],
|
212 |
-
[click_state, image_input],
|
213 |
queue=False,
|
214 |
show_progress=False
|
215 |
)
|
216 |
clear_button_image.click(
|
217 |
-
lambda: (None, [], [], [[], [], []]),
|
218 |
[],
|
219 |
-
[image_input, chatbot, state, click_state],
|
220 |
queue=False,
|
221 |
show_progress=False
|
222 |
)
|
@@ -228,9 +234,9 @@ with gr.Blocks(
|
|
228 |
show_progress=False
|
229 |
)
|
230 |
image_input.clear(
|
231 |
-
lambda: (None, [], [], [[], [], []]),
|
232 |
[],
|
233 |
-
[image_input, chatbot, state, click_state],
|
234 |
queue=False,
|
235 |
show_progress=False
|
236 |
)
|
@@ -255,9 +261,8 @@ with gr.Blocks(
|
|
255 |
state,
|
256 |
click_state
|
257 |
],
|
258 |
-
|
259 |
-
|
260 |
-
show_progress=False, queue=True)
|
261 |
|
262 |
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
263 |
-
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
|
|
2 |
import string
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
import torch
|
7 |
import json
|
8 |
import sys
|
9 |
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
import numpy as np
|
12 |
import PIL.ImageDraw as ImageDraw
|
13 |
from image_editing_utils import create_bubble_frame
|
|
|
47 |
]
|
48 |
|
49 |
args = parse_augment()
|
50 |
+
args.captioner = 'blip2'
|
51 |
+
args.seg_crop_mode = 'wo_bg'
|
52 |
+
args.regular_box = True
|
53 |
# args.device = 'cuda:5'
|
54 |
# args.disable_gpt = False
|
55 |
# args.enable_reduce_tokens = True
|
|
|
84 |
return state, state
|
85 |
|
86 |
points, labels, captions = click_state
|
87 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
88 |
+
# # "The image is of width {width} and height {height}."
|
89 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
90 |
prev_visual_context = ""
|
91 |
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
92 |
if len(captions):
|
|
|
117 |
|
118 |
out = model.inference(image_input, prompt, controls)
|
119 |
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
120 |
+
# for k, v in out['generated_captions'].items():
|
121 |
+
# state = state + [(f'{k}: {v}', None)]
|
122 |
+
state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
|
123 |
+
wiki = out['generated_captions'].get('wiki', "")
|
124 |
click_state[2].append(out['generated_captions']['raw_caption'])
|
125 |
|
126 |
text = out['generated_captions']['raw_caption']
|
|
|
131 |
origin_image_input = image_input
|
132 |
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
133 |
|
134 |
+
yield state, state, click_state, chat_input, image_input, wiki
|
135 |
if not args.disable_gpt and hasattr(model, "text_refiner"):
|
136 |
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
137 |
+
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
138 |
+
new_cap = refined_caption['caption']
|
139 |
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
140 |
+
yield state, state, click_state, chat_input, refined_image_input, wiki
|
141 |
|
142 |
|
143 |
def upload_callback(image_input, state):
|
|
|
200 |
with gr.Column(scale=0.5):
|
201 |
openai_api_key = gr.Textbox(
|
202 |
placeholder="Input your openAI API key and press Enter",
|
203 |
+
show_label=False,
|
204 |
label = "OpenAI API Key",
|
205 |
lines=1,
|
206 |
type="password"
|
207 |
)
|
208 |
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
209 |
+
wiki_output = gr.Textbox(lines=6, label="Wiki")
|
210 |
+
chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
|
211 |
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
212 |
with gr.Row():
|
213 |
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
214 |
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
215 |
clear_button_clike.click(
|
216 |
+
lambda x: ([[], [], []], x, ""),
|
217 |
[origin_image],
|
218 |
+
[click_state, image_input, wiki_output],
|
219 |
queue=False,
|
220 |
show_progress=False
|
221 |
)
|
222 |
clear_button_image.click(
|
223 |
+
lambda: (None, [], [], [[], [], []], ""),
|
224 |
[],
|
225 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
226 |
queue=False,
|
227 |
show_progress=False
|
228 |
)
|
|
|
234 |
show_progress=False
|
235 |
)
|
236 |
image_input.clear(
|
237 |
+
lambda: (None, [], [], [[], [], []], ""),
|
238 |
[],
|
239 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
240 |
queue=False,
|
241 |
show_progress=False
|
242 |
)
|
|
|
261 |
state,
|
262 |
click_state
|
263 |
],
|
264 |
+
outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
|
265 |
+
show_progress=False, queue=True)
|
|
|
266 |
|
267 |
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
268 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
app_huggingface.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
import string
|
3 |
+
import gradio as gr
|
4 |
+
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
+
import torch
|
7 |
+
import json
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
+
import numpy as np
|
12 |
+
import PIL.ImageDraw as ImageDraw
|
13 |
+
from image_editing_utils import create_bubble_frame
|
14 |
+
import copy
|
15 |
+
from tools import mask_painter
|
16 |
+
from PIL import Image
|
17 |
+
import os
|
18 |
+
|
19 |
+
def download_checkpoint(url, folder, filename):
|
20 |
+
os.makedirs(folder, exist_ok=True)
|
21 |
+
filepath = os.path.join(folder, filename)
|
22 |
+
|
23 |
+
if not os.path.exists(filepath):
|
24 |
+
response = requests.get(url, stream=True)
|
25 |
+
with open(filepath, "wb") as f:
|
26 |
+
for chunk in response.iter_content(chunk_size=8192):
|
27 |
+
if chunk:
|
28 |
+
f.write(chunk)
|
29 |
+
|
30 |
+
return filepath
|
31 |
+
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
32 |
+
folder = "segmenter"
|
33 |
+
filename = "sam_vit_h_4b8939.pth"
|
34 |
+
|
35 |
+
download_checkpoint(checkpoint_url, folder, filename)
|
36 |
+
|
37 |
+
|
38 |
+
title = """<h1 align="center">Caption-Anything</h1>"""
|
39 |
+
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
|
40 |
+
"""
|
41 |
+
|
42 |
+
examples = [
|
43 |
+
["test_img/img2.jpg"],
|
44 |
+
["test_img/img5.jpg"],
|
45 |
+
["test_img/img12.jpg"],
|
46 |
+
["test_img/img14.jpg"],
|
47 |
+
]
|
48 |
+
|
49 |
+
args = parse_augment()
|
50 |
+
args.captioner = 'blip2'
|
51 |
+
args.seg_crop_mode = 'wo_bg'
|
52 |
+
args.regular_box = True
|
53 |
+
# args.device = 'cuda:5'
|
54 |
+
# args.disable_gpt = False
|
55 |
+
# args.enable_reduce_tokens = True
|
56 |
+
# args.port=20322
|
57 |
+
model = CaptionAnything(args)
|
58 |
+
|
59 |
+
def init_openai_api_key(api_key):
|
60 |
+
os.environ['OPENAI_API_KEY'] = api_key
|
61 |
+
model.init_refiner()
|
62 |
+
|
63 |
+
|
64 |
+
def get_prompt(chat_input, click_state):
|
65 |
+
points = click_state[0]
|
66 |
+
labels = click_state[1]
|
67 |
+
inputs = json.loads(chat_input)
|
68 |
+
for input in inputs:
|
69 |
+
points.append(input[:2])
|
70 |
+
labels.append(input[2])
|
71 |
+
|
72 |
+
prompt = {
|
73 |
+
"prompt_type":["click"],
|
74 |
+
"input_point":points,
|
75 |
+
"input_label":labels,
|
76 |
+
"multimask_output":"True",
|
77 |
+
}
|
78 |
+
return prompt
|
79 |
+
|
80 |
+
def chat_with_points(chat_input, click_state, state):
|
81 |
+
if not hasattr(model, "text_refiner"):
|
82 |
+
response = "Text refiner is not initilzed, please input openai api key."
|
83 |
+
state = state + [(chat_input, response)]
|
84 |
+
return state, state
|
85 |
+
|
86 |
+
points, labels, captions = click_state
|
87 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
88 |
+
# # "The image is of width {width} and height {height}."
|
89 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
90 |
+
prev_visual_context = ""
|
91 |
+
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
92 |
+
if len(captions):
|
93 |
+
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
94 |
+
else:
|
95 |
+
prev_visual_context = 'no point exists.'
|
96 |
+
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
97 |
+
response = model.text_refiner.llm(chat_prompt)
|
98 |
+
state = state + [(chat_input, response)]
|
99 |
+
return state, state
|
100 |
+
|
101 |
+
def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
|
102 |
+
|
103 |
+
if point_prompt == 'Positive':
|
104 |
+
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
105 |
+
else:
|
106 |
+
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
107 |
+
|
108 |
+
controls = {'length': length,
|
109 |
+
'sentiment': sentiment,
|
110 |
+
'factuality': factuality,
|
111 |
+
'language': language}
|
112 |
+
|
113 |
+
# click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
114 |
+
# chat_input = click_coordinate
|
115 |
+
prompt = get_prompt(coordinate, click_state)
|
116 |
+
print('prompt: ', prompt, 'controls: ', controls)
|
117 |
+
|
118 |
+
out = model.inference(image_input, prompt, controls)
|
119 |
+
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
120 |
+
# for k, v in out['generated_captions'].items():
|
121 |
+
# state = state + [(f'{k}: {v}', None)]
|
122 |
+
state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
|
123 |
+
wiki = out['generated_captions'].get('wiki', "")
|
124 |
+
click_state[2].append(out['generated_captions']['raw_caption'])
|
125 |
+
|
126 |
+
text = out['generated_captions']['raw_caption']
|
127 |
+
# draw = ImageDraw.Draw(image_input)
|
128 |
+
# draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
|
129 |
+
input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
|
130 |
+
image_input = mask_painter(np.array(image_input), input_mask)
|
131 |
+
origin_image_input = image_input
|
132 |
+
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
133 |
+
|
134 |
+
yield state, state, click_state, chat_input, image_input, wiki
|
135 |
+
if not args.disable_gpt and hasattr(model, "text_refiner"):
|
136 |
+
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
137 |
+
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
138 |
+
new_cap = refined_caption['caption']
|
139 |
+
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
140 |
+
yield state, state, click_state, chat_input, refined_image_input, wiki
|
141 |
+
|
142 |
+
|
143 |
+
def upload_callback(image_input, state):
|
144 |
+
state = [] + [('Image size: ' + str(image_input.size), None)]
|
145 |
+
click_state = [[], [], []]
|
146 |
+
model.segmenter.image = None
|
147 |
+
model.segmenter.image_embedding = None
|
148 |
+
model.segmenter.set_image(image_input)
|
149 |
+
return state, image_input, click_state
|
150 |
+
|
151 |
+
with gr.Blocks(
|
152 |
+
css='''
|
153 |
+
#image_upload{min-height:400px}
|
154 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
|
155 |
+
'''
|
156 |
+
) as iface:
|
157 |
+
state = gr.State([])
|
158 |
+
click_state = gr.State([[],[],[]])
|
159 |
+
origin_image = gr.State(None)
|
160 |
+
|
161 |
+
gr.Markdown(title)
|
162 |
+
gr.Markdown(description)
|
163 |
+
|
164 |
+
with gr.Row():
|
165 |
+
with gr.Column(scale=1.0):
|
166 |
+
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
|
167 |
+
with gr.Row(scale=1.0):
|
168 |
+
point_prompt = gr.Radio(
|
169 |
+
choices=["Positive", "Negative"],
|
170 |
+
value="Positive",
|
171 |
+
label="Point Prompt",
|
172 |
+
interactive=True)
|
173 |
+
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
|
174 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
175 |
+
with gr.Row(scale=1.0):
|
176 |
+
language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
|
177 |
+
|
178 |
+
sentiment = gr.Radio(
|
179 |
+
choices=["Positive", "Natural", "Negative"],
|
180 |
+
value="Natural",
|
181 |
+
label="Sentiment",
|
182 |
+
interactive=True,
|
183 |
+
)
|
184 |
+
with gr.Row(scale=1.0):
|
185 |
+
factuality = gr.Radio(
|
186 |
+
choices=["Factual", "Imagination"],
|
187 |
+
value="Factual",
|
188 |
+
label="Factuality",
|
189 |
+
interactive=True,
|
190 |
+
)
|
191 |
+
length = gr.Slider(
|
192 |
+
minimum=10,
|
193 |
+
maximum=80,
|
194 |
+
value=10,
|
195 |
+
step=1,
|
196 |
+
interactive=True,
|
197 |
+
label="Length",
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Column(scale=0.5):
|
201 |
+
openai_api_key = gr.Textbox(
|
202 |
+
placeholder="Input your openAI API key and press Enter",
|
203 |
+
show_label=False,
|
204 |
+
label = "OpenAI API Key",
|
205 |
+
lines=1,
|
206 |
+
type="password"
|
207 |
+
)
|
208 |
+
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
209 |
+
wiki_output = gr.Textbox(lines=6, label="Wiki")
|
210 |
+
chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
|
211 |
+
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
212 |
+
with gr.Row():
|
213 |
+
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
214 |
+
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
215 |
+
clear_button_clike.click(
|
216 |
+
lambda x: ([[], [], []], x, ""),
|
217 |
+
[origin_image],
|
218 |
+
[click_state, image_input, wiki_output],
|
219 |
+
queue=False,
|
220 |
+
show_progress=False
|
221 |
+
)
|
222 |
+
clear_button_image.click(
|
223 |
+
lambda: (None, [], [], [[], [], []], ""),
|
224 |
+
[],
|
225 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
226 |
+
queue=False,
|
227 |
+
show_progress=False
|
228 |
+
)
|
229 |
+
clear_button_text.click(
|
230 |
+
lambda: ([], [], [[], [], []]),
|
231 |
+
[],
|
232 |
+
[chatbot, state, click_state],
|
233 |
+
queue=False,
|
234 |
+
show_progress=False
|
235 |
+
)
|
236 |
+
image_input.clear(
|
237 |
+
lambda: (None, [], [], [[], [], []], ""),
|
238 |
+
[],
|
239 |
+
[image_input, chatbot, state, click_state, wiki_output],
|
240 |
+
queue=False,
|
241 |
+
show_progress=False
|
242 |
+
)
|
243 |
+
|
244 |
+
examples = gr.Examples(
|
245 |
+
examples=examples,
|
246 |
+
inputs=[image_input],
|
247 |
+
)
|
248 |
+
|
249 |
+
image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
|
250 |
+
chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
251 |
+
|
252 |
+
# select coordinate
|
253 |
+
image_input.select(inference_seg_cap,
|
254 |
+
inputs=[
|
255 |
+
origin_image,
|
256 |
+
point_prompt,
|
257 |
+
language,
|
258 |
+
sentiment,
|
259 |
+
factuality,
|
260 |
+
length,
|
261 |
+
state,
|
262 |
+
click_state
|
263 |
+
],
|
264 |
+
outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
|
265 |
+
show_progress=False, queue=True)
|
266 |
+
|
267 |
+
iface.queue(concurrency_count=1, api_open=False, max_size=10)
|
268 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
app_old.py
CHANGED
@@ -2,12 +2,12 @@ from io import BytesIO
|
|
2 |
import string
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
-
from
|
6 |
import torch
|
7 |
import json
|
8 |
import sys
|
9 |
import argparse
|
10 |
-
from
|
11 |
import os
|
12 |
|
13 |
# download sam checkpoint if not downloaded
|
@@ -83,12 +83,12 @@ def get_select_coords(image_input, point_prompt, language, sentiment, factuality
|
|
83 |
else:
|
84 |
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
85 |
return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
|
86 |
-
|
87 |
def chat_with_points(chat_input, click_state, state):
|
88 |
points, labels, captions = click_state
|
89 |
-
point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
|
90 |
# "The image is of width {width} and height {height}."
|
91 |
-
|
92 |
prev_visual_context = ""
|
93 |
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
94 |
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
|
|
2 |
import string
|
3 |
import gradio as gr
|
4 |
import requests
|
5 |
+
from caption_anything import CaptionAnything
|
6 |
import torch
|
7 |
import json
|
8 |
import sys
|
9 |
import argparse
|
10 |
+
from caption_anything import parse_augment
|
11 |
import os
|
12 |
|
13 |
# download sam checkpoint if not downloaded
|
|
|
83 |
else:
|
84 |
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
85 |
return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
|
86 |
+
|
87 |
def chat_with_points(chat_input, click_state, state):
|
88 |
points, labels, captions = click_state
|
89 |
+
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
|
90 |
# "The image is of width {width} and height {height}."
|
91 |
+
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
92 |
prev_visual_context = ""
|
93 |
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
94 |
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
caption_anything.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from captioner import build_captioner, BaseCaptioner
|
2 |
+
from segmenter import build_segmenter
|
3 |
+
from text_refiner import build_text_refiner
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import pdb
|
7 |
+
import time
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
class CaptionAnything():
|
11 |
+
def __init__(self, args):
|
12 |
+
self.args = args
|
13 |
+
self.captioner = build_captioner(args.captioner, args.device, args)
|
14 |
+
self.segmenter = build_segmenter(args.segmenter, args.device, args)
|
15 |
+
if not args.disable_gpt:
|
16 |
+
self.init_refiner()
|
17 |
+
|
18 |
+
|
19 |
+
def init_refiner(self):
|
20 |
+
if os.environ.get('OPENAI_API_KEY', None):
|
21 |
+
self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
|
22 |
+
|
23 |
+
def inference(self, image, prompt, controls, disable_gpt=False):
|
24 |
+
# segment with prompt
|
25 |
+
print("CA prompt: ", prompt, "CA controls",controls)
|
26 |
+
seg_mask = self.segmenter.inference(image, prompt)[0, ...]
|
27 |
+
mask_save_path = f'result/mask_{time.time()}.png'
|
28 |
+
if not os.path.exists(os.path.dirname(mask_save_path)):
|
29 |
+
os.makedirs(os.path.dirname(mask_save_path))
|
30 |
+
new_p = Image.fromarray(seg_mask.astype('int') * 255.)
|
31 |
+
if new_p.mode != 'RGB':
|
32 |
+
new_p = new_p.convert('RGB')
|
33 |
+
new_p.save(mask_save_path)
|
34 |
+
print('seg_mask path: ', mask_save_path)
|
35 |
+
print("seg_mask.shape: ", seg_mask.shape)
|
36 |
+
# captioning with mask
|
37 |
+
if self.args.enable_reduce_tokens:
|
38 |
+
caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
|
39 |
+
else:
|
40 |
+
caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
|
41 |
+
# refining with TextRefiner
|
42 |
+
context_captions = []
|
43 |
+
if self.args.context_captions:
|
44 |
+
context_captions.append(self.captioner.inference(image))
|
45 |
+
if not disable_gpt and hasattr(self, "text_refiner"):
|
46 |
+
refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
|
47 |
+
else:
|
48 |
+
refined_caption = {'raw_caption': caption}
|
49 |
+
out = {'generated_captions': refined_caption,
|
50 |
+
'crop_save_path': crop_save_path,
|
51 |
+
'mask_save_path': mask_save_path,
|
52 |
+
'context_captions': context_captions}
|
53 |
+
return out
|
54 |
+
|
55 |
+
def parse_augment():
|
56 |
+
parser = argparse.ArgumentParser()
|
57 |
+
parser.add_argument('--captioner', type=str, default="blip")
|
58 |
+
parser.add_argument('--segmenter', type=str, default="base")
|
59 |
+
parser.add_argument('--text_refiner', type=str, default="base")
|
60 |
+
parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
|
61 |
+
parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
|
62 |
+
parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
|
63 |
+
parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
|
64 |
+
parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
|
65 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
66 |
+
parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
|
67 |
+
parser.add_argument('--debug', action="store_true")
|
68 |
+
parser.add_argument('--gradio_share', action="store_true")
|
69 |
+
parser.add_argument('--disable_gpt', action="store_true")
|
70 |
+
parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
|
71 |
+
parser.add_argument('--disable_reuse_features', action="store_true", default=False)
|
72 |
+
args = parser.parse_args()
|
73 |
+
|
74 |
+
if args.debug:
|
75 |
+
print(args)
|
76 |
+
return args
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
args = parse_augment()
|
80 |
+
# image_path = 'test_img/img3.jpg'
|
81 |
+
image_path = 'test_img/img13.jpg'
|
82 |
+
prompts = [
|
83 |
+
{
|
84 |
+
"prompt_type":["click"],
|
85 |
+
"input_point":[[500, 300], [1000, 500]],
|
86 |
+
"input_label":[1, 0],
|
87 |
+
"multimask_output":"True",
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"prompt_type":["click"],
|
91 |
+
"input_point":[[900, 800]],
|
92 |
+
"input_label":[1],
|
93 |
+
"multimask_output":"True",
|
94 |
+
}
|
95 |
+
]
|
96 |
+
controls = {
|
97 |
+
"length": "30",
|
98 |
+
"sentiment": "positive",
|
99 |
+
# "imagination": "True",
|
100 |
+
"imagination": "False",
|
101 |
+
"language": "English",
|
102 |
+
}
|
103 |
+
|
104 |
+
model = CaptionAnything(args)
|
105 |
+
for prompt in prompts:
|
106 |
+
print('*'*30)
|
107 |
+
print('Image path: ', image_path)
|
108 |
+
image = Image.open(image_path)
|
109 |
+
print(image)
|
110 |
+
print('Visual controls (SAM prompt):\n', prompt)
|
111 |
+
print('Language controls:\n', controls)
|
112 |
+
out = model.inference(image_path, prompt, controls)
|
113 |
+
|
114 |
+
|
captioner/base_captioner.py
CHANGED
@@ -146,7 +146,8 @@ class BaseCaptioner:
|
|
146 |
seg_mask = np.array(seg_mask) > 0
|
147 |
|
148 |
if crop_mode=="wo_bg":
|
149 |
-
image = np.array(image) * seg_mask[:,:,np.newaxis]
|
|
|
150 |
else:
|
151 |
image = np.array(image)
|
152 |
|
@@ -168,7 +169,7 @@ class BaseCaptioner:
|
|
168 |
seg_mask = np.array(seg_mask) > 0
|
169 |
|
170 |
if crop_mode=="wo_bg":
|
171 |
-
image = np.array(image) * seg_mask[:,:,np.newaxis]
|
172 |
else:
|
173 |
image = np.array(image)
|
174 |
|
|
|
146 |
seg_mask = np.array(seg_mask) > 0
|
147 |
|
148 |
if crop_mode=="wo_bg":
|
149 |
+
image = np.array(image) * seg_mask[:,:,np.newaxis] + (1 - seg_mask[:,:,np.newaxis]) * 255
|
150 |
+
image = np.uint8(image)
|
151 |
else:
|
152 |
image = np.array(image)
|
153 |
|
|
|
169 |
seg_mask = np.array(seg_mask) > 0
|
170 |
|
171 |
if crop_mode=="wo_bg":
|
172 |
+
image = np.array(image) * seg_mask[:,:,np.newaxis] + (1- seg_mask[:,:,np.newaxis]) * 255
|
173 |
else:
|
174 |
image = np.array(image)
|
175 |
|
env.sh
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
conda create -n caption_anything python=3.8 -y
|
2 |
source activate caption_anything
|
3 |
-
pip install -r
|
4 |
cd segmenter
|
5 |
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
6 |
|
|
|
1 |
conda create -n caption_anything python=3.8 -y
|
2 |
source activate caption_anything
|
3 |
+
pip install -r requirements.txt
|
4 |
cd segmenter
|
5 |
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
6 |
|
image_editing_utils.py
CHANGED
@@ -17,7 +17,7 @@ def wrap_text(text, font, max_width):
|
|
17 |
lines.append(current_line)
|
18 |
return lines
|
19 |
|
20 |
-
def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.
|
21 |
# Load the image
|
22 |
if type(image) == np.ndarray:
|
23 |
image = Image.fromarray(image)
|
@@ -27,7 +27,7 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
|
|
27 |
|
28 |
# Calculate max_text_width and font_size based on image dimensions and total number of characters
|
29 |
total_chars = len(text)
|
30 |
-
max_text_width = int(0.
|
31 |
font_size = int(height * font_size_ratio)
|
32 |
|
33 |
# Load the font
|
|
|
17 |
lines.append(current_line)
|
18 |
return lines
|
19 |
|
20 |
+
def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.025):
|
21 |
# Load the image
|
22 |
if type(image) == np.ndarray:
|
23 |
image = Image.fromarray(image)
|
|
|
27 |
|
28 |
# Calculate max_text_width and font_size based on image dimensions and total number of characters
|
29 |
total_chars = len(text)
|
30 |
+
max_text_width = int(0.4 * width)
|
31 |
font_size = int(height * font_size_ratio)
|
32 |
|
33 |
# Load the font
|
requirements.txt
CHANGED
@@ -16,3 +16,4 @@ matplotlib
|
|
16 |
onnxruntime
|
17 |
onnx
|
18 |
https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
|
|
|
|
16 |
onnxruntime
|
17 |
onnx
|
18 |
https://gradio-builds.s3.amazonaws.com/3e68e5e882a6790ac5b457bd33f4edf9b695af90/gradio-3.24.1-py3-none-any.whl
|
19 |
+
accelerate
|
tools.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
|
|
4 |
|
5 |
|
6 |
def colormap(rgb=True):
|
@@ -145,6 +146,11 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
|
|
145 |
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
146 |
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
147 |
|
|
|
|
|
|
|
|
|
|
|
148 |
# 0: background, 1: foreground
|
149 |
input_mask[input_mask>0] = 255
|
150 |
|
@@ -157,7 +163,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
|
|
157 |
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
|
158 |
contour_mask = cv2.dilate(contour_mask, kernel)
|
159 |
painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
|
160 |
-
|
161 |
return painted_image
|
162 |
|
163 |
|
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
+
import copy
|
5 |
|
6 |
|
7 |
def colormap(rgb=True):
|
|
|
146 |
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
147 |
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
148 |
|
149 |
+
width, height = input_image.shape[0], input_image.shape[1]
|
150 |
+
res = 1024
|
151 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
152 |
+
input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
|
153 |
+
input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
|
154 |
# 0: background, 1: foreground
|
155 |
input_mask[input_mask>0] = 255
|
156 |
|
|
|
163 |
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
|
164 |
contour_mask = cv2.dilate(contour_mask, kernel)
|
165 |
painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
|
166 |
+
painted_image = cv2.resize(painted_image, (height, width))
|
167 |
return painted_image
|
168 |
|
169 |
|