AngeT10 commited on
Commit
8f832c3
1 Parent(s): 83b2157

Upload run_inference.py

Browse files
Files changed (1) hide show
  1. run_inference.py +181 -0
run_inference.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+
11
+ from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
12
+ from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
13
+ from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond
14
+
15
+
16
+ # Base Model
17
+ # When using "showlab/show-1-base-0.0", it's advisable to increase the number of inference steps (e.g., 100)
18
+ # and opt for a larger guidance scale (e.g., 12.0) to enhance visual quality.
19
+ pretrained_model_path = "showlab/show-1-base"
20
+ pipe_base = TextToVideoIFPipeline.from_pretrained(
21
+ pretrained_model_path,
22
+ torch_dtype=torch.float16,
23
+ variant="fp16"
24
+ )
25
+ pipe_base.enable_model_cpu_offload()
26
+
27
+ # Interpolation Model
28
+ pretrained_model_path = "showlab/show-1-interpolation"
29
+ pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
30
+ pretrained_model_path,
31
+ torch_dtype=torch.float16,
32
+ variant="fp16"
33
+ )
34
+ pipe_interp_1.enable_model_cpu_offload()
35
+
36
+ # Super-Resolution Model 1
37
+ # Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
38
+ pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
39
+ pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
40
+ pretrained_model_path,
41
+ text_encoder=None,
42
+ torch_dtype=torch.float16,
43
+ variant="fp16"
44
+ )
45
+ pipe_sr_1_image.enable_model_cpu_offload()
46
+
47
+ pretrained_model_path = "showlab/show-1-sr1"
48
+ pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
49
+ pretrained_model_path,
50
+ torch_dtype=torch.float16
51
+ )
52
+ pipe_sr_1_cond.enable_model_cpu_offload()
53
+
54
+ # Super-Resolution Model 2
55
+ pretrained_model_path = "showlab/show-1-sr2"
56
+ pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
57
+ pretrained_model_path,
58
+ torch_dtype=torch.float16
59
+ )
60
+ pipe_sr_2.enable_model_cpu_offload()
61
+ pipe_sr_2.enable_vae_slicing()
62
+
63
+
64
+ # Inference
65
+ prompt = "A burning lamborghini driving on rainbow."
66
+ output_dir = "./outputs/example"
67
+ negative_prompt = "low resolution, blur"
68
+
69
+ seed = 345
70
+ os.makedirs(output_dir, exist_ok=True)
71
+
72
+ # Text embeds
73
+ prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
74
+
75
+ # Keyframes generation (8x64x40, 2fps)
76
+ video_frames = pipe_base(
77
+ prompt_embeds=prompt_embeds,
78
+ negative_prompt_embeds=negative_embeds,
79
+ num_frames=8,
80
+ height=40,
81
+ width=64,
82
+ num_inference_steps=75,
83
+ guidance_scale=9.0,
84
+ generator=torch.manual_seed(seed),
85
+ output_type="pt"
86
+ ).frames
87
+
88
+ imageio.mimsave(f"{output_dir}/{prompt}_base.gif", tensor2vid(video_frames.clone()), fps=2)
89
+
90
+ # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
91
+ bsz, channel, num_frames, height, width = video_frames.shape
92
+ new_num_frames = 3 * (num_frames - 1) + num_frames
93
+ new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
94
+ dtype=video_frames.dtype, device=video_frames.device)
95
+ new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
96
+ init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
97
+ device=video_frames.device, generator=torch.manual_seed(seed))
98
+
99
+ for i in range(num_frames - 1):
100
+ batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
101
+ batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
102
+ batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
103
+ batch_i = pipe_interp_1(
104
+ pixel_values=batch_i,
105
+ prompt_embeds=prompt_embeds,
106
+ negative_prompt_embeds=negative_embeds,
107
+ num_frames=batch_i.shape[2],
108
+ height=40,
109
+ width=64,
110
+ num_inference_steps=75,
111
+ guidance_scale=4.0,
112
+ generator=torch.manual_seed(seed),
113
+ output_type="pt",
114
+ init_noise=init_noise,
115
+ cond_interpolation=True,
116
+ ).frames
117
+
118
+ new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i
119
+
120
+ video_frames = new_video_frames
121
+ imageio.mimsave(f"{output_dir}/{prompt}_interp.gif", tensor2vid(video_frames.clone()), fps=8)
122
+
123
+ # Super-resolution 1 (29x64x40 -> 29x256x160)
124
+ bsz, channel, num_frames, height, width = video_frames.shape
125
+ window_size, stride = 8, 7
126
+ new_video_frames = torch.zeros(
127
+ (bsz, channel, num_frames, height * 4, width * 4),
128
+ dtype=video_frames.dtype,
129
+ device=video_frames.device)
130
+ for i in range(0, num_frames - window_size + 1, stride):
131
+ batch_i = video_frames[:, :, i:i + window_size, ...]
132
+ all_frame_cond = None
133
+
134
+ if i == 0:
135
+ first_frame_cond = pipe_sr_1_image(
136
+ image=video_frames[:, :, 0, ...],
137
+ prompt_embeds=prompt_embeds,
138
+ negative_prompt_embeds=negative_embeds,
139
+ height=height * 4,
140
+ width=width * 4,
141
+ num_inference_steps=70,
142
+ guidance_scale=4.0,
143
+ noise_level=150,
144
+ generator=torch.manual_seed(seed),
145
+ output_type="pt"
146
+ ).images
147
+ first_frame_cond = first_frame_cond.unsqueeze(2)
148
+ else:
149
+ first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
150
+
151
+ batch_i = pipe_sr_1_cond(
152
+ image=batch_i,
153
+ prompt_embeds=prompt_embeds,
154
+ negative_prompt_embeds=negative_embeds,
155
+ first_frame_cond=first_frame_cond,
156
+ height=height * 4,
157
+ width=width * 4,
158
+ num_inference_steps=125,
159
+ guidance_scale=7.0,
160
+ noise_level=250,
161
+ generator=torch.manual_seed(seed),
162
+ output_type="pt"
163
+ ).frames
164
+ new_video_frames[:, :, i:i + window_size, ...] = batch_i
165
+
166
+ video_frames = new_video_frames
167
+ imageio.mimsave(f"{output_dir}/{prompt}_sr1.gif", tensor2vid(video_frames.clone()), fps=8)
168
+
169
+ # Super-resolution 2 (29x256x160 -> 29x576x320)
170
+ video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
171
+ video_frames = pipe_sr_2(
172
+ prompt,
173
+ negative_prompt=negative_prompt,
174
+ video=video_frames,
175
+ strength=0.8,
176
+ num_inference_steps=50,
177
+ generator=torch.manual_seed(seed),
178
+ output_type="pt"
179
+ ).frames
180
+
181
+ imageio.mimsave(f"{output_dir}/{prompt}.gif", tensor2vid(video_frames.clone()), fps=8)