File size: 14,932 Bytes
3301107
2362b20
81022ab
 
 
 
 
 
2362b20
f949b3f
81022ab
 
 
f949b3f
81022ab
 
 
 
f949b3f
81022ab
 
 
 
 
 
 
bde79e1
81022ab
 
 
 
75aaff7
 
 
 
 
81022ab
 
 
 
 
 
 
 
 
75aaff7
d67a615
75aaff7
 
 
69e088c
 
81022ab
75aaff7
81022ab
9cd26c4
81022ab
 
 
 
 
2362b20
d67a615
81022ab
 
 
 
5fe7264
81022ab
 
5fe7264
 
bde79e1
 
 
81022ab
5fe7264
81022ab
d67a615
2c5b700
d67a615
 
2c5b700
d67a615
 
5fe7264
 
 
2c5b700
d67a615
 
 
81022ab
 
 
2362b20
d67a615
bde79e1
 
 
d67a615
 
81022ab
75aaff7
 
81022ab
 
d67a615
 
5fe7264
d67a615
5fe7264
d67a615
 
7d526b1
5fe7264
d67a615
7d526b1
5fe7264
7d526b1
d67a615
7d526b1
5fe7264
7d526b1
5fe7264
7d526b1
d67a615
7d526b1
d67a615
 
5fe7264
 
7d526b1
5fe7264
7d526b1
5fe7264
81022ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840fe70
81022ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d67a615
81022ab
bde79e1
81022ab
5fe7264
81022ab
 
 
d67a615
3c18edf
d67a615
3c18edf
81022ab
 
 
 
 
d67a615
 
 
81022ab
 
d67a615
 
 
81022ab
 
 
 
 
 
 
 
d67a615
 
 
 
 
 
 
 
 
7d526b1
 
7de406f
 
 
 
c886731
7de406f
 
 
 
5fe7264
7d526b1
 
 
 
 
5fe7264
 
 
7d526b1
 
 
 
 
 
 
5fe7264
 
d67a615
 
81022ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2c466d
81022ab
 
5fe7264
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import torch
# torch.jit.script = lambda f: f
# General
import os
from os.path import join as opj
import argparse
import datetime
from pathlib import Path
# import spaces
import gradio as gr
import tempfile
import yaml
from t2v_enhanced.model.video_ldm import VideoLDM

# Utilities
from t2v_enhanced.inference_utils import *
from t2v_enhanced.model_init import *
from t2v_enhanced.model_func import *


# True only when running inside the PAIR Hugging Face Space (checked via the
# SPACE_AUTHOR_NAME environment variable); switches queue/launch behaviour below.
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
parser = argparse.ArgumentParser()
# ``--public_access`` used to be ``store_true`` with ``default=True``, which made it
# impossible to turn off. The flag and its default are kept for backward
# compatibility; ``--no_public_access`` is an explicit opt-out sharing the same dest.
parser.add_argument('--public_access', dest='public_access', action='store_true', default=True,
                    help="Pass share=True to demo.launch() when not on the HF Space (default).")
parser.add_argument('--no_public_access', dest='public_access', action='store_false',
                    help="Pass share=False to demo.launch() when not on the HF Space.")
parser.add_argument('--where_to_log', type=str, default="gradio_output",
                    help="Directory (created if missing) where generated videos are written.")
parser.add_argument('--device', type=str, default="cuda",
                    help="Device string forwarded to the short-video generation helpers.")
args = parser.parse_args()
default_prompt = "A man with yellow ballon head is riding a bike on the street of New York City"

Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
result_fol = Path(args.where_to_log).absolute()
device = args.device
# With exactly 4 devices (NDEVICES env var, default 4) every model gets its own GPU;
# any other count places all models on the default "cuda" device.
n_devices = int(os.environ.get('NDEVICES', 4))
if n_devices == 4:
    devices = [f"cuda:{idx}" for idx in range(4)]
else:
    devices = ["cuda"] * 4
# --------------------------
# ----- Configurations -----
# --------------------------
# Settings for the MS-Vid2Vid-XL enhancement stage. They are passed both to
# init_v2v_model() below and to video2video() in enhance().
# NOTE(review): key semantics live in t2v_enhanced; 'upscale_size' presumably is
# the (width, height) of the enhanced output — confirm.
cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}


# --------------------------
# ----- Initialization -----
# --------------------------
# All models are loaded eagerly at import time. The device indices spread them
# over the entries of ``devices``; devices[0] is not used in this block.
# Stage-1 text-to-video models share devices[1].
ms_model = init_modelscope(devices[1])
# # zs_model = init_zeroscope(device)
ad_model = init_animatediff(devices[1])
# Image-to-video (SVD) and SDXL share devices[2]. SDXL is passed to svd_short_gen
# in generate(), presumably to synthesize a start image when none is uploaded.
svd_model = init_svd(devices[2])
sdxl_model = init_sdxl(devices[2])

# The checkpoint path is relative to the current working directory and is resolved
# to an absolute path here. Device placement for the StreamingT2V model is decided
# inside init_streamingt2v_model (not visible here); result_fol is handed to it,
# presumably as its output folder.
ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
# Enhancement (video-to-video) model on devices[3].
msxl_model = init_v2v_model(cfg_v2v, devices[3])




# -------------------------
# ----- Functionality -----
# -------------------------
# @spaces.GPU(duration=120)
def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol):
    """Generate a long video preview: a short base clip extended by StreamingT2V.

    Args:
        prompt: Text prompt. Empty or ``None`` falls back to ``default_prompt``.
        num_frames: Dropdown value such as ``"24 - frames"``. Only the leading
            integer is used. Empty or ``None`` means 24, and values above 56 are capped to 56.
        image: Optional image prompt for the SVD base model. Cached Gradio examples
            may deliver it as a dict carrying a ``"path"`` key.
        model_name_stage1: Base-model label from the "Base Model" dropdown.
        model_name_stage2: Not used here. It is accepted so the UI can pass one input list.
        seed: Seed for the base model's ``torch.Generator``; also forwarded to
            ``stream_long_gen``.
        t: Timesteps value from the "Timesteps" slider.
        image_guidance: Image guidance scale forwarded to ``stream_long_gen``.
        where_to_log: Directory used to build the returned path.

    Returns:
        The path string ``<where_to_log>/<name>.mp4``.

    Raises:
        gr.Error: If ``model_name_stage1`` is not one of the supported base models.
    """
    # Resolve the prompt first. The output file name is derived from it, and
    # slicing a ``None`` prompt would raise TypeError.
    if not prompt:
        prompt = default_prompt

    # "56 - frames" -> 56. Going through ``str`` also accepts a plain int from API
    # callers. An empty value (None, [], "") falls back to 24 frames.
    frames_tokens = str(num_frames).split() if num_frames else []
    num_frames = min(int(frames_tokens[0]), 56) if frames_tokens else 24
    # NOTE(review): the 56-frame cap applies even off the HF Space, although the
    # dropdown offers longer options "recommended to run on local GPUs" — confirm intent.

    # The base clip covers the first 8 frames; every autoregressive step adds 8 more.
    n_autoreg_gen = (num_frames - 8) // 8

    # Build a filesystem-safe name. Spaces (as before) and characters that are path
    # separators or reserved on common filesystems become "_", so a prompt such as
    # "cats/dogs" cannot redirect the output into a non-existent sub-directory.
    unsafe_to_underscore = str.maketrans({ch: "_" for ch in ' <>:"/\\|?*'})
    stem = prompt[:100].translate(unsafe_to_underscore)
    # 100 non-ASCII characters can exceed the 255-byte file-name limit of most Linux
    # filesystems. Trim by UTF-8 bytes, dropping any partially cut character.
    stem = stem.encode("utf-8")[:200].decode("utf-8", errors="ignore")
    now = datetime.datetime.now()
    name = stem + "_" + str(now.time()).replace(":", "_").replace(".", "_")

    if model_name_stage1 == "ModelScopeT2V (text to video)":
        inference_generator = torch.Generator(device=ms_model.device).manual_seed(seed)
        short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
    elif model_name_stage1 == "AnimateDiff (text to video)":
        inference_generator = torch.Generator(device=ad_model.device).manual_seed(seed)
        short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
    elif model_name_stage1 == "SVD (image to video)":
        # For cached examples
        if isinstance(image, dict):
            image = image["path"]
        inference_generator = torch.Generator(device=svd_model.device).manual_seed(seed)
        short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
    else:
        # Previously this fell through and failed later with UnboundLocalError
        # on ``short_video``. Report the actual problem instead.
        raise gr.Error(f"Unsupported base model: {model_name_stage1!r}")

    stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
    # NOTE(review): stream_long_gen does not receive ``where_to_log``. It presumably
    # writes under the folder stream_cli was initialised with (result_fol), so a
    # non-default ``where_to_log`` may yield a path that does not exist — confirm.
    video_path = opj(where_to_log, name + ".mp4")
    return video_path

# @spaces.GPU(duration=400)
def enhance(prompt, input_to_enhance, num_frames=None, image=None, model_name_stage1=None, model_name_stage2=None, seed=33, t=50, image_guidance=9.5, result_fol=result_fol):
    """Upscale/enhance a long video with the MS-Vid2Vid-XL stage.

    If no preview video is supplied (``input_to_enhance is None``), ``generate``
    is run first with the remaining arguments to produce one.

    Args:
        prompt: Text prompt. Empty or ``None`` is replaced by ``default_prompt``.
        input_to_enhance: Path of the stage-1 preview video, or ``None``.
        num_frames, image, model_name_stage1, model_name_stage2, seed, t,
            image_guidance: Forwarded unchanged to ``generate`` when it is needed.
        result_fol: Output folder handed to ``video2video``.

    Returns:
        Whatever ``video2video`` returns. It is used as the "Long Video" output.
    """
    effective_prompt = default_prompt if prompt in ("", None) else prompt

    source_video = input_to_enhance
    if source_video is None:
        # No preview yet: run the full stage-1 pipeline to obtain one.
        source_video = generate(effective_prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance)

    return video2video(effective_prompt, source_video, result_fol, cfg_v2v, msxl_model)

def change_visibility(value):
    """Return a fresh, cleared image-prompt component for the chosen base model.

    Image upload is enabled only for "SVD (image to video)". For every other model
    the component is read-only and its label explains how to enable it.
    Passing ``value=None`` clears any previously attached image.
    """
    image_to_video = value == "SVD (image to video)"
    if image_to_video:
        label = 'Image Prompt (if not attached then SDXL will be used to generate the starting image)'
    else:
        label = 'Image Prompt (first select Image-to-Video model from advanced options to enable image upload)'
    return gr.Image(
        label=label,
        show_label=True,
        scale=1,
        show_download_button=False,
        interactive=image_to_video,
        value=None,
    )


# Every example row lines up with ``inputs_examples`` in the UI:
# [prompt_stage1, video_stage2, num_frames, image_stage1, model_name_stage1, seed, t, image_guidance]
def _example_row(prompt, video, model_name, image=None):
    """Build one example row with the shared settings: 56 frames, seed 33, t=50, guidance 9.0."""
    return [prompt, video, "56 - frames", image, model_name, 33, 50, 9.0]


# Text-to-video (and SVD-without-image) showcase examples.
examples_1 = [
    _example_row(
        "Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
        "__assets__/examples/t2v/1.mp4",
        "ModelScopeT2V (text to video)",
    ),
    _example_row(
        "People dancing in room filled with fog and colorful lights.",
        "__assets__/examples/t2v/2.mp4",
        "ModelScopeT2V (text to video)",
    ),
    _example_row(
        "Discover the secret language of bees: delve into the complex communication system that allows bees to coordinate their actions and navigate the world.",
        "__assets__/examples/t2v/3.mp4",
        "AnimateDiff (text to video)",
    ),
    _example_row(
        "sunset, orange sky, warm lighting, fishing boats, ocean waves seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, coastal landscape, seaside scenery.",
        "__assets__/examples/t2v/4.mp4",
        "AnimateDiff (text to video)",
    ),
    _example_row(
        "Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.",
        "__assets__/examples/t2v/5.mp4",
        "SVD (image to video)",
    ),
    _example_row(
        "Ants, beetles and centipede nest.",
        "__assets__/examples/t2v/6.mp4",
        "SVD (image to video)",
    ),
]

# Image-to-video examples that ship with a starting image.
examples_2 = [
    _example_row(
        "Fishes swimming in ocean camera moving, cinematic.",
        "__assets__/examples/i2v/1.mp4",
        "SVD (image to video)",
        image="__assets__/fish.jpg",
    ),
    _example_row(
        "A squirrel on a table full of big nuts.",
        "__assets__/examples/i2v/2.mp4",
        "SVD (image to video)",
        image="__assets__/squirrel.jpg",
    ),
]

# --------------------------
# ----- Gradio-Demo UI -----
# --------------------------
with gr.Blocks() as demo:
    # ----- Header: title, authors, affiliations, links and short abstract -----
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
            <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingT2V</a> 
        </h1>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, Humphrey Shi<sup>1,3</sup>
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        <sup>1</sup>Picsart AI Research (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        *Equal Contribution
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
        [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>] 
        [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
        [<a href="https://streamingt2v.github.io/" style="color:blue;">Project page</a>]
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
        <b>StreamingT2V</b> is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation. 
        It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality. 
        Our demonstrations include successful examples of videos up to <b>1200 frames, spanning 2 minutes</b>, and can be extended for even longer durations. 
        Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.
        </h2>
        </div>
        """)

    # On the PAIR Space only: suggest duplicating the Space to skip the shared queue.
    if on_huggingspace:
        gr.HTML("""
        <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
        <br/>
        <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
        <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
        </p>""")

    # ----- Layout: inputs and preview on the left, enhanced video on the right -----
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames")
                    with gr.Row():
                        prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder=f"Ex: {default_prompt}")
                    with gr.Row():
                        # Starts read-only. change_visibility() swaps it when the base model changes.
                        image_stage1 = gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False)
                with gr.Column():
                    video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
            with gr.Row():
                with gr.Row():
                    run_button_stage1 = gr.Button("Long Video Generation (faster preview)")
                with gr.Row():
                    run_button_stage2 = gr.Button("Long Video Generation")

            with gr.Row():
                with gr.Column():
                    with gr.Accordion('Advanced options', open=False):
                        model_name_stage1 = gr.Dropdown(
                            choices=["ModelScopeT2V (text to video)", "AnimateDiff (text to video)", "SVD (image to video)"],
                            label="Base Model",
                            value="ModelScopeT2V (text to video)"
                        )
                        model_name_stage2 = gr.Dropdown(
                            choices=["MS-Vid2Vid-XL"],
                            label="Enhancement Model",
                            value="MS-Vid2Vid-XL"
                        )
                        seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33, step=1,)

                        t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
                        image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0)

        with gr.Column():
            with gr.Row():
                video_stage2 = gr.Video(label='Long Video', show_label=True, interactive=False, height=588, show_download_button=True)

    # ----- Event wiring -----
    # Enable or disable image upload depending on the selected base model.
    model_name_stage1.change(fn=change_visibility, inputs=[model_name_stage1], outputs=image_stage1)

    # Order must match generate(prompt, num_frames, image, model_name_stage1,
    # model_name_stage2, seed, t, image_guidance).
    inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]
    run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)

    # Order must match enhance(prompt, input_to_enhance, num_frames, image, ...).
    # The stage-1 preview (if any) is the video that gets enhanced.
    inputs_v2v = [prompt_stage1, video_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]

    # Column order of every row in examples_1 / examples_2.
    inputs_examples = [prompt_stage1, video_stage2, num_frames, image_stage1, model_name_stage1, seed, t, image_guidance]

    gr.HTML("""
        <h2>
            You can check the inference time for different numbers of frames
            <p style=" display: inline">
                <a href="https://github.com/Picsart-AI-Research/StreamingT2V/blob/main/README.md#inference-time" style="color:blue;" target="_blank">[here].</a> 
            </p>
        </h2>
        """)

    # No ``fn`` is passed, so clicking an example only fills the listed input
    # components (including the pre-rendered video shown in ``video_stage2``).
    # Nothing is generated.
    gr.Examples(examples=examples_1, inputs=inputs_examples)
    gr.Examples(examples=examples_2, inputs=inputs_examples)

    run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)

    # ----- Footer: version, caution and bias notice -----
    gr.HTML(
        """
        <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Version: v1.0</b>
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Caution</b>: 
        We would like to raise the awareness of users of this demo of its potential issues and concerns.
        Like previous large foundation models, StreamingT2V could be problematic in some cases; in part because we use pretrained ModelScope, StreamingT2V can inherit its imperfections.
        So far, we keep all features available for research testing both to show the great potential of the StreamingT2V framework and to collect important feedback to improve the model in the future.
        We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
        </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Biases and content acknowledgement</b>:
        Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence. 
        StreamingT2V in this demo is meant only for research purposes.
        </h3>
        </div>
        """)


# Launch the app. Locally: queue with the REST API routes closed (api_open=False) and
# an optional public share link from --public_access. On the PAIR Space: a queue
# bounded to 10 pending requests and debug output.
if not on_huggingspace:
    demo.queue(api_open=False).launch(share=args.public_access)
else:
    demo.queue(max_size=10).launch(debug=True)