feifeiobama committed on
Commit
b664a31
1 Parent(s): d36add3

Update app.py

Browse files

Change the demo to 384p

Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -10,15 +10,12 @@ from diffusers.utils import export_to_video
10
  import spaces
11
  import uuid
12
 
13
- import subprocess
14
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
15
-
16
  is_canonical = True if os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow" else False
17
 
18
  # Constants
19
  MODEL_PATH = "pyramid-flow-model"
20
  MODEL_REPO = "rain1011/pyramid-flow-sd3"
21
- MODEL_VARIANT = "diffusion_transformer_768p"
22
  MODEL_DTYPE = "bf16"
23
 
24
  def center_crop(image, target_width, target_height):
@@ -66,19 +63,18 @@ model = load_model()
66
  # Text-to-video generation function
67
  @spaces.GPU(duration=120)
68
  def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
69
- multiplier = 0.8 if is_canonical else 2.4
70
- temp = int(duration * 0.8) # Convert seconds to temp value (assuming 24 FPS)
71
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
72
  if(image):
73
- cropped_image = center_crop(image, 1280, 720)
74
- resized_image = cropped_image.resize((1280, 720))
75
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
76
  frames = model.generate_i2v(
77
  prompt=prompt,
78
  input_image=resized_image,
79
  num_inference_steps=[10, 10, 10],
80
  temp=temp,
81
- guidance_scale=7.0,
82
  video_guidance_scale=video_guidance_scale,
83
  output_type="pil",
84
  save_memory=True,
@@ -89,8 +85,8 @@ def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guida
89
  prompt=prompt,
90
  num_inference_steps=[20, 20, 20],
91
  video_num_inference_steps=[10, 10, 10],
92
- height=768,
93
- width=1280,
94
  temp=temp,
95
  guidance_scale=guidance_scale,
96
  video_guidance_scale=video_guidance_scale,
@@ -98,14 +94,14 @@ def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guida
98
  save_memory=True,
99
  )
100
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
101
- export_to_video(frames, output_path, fps=8 if is_canonical else 24)
102
  return output_path
103
 
104
  # Gradio interface
105
  with gr.Blocks() as demo:
106
- gr.Markdown("# Pyramid Flow")
107
- gr.Markdown("Pyramid Flow is a training-efficient Autoregressive Video Generation model based on Flow Matching. It is trained only on open-source datasets within 20.7k A100 GPU hours")
108
- gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)]")
109
 
110
  with gr.Row():
111
  with gr.Column():
@@ -113,8 +109,8 @@ with gr.Blocks() as demo:
113
  i2v_image = gr.Image(type="pil", label="Input Image")
114
  t2v_prompt = gr.Textbox(label="Prompt")
115
  with gr.Accordion("Advanced settings", open=False):
116
- t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)", visible=not is_canonical)
117
- t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
118
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
119
  t2v_generate_btn = gr.Button("Generate Video")
120
  with gr.Column():
 
10
  import spaces
11
  import uuid
12
 
 
 
 
13
  is_canonical = True if os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow" else False
14
 
15
  # Constants
16
  MODEL_PATH = "pyramid-flow-model"
17
  MODEL_REPO = "rain1011/pyramid-flow-sd3"
18
+ MODEL_VARIANT = "diffusion_transformer_384p"
19
  MODEL_DTYPE = "bf16"
20
 
21
  def center_crop(image, target_width, target_height):
 
63
  # Text-to-video generation function
64
  @spaces.GPU(duration=120)
65
  def generate_video(prompt, image=None, duration=5, guidance_scale=9, video_guidance_scale=5, progress=gr.Progress(track_tqdm=True)):
66
+ multiplier = 3
67
+ temp = int(duration * multiplier) + 1 # Convert seconds to temp value (assuming 24 FPS)
68
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
69
  if(image):
70
+ cropped_image = center_crop(image, 640, 384)
71
+ resized_image = cropped_image.resize((640, 384))
72
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
73
  frames = model.generate_i2v(
74
  prompt=prompt,
75
  input_image=resized_image,
76
  num_inference_steps=[10, 10, 10],
77
  temp=temp,
 
78
  video_guidance_scale=video_guidance_scale,
79
  output_type="pil",
80
  save_memory=True,
 
85
  prompt=prompt,
86
  num_inference_steps=[20, 20, 20],
87
  video_num_inference_steps=[10, 10, 10],
88
+ height=384,
89
+ width=640,
90
  temp=temp,
91
  guidance_scale=guidance_scale,
92
  video_guidance_scale=video_guidance_scale,
 
94
  save_memory=True,
95
  )
96
  output_path = f"{str(uuid.uuid4())}_output_video.mp4"
97
+ export_to_video(frames, output_path, fps=24)
98
  return output_path
99
 
100
  # Gradio interface
101
  with gr.Blocks() as demo:
102
+ gr.Markdown("# Pyramid Flow 384p demo")
103
+ gr.Markdown("Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours")
104
+ gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)] [[Project Page]](https://pyramid-flow.github.io)")
105
 
106
  with gr.Row():
107
  with gr.Column():
 
109
  i2v_image = gr.Image(type="pil", label="Input Image")
110
  t2v_prompt = gr.Textbox(label="Prompt")
111
  with gr.Accordion("Advanced settings", open=False):
112
+ t2v_duration = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Duration (seconds)", visible=not is_canonical)
113
+ t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=7, step=0.1, label="Guidance Scale")
114
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
115
  t2v_generate_btn = gr.Button("Generate Video")
116
  with gr.Column():