import torch
from einops import rearrange, repeat

from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper


class StreamingWrapper(OpenAIWrapper):
    """
    Model wrapper for StreamingSVD, which holds both the CAM (controlnet) model and the base model.
    """
    def __init__(self,
                 diffusion_model,
                 controlnet,
                 num_frame_conditioning: int,
                 compile_model: bool = False,
                 pipeline_offloading: bool = False):
        super().__init__(diffusion_model=diffusion_model,
                         compile_model=compile_model)
        self.controlnet = controlnet
        self.num_frame_conditioning = num_frame_conditioning
        self.pipeline_offloading = pipeline_offloading
        if pipeline_offloading:
            raise NotImplementedError(
                "Pipeline offloading for StreamingI2V not implemented yet.")

    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs):
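        # Expected inputs, as read below: c holds the "concat", "crossattn" and
        # "vector" conditionings; kwargs must provide "batch_size",
        # "num_video_frames", "image_only_indicator" and "ctrl_frames".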
        batch_size = kwargs.pop("batch_size")

        # We apply the controlnet model only to the control frames.
        def reduce_to_cond_frames(input):
            input = rearrange(input, "(B F) ... -> B F ...", B=batch_size)
            input = input[:, :self.num_frame_conditioning]
            return rearrange(input, "B F ... -> (B F) ...")
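
        # Channel-wise concat conditioning (if present) is appended to the latent,
        # mirroring the forward of the base OpenAIWrapper.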
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        x_ctrl = reduce_to_cond_frames(x)
        t_ctrl = reduce_to_cond_frames(t)
        context = c.get("crossattn", None)
        # The controlnet does not use APM, so we remove any additional tokens
        # and keep only the first one.
        context_ctrl = context[:, :1]
        context_ctrl = reduce_to_cond_frames(context_ctrl)
        y = c.get("vector", None)
        y_ctrl = reduce_to_cond_frames(y)
        num_video_frames = kwargs.pop("num_video_frames")
        image_only_indicator = kwargs.pop("image_only_indicator")
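        # The control frames are duplicated along the batch dimension, presumably
        # to match a doubled (e.g. conditional/unconditional) batch from the sampler.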
        ctrl_img_enc_frames = repeat(
            kwargs['ctrl_frames'], "B ... -> (2 B) ...")
        controlnet_cond = rearrange(
            ctrl_img_enc_frames, "B F ... -> (B F) ...")
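        # Run the CAM/controlnet branch on the conditioning frames only; its
        # input-block and middle-block features are injected into the base model below.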
        if self.diffusion_model.controlnet_mode:
            hs_control_input, hs_control_mid = self.controlnet(
                x=x_ctrl,                         # video latent
                timesteps=t_ctrl,                 # timestep
                context=context_ctrl,             # CLIP image conditioning
                y=y_ctrl,                         # conditionings, e.g. fps
                controlnet_cond=controlnet_cond,  # control frames
                num_video_frames=self.num_frame_conditioning,
                num_video_frames_conditional=self.num_frame_conditioning,
                image_only_indicator=image_only_indicator[:, :self.num_frame_conditioning],
            )
        else:
            hs_control_input = None
            hs_control_mid = None
        kwargs["hs_control_input"] = hs_control_input
        kwargs["hs_control_mid"] = hs_control_mid
        out = self.diffusion_model(
            x=x,
            timesteps=t,
            context=context,  # must be (B F) T C
            y=y,              # must be (B F) 768
            num_video_frames=num_video_frames,
            num_conditional_frames=self.num_frame_conditioning,
            image_only_indicator=image_only_indicator,
            hs_control_input=hs_control_input,
            hs_control_mid=hs_control_mid,
        )
        return out