VLog hf gradio demo

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- app.py +159 -0
- examples/C8lMW0MODFs.log +3 -0
- examples/C8lMW0MODFs.mp4 +3 -0
- examples/XZVHmRvfDHM.log +3 -0
- examples/XZVHmRvfDHM.mp4 +3 -0
- examples/basketball_vlog.log +3 -0
- examples/basketball_vlog.mp4 +3 -0
- examples/buy_watermelon.log +3 -0
- examples/buy_watermelon.mp4 +3 -0
- examples/covid.log +3 -0
- examples/covid.mp4 +3 -0
- examples/huaqiang.log +3 -0
- examples/huaqiang.mp4 +3 -0
- examples/news.log +3 -0
- examples/news.mp4 +3 -0
- examples/outcGtbnMuQ.log +3 -0
- examples/outcGtbnMuQ.mp4 +3 -0
- examples/travel_in_roman.log +3 -0
- examples/travel_in_roman.mp4 +3 -0
- examples/travel_in_roman_full.log +3 -0
- examples/travel_in_roman_full.mp4 +3 -0
- examples/vlog.jpg +0 -0
- models/__init__.py +3 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/blip2_model.cpython-38.pyc +0 -0
- models/__pycache__/clip_model.cpython-38.pyc +0 -0
- models/__pycache__/gpt_model.cpython-38.pyc +0 -0
- models/__pycache__/grit_model.cpython-38.pyc +0 -0
- models/__pycache__/kts_model.cpython-38.pyc +0 -0
- models/__pycache__/vlog.cpython-38.pyc +0 -0
- models/__pycache__/whisper_model.cpython-38.pyc +0 -0
- models/blip2_model.py +47 -0
- models/clip_model.py +54 -0
- models/gpt_model.py +102 -0
- models/grit_model.py +21 -0
- models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc +0 -0
- models/grit_src/configs/Base.yaml +77 -0
- models/grit_src/configs/GRiT_B_DenseCap.yaml +20 -0
- models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml +23 -0
- models/grit_src/configs/GRiT_B_ObjectDet.yaml +20 -0
- models/grit_src/configs/GRiT_H_ObjectDet.yaml +21 -0
- models/grit_src/configs/GRiT_L_ObjectDet.yaml +20 -0
- models/grit_src/grit/__init__.py +7 -0
- models/grit_src/grit/__pycache__/__init__.cpython-38.pyc +0 -0
- models/grit_src/grit/__pycache__/config.cpython-38.pyc +0 -0
- models/grit_src/grit/__pycache__/predictor.cpython-38.pyc +0 -0
- models/grit_src/grit/config.py +50 -0
- models/grit_src/grit/custom_solver.py +88 -0
- models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc +0 -0
- models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc +0 -0
app.py
ADDED
@@ -0,0 +1,159 @@
import os
import gradio as gr
import openai
import requests
import csv
import argparse
from models.vlog import Vlogger

parser = argparse.ArgumentParser()
parser.add_argument('--video_path', default='examples/huaqiang.mp4')
parser.add_argument('--alpha', default=10, type=int, help='Determines the maximum number of segments for the KTS algorithm; the larger the value, the fewer segments.')
parser.add_argument('--beta', default=1, type=int, help='The smallest time gap between successive clips, in seconds.')
parser.add_argument('--data_dir', default='./examples', type=str, help='Directory for saving videos and logs.')
parser.add_argument('--tmp_dir', default='./tmp', type=str, help='Directory for saving intermediate files.')

# * Model settings *
parser.add_argument('--openai_api_key', default='xxx', type=str, help='OpenAI API key')
parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True to use BLIP image captioning')
parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True to use dense captioning')
parser.add_argument('--feature_extractor', default='openai/clip-vit-base-patch32', help='Select the feature extractor model for video segmentation')
parser.add_argument('--feature_extractor_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu')
parser.add_argument('--image_captioner', choices=['blip', 'blip2'], dest='captioner_base_model', default='blip2', help='blip2 requires 15G GPU memory, blip requires 6G GPU memory')
parser.add_argument('--image_captioner_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu; GPU memory larger than 14G is recommended')
parser.add_argument('--dense_captioner_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu; GPUs with less than 6G memory are not recommended')
parser.add_argument('--audio_translator', default='large')
parser.add_argument('--audio_translator_device', choices=['cuda', 'cpu'], default='cuda')
parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo'], default='gpt-3.5-turbo')

args = parser.parse_args()


def get_empty_state():
    return {"total_tokens": 0, "messages": []}


def submit_api_key_fn(api_key, vlogger):
    try:
        vlogger.init_llm_with_api_key(api_key)
        return gr.update(value="OpenAI key submitted successfully 🎉"), True, vlogger
    except Exception as e:
        return gr.update(value=f"Error: {e}"), False, vlogger


def submit_message(prompt, state, vlogger, api_key_submitted, vlog_loaded):
    if not api_key_submitted:
        return gr.update(value=''), [("👀", "Please enter your OpenAI API key 😊")], state, vlogger

    if not vlog_loaded:
        return gr.update(value=''), [("👀", "Please follow the instructions to select a video and generate the document for chatting 😊")], state, vlogger

    history = state['messages']

    if not prompt:
        return gr.update(value=''), [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)], state, vlogger

    prompt_msg = {"role": "user", "content": prompt}

    try:
        history.append(prompt_msg)
        answer = vlogger.chat2video(prompt)
        history.append({"role": "system", "content": answer})

    except Exception as e:
        history.append(prompt_msg)
        history.append({
            "role": "system",
            "content": f"Error: {e}"
        })

    chat_messages = [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)]
    return '', chat_messages, state, vlogger


def clear_conversation(vlogger):
    vlogger.clean_history()

    # return input_message, video_inp, chatbot, vlog_outp, state, vlogger, vlog_loaded
    return gr.update(value=None, visible=True), gr.update(value=None, interactive=False), None, gr.update(value=None, visible=True), get_empty_state(), vlogger, False


def vlog_fn(vid_path, vlogger, api_key_submitted):
    if not api_key_submitted:
        log_text = "====== Please enter your OpenAI API key first 😊 ======"
        return gr.update(value=log_text, visible=True), False, vlogger

    print(vid_path)
    if vid_path is None:
        log_text = "====== Please select a video from the examples first 🤔 ======"
        vloaded_flag = False
    else:
        log_list = vlogger.video2log(vid_path)
        log_text = "\n".join(log_list)
        vloaded_flag = True
    return gr.update(value=log_text, visible=True), vloaded_flag, vlogger


css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
#video_inp {min-height: 300px}
#chatbox {min-height: 100px;}
#header {text-align: center;}
#hint {font-size: 0.9em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""

with gr.Blocks(css=css) as demo:

    state = gr.State(get_empty_state())
    vlogger = gr.State(Vlogger(args))
    vlog_loaded = gr.State(False)
    api_key_submitted = gr.State(False)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""## 🎞️ VLog Demo
                    Powered by BLIP2, GRIT, Whisper, ChatGPT and LangChain
                    Github: [https://github.com/showlab/VLog](https://github.com/showlab/VLog)""",
                    elem_id="header")
        gr.Markdown("*Instructions*: For the current demo, please enter your OpenAI API key, select an example video, click the button to generate a document, and try chatting about the video 😊", elem_id="hint")
        with gr.Row():
            with gr.Column(scale=6):
                video_inp = gr.Video(label="video_input", interactive=False)
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter text and press enter", visible=True).style(container=False)
                btn_submit = gr.Button("Submit")
                btn_clear_conversation = gr.Button("🔃 Start New Conversation")

            with gr.Column(scale=6):
                vlog_btn = gr.Button("Generate Video Document")
                vlog_outp = gr.Textbox(label="Document output", lines=30)

            with gr.Column(scale=1):
                openai_api_key = gr.Textbox(
                    placeholder="Input OpenAI API key and press Enter",
                    show_label=False,
                    label="OpenAI API Key",
                    lines=1,
                    type="password"
                )
                examples = gr.Examples(
                    examples=[
                        ["examples/basketball_vlog.mp4"],
                        ["examples/travel_in_roman.mp4"],
                        ["examples/C8lMW0MODFs.mp4"],
                        ["examples/outcGtbnMuQ.mp4"],
                        ["examples/huaqiang.mp4"],
                    ],
                    inputs=[video_inp],
                )

    gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/TencentARC/VLog?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br></center>''')

    btn_submit.click(submit_message, [input_message, state, vlogger, api_key_submitted, vlog_loaded], [input_message, chatbot, state, vlogger])
    input_message.submit(submit_message, [input_message, state, vlogger, api_key_submitted, vlog_loaded], [input_message, chatbot, state, vlogger])
    btn_clear_conversation.click(clear_conversation, [vlogger], [input_message, video_inp, chatbot, vlog_outp, state, vlogger, vlog_loaded])
    vlog_btn.click(vlog_fn, [video_inp, vlogger, api_key_submitted], [vlog_outp, vlog_loaded, vlogger])
    openai_api_key.submit(submit_api_key_fn, [openai_api_key, vlogger], [vlog_outp, api_key_submitted, vlogger])
    demo.load(queue=False)

demo.queue(concurrency_count=10)
demo.launch(height='800px', server_port=8749, debug=True, share=True)
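For reference, the same pipeline can be driven without the Gradio UI. The sketch below is a minimal, illustrative script: it only uses calls that appear in app.py above (Vlogger, init_llm_with_api_key, video2log, chat2video), and the argument values, the API key placeholder, and the question string are assumptions, not part of the committed code.

# Minimal sketch of driving the VLog pipeline without the Gradio UI.
# Assumes models.vlog.Vlogger behaves as the app.py handlers above suggest;
# all literal values here are illustrative.
import argparse
from models.vlog import Vlogger

args = argparse.Namespace(
    video_path='examples/huaqiang.mp4', alpha=10, beta=1,
    data_dir='./examples', tmp_dir='./tmp',
    openai_api_key='sk-...', image_caption=True, dense_caption=True,
    feature_extractor='openai/clip-vit-base-patch32', feature_extractor_device='cuda',
    captioner_base_model='blip2', image_captioner_device='cuda',
    dense_captioner_device='cuda', audio_translator='large',
    audio_translator_device='cuda', gpt_version='gpt-3.5-turbo',
)

vlogger = Vlogger(args)
vlogger.init_llm_with_api_key(args.openai_api_key)   # same call as submit_api_key_fn
log_lines = vlogger.video2log(args.video_path)       # same call as vlog_fn
print("\n".join(log_lines))
print(vlogger.chat2video("What happens in this video?"))  # same call as submit_message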
examples/C8lMW0MODFs.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b044e554f8dc7a790b02aa1ebc391165b84d93cce9579fa6b2fe0418cd4d1122
size 9075
examples/C8lMW0MODFs.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d094489e459ae952880f4cbd8fdbcc790df1a69ccf9fb4f6c5fca998b6871133
size 10537029
examples/XZVHmRvfDHM.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e37046ae268e20f3d44df7410954c3cf5ffd73116e6f5e3f9ef73a690f001d51
size 7262
examples/XZVHmRvfDHM.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2da0eae7e0b18c04ad4f2b8124a09fbbde407eeedb0a532dbf40701c8c744b5
size 1961212
examples/basketball_vlog.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b2d18c6d7d7c5061ae41b9cd2b8cc0828d2aee2b02b40b4286fdd26905b0ac0
size 23527
examples/basketball_vlog.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5d6034c324f3e9de35278783ed68a85081ef74a252c9394e273b339f7d1b6c3
size 32376805
examples/buy_watermelon.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cd0e7bfca9fba4b71428235d41b446083ffe8d7496ef43249f7438017def067
size 3922
examples/buy_watermelon.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:926ee7ec1ca4d3e0674a647bf84887bdf077961c3972148ae23fb569c22e0e4e
size 6209789
examples/covid.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50281df7c21815c662e2f03e461b02dd5b2f8253a3f92bcd1dfca4229d89e3ce
size 9782
examples/covid.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53c35480ff6ac15f2f8747aa9ba9dc36086d5f4e342ac79eac5e43e5bd248817
size 16090827
examples/huaqiang.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cd0e7bfca9fba4b71428235d41b446083ffe8d7496ef43249f7438017def067
size 3922
examples/huaqiang.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:926ee7ec1ca4d3e0674a647bf84887bdf077961c3972148ae23fb569c22e0e4e
size 6209789
examples/news.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42068b9573daee32bf33d5aa4049f937bfa2cb3c6472d40c42332f8d2173a929
size 8968
examples/news.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:905e453db16213c962d01371357877b8a168da50508676b81cf474d431d3d2ca
size 23599849
examples/outcGtbnMuQ.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45a3911acfe78745ed9cfc9502deebef1ab6912dc89566735fcbdf7acda00b44
size 63033
examples/outcGtbnMuQ.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:47f4ddd4debd3c5955cb7c0a1f5e2ffa9c0d6a171931898ee085c5eab521f33d
size 98609326
examples/travel_in_roman.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f90d4b4c46322b6f15984b64aaedf35232b2fb21ddac518f1c5784fe25944e3c
size 9166
examples/travel_in_roman.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b02522bb72215bcb1a69657af9d08cad0141e1b3e30553024609cb0927471e04
size 34442658
examples/travel_in_roman_full.log
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a163943c7676168b51adf08e5305dd78fba7c54e6cd00330c06541eb23d0d23
size 45295
examples/travel_in_roman_full.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1cb54427e21a1ccbba23bfe5314e4ae2d45658d0b4b654f815abf1861c1ca3c
size 92642344
examples/vlog.jpg
ADDED
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .kts_src import *
from .clip_model import *
from .grit_model import *
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (215 Bytes).
models/__pycache__/blip2_model.cpython-38.pyc
ADDED
Binary file (2.02 kB).
models/__pycache__/clip_model.cpython-38.pyc
ADDED
Binary file (1.91 kB).
models/__pycache__/gpt_model.cpython-38.pyc
ADDED
Binary file (3.43 kB).
models/__pycache__/grit_model.cpython-38.pyc
ADDED
Binary file (1.21 kB).
models/__pycache__/kts_model.cpython-38.pyc
ADDED
Binary file (1.34 kB).
models/__pycache__/vlog.cpython-38.pyc
ADDED
Binary file (4.34 kB).
models/__pycache__/whisper_model.cpython-38.pyc
ADDED
Binary file (1.24 kB).
models/blip2_model.py
ADDED
@@ -0,0 +1,47 @@
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForConditionalGeneration

class ImageCaptioner:
    def __init__(self, model_name="blip2-opt", device="cpu"):
        self.model_name = model_name
        self.device = device
        self.processor, self.model = self.initialize_model()

    def initialize_model(self):
        if self.device == 'cpu':
            self.data_type = torch.float32
        else:
            self.data_type = torch.float16
        processor, model = None, None
        if self.model_name == "blip2-opt":
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b-coco", torch_dtype=self.data_type, low_cpu_mem_usage=True)

        elif self.model_name == "blip2-flan-t5":
            processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
            model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-flan-t5-xl", torch_dtype=self.data_type, low_cpu_mem_usage=True)

        # for GPUs with small memory
        elif self.model_name == "blip":
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        else:
            raise NotImplementedError(f"{self.model_name} not implemented.")
        model.to(self.device)

        if self.device != 'cpu':
            model.half()
        return processor, model

    def image_caption(self, image):
        inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return generated_text

    def image_caption_debug(self, image_src):
        return "A dish with salmon, broccoli, and something yellow."
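A short illustrative call into the captioner above. The image path reuses examples/vlog.jpg from this commit; the choice of "blip2-opt" on CUDA mirrors the defaults in app.py and assumes a GPU with enough memory (the argparse help above suggests roughly 15G for blip2).

# Illustrative use of ImageCaptioner; paths and device are assumptions.
from PIL import Image
from models.blip2_model import ImageCaptioner

captioner = ImageCaptioner(model_name="blip2-opt", device="cuda")
frame = Image.open("examples/vlog.jpg").convert("RGB")
print(captioner.image_caption(frame))  # one caption string for the frame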
models/clip_model.py
ADDED
@@ -0,0 +1,54 @@
import os
import cv2
import pdb
import torch
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPVisionModelWithProjection
from transformers import logging
logging.set_verbosity_error()

class FeatureExtractor():
    def __init__(self, args):
        self.device = args.feature_extractor_device
        self.beta = args.beta
        self.processor = CLIPProcessor.from_pretrained(args.feature_extractor)
        self.model = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor).to(self.device)
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir

    def __call__(self, video_path, video_id):
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_length = frame_count / fps
        sample_rate = int(fps) * self.beta

        # Reuse cached features if they were already extracted for this video.
        save_path = os.path.join(self.tmp_dir, video_id + '.npz')
        if os.path.exists(save_path):
            data = np.load(save_path)
            clip_features = data['features']
            return clip_features, video_length

        clip_features = []
        print("Extracting CLIP features.")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Sample one frame every `beta` seconds.
            if cap.get(cv2.CAP_PROP_POS_FRAMES) % sample_rate == 0:
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                inputs = self.processor(images=image, return_tensors="pt").pixel_values
                inputs = inputs.to(self.device)

                with torch.no_grad():
                    feat = self.model(inputs)['image_embeds']
                clip_features.append(feat.cpu().numpy())
        print("Finished.")

        clip_features = np.concatenate(clip_features, axis=0)
        np.savez_compressed(save_path, features=clip_features)

        return clip_features, video_length
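A sketch of calling FeatureExtractor directly, outside the Vlogger. Only the attributes read in __init__ above are supplied; the Namespace values, the example clip, and the video_id string are illustrative and assume ./tmp already exists.

# Illustrative direct use of FeatureExtractor; argument values are assumptions.
import argparse
from models.clip_model import FeatureExtractor

args = argparse.Namespace(
    feature_extractor='openai/clip-vit-base-patch32',
    feature_extractor_device='cpu', beta=1,
    data_dir='./examples', tmp_dir='./tmp',
)
extractor = FeatureExtractor(args)
features, video_length = extractor('examples/huaqiang.mp4', 'huaqiang')
print(features.shape, video_length)  # one CLIP embedding per sampled frame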
models/gpt_model.py
ADDED
@@ -0,0 +1,102 @@
import os
import pdb
import pickle
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS
from langchain.chains import ChatVectorDBChain
from langchain.prompts.prompt import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import OpenAIEmbeddings

_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

qa_template = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question; the document records what people see and hear in this video.
Try to connect this information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""
QA_PROMPT = PromptTemplate(template=qa_template, input_variables=["question", "context"])


class LlmReasoner():
    def __init__(self, args):
        self.history = []
        self.gpt_version = args.gpt_version
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.qa_chain = None
        self.vectorstore = None
        self.top_k = 3
        self.llm = OpenAI(temperature=0, model_name=self.gpt_version)

    def exist_vectorstore(self, video_id):
        pkl_path = os.path.join(self.tmp_dir, f"{video_id}.pkl")
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        if os.path.exists(pkl_path) and os.path.exists(log_path):
            with open(pkl_path, 'rb') as file:
                self.vectorstore = pickle.load(file)

            self.qa_chain = ChatVectorDBChain.from_llm(
                self.llm,
                self.vectorstore,
                qa_prompt=QA_PROMPT,
                condense_question_prompt=CONDENSE_QUESTION_PROMPT,
            )
            self.qa_chain.top_k_docs_for_context = self.top_k
            return True
        return False

    def create_vectorstore(self, video_id):
        pkl_path = os.path.join(self.tmp_dir, f"{video_id}.pkl")

        if not os.path.exists(pkl_path):
            loader = UnstructuredFileLoader(os.path.join(self.data_dir, f"{video_id}.log"))
            raw_documents = loader.load()

            # Split text
            text_splitter = RecursiveCharacterTextSplitter()
            documents = text_splitter.split_documents(raw_documents)

            # Load data into the vectorstore
            embeddings = OpenAIEmbeddings()
            vectorstore = FAISS.from_documents(documents, embeddings)

            # Save vectorstore
            with open(pkl_path, "wb") as f:
                pickle.dump(vectorstore, f)

        with open(pkl_path, 'rb') as file:
            self.vectorstore = pickle.load(file)

        self.qa_chain = ChatVectorDBChain.from_llm(
            self.llm,
            self.vectorstore,
            qa_prompt=QA_PROMPT,
            condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        )
        self.qa_chain.top_k_docs_for_context = self.top_k

        return

    def __call__(self, question):
        print(f"Question: {question}")
        response = self.qa_chain({"question": question, "chat_history": self.history})["answer"]
        self.history.append((question, response))

        print(f"Assistant: {response}")
        print("\n")
        return response

    def clean_history(self):
        self.history = []
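A sketch of the retrieval flow implemented above: build (or reuse) the FAISS index for a video document, then ask a question. The video_id, question, and Namespace values are illustrative, and it assumes OPENAI_API_KEY is set in the environment, since OpenAI and OpenAIEmbeddings read the key from there.

# Illustrative use of LlmReasoner; values and environment are assumptions.
import argparse
from models.gpt_model import LlmReasoner

args = argparse.Namespace(gpt_version='gpt-3.5-turbo',
                          data_dir='./examples', tmp_dir='./tmp')
reasoner = LlmReasoner(args)
if not reasoner.exist_vectorstore('huaqiang'):
    reasoner.create_vectorstore('huaqiang')      # reads ./examples/huaqiang.log
print(reasoner("Who appears in the video?"))     # __call__ runs the QA chain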
models/grit_model.py
ADDED
@@ -0,0 +1,21 @@
import os
from models.grit_src.image_dense_captions import image_caption_api

class DenseCaptioner():
    def __init__(self, device):
        self.device = device

    def initialize_model(self):
        pass

    def image_dense_caption_debug(self, image_src):
        dense_caption = """
        1. the broccoli is green, [0, 0, 333, 325];
        2. a piece of broccoli, [0, 147, 143, 324];
        3. silver fork on plate, [4, 547, 252, 612];
        """
        return dense_caption

    def image_dense_caption(self, image_src):
        dense_caption = image_caption_api(image_src, self.device)
        return dense_caption
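An illustrative call into the wrapper above; the image path is an example asset from this commit, and the output is presumably region captions with boxes in the same style as the debug string, though that depends on image_caption_api, which is not shown in this diff.

# Illustrative use of DenseCaptioner; device and path are assumptions.
from models.grit_model import DenseCaptioner

dense_captioner = DenseCaptioner(device='cuda')
print(dense_captioner.image_dense_caption('examples/vlog.jpg'))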
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc
ADDED
Binary file (2.33 kB).
models/grit_src/configs/Base.yaml
ADDED
@@ -0,0 +1,77 @@
MODEL:
  META_ARCHITECTURE: "GRiT"
  MASK_ON: True
  PROPOSAL_GENERATOR:
    NAME: "CenterNet"
  FPN:
    IN_FEATURES: ["layer3", "layer4", "layer5"]
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.12, 57.375]
  ROI_HEADS:
    NAME: GRiTROIHeadsAndTextDecoder
    IN_FEATURES: ["p3", "p4", "p5"]
    IOU_THRESHOLDS: [0.6]
    NUM_CLASSES: 1
    SCORE_THRESH_TEST: 0.02
    NMS_THRESH_TEST: 0.5
    OBJECT_FEAT_POOLER_RES: 14
  ROI_BOX_CASCADE_HEAD:
    IOUS: [0.6, 0.7, 0.8]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
    CLS_AGNOSTIC_BBOX_REG: True
    MULT_PROPOSAL_SCORE: True
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
    CLS_AGNOSTIC_MASK: True
  CENTERNET:
    NUM_CLASSES: 1
    REG_WEIGHT: 1.
    NOT_NORM_REG: True
    ONLY_PROPOSAL: True
    WITH_AGN_HM: True
    INFERENCE_TH: 0.0001
    PRE_NMS_TOPK_TRAIN: 4000
    POST_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TEST: 256
    NMS_TH_TRAIN: 0.9
    NMS_TH_TEST: 0.9
    POS_WEIGHT: 0.5
    NEG_WEIGHT: 0.5
    IGNORE_HIGH_FP: 0.85
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1]
  DATASET_INPUT_SIZE: [1024]
  DATASET_INPUT_SCALE: [[0.1, 2.0]]
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
TEST:
  DETECTIONS_PER_IMAGE: 256
SOLVER:
  LR_SCHEDULER_NAME: "WarmupCosineLR"
  CHECKPOINT_PERIOD: 10000
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
  USE_CUSTOM_SOLVER: True
  OPTIMIZER: "ADAMW"
  MAX_ITER: 180000
  IMS_PER_BATCH: 64
  BASE_LR: 0.00008
  VIT_LAYER_DECAY: True
  CLIP_GRADIENTS:
    ENABLED: True
INPUT:
  FORMAT: RGB
  CUSTOM_AUG: EfficientDetResizeCrop
  TRAIN_SIZE: 640
USE_ACT_CHECKPOINT: True
VERSION: 2
models/grit_src/configs/GRiT_B_DenseCap.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["DenseCap"]
  TEST_TASK: "DenseCap"
  MASK_ON: False
  ROI_HEADS:
    SOFT_NMS_ENABLED: False
  BEAM_SIZE: 1
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("vg_train",)
  TEST: ("vg_test",)
DATALOADER:
  DATASET_BS: 2
OUTPUT_DIR: "./output/GRiT_B_DenseCap"
models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml
ADDED
@@ -0,0 +1,23 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet", "DenseCap"]
  TEST_TASK: "DenseCap" # DenseCap or ObjectDet: choose one for testing
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: False
  BEAM_SIZE: 1
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("GRiT_coco2017_train", "vg_train")
  TEST: ("coco_2017_test-dev",)
DATALOADER:
  DATASET_RATIO: [1, 1]
  DATASET_BS: 2
  DATASET_INPUT_SIZE: [1024, 1024]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]]
OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet"
models/grit_src/configs/GRiT_B_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone
  VIT_LAYERS: 12
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.7
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 2
OUTPUT_DIR: "./output/GRiT_B_ObjectDet"
models/grit_src/configs/GRiT_H_ObjectDet.yaml
ADDED
@@ -0,0 +1,21 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone_huge
  VIT_LAYERS: 32
SOLVER:
  MAX_ITER: 135000
  VIT_LAYER_DECAY_RATE: 0.9
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 1
OUTPUT_DIR: "./output/GRiT_H_ObjectDet"
models/grit_src/configs/GRiT_L_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
_BASE_: "Base.yaml"
MODEL:
  TRAIN_TASK: ["ObjectDet"]
  TEST_TASK: "ObjectDet"
  MASK_ON: True
  ROI_HEADS:
    SOFT_NMS_ENABLED: True
  BEAM_SIZE: 3
  WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth"
  BACKBONE:
    NAME: build_vit_fpn_backbone_large
  VIT_LAYERS: 24
SOLVER:
  VIT_LAYER_DECAY_RATE: 0.8
DATASETS:
  TRAIN: ("GRiT_coco2017_train",)
  TEST: ("coco_2017_val",)
DATALOADER:
  DATASET_BS: 1
OUTPUT_DIR: "./output/GRiT_L_ObjectDet"
models/grit_src/grit/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .modeling.meta_arch import grit
from .modeling.roi_heads import grit_roi_heads
from .modeling.backbone import vit

from .data.datasets import object365
from .data.datasets import vg
from .data.datasets import grit_coco
models/grit_src/grit/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (414 Bytes).
models/grit_src/grit/__pycache__/config.cpython-38.pyc
ADDED
Binary file (1.41 kB).
models/grit_src/grit/__pycache__/predictor.cpython-38.pyc
ADDED
Binary file (2.65 kB).
models/grit_src/grit/config.py
ADDED
@@ -0,0 +1,50 @@
from detectron2.config import CfgNode as CN


def add_grit_config(cfg):
    _C = cfg

    _C.MODEL.BEAM_SIZE = 1
    _C.MODEL.TRAIN_TASK = ["ObjectDet", "DenseCap"]
    _C.MODEL.TEST_TASK = "DenseCap"  # This can be varied if the model is jointly trained on multiple tasks

    _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0  # >= 0: not use
    _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False

    _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14
    _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False

    # Backbones
    _C.MODEL.VIT_LAYERS = 12

    # Text Decoder
    _C.TEXT_DECODER = CN()
    _C.TEXT_DECODER.VOCAB_SIZE = 30522
    _C.TEXT_DECODER.HIDDEN_SIZE = 768
    _C.TEXT_DECODER.NUM_LAYERS = 6
    _C.TEXT_DECODER.ATTENTION_HEADS = 12
    _C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4

    # Multi-dataset dataloader
    _C.DATALOADER.DATASET_RATIO = [1, 1]  # sample ratio
    _C.DATALOADER.DATASET_BS = 1
    _C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024]
    _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)]
    _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)]
    _C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333]

    _C.SOLVER.USE_CUSTOM_SOLVER = True
    _C.SOLVER.OPTIMIZER = 'ADAMW'
    _C.SOLVER.VIT_LAYER_DECAY = True
    _C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7

    _C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
    _C.INPUT.TRAIN_SIZE = 1024
    _C.INPUT.TEST_SIZE = 1024
    _C.INPUT.SCALE_RANGE = (0.1, 2.)
    # 'default' for fixed short / long edge
    _C.INPUT.TEST_INPUT_TYPE = 'default'

    _C.FIND_UNUSED_PARAM = True
    _C.USE_ACT_CHECKPOINT = True
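For context, a sketch of how these extra keys are typically registered on a detectron2 config before the GRiT YAMLs are used. get_cfg() is standard detectron2 API; note that merging Base.yaml also requires the CenterNet2 keys, which are registered elsewhere in this repo and are not shown in this diff, so this sketch stops at registration.

# Illustrative registration of the GRiT config keys defined above;
# assumes detectron2 and the rest of models/grit_src are importable.
from detectron2.config import get_cfg
from models.grit_src.grit.config import add_grit_config

cfg = get_cfg()          # detectron2 defaults
add_grit_config(cfg)     # add the GRiT-specific keys defined above
print(cfg.MODEL.TEST_TASK, cfg.TEXT_DECODER.HIDDEN_SIZE)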
models/grit_src/grit/custom_solver.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch

from detectron2.config import CfgNode

from detectron2.solver.build import maybe_add_gradient_clipping


def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    optimizer_type = cfg.SOLVER.OPTIMIZER

    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY

        if cfg.SOLVER.VIT_LAYER_DECAY:
            lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)

        param = {"params": [value], "lr": lr}
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
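A quick check of the layer-wise decay schedule implemented by get_vit_lr_decay_rate above: with a 12-block ViT and decay rate 0.7 (as in the GRiT_B configs), patch/position embeddings get the smallest multiplier, later blocks get larger ones, and non-backbone parameters get 1.0. The parameter names below are illustrative, not taken from an actual GRiT checkpoint.

# Illustrative decay multipliers; parameter names are assumptions.
from models.grit_src.grit.custom_solver import get_vit_lr_decay_rate

for name in ["backbone.bottom_up.patch_embed.proj.weight",   # layer_id 0 -> 0.7**13
             "backbone.bottom_up.blocks.0.attn.qkv.weight",  # layer_id 1 -> 0.7**12
             "backbone.bottom_up.blocks.11.mlp.fc1.weight",  # layer_id 12 -> 0.7**1
             "roi_heads.box_head.fc1.weight"]:               # not backbone -> 1.0
    print(name, get_vit_lr_decay_rate(name, lr_decay_rate=0.7, num_layers=12))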
models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc
ADDED
Binary file (1.22 kB).
models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc
ADDED
Binary file (5.69 kB).