UniVTG / app.py
KevinQHLin's picture
Update app.py
f827fdd
import os
import pdb
import time
import torch
import gradio as gr
import numpy as np
import argparse
import subprocess
from run_on_video import clip, vid2clip, txt2clip
parser = argparse.ArgumentParser(description='')
parser.add_argument('--save_dir', type=str, default='./tmp')
parser.add_argument('--resume', type=str, default='./results/omni/model_best.ckpt')
parser.add_argument("--gpu_id", type=int, default=0)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
#################################
model_version = "ViT-B/32"
output_feat_size = 512
clip_len = 2
overwrite = True
num_decoding_thread = 4
half_precision = False
clip_model, _ = clip.load(model_version, device=args.gpu_id, jit=False)
import logging
import torch.backends.cudnn as cudnn
from main.config import TestOptions, setup_model
from utils.basic_utils import l2_normalize_np_array
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO)
def load_model():
logger.info("Setup config, data and model...")
opt = TestOptions().parse(args)
# pdb.set_trace()
cudnn.benchmark = True
cudnn.deterministic = False
if opt.lr_warmup > 0:
total_steps = opt.n_epoch
warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
opt.lr_warmup = [warmup_steps, total_steps]
model, criterion, _, _ = setup_model(opt)
return model
vtg_model = load_model()
def convert_to_hms(seconds):
return time.strftime('%H:%M:%S', time.gmtime(seconds))
def load_data(save_dir):
vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)
vid = torch.from_numpy(l2_normalize_np_array(vid))
txt = torch.from_numpy(l2_normalize_np_array(txt))
clip_len = 2
ctx_l = vid.shape[0]
timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
if True:
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
tef_ed = tef_st + 1.0 / ctx_l
tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2)
src_vid = vid.unsqueeze(0).cuda()
src_txt = txt.unsqueeze(0).cuda()
src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()
return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
def forward(model, save_dir, query):
src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
src_vid = src_vid.cuda(args.gpu_id)
src_txt = src_txt.cuda(args.gpu_id)
src_vid_mask = src_vid_mask.cuda(args.gpu_id)
src_txt_mask = src_txt_mask.cuda(args.gpu_id)
model.eval()
with torch.no_grad():
output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)
# prepare the model prediction
pred_logits = output['pred_logits'][0].cpu()
pred_spans = output['pred_spans'][0].cpu()
pred_saliency = output['saliency_scores'].cpu()
# prepare the model prediction
pred_windows = (pred_spans + timestamp) * ctx_l * clip_len
pred_confidence = pred_logits
# grounding
top1_window = pred_windows[torch.argmax(pred_confidence)].tolist()
top5_values, top5_indices = torch.topk(pred_confidence.flatten(), k=5)
top5_windows = pred_windows[top5_indices].tolist()
# print(f"The video duration is {convert_to_hms(src_vid.shape[1]*clip_len)}.")
q_response = f"For query: {query}"
mr_res = " - ".join([convert_to_hms(int(i)) for i in top1_window])
mr_response = f"The Top-1 interval is: {mr_res}"
hl_res = convert_to_hms(torch.argmax(pred_saliency) * clip_len)
hl_response = f"The Top-1 highlight is: {hl_res}"
return '\n'.join([q_response, mr_response, hl_response])
def extract_vid(vid_path, state):
history = state['messages']
vid_features = vid2clip(clip_model, vid_path, args.save_dir)
history.append({"role": "user", "content": "Finish extracting video features."})
history.append({"role": "system", "content": "Please Enter the text query."})
chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history),2)]
return '', chat_messages, state
def extract_txt(txt):
txt_features = txt2clip(clip_model, txt, args.save_dir)
return
def download_video(url, save_dir='./examples', size=768):
save_path = f'{save_dir}/{url}.mp4'
cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}'
if not os.path.exists(save_path):
try:
subprocess.call(cmd, shell=True)
except:
return None
return save_path
def get_empty_state():
return {"total_tokens": 0, "messages": []}
def submit_message(prompt, state):
history = state['messages']
if not prompt:
return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state
prompt_msg = { "role": "user", "content": prompt }
try:
history.append(prompt_msg)
# answer = vlogger.chat2video(prompt)
# answer = prompt
extract_txt(prompt)
answer = forward(vtg_model, args.save_dir, prompt)
history.append({"role": "system", "content": answer})
except Exception as e:
history.append(prompt_msg)
history.append({
"role": "system",
"content": f"Error: {e}"
})
chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)]
return '', chat_messages, state
def clear_conversation():
return gr.update(value=None, visible=True), gr.update(value=None, interactive=True), None, gr.update(value=None, visible=True), get_empty_state()
def subvid_fn(vid):
save_path = download_video(vid)
return gr.update(value=save_path)
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#video_inp {min-height: 100px}
#chatbox {min-height: 100px;}
#header {text-align: center;}
#hint {font-size: 1.0em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
with gr.Blocks(css=css) as demo:
state = gr.State(get_empty_state())
with gr.Column(elem_id="col-container"):
gr.Markdown("""## πŸ€–οΈ UniVTG: Towards Unified Video-Language Temporal Grounding
Given a video and text query, return relevant window and highlight.
https://github.com/showlab/UniVTG/""",
elem_id="header")
with gr.Row():
with gr.Column():
video_inp = gr.Video(label="video_input")
gr.Markdown("πŸ‘‹ **Step1**: Select a video in Examples (bottom) or input youtube video_id in this textbox, *e.g.* *G7zJK6lcbyU* for https://www.youtube.com/watch?v=G7zJK6lcbyU", elem_id="hint")
with gr.Row():
video_id = gr.Textbox(value="", placeholder="Youtube video url", show_label=False)
vidsub_btn = gr.Button("(Optional) Submit Youtube id")
with gr.Column():
vid_ext = gr.Button("Step2: Extract video feature, may takes a while")
# vlog_outp = gr.Textbox(label="Document output", lines=40)
total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
chatbot = gr.Chatbot(elem_id="chatbox")
input_message = gr.Textbox(show_label=False, placeholder="Enter text query and press enter", visible=True).style(container=False)
btn_submit = gr.Button("Step3: Enter your text query")
btn_clear_conversation = gr.Button("πŸ”ƒ Clear")
examples = gr.Examples(
examples=[
["./examples/youtube.mp4"],
["./examples/charades.mp4"],
["./examples/ego4d.mp4"],
],
inputs=[video_inp],
)
gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/anzorq/chatgpt-demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br></center>''')
btn_submit.click(submit_message, [input_message, state], [input_message, chatbot])
input_message.submit(submit_message, [input_message, state], [input_message, chatbot])
# btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, vlog_outp, state])
btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, state])
vid_ext.click(extract_vid, [video_inp, state], [input_message, chatbot])
vidsub_btn.click(subvid_fn, [video_id], [video_inp])
demo.load(queur=False)
demo.queue(concurrency_count=10)
# demo.launch(height='800px', server_port=2253, debug=True, share=True)
demo.launch(height='800px')