#!/usr/bin/env python
"""Gradio demo for ModelScope-Vid2Vid-XL: text-guided video-to-video generation."""

import os
import pathlib
import tempfile

import cv2
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

DESCRIPTION = "# ModelScope-Vid2Vid-XL"
if not torch.cuda.is_available():
    DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU.\n"

# The pipeline is only built when a GPU is present. Without one it stays None so
# that video_to_video can report a user-facing error instead of a NameError.
pipe = None
if torch.cuda.is_available():
    model_cache_dir = os.getenv("MODEL_CACHE_DIR", "./models")
    model_dir = pathlib.Path(model_cache_dir) / "MS-Vid2Vid-XL"
    snapshot_download(repo_id="damo-vilab/MS-Vid2Vid-XL", repo_type="model", local_dir=model_dir)
    pipe = pipeline(
        task="video-to-video",
        model=model_dir.as_posix(),
        model_revision="v1.1.0",
        device="cuda:0",
    )

# Exact input shape accepted by this demo (enforced by check_input_video).
REQUIRED_N_FRAMES = 32
REQUIRED_WIDTH = 448
REQUIRED_HEIGHT = 256


def check_input_video(video_path: str) -> None:
    """Validate that the uploaded video has the frame count and size this demo accepts.

    Args:
        video_path: Path to the input video file. May be None/empty when no video
            was uploaded.

    Raises:
        gr.Error: If no video was given, the file cannot be opened by OpenCV, or it
            is not exactly REQUIRED_N_FRAMES frames of REQUIRED_WIDTH x REQUIRED_HEIGHT.
    """
    if not video_path:
        raise gr.Error("Please upload an input video.")

    cap = cv2.VideoCapture(video_path)
    try:
        # Without this check an unreadable file would be reported as "0 frames of size 0x0".
        if not cap.isOpened():
            raise gr.Error("Could not open the input video.")
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Release the capture on every path, including the error path above.
        cap.release()

    if n_frames != REQUIRED_N_FRAMES or width != REQUIRED_WIDTH or height != REQUIRED_HEIGHT:
        raise gr.Error(
            f"Input video must be {REQUIRED_N_FRAMES} frames of size {REQUIRED_WIDTH}x{REQUIRED_HEIGHT}. "
            f"Your video is {n_frames} frames of size {width}x{height}."
        )


def video_to_video(video_path: str, text: str) -> str:
    """Run the Vid2Vid-XL pipeline on a video, guided by a text description.

    Args:
        video_path: Path to the input video (validated by check_input_video).
        text: Text description passed to the pipeline.

    Returns:
        Path of a temporary .mp4 file written by the pipeline. The file is not
        deleted automatically (delete=False) so Gradio can serve it.

    Raises:
        gr.Error: If the input video is invalid or no GPU pipeline is available.
    """
    # Validate again here so direct API calls ("run" endpoint) are checked too.
    check_input_video(video_path)
    if pipe is None:
        raise gr.Error("This demo requires a GPU and cannot run on CPU.")

    p_input = {"video_path": video_path, "text": text}
    # Only reserve a unique path. Close the handle immediately so the pipeline
    # can write to it (an open handle blocks writers on Windows) and nothing leaks.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
        output_path = output_file.name
    pipe(p_input, output_video=output_path)[OutputKeys.OUTPUT_VIDEO]
    return output_path


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        input_video = gr.Video(label="Input video")
        text_description = gr.Textbox(label="Text description")
        run_button = gr.Button()
        output_video = gr.Video(label="Output video")

    # Cheap, unqueued validation first; the expensive generation runs only on success.
    gr.on(
        triggers=[text_description.submit, run_button.click],
        fn=check_input_video,
        inputs=input_video,
        queue=False,
        api_name=False,
    ).success(
        fn=video_to_video,
        inputs=[input_video, text_description],
        outputs=output_video,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch()