Spaces:
Runtime error
Runtime error
File size: 14,932 Bytes
3301107 2362b20 81022ab 2362b20 f949b3f 81022ab f949b3f 81022ab f949b3f 81022ab bde79e1 81022ab 75aaff7 81022ab 75aaff7 d67a615 75aaff7 69e088c 81022ab 75aaff7 81022ab 9cd26c4 81022ab 2362b20 d67a615 81022ab 5fe7264 81022ab 5fe7264 bde79e1 81022ab 5fe7264 81022ab d67a615 2c5b700 d67a615 2c5b700 d67a615 5fe7264 2c5b700 d67a615 81022ab 2362b20 d67a615 bde79e1 d67a615 81022ab 75aaff7 81022ab d67a615 5fe7264 d67a615 5fe7264 d67a615 7d526b1 5fe7264 d67a615 7d526b1 5fe7264 7d526b1 d67a615 7d526b1 5fe7264 7d526b1 5fe7264 7d526b1 d67a615 7d526b1 d67a615 5fe7264 7d526b1 5fe7264 7d526b1 5fe7264 81022ab 840fe70 81022ab d67a615 81022ab bde79e1 81022ab 5fe7264 81022ab d67a615 3c18edf d67a615 3c18edf 81022ab d67a615 81022ab d67a615 81022ab d67a615 7d526b1 7de406f c886731 7de406f 5fe7264 7d526b1 5fe7264 7d526b1 5fe7264 d67a615 81022ab f2c466d 81022ab 5fe7264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
import torch
# torch.jit.script = lambda f: f
# General
import os
from os.path import join as opj
import argparse
import datetime
from pathlib import Path
# import spaces
import gradio as gr
import tempfile
import yaml
from t2v_enhanced.model.video_ldm import VideoLDM
# Utilities
from t2v_enhanced.inference_utils import *
from t2v_enhanced.model_init import *
from t2v_enhanced.model_func import *
# True when running inside the official PAIR Hugging Face Space.
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

parser = argparse.ArgumentParser()
# `store_true` combined with default=True can never be switched off from the
# command line, so an explicit opt-out flag writing to the same destination is
# provided. `--public_access` is still accepted for backward compatibility.
parser.add_argument('--public_access', action='store_true', default=True,
                    help="Create a public Gradio share link (default).")
parser.add_argument('--no_public_access', dest='public_access', action='store_false',
                    help="Do not create a public Gradio share link.")
parser.add_argument('--where_to_log', type=str, default="gradio_output")
parser.add_argument('--device', type=str, default="cuda")
args = parser.parse_args()

# Used whenever the user leaves the prompt empty.
default_prompt = "A man with yellow ballon head is riding a bike on the street of New York City"

# Output directory for generated videos; created eagerly at import time.
Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
result_fol = Path(args.where_to_log).absolute()
device = args.device

# Four device slots are used for model placement below. With at least four
# GPUs, spread the models over cuda:0..cuda:3; otherwise put everything on the
# default CUDA device. (Previously any count other than exactly 4 -- including
# larger machines -- collapsed onto a single GPU.)
n_devices = int(os.environ.get('NDEVICES', 4))
if n_devices >= 4:
    devices = [f"cuda:{idx}" for idx in range(4)]
else:
    devices = ["cuda"] * 4
# --------------------------
# ----- Configurations -----
# --------------------------
# Settings for the video-to-video enhancement stage; passed both to
# init_v2v_model (below) and to video2video in `enhance`.
# NOTE(review): upscale_size is presumably (width, height) -- confirm.
cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
# --------------------------
# ----- Initialization -----
# --------------------------
# All models are loaded at import time, so the app only becomes available after
# every load below has finished. The statement order also fixes the order in
# which GPU memory is allocated.
# Text-to-video base models share devices[1].
ms_model = init_modelscope(devices[1])
# ZeroScope base model is currently disabled:
# # zs_model = init_zeroscope(device)
ad_model = init_animatediff(devices[1])
# Image-to-video base model and the SDXL model share devices[2]; SDXL is handed
# to svd_short_gen in `generate` alongside the SVD model.
svd_model = init_svd(devices[2])
sdxl_model = init_sdxl(devices[2])
# Path.absolute() resolves against the current working directory, so the app
# must be launched from the directory containing t2v_enhanced/.
ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
# No device is passed here; devices[0] is not assigned to any model in this file.
# NOTE(review): presumably the StreamingT2V model uses the default CUDA device -- confirm.
stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
# Enhancement (video-to-video) model on its own device slot.
msxl_model = init_v2v_model(cfg_v2v, devices[3])
# -------------------------
# ----- Functionality -----
# -------------------------
# @spaces.GPU(duration=120)
def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol):
    """Generate a long preview video and return the path of the resulting .mp4.

    A short clip is first produced with the selected base model; it is then
    extended with ``stream_long_gen``.

    Args:
        prompt: Text prompt; ``default_prompt`` is used when it is empty or None.
        num_frames: Dropdown label such as ``"24 - frames"``. Only the leading
            integer is used; it defaults to 24 and is capped at 56.
        image: Conditioning image for the SVD base model. Cached examples
            deliver a dict, whose ``"path"`` entry is used.
        model_name_stage1: Base-model label (one of the UI dropdown choices).
        model_name_stage2: Enhancement-model label; not used by this function.
        seed: Seed for the stage-1 generator; also forwarded to stream_long_gen.
        t: Number of timesteps, forwarded to the generation helpers.
        image_guidance: Image guidance scale forwarded to stream_long_gen.
        where_to_log: Directory used to build the returned path.
            NOTE(review): stream_long_gen only receives ``name``, so the file is
            presumably written to the folder configured at model init
            (``result_fol``) -- confirm before passing another directory.

    Returns:
        The path ``<where_to_log>/<name>.mp4`` as a string.

    Raises:
        ValueError: If ``model_name_stage1`` is not a supported base model.
    """
    # Resolve the default prompt *before* deriving the file name from it:
    # slicing a None prompt would otherwise raise a TypeError.
    if not prompt:
        prompt = default_prompt
    name = _output_name(prompt, datetime.datetime.now())

    num_frames = _parse_num_frames(num_frames)
    # Number of autoregressive extension steps handed to stream_long_gen;
    # presumably the first chunk is 8 frames and each step adds 8 -- confirm.
    n_autoreg_gen = (num_frames - 8) // 8
    # Sliders may deliver floats; torch.Generator.manual_seed requires an int.
    seed = int(seed)

    if model_name_stage1 == "ModelScopeT2V (text to video)":
        inference_generator = torch.Generator(device=ms_model.device).manual_seed(seed)
        short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
    elif model_name_stage1 == "AnimateDiff (text to video)":
        inference_generator = torch.Generator(device=ad_model.device).manual_seed(seed)
        short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
    elif model_name_stage1 == "SVD (image to video)":
        # Cached examples pass the image as a dict rather than a path.
        if isinstance(image, dict):
            image = image["path"]
        inference_generator = torch.Generator(device=svd_model.device).manual_seed(seed)
        short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unsupported base model: {model_name_stage1!r}")

    stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
    return opj(where_to_log, name + ".mp4")


def _parse_num_frames(num_frames, default=24, max_frames=56):
    """Return the frame count encoded in a dropdown label like ``"24 - frames"``."""
    if not num_frames:
        return default
    return min(int(num_frames.split(" ")[0]), max_frames)


def _output_name(prompt, now):
    """Build a filesystem-safe file stem from the prompt and a timestamp."""
    import re

    # The prompt is user input: keep only word characters and dashes, so it can
    # never inject path separators ("/", "../") into the output file name.
    stem = re.sub(r"[^\w\-]", "_", prompt[:100])
    return stem + "_" + str(now.time()).replace(":", "_").replace(".", "_")
# @spaces.GPU(duration=400)
def enhance(prompt, input_to_enhance, num_frames=None, image=None, model_name_stage1=None, model_name_stage2=None, seed=33, t=50, image_guidance=9.5, result_fol=result_fol):
    """Run the video-to-video enhancement stage on a long video.

    When no preview video is supplied, one is created first by calling
    ``generate`` with the remaining arguments.

    Args:
        prompt: Text prompt; ``default_prompt`` is used when it is empty or None.
        input_to_enhance: Existing preview video, or None to generate one.
        num_frames, image, model_name_stage1, model_name_stage2, seed, t,
            image_guidance: Forwarded to ``generate`` when a preview is needed.
        result_fol: Output folder passed to ``video2video``.

    Returns:
        The value returned by ``video2video`` (presumably the enhanced video's
        path -- confirm).
    """
    if prompt is None or prompt == "":
        prompt = default_prompt

    source_video = input_to_enhance
    if source_video is None:
        source_video = generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance)

    return video2video(prompt, source_video, result_fol, cfg_v2v, msxl_model)
def change_visibility(value):
    """Lock or unlock the image-prompt widget when the base model changes.

    Only the SVD (image-to-video) base model consumes an image, so the upload
    widget is made interactive for it and locked otherwise.

    Args:
        value: The currently selected base-model label.

    Returns:
        A ``gr.Image`` that Gradio applies as a property update to the image
        prompt component.
    """
    if value == "SVD (image to video)":
        # Deliberately no ``value=`` here: Gradio only updates the properties
        # passed explicitly. An image-to-video example fills in the image and
        # also switches the model to SVD, which fires this handler; resetting
        # the value would wipe the example's image.
        return gr.Image(label='Image Prompt (if not attached then SDXL will be used to generate the starting image)', show_label=True, scale=1, show_download_button=False, interactive=True)
    # Other base models ignore the image, so clear it while locking the widget.
    return gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False, value=None)
# [prompt_stage1, video_stage2, num_frames, image_stage1, model_name_stage1, seed, t, image_guidance]
def _example_row(prompt, video, model, image=None):
    """Return one example row using the defaults shared by all demo examples.

    The defaults are 56 frames, seed 33, 50 timesteps and image guidance 9.0.
    """
    return [prompt, video, "56 - frames", image, model, 33, 50, 9.0]


# Text-to-video examples; the i-th entry pairs with __assets__/examples/t2v/<i>.mp4.
examples_1 = [
    _example_row(text, f"__assets__/examples/t2v/{idx}.mp4", base_model)
    for idx, (text, base_model) in enumerate(
        [
            ("Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
             "ModelScopeT2V (text to video)"),
            ("People dancing in room filled with fog and colorful lights.",
             "ModelScopeT2V (text to video)"),
            ("Discover the secret language of bees: delve into the complex communication system that allows bees to coordinate their actions and navigate the world.",
             "AnimateDiff (text to video)"),
            ("sunset, orange sky, warm lighting, fishing boats, ocean waves seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, coastal landscape, seaside scenery.",
             "AnimateDiff (text to video)"),
            ("Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.",
             "SVD (image to video)"),
            ("Ants, beetles and centipede nest.",
             "SVD (image to video)"),
        ],
        start=1,
    )
]

# Image-to-video examples, each with its own starting image.
examples_2 = [
    _example_row("Fishes swimming in ocean camera moving, cinematic.",
                 "__assets__/examples/i2v/1.mp4", "SVD (image to video)",
                 image="__assets__/fish.jpg"),
    _example_row("A squirrel on a table full of big nuts.",
                 "__assets__/examples/i2v/2.mp4", "SVD (image to video)",
                 image="__assets__/squirrel.jpg"),
]
# --------------------------
# ----- Gradio-Demo UI -----
# --------------------------
with gr.Blocks() as demo:
    # Header: title, authors, affiliations, links and a short description.
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
        <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingT2V</a>
        </h1>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, Humphrey Shi<sup>1,3</sup>
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        <sup>1</sup>Picsart AI Research (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        *Equal Contribution
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
        [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
        [<a href="https://streamingt2v.github.io/" style="color:blue;">Project page</a>]
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
        <b>StreamingT2V</b> is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation.
        It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality.
        Our demonstrations include successful examples of videos up to <b>1200 frames, spanning 2 minutes</b>, and can be extended for even longer durations.
        Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.
        </h2>
        </div>
        """)

    if on_huggingspace:
        gr.HTML("""
        <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
        <br/>
        <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
        <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
        </p>""")

    with gr.Row():
        # Left column: inputs, preview video and controls.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # Only the leading integer of the label is used, and
                        # `generate` caps it at 56 frames in this demo.
                        num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames")
                    with gr.Row():
                        prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder=f"Ex: {default_prompt}")
                    with gr.Row():
                        # Locked until the SVD base model is selected (see change_visibility).
                        image_stage1 = gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False)
                with gr.Column():
                    video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
            with gr.Row():
                with gr.Row():
                    run_button_stage1 = gr.Button("Long Video Generation (faster preview)")
                with gr.Row():
                    run_button_stage2 = gr.Button("Long Video Generation")
            with gr.Row():
                with gr.Column():
                    with gr.Accordion('Advanced options', open=False):
                        model_name_stage1 = gr.Dropdown(
                            choices=["ModelScopeT2V (text to video)", "AnimateDiff (text to video)", "SVD (image to video)"],
                            label="Base Model",
                            value="ModelScopeT2V (text to video)"
                        )
                        model_name_stage2 = gr.Dropdown(
                            choices=["MS-Vid2Vid-XL"],
                            label="Enhancement Model",
                            value="MS-Vid2Vid-XL"
                        )
                        seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33, step=1,)
                        t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
                        image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0)
        # Right column: the final (enhanced) long video.
        with gr.Column():
            with gr.Row():
                video_stage2 = gr.Video(label='Long Video', show_label=True, interactive=False, height=588, show_download_button=True)

    # Unlock the image prompt only for the image-to-video base model.
    model_name_stage1.change(fn=change_visibility, inputs=[model_name_stage1], outputs=image_stage1)

    # Input order must match the positional parameters of `generate`.
    inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]
    run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
    # Input order must match the positional parameters of `enhance`.
    inputs_v2v = [prompt_stage1, video_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]
    # Column order of the rows in examples_1 / examples_2.
    inputs_examples = [prompt_stage1, video_stage2, num_frames, image_stage1, model_name_stage1, seed, t, image_guidance]

    gr.HTML("""
    <h2>
    You can check the inference time for different numbers of frames
    <p style=" display: inline">
    <a href="https://github.com/Picsart-AI-Research/StreamingT2V/blob/main/README.md#inference-time" style="color:blue;" target="_blank">[here].</a>
    </p>
    </h2>
    """)

    # No `fn` is given, so clicking an example only fills in the inputs,
    # including the precomputed result video shown in video_stage2; nothing is
    # generated or cached.
    gr.Examples(examples=examples_1,
                inputs=inputs_examples,
                )
    gr.Examples(examples=examples_2,
                inputs=inputs_examples,
                )

    run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)

    # Footer: version, caution and bias notes.
    gr.HTML(
        """
        <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Version: v1.0</b>
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Caution</b>:
        We would like to make users of this demo aware of its potential issues and concerns.
        Like previous large foundation models, StreamingT2V could be problematic in some cases; partly because we use pretrained ModelScope, StreamingT2V can inherit its imperfections.
        So far, we keep all features available for research testing both to show the great potential of the StreamingT2V framework and to collect important feedback to improve the model in the future.
        We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Biases and content acknowledgement</b>:
        Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
        StreamingT2V in this demo is meant only for research purposes.
        </h3>
        </div>
        """)
if on_huggingspace:
    # Hosted Space: max_size=10 bounds the number of pending requests.
    demo.queue(max_size=10)
    demo.launch(debug=True)
else:
    # Local run: api_open=False keeps direct REST calls from bypassing the
    # queue; a public share link is created unless --no_public_access is given.
    demo.queue(api_open=False).launch(share=args.public_access)