Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from typing import Dict, List, Optional | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
from pytorch_lightning.utilities.types import EVAL_DATALOADERS | |
from t2v_enhanced.model.datasets.video_dataset import Annotations | |
import json | |
class ConcatDataset(torch.utils.data.Dataset):
    """Index several named datasets in lockstep.

    ``datasets`` is a mapping from a name to an indexable dataset. It must
    contain the key ``"reconstruction_dataset"``, whose ``model_id`` attribute
    is re-exported on this object. Item ``idx`` is a dict holding, for each
    name, that dataset's item ``idx``.
    """

    def __init__(self, datasets):
        self.datasets = datasets
        reconstruction = datasets["reconstruction_dataset"]
        self.model_id = reconstruction.model_id

    def __getitem__(self, idx):
        """Return ``{name: dataset item at idx}`` for every wrapped dataset."""
        sample = {}
        for name in self.datasets:
            # Explicit method call, as in the original lookup semantics.
            sample[name] = self.datasets[name].__getitem__(idx)
        return sample

    def __len__(self):
        """Return the length of the shortest wrapped dataset."""
        lengths = [len(self.datasets[name]) for name in self.datasets]
        return min(lengths)
class CustomPromptsDataset(torch.utils.data.Dataset):
    """Dataset of text prompts, optionally paired with video file paths.

    ``prompt_cfg`` keys read by this class:

    * ``"type"``: ``"prompt"`` (``"content"`` is the prompt itself) or
      ``"file"`` (``"content"`` is a path to a prompt file).
    * ``"content"``: the prompt text or the file path, depending on ``"type"``.
    * ``"videos_root"`` (optional, JSON files only): directory prepended to
      every video path.

    Supported file suffixes:

    * ``.npy``: loaded with ``numpy.load``.
    * ``.txt``: one prompt per line, trailing whitespace stripped.
    * ``.json``: a list of objects with the keys ``"prompt"``, ``"page_dir"``
      and ``"videoid"``; the video path of each entry is
      ``[videos_root/]page_dir/<videoid>.mp4``.

    Every prompt is passed through ``Annotations.clean_prompt`` before being
    stored in ``self.prompts``. ``self.video_path`` is only set when the
    prompts come from a JSON file.

    Raises:
        ValueError: if ``prompt_cfg["type"]`` or the file suffix is not one of
            the supported values.
    """

    def __init__(self, prompt_cfg: Dict[str, str]):
        super().__init__()
        source_type = prompt_cfg["type"]
        if source_type == "prompt":
            raw_prompts = [prompt_cfg["content"]]
        elif source_type == "file":
            raw_prompts = self._load_prompt_file(prompt_cfg)
        else:
            # Previously this fell through and failed later with an
            # AttributeError on the missing ``self.prompts``.
            raise ValueError(
                f"Unsupported prompt type {source_type!r}; "
                "expected 'prompt' or 'file'.")

        self.prompts = [Annotations.clean_prompt(prompt)
                        for prompt in raw_prompts]

    def _load_prompt_file(self, prompt_cfg):
        """Read raw prompts from the file named by ``prompt_cfg["content"]``.

        Sets ``self.video_path`` as a side effect when the file is JSON.
        """
        content = prompt_cfg["content"]
        suffix = Path(content).suffix

        if suffix == ".npy":
            return np.load(Path(content).as_posix())

        if suffix == ".txt":
            with open(content) as handle:
                return [line.rstrip() for line in handle]

        if suffix == ".json":
            with open(content, "r") as handle:
                metadata = json.load(handle)
            videos_root = (prompt_cfg["videos_root"]
                           if "videos_root" in prompt_cfg else None)
            self.video_path = self._build_video_paths(metadata, videos_root)
            return [sample["prompt"] for sample in metadata]

        raise ValueError(
            f"Unsupported prompt file suffix {suffix!r}; "
            "expected '.npy', '.txt' or '.json'.")

    @staticmethod
    def _build_video_paths(metadata, videos_root=None):
        """Return ``[videos_root/]page_dir/<videoid>.mp4`` for each entry."""
        # ``page_dir`` comes from JSON as a plain string, so it has to be
        # wrapped in ``Path`` before ``/`` can be applied to it.
        base = Path(videos_root) if videos_root is not None else Path()
        return [
            str(base / Path(sample["page_dir"]) / f"{sample['videoid']}.mp4")
            for sample in metadata
        ]

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, index):
        """Return ``{"prompt": ...}``, plus ``"video"`` when paths are set."""
        output = {"prompt": self.prompts[index]}
        if hasattr(self, "video_path"):
            output["video"] = self.video_path[index]
        return output
class PromptReader(pl.LightningDataModule):
    """Lightning data module that serves a ``CustomPromptsDataset`` for prediction."""

    def __init__(self, prompt_cfg: Dict[str, str]):
        super().__init__()
        dataset = CustomPromptsDataset(prompt_cfg)
        self.predict_dataset = dataset

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return a loader over the prompts: one sample per batch, in order."""
        loader_options = {
            "batch_size": 1,
            "pin_memory": False,
            "shuffle": False,
            "drop_last": False,
        }
        return torch.utils.data.DataLoader(
            self.predict_dataset, **loader_options)