File size: 4,381 Bytes
6a87547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video

# Model configuration: local checkpoint directory, the Hub repository it is
# fetched from, the transformer variant to load, and the precision string
# handed to the pipeline constructor.
MODEL_PATH = "pyramid-flow-model"
MODEL_REPO = "rain1011/pyramid-flow-sd3"
MODEL_VARIANT = "diffusion_transformer_768p"
MODEL_DTYPE = "bf16"


def load_model():
    """Build the Pyramid Flow pipeline and move it onto the GPU.

    The Hub snapshot is downloaded into ``MODEL_PATH`` only when that
    directory does not exist yet; an existing directory is trusted as-is.
    The VAE, DiT and text encoder are then moved to ``"cuda"`` and VAE
    tiling is enabled.

    Returns:
        The constructed ``PyramidDiTForVideoGeneration`` instance.
    """
    checkpoint_present = os.path.exists(MODEL_PATH)
    if not checkpoint_present:
        snapshot_download(
            MODEL_REPO,
            local_dir=MODEL_PATH,
            local_dir_use_symlinks=False,
            repo_type='model',
        )

    pipeline = PyramidDiTForVideoGeneration(
        MODEL_PATH,
        MODEL_DTYPE,
        model_variant=MODEL_VARIANT,
    )

    # Every sub-module used during generation must live on the GPU.
    for submodule in (pipeline.vae, pipeline.dit, pipeline.text_encoder):
        submodule.to("cuda")
    pipeline.vae.enable_tiling()

    return pipeline


# Loaded once at import time; the generation functions in this module read
# this global.
model = load_model()

# Text-to-video generation function
def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
    """Generate a 1280x768 video from a text prompt and return the MP4 path.

    Args:
        prompt: Text description of the desired video.
        duration: Requested clip length in seconds (the UI slider supplies
            whole seconds from 1 to 10).
        guidance_scale: Passed through to ``model.generate`` unchanged.
        video_guidance_scale: Passed through to ``model.generate`` unchanged.

    Returns:
        Path of the written MP4 file: ``output_video.mp4`` in the current
        working directory, written at 24 fps and overwritten on every call.
    """
    # ``temp`` is Pyramid Flow's latent-frame count, not seconds. Per the
    # upstream README, temp=16 gives ~5 s and temp=31 ~10 s of 24 fps video,
    # i.e. 3 latent frames per second plus one. The previous
    # ``int(duration * 2.4)`` gave temp=12 for a 5 s request, so clips came
    # out noticeably shorter than requested.
    # NOTE(review): confirm this mapping against the pyramid_dit version in use.
    temp = round(duration * 3) + 1
    torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # ``torch.cuda.amp.autocast`` is deprecated; ``torch.autocast`` with
    # device_type="cuda" takes the same arguments.
    with torch.no_grad(), torch.autocast("cuda", enabled=True, dtype=torch_dtype):
        frames = model.generate(
            prompt=prompt,
            num_inference_steps=[20, 20, 20],
            video_num_inference_steps=[10, 10, 10],
            height=768,
            width=1280,
            temp=temp,
            guidance_scale=guidance_scale,
            video_guidance_scale=video_guidance_scale,
            output_type="pil",
            save_memory=True,
        )

    output_path = "output_video.mp4"
    export_to_video(frames, output_path, fps=24)
    return output_path

# Image-to-video generation function
def generate_video_from_image(image, prompt, duration, video_guidance_scale, guidance_scale=7.0):
    """Animate an input image into a 1280x768 video and return the MP4 path.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        prompt: Text description passed to ``model.generate_i2v``.
        duration: Requested clip length in seconds (the UI slider supplies
            whole seconds from 1 to 10).
        video_guidance_scale: Passed through to ``model.generate_i2v``.
        guidance_scale: Passed through to ``model.generate_i2v``. Defaults to
            the previously hard-coded 7.0, so existing callers are unaffected.

    Returns:
        Path of the written MP4 file: ``output_video_i2v.mp4`` in the current
        working directory, written at 24 fps and overwritten on every call.

    Raises:
        gr.Error: If no input image was provided.
    """
    from PIL import ImageOps  # only this handler needs ImageOps

    if image is None:
        # Without this guard the handler dies with AttributeError on
        # ``None.resize``; gr.Error shows the user an actionable message.
        raise gr.Error("Please upload an input image.")

    # ``temp`` is Pyramid Flow's latent-frame count, not seconds. Per the
    # upstream README, temp=16 gives ~5 s and temp=31 ~10 s of 24 fps video,
    # i.e. 3 latent frames per second plus one. The previous
    # ``int(duration * 2.4)`` produced clips shorter than requested.
    # NOTE(review): confirm this mapping against the pyramid_dit version in use.
    temp = round(duration * 3) + 1
    torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # Scale to cover 1280x768 and center-crop, rather than stretching, so
    # uploads with a different aspect ratio are not distorted. convert("RGB")
    # defensively drops any alpha channel.
    image = ImageOps.fit(image.convert("RGB"), (1280, 768))

    # ``torch.cuda.amp.autocast`` is deprecated; ``torch.autocast`` with
    # device_type="cuda" takes the same arguments.
    with torch.no_grad(), torch.autocast("cuda", enabled=True, dtype=torch_dtype):
        frames = model.generate_i2v(
            prompt=prompt,
            input_image=image,
            num_inference_steps=[10, 10, 10],
            temp=temp,
            guidance_scale=guidance_scale,
            video_guidance_scale=video_guidance_scale,
            output_type="pil",
            save_memory=True,
        )

    output_path = "output_video_i2v.mp4"
    export_to_video(frames, output_path, fps=24)
    return output_path

# Gradio interface: one tab per generation mode. The layout is declared first
# and the button handlers are wired afterwards, still inside the Blocks
# context.
with gr.Blocks() as demo:
    gr.Markdown("# Pyramid Flow Video Generation Demo")

    # ---- Text-to-video tab: prompt, duration and two guidance sliders ----
    with gr.Tab("Text-to-Video"):
        with gr.Row():
            with gr.Column():
                t2v_prompt = gr.Textbox(label="Prompt")
                t2v_duration = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1,
                    label="Duration (seconds)",
                )
                t2v_guidance_scale = gr.Slider(
                    minimum=1, maximum=15, value=9, step=0.1,
                    label="Guidance Scale",
                )
                t2v_video_guidance_scale = gr.Slider(
                    minimum=1, maximum=15, value=5, step=0.1,
                    label="Video Guidance Scale",
                )
                t2v_generate_btn = gr.Button("Generate Video")
            with gr.Column():
                t2v_output = gr.Video(label="Generated Video")

    # ---- Image-to-video tab: source image, prompt, duration, guidance ----
    with gr.Tab("Image-to-Video"):
        with gr.Row():
            with gr.Column():
                i2v_image = gr.Image(type="pil", label="Input Image")
                i2v_prompt = gr.Textbox(label="Prompt")
                i2v_duration = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1,
                    label="Duration (seconds)",
                )
                i2v_video_guidance_scale = gr.Slider(
                    minimum=1, maximum=15, value=4, step=0.1,
                    label="Video Guidance Scale",
                )
                i2v_generate_btn = gr.Button("Generate Video")
            with gr.Column():
                i2v_output = gr.Video(label="Generated Video")

    # Event wiring, registered in the same order as the tabs. Each handler
    # returns a video file path that is shown in its tab's Video component.
    t2v_generate_btn.click(
        fn=generate_video,
        inputs=[t2v_prompt, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
        outputs=t2v_output,
    )
    i2v_generate_btn.click(
        fn=generate_video_from_image,
        inputs=[i2v_image, i2v_prompt, i2v_duration, i2v_video_guidance_scale],
        outputs=i2v_output,
    )

demo.launch()