# Hugging Face Space: text-to-video search demo for the 360+x dataset.
import json
import os
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast

# Load GPT-2 once at startup. It is used purely for inference
# (text -> logits used as a crude embedding), so make eval mode explicit.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS so fixed-length padding works.
tokenizer.pad_token = tokenizer.eos_token

# Pre-compute logits for every indexed video description.
# index.json maps uuid -> {"text_description": ...}  (schema inferred from
# the loop below — confirm against the file).
logit = {}
json_file = 'index.json'
with open(json_file, 'r') as file:
    data = json.load(file)
    for key, value in data.items():
        text_description = value['text_description']
        inputs = tokenizer(text_description, return_tensors="pt",
                           padding="max_length", max_length=32, truncation=True)
        # no_grad(): without it every cached logits tensor keeps its whole
        # autograd graph alive, leaking memory for the life of the process.
        # Dropping `labels=` skips a pointless loss computation; the logits
        # themselves are identical.
        with torch.no_grad():
            outputs = model(**inputs)
        logit[key] = outputs.logits
def search_index(query):
    """Return the uuid of the indexed description most similar to *query*.

    Similarity is the (unnormalised) dot product between the query's GPT-2
    logits and each cached description's logits in the module-level `logit`
    dict. Returns None if the index is empty.
    """
    inputs = tokenizer(query, return_tensors="pt",
                       padding="max_length", max_length=32, truncation=True)
    # Inference only: no_grad() avoids building an autograd graph per query;
    # `labels=` is dropped because only the logits are used.
    with torch.no_grad():
        outputs = model(**inputs)
    max_similarity = float('-inf')
    max_similarity_uuid = None
    for uuid, logits in logit.items():
        # Element-wise product summed over all positions/vocab entries.
        similarity = (outputs.logits * logits).sum()
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_uuid = uuid
    gr.Info(f"Found the most similar video with UUID: {max_similarity_uuid}. \n Downloading the video...")
    return max_similarity_uuid
def download_video(uuid):
    """Download `<uuid>.mp4` from the 360x dataset repo and return its local path.

    Raises whatever `hf_hub_download` raises (e.g. if the file is missing).
    """
    dataset_name = "quchenyuan/360x_dataset_LR"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)
    # The repo is a dataset, so repo_type must be given explicitly —
    # hf_hub_download defaults to repo_type="model" and would 404.
    downloaded_file_path = hf_hub_download(
        dataset_name,
        dataset_path + video_filename,
        repo_type="dataset",
    )
    # BUG FIX: return the path the file was actually downloaded to.
    # The old `storage_dir / video_filename` was never written by anything,
    # so the UI received a path to a non-existent file.
    return downloaded_file_path
def search_and_show_video(query):
    """Gradio click handler: find the best-matching video and fill the outputs.

    The click listener binds three Video components; returning a single value
    would make Gradio raise a missing-output error, so the same clip is shown
    in all three slots. (Only one file per uuid is downloaded here — TODO:
    wire up the other binocular views if they become available.)
    """
    uuid = search_index(query)
    video_path = download_video(uuid)
    return video_path, video_path, video_path
if __name__ == "__main__":
    # Assemble the demo UI: three header rows, a search box, three video
    # result slots side by side, and a search button wired to the handler.
    with gr.Blocks() as demo:
        with gr.Column():
            header_snippets = [
                "<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>",
                "<p><a href='https://x360dataset.github.io/'>Official Website</a> <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>",
                "<h2>Search for a video by entering a query below:</h2>",
            ]
            for snippet in header_snippets:
                with gr.Row():
                    gr.HTML(snippet)
            with gr.Row():
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            video_slots = []
            with gr.Row():
                for _ in range(3):
                    with gr.Column():
                        video_slots.append(gr.Video())
            with gr.Row():
                submit_button = gr.Button(value="Search")
            submit_button.click(search_and_show_video, search_input,
                                outputs=video_slots)
    demo.launch()