# StreamingT2V — t2v_enhanced/inference.py
# Source snapshot: author hpoghos, commit 81022ab ("minor"), 3.6 kB.
# General
import os
from os.path import join as opj
import argparse
import datetime
from pathlib import Path
import torch
import gradio as gr
import tempfile
import yaml
from t2v_enhanced.model.video_ldm import VideoLDM
# Utilities
from t2v_enhanced.inference_utils import *
from t2v_enhanced.model_init import *
from t2v_enhanced.model_func import *
def _build_parser():
    """Return the argument parser for the StreamingT2V inference CLI."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default="A cat running on the street", help="The prompt to guide video generation.")
    parser.add_argument('--image', type=str, default="", help="Path to image conditioning.")
    # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.")
    parser.add_argument('--base_model', type=str, default="ModelscopeT2V", help="Base model to generate first chunk from", choices=["ModelscopeT2V", "AnimateDiff", "SVD"])
    parser.add_argument('--num_frames', type=int, default=24, help="The number of video frames to generate.")
    parser.add_argument('--negative_prompt', type=str, default="", help="The prompt to guide what to not include in video generation.")
    parser.add_argument('--num_steps', type=int, default=50, help="The number of denoising steps.")
    parser.add_argument('--image_guidance', type=float, default=9.0, help="The guidance scale.")
    parser.add_argument('--output_dir', type=str, default="results", help="Path where to save the generated videos.")
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--seed', type=int, default=33, help="Random seed")
    return parser


def _make_run_name(prompt, now):
    """Build a file stem for one generation run.

    The stem is the first 100 characters of ``prompt`` with spaces replaced by
    underscores, then ``_``, then the time of day of ``now`` with ':' and '.'
    replaced by underscores (e.g. ``A_cat_12_34_56_789012``).

    Path separators ('/' and '\\') in the prompt are replaced as well: the stem
    is joined onto the output directory (see the ``video2video`` call in
    ``main``), and a separator left in place would point into a subdirectory.
    """
    stem = prompt[:100].replace(" ", "_").replace("/", "_").replace("\\", "_")
    time_part = str(now.time()).replace(":", "_").replace(".", "_")
    return stem + "_" + time_part


def _compute_n_autoreg_gen(num_frames):
    """Return the number of autoregressive extension steps for ``num_frames``.

    Computed as ``(num_frames - 8) // 8``, clamped at 0. The previous expression
    ``num_frames // 8 - 8`` parsed as ``(num_frames // 8) - 8``. That value was
    negative for every ``num_frames < 64``, including the CLI default of 24.
    """
    # NOTE(review): the 8-frame step is assumed to match the number of new
    # frames stream_long_gen adds per autoregressive step -- confirm there.
    return max(0, (num_frames - 8) // 8)


def main(argv=None):
    """Run StreamingT2V text-to-video inference from the command line.

    Steps:
    1. Parse arguments and create the output directory.
    2. Load the StreamingT2V model and the selected base model.
    3. Generate an initial short clip with the base model.
    4. Extend the clip with ``stream_long_gen``.
    5. Run ``video2video`` on ``<output_dir>/<run name>.mp4``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    result_fol = Path(args.output_dir).absolute()
    device = args.device

    # --------------------------
    # ----- Configurations -----
    # --------------------------
    # Relative path: resolved against the current working directory.
    ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
    cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}

    # --------------------------
    # ----- Initialization -----
    # --------------------------
    stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
    # SDXL is only consumed by svd_short_gen, so only load it on the SVD path.
    sdxl_model = None
    if args.base_model == "ModelscopeT2V":
        model = init_modelscope(device)
    elif args.base_model == "AnimateDiff":
        model = init_animatediff(device)
    elif args.base_model == "SVD":
        model = init_svd(device)
        sdxl_model = init_sdxl(device)
    # msxl_model was previously passed to video2video without ever being
    # created, so the run died with a NameError after all generation work.
    # Loading it here makes a failure surface before generation starts.
    # NOTE(review): init_v2v_model is assumed to come from
    # t2v_enhanced.model_init -- confirm its signature (it may take a device).
    msxl_model = init_v2v_model(cfg_v2v)

    # ------------------
    # ----- Inputs -----
    # ------------------
    name = _make_run_name(args.prompt, datetime.datetime.now())

    # A single seeded generator on the requested device. It was previously
    # created twice and was always on "cuda", regardless of --device.
    inference_generator = torch.Generator(device=device)
    inference_generator.manual_seed(args.seed)

    if args.base_model == "ModelscopeT2V":
        short_video = ms_short_gen(args.prompt, model, inference_generator)
    elif args.base_model == "AnimateDiff":
        short_video = ad_short_gen(args.prompt, model, inference_generator)
    elif args.base_model == "SVD":
        short_video = svd_short_gen(args.image, args.prompt, model, sdxl_model, inference_generator)

    n_autoreg_gen = _compute_n_autoreg_gen(args.num_frames)
    stream_long_gen(args.prompt, short_video, n_autoreg_gen, args.negative_prompt, args.seed, args.num_steps, args.image_guidance, name, stream_cli, stream_model)
    video2video(args.prompt, opj(result_fol, name + ".mp4"), result_fol, cfg_v2v, msxl_model)


if __name__ == "__main__":
    main()