File size: 3,765 Bytes
8fd2f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

import torch
from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
from einops import rearrange, repeat


class StreamingWrapper(OpenAIWrapper):
    """
    Model wrapper for StreamingSVD, which holds the CAM model
    (``self.controlnet``) and the base model (``self.diffusion_model``).

    On each forward pass the controlnet is evaluated on the first
    ``num_frame_conditioning`` frames of every sample, but only when
    ``self.diffusion_model.controlnet_mode`` is truthy. Its two outputs are
    handed to the base model as ``hs_control_input`` and ``hs_control_mid``;
    otherwise both are ``None``.
    """

    def __init__(
        self,
        diffusion_model,
        controlnet,
        num_frame_conditioning: int,
        compile_model: bool = False,
        pipeline_offloading: bool = False,
    ):
        """
        Args:
            diffusion_model: Base model, forwarded to ``OpenAIWrapper``.
            controlnet: Module called on the conditioning frames.
            num_frame_conditioning: Number of leading frames per sample that
                are passed to the controlnet.
            compile_model: Forwarded unchanged to ``OpenAIWrapper``.
            pipeline_offloading: Must be falsy; offloading is not supported.

        Raises:
            NotImplementedError: If ``pipeline_offloading`` is truthy.
        """
        # Fail fast, before the base class performs any setup.
        if pipeline_offloading:
            raise NotImplementedError(
                "Pipeline offloading for StreamingI2V not implemented yet.")
        super().__init__(diffusion_model=diffusion_model,
                         compile_model=compile_model)
        self.controlnet = controlnet
        self.num_frame_conditioning = num_frame_conditioning
        self.pipeline_offloading = pipeline_offloading

    def _reduce_to_cond_frames(self, tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Keep the first ``num_frame_conditioning`` frames of each sample.

        ``tensor`` has a flattened ``(B F)`` leading axis; it is unflattened
        with ``B=batch_size``, sliced along the frame axis and flattened again.
        """
        per_sample = rearrange(tensor, "(B F) ... -> B F ...", B=batch_size)
        per_sample = per_sample[:, :self.num_frame_conditioning]
        return rearrange(per_sample, "B F ... -> (B F) ...")

    def _run_controlnet(self, x, t, context, y, batch_size, image_only_indicator, ctrl_frames):
        """Evaluate the controlnet on the conditioning frames only.

        Returns the controlnet's result, which ``forward`` unpacks into
        ``(hs_control_input, hs_control_mid)``.
        """
        # We apply the controlnet model only to the control frames.
        x_ctrl = self._reduce_to_cond_frames(x, batch_size)
        t_ctrl = self._reduce_to_cond_frames(t, batch_size)
        # controlnet is not using APM so we remove potentially additional tokens
        context_ctrl = self._reduce_to_cond_frames(context[:, :1], batch_size)
        y_ctrl = self._reduce_to_cond_frames(y, batch_size)

        # Tile the control frames twice along the batch axis, then flatten
        # batch and frame axes to match the (B F) layout used above.
        # NOTE(review): the factor 2 presumably matches the conditional /
        # unconditional halves of classifier-free guidance - TODO confirm.
        ctrl_img_enc_frames = repeat(ctrl_frames, "B ... -> (2 B) ... ")
        controlnet_cond = rearrange(
            ctrl_img_enc_frames, "B F ... -> (B F) ...")

        return self.controlnet(
            x=x_ctrl,  # video latent
            timesteps=t_ctrl,  # timestep
            context=context_ctrl,  # clip image conditioning
            y=y_ctrl,  # conditionings, e.g. fps
            controlnet_cond=controlnet_cond,  # control frames
            num_video_frames=self.num_frame_conditioning,
            num_video_frames_conditional=self.num_frame_conditioning,
            image_only_indicator=image_only_indicator[:, :self.num_frame_conditioning],
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs):
        """Run the (optional) controlnet and then the base diffusion model.

        Args:
            x: Input tensor whose leading axis is a flattened ``(B F)`` axis.
            t: Timesteps, with the same flattened leading axis as ``x``.
            c: Conditioning dict. Optional keys read here: ``"concat"``
                (concatenated to ``x`` along dim 1), ``"crossattn"`` and
                ``"vector"``.
            **kwargs: Must contain ``batch_size``, ``num_video_frames`` and
                ``image_only_indicator``. ``ctrl_frames`` is required only
                when ``self.diffusion_model.controlnet_mode`` is truthy.

        Returns:
            Whatever ``self.diffusion_model`` returns.

        Raises:
            KeyError: If a required keyword argument is missing.
        """
        batch_size = kwargs.pop("batch_size")
        num_video_frames = kwargs.pop("num_video_frames")
        image_only_indicator = kwargs.pop("image_only_indicator")

        # torch.cat accepts a 1-D empty tensor, so a missing "concat" entry
        # leaves the channels of x unchanged.
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        context = c.get("crossattn", None)
        y = c.get("vector", None)

        hs_control_input = None
        hs_control_mid = None
        if self.diffusion_model.controlnet_mode:
            # Control inputs are only built when they are actually consumed,
            # so ctrl_frames is not required outside controlnet mode.
            hs_control_input, hs_control_mid = self._run_controlnet(
                x=x,
                t=t,
                context=context,
                y=y,
                batch_size=batch_size,
                image_only_indicator=image_only_indicator,
                ctrl_frames=kwargs["ctrl_frames"],
            )

        out = self.diffusion_model(
            x=x,
            timesteps=t,
            context=context,  # must be (B F) T C
            y=y,  # must be (B F) 768
            num_video_frames=num_video_frames,
            num_conditional_frames=self.num_frame_conditioning,
            image_only_indicator=image_only_indicator,
            hs_control_input=hs_control_input,
            hs_control_mid=hs_control_mid,
        )
        return out