multimodalart (HF staff) committed
Commit 12e9e51
Parent: b664a31

Update app.py

Files changed (1)
  1. app.py +18 -14
app.py CHANGED
@@ -10,12 +10,15 @@ from diffusers.utils import export_to_video
 import spaces
 import uuid
 
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
 is_canonical = True if os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow" else False
 
 # Constants
 MODEL_PATH = "pyramid-flow-model"
 MODEL_REPO = "rain1011/pyramid-flow-sd3"
-MODEL_VARIANT = "diffusion_transformer_384p"
+MODEL_VARIANT = "diffusion_transformer_768p"
 MODEL_DTYPE = "bf16"
 
 def center_crop(image, target_width, target_height):
@@ -63,18 +66,19 @@ model = load_model()
 # Text-to-video generation function
 @spaces.GPU(duration=120)
 def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
-    multiplier = 3
-    temp = int(duration * multiplier) + 1  # Convert seconds to temp value (assuming 24 FPS)
+    multiplier = 0.8 if is_canonical else 2.4
+    temp = int(duration * 0.8)  # Convert seconds to temp value (assuming 24 FPS)
     torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
     if(image):
-        cropped_image = center_crop(image, 640, 384)
-        resized_image = cropped_image.resize((640, 384))
+        cropped_image = center_crop(image, 1280, 720)
+        resized_image = cropped_image.resize((1280, 720))
         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
             frames = model.generate_i2v(
                 prompt=prompt,
                 input_image=resized_image,
                 num_inference_steps=[10, 10, 10],
                 temp=temp,
+                guidance_scale=7.0,
                 video_guidance_scale=video_guidance_scale,
                 output_type="pil",
                 save_memory=True,
@@ -85,8 +89,8 @@ def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guida
                 prompt=prompt,
                 num_inference_steps=[20, 20, 20],
                 video_num_inference_steps=[10, 10, 10],
-                height=384,
-                width=640,
+                height=768,
+                width=1280,
                 temp=temp,
                 guidance_scale=guidance_scale,
                 video_guidance_scale=video_guidance_scale,
@@ -94,14 +98,14 @@ def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guida
                 save_memory=True,
             )
     output_path = f"{str(uuid.uuid4())}_output_video.mp4"
-    export_to_video(frames, output_path, fps=24)
+    export_to_video(frames, output_path, fps=8 if is_canonical else 24)
     return output_path
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Pyramid Flow 384p demo")
-    gr.Markdown("Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours")
-    gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)] [[Project Page]](https://pyramid-flow.github.io)")
+    gr.Markdown("# Pyramid Flow")
+    gr.Markdown("Pyramid Flow is a training-efficient Autoregressive Video Generation model based on Flow Matching. It is trained only on open-source datasets within 20.7k A100 GPU hours")
+    gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)]")
 
     with gr.Row():
         with gr.Column():
@@ -109,8 +113,8 @@ with gr.Blocks() as demo:
             i2v_image = gr.Image(type="pil", label="Input Image")
             t2v_prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced settings", open=False):
-                t2v_duration = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Duration (seconds)", visible=not is_canonical)
-                t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=7, step=0.1, label="Guidance Scale")
+                t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)", visible=not is_canonical)
+                t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
                 t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
                 t2v_generate_btn = gr.Button("Generate Video")
         with gr.Column():
@@ -122,7 +126,7 @@ with gr.Blocks() as demo:
                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
             </a>
         </p>
-        <p>to use privately and generate videos up to 10s</p>
+        <p>to use privately and generate videos up to 10s at 24fps</p>
         </div>
         """)
     gr.Examples(
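
For reference, a minimal standalone sketch of the timing math this commit introduces. The constants are copied from the hunks above; timing_for is a hypothetical helper for inspection only and is not part of app.py.

# Hypothetical helper mirroring the duration -> temp / fps mapping from the
# updated generate_video(); constants are copied from the diff above.
def timing_for(duration_s: int, is_canonical: bool):
    multiplier = 0.8 if is_canonical else 2.4   # as committed; the temp line below uses the literal 0.8
    temp = int(duration_s * 0.8)                # matches `temp = int(duration * 0.8)` in the commit
    fps = 8 if is_canonical else 24             # matches export_to_video(..., fps=8 if is_canonical else 24)
    return multiplier, temp, fps

if __name__ == "__main__":
    for canonical in (True, False):
        for seconds in (1, 5, 10):
            print(f"canonical={canonical} duration={seconds}s -> {timing_for(seconds, canonical)}")

For example, the default 5-second request maps to temp=4 on either path; the canonical Space then exports at 8 fps, while duplicated Spaces export at 24 fps.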