quchenyuan's picture
Update app.py
7fdf6e4 verified
raw
history blame contribute delete
No virus
3.6 kB
import gradio as gr
import os
from huggingface_hub import hf_hub_download
from pathlib import Path
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
import json
import torch
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 has no pad token; reuse EOS so fixed-length (max_length) padding works.
tokenizer.pad_token = tokenizer.eos_token

# Search index: maps each top-level key of index.json (used as the video UUID
# by search_index/download_video) to the GPT-2 logits of that entry's
# 'text_description', padded/truncated to 32 tokens -> shape (1, 32, vocab).
logit = {}
json_file = 'index.json'
with open(json_file, 'r', encoding='utf-8') as file:
    data = json.load(file)

# The logits are only used for similarity scoring. Without no_grad every stored
# tensor would keep its full autograd graph alive (memory grows with the index).
# No `labels=` is passed: the LM loss was computed but never used.
with torch.no_grad():
    for key, value in data.items():
        inputs = tokenizer(
            value['text_description'],
            return_tensors="pt",
            padding="max_length",
            max_length=32,
            truncation=True,
        )
        logit[key] = model(**inputs).logits
def search_index(query):
    """Return the UUID of the indexed video whose description best matches *query*.

    The query is encoded exactly like the index entries (GPT-2, padded/truncated
    to 32 tokens). Its logits are scored against every stored logits tensor in
    ``logit``. The score is the sum of their element-wise product, an
    unnormalised dot product. On ties the first entry in index order wins.

    Args:
        query: Free-text search string from the UI.

    Returns:
        The key (UUID) of the best-scoring entry in ``logit``.

    Raises:
        gr.Error: If the index is empty, so there is nothing to match against.
    """
    if not logit:
        # Previously this returned None and the caller tried to download "None.mp4".
        raise gr.Error("The search index is empty; no videos are available.")
    # Inference only: no autograd graph needed; `labels=` (unused loss) dropped.
    with torch.no_grad():
        inputs = tokenizer(query, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        query_logits = model(**inputs).logits
        best_uuid = max(logit, key=lambda uuid: (query_logits * logit[uuid]).sum().item())
    gr.Info(f"Found the most similar video with UUID: {best_uuid}. \n Downloading the video...")
    return best_uuid
def download_video(uuid):
    """Download ``<uuid>.mp4`` from the 360+x dataset repo and return its local path.

    Args:
        uuid: Identifier of the video; the file fetched is
            ``360_dataset/binocular/<uuid>.mp4`` in the dataset repo.

    Returns:
        str: The local path where ``hf_hub_download`` placed the file.
    """
    dataset_name = "quchenyuan/360x_dataset_LR"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)
    # TODO: enforce a storage cap on storage_dir (e.g. evict oldest files).
    downloaded_file_path = hf_hub_download(
        repo_id=dataset_name,
        filename=dataset_path + video_filename,
        # This is a dataset repo; hf_hub_download defaults to a model repo.
        repo_type="dataset",
        # Place the file under storage_dir (keeping the repo sub-path) rather
        # than only in the shared HF cache.
        local_dir=storage_dir,
    )
    # Return where the file actually is. The old `storage_dir / video_filename`
    # pointed at a path that was never written.
    return str(downloaded_file_path)
def search_and_show_video(query):
    """Gradio handler: find the best-matching video for *query* and return its local file path."""
    return download_video(search_index(query))
if __name__ == "__main__":
    # Build and launch the search UI.
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                gr.HTML("<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>")
            with gr.Row():
                gr.HTML("<p><a href='https://x360dataset.github.io/'>Official Website</a> <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>")
            with gr.Row():
                gr.HTML("<h2>Search for a video by entering a query below:</h2>")
            with gr.Row():
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            with gr.Row():
                with gr.Column():
                    video_output_1 = gr.Video()
                # NOTE: video_output_2/3 are not filled yet; the handler
                # returns a single path. Kept for a future multi-result view.
                with gr.Column():
                    video_output_2 = gr.Video()
                with gr.Column():
                    video_output_3 = gr.Video()
            with gr.Row():
                submit_button = gr.Button(value="Search")
            # search_and_show_video returns ONE path. Declaring three outputs
            # makes Gradio fail (or index into the path string), so wire only
            # the first player.
            submit_button.click(search_and_show_video, search_input, outputs=video_output_1)
    demo.launch()