import numpy as np
from lib.farancia import IImage
from PIL import Image
from i2v_enhance import i2v_enhance_interface
from dataloader.dataset_factory import SingleImageDatasetFactory
from pytorch_lightning import Trainer, LightningDataModule, seed_everything
import math
from diffusion_trainer import streaming_svd as streaming_svd_model
import torch
from safetensors.torch import load_file as load_safetensors
from utils.loader import download_ckpt
from functools import partial
from dataloader.video_data_module import VideoDataModule
from pathlib import Path
from pytorch_lightning.cli import LightningCLI, LightningArgumentParser
from pytorch_lightning import LightningModule
import sys
import os
from copy import deepcopy
from utils.aux import ensure_annotation_class
from diffusers import FluxPipeline
from typing import Union


class CustomCLI(LightningCLI):
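    """LightningCLI wrapper that registers the StreamingSVD command-line options."""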

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_argument("--image", type=Path,
                            help="Path to the input image(s)")
        parser.add_argument("--output", type=Path,
                            help="Path to the output folder")
        parser.add_argument("--num_frames", type=int, default=100,
                            help="Number of frames to generate.")
        parser.add_argument("--out_fps", type=int, default=24,
                            help="Framerate of the generated video.")
        parser.add_argument("--chunk_size", type=int, default=38,
                            help="Chunk size used in randomized blending.")
        parser.add_argument("--overlap_size", type=int, default=12,
                            help="Overlap size used in randomized blending.")
        parser.add_argument("--use_randomized_blending", action="store_true",
                            help="Wether to use randomized blending.")
        parser.add_argument("--use_fp16", action="store_true",
                            help="Wether to use float16 quantization.")
        parser.add_argument("--prompt", type=str, default = "")

        return parser


class StreamingSVD():
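    """Autoregressive image-to-video generation pipeline.

    The pipeline runs StreamingSVD to generate a video from an input image,
    enhances it to 1280x720 (optionally with randomized blending over
    overlapping chunks), and interpolates the result to the requested number
    of frames. For text-to-video, a start image is first generated from the
    prompt with FLUX.

    Minimal programmatic usage sketch (assuming the repository's config.yaml
    and checkpoints are available):

        svd = StreamingSVD(load_argv=False)
        video = svd.streaming_t2v(prompt="a sailboat at sunset", num_frames=100)
    """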

    def __init__(self, load_argv: bool = True) -> None:

        call_fol = Path(os.getcwd()).resolve()

        # Resolve the config file relative to the folder this script lives in.
        code_fol = Path(__file__).resolve().parent
        code_fol = os.path.relpath(code_fol, call_fol)
        argv_backup = deepcopy(sys.argv)

        if "--use_fp16" in sys.argv:
            os.environ["STREAMING_USE_FP16"] = "True"
        # Rebuild sys.argv so that LightningCLI picks up the default config
        # plus any user-provided overrides; the original argv is restored below.
        sys.argv = [__file__]
        sys.argv.extend(self.__config_call(argv_backup[1:] if load_argv else [], code_fol))
        cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True, parser_kwargs={
            "parser_mode": "omegaconf"}, save_config_callback=None)
        self.__init_models(cli)
        self.__init_fields(cli)

        sys.argv = argv_backup

    def __init_models(self, cli):
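        """Load the StreamingSVD checkpoint and set up the auxiliary models
        (frame interpolation, enhancement pipeline and FLUX text-to-image)."""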
        model = cli.model
        trainer = cli.trainer

        path = download_ckpt(
            local_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_local,
            global_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_global
        )
        if path.endswith(".safetensors"):
            ckpt = load_safetensors(path)
        else:
            ckpt = torch.load(path, map_location="cpu")["state_dict"]

        model.load_state_dict(ckpt)  # load trained model
        data_module_loader = partial(VideoDataModule, workers=0)
        vfi = i2v_enhance_interface.vfi_init(model.vfi)

        enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init(
            model.i2v_enhance)
        enhance_pipeline.unet.enable_forward_chunking(chunk_size=1, dim=1)
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        flux_pipe.enable_model_cpu_offload()

        # type hints for the objects stored on the instance
        model: streaming_svd_model
        data_module_loader: LightningDataModule
        trainer: Trainer

        self.model = model
        self.vfi = vfi
        self.data_module_loader = data_module_loader
        self.enhance_pipeline = enhance_pipeline
        self.enhance_generator = enhance_generator
        self.trainer = trainer
        self.flux_pipe = flux_pipe

    def __init_fields(self, cli):
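        """Cache the parsed command-line options as instance attributes."""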
        self.input_path = cli.config["image"]
        self.output_path = cli.config["output"]
        self.num_frames = cli.config["num_frames"]
        self.fps = cli.config["out_fps"]
        self.use_randomized_blending = cli.config["use_randomized_blending"]
        self.chunk_size = cli.config["chunk_size"]
        self.overlap_size = cli.config["overlap_size"]
        self.prompt = cli.config["prompt"]

    def __config_call(self, config_cmds, code_fol):
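        """Build the argument list passed to LightningCLI: the default config
        file first, followed by the user-provided overrides."""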
        cmds = [cmd for cmd in config_cmds if len(cmd) > 0]
        cmd_init = ["--config", f"{code_fol}/config.yaml"]
        if "--use_fp16" in config_cmds:
            cmd_init.append("--trainer.precision=16-true")
        cmd_init.extend(cmds)
        return cmd_init

    # interfaces

    def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33):
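        """Generate a video from a text prompt: first synthesize a start image
        with FLUX, then run the image-to-video pipeline on it."""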
        image = self.text_to_image(prompt=prompt)
        return self.streaming_i2v(image, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)

    def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.ndarray:
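        """Full image-to-video pipeline: StreamingSVD generation, enhancement,
        frame interpolation, and scaling/cropping back to the input size."""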
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            image, num_frames=(num_frames+1)//2, seed=seed)
        max_memory_allocated = torch.cuda.max_memory_allocated()
        print(
            f"max_memory_allocated at image_to_video: {max_memory_allocated}")
        video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video, chunk_size=chunk_size, overlap_size=overlap_size,
                                   use_randomized_blending=use_randomized_blending, seed=seed)
        video = self.interpolate_video(video, dest_num_frames=num_frames)

        # scale/crop back to input size
        if image.shape[0] == 1:
            image = image[0]
        video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop((0, 0, image.shape[1], image.shape[0])).numpy()

        max_memory_allocated = torch.cuda.max_memory_allocated()
        print(
            f"max_memory_allocated at interpolate_video: {max_memory_allocated}")
        return video

    # StreamingSVD pipeline
    def streaming(self, image: np.ndarray):
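        """Run the StreamingSVD predict loop on a single image and return the
        generated video together with the scaled/outpainted input image and
        the expanded generation size."""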

        datamodule = self.data_module_loader(predict_dataset_factory=SingleImageDatasetFactory(
            file=image))
        self.trainer.predict(model=self.model, datamodule=datamodule)
        video = self.trainer.generated_video
        expanded_size = self.trainer.expanded_size
        scaled_outpainted_image = self.trainer.scaled_outpainted_image
    
        return video, scaled_outpainted_image, expanded_size

    def image_to_video(self, image: Union[np.ndarray, str], num_frames: int, seed=33) -> tuple[np.ndarray, Image.Image, list[int]]:
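        """Generate `num_frames` frames from a single input image (or image
        path) via autoregressive StreamingSVD generation."""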
        seed_everything(seed)
        if isinstance(image, str):
            image = IImage.open(image).numpy()

        if image.shape[0] == 1 and image.ndim == 4:
            image = image[0]

        assert image.shape[-1] == 3 and image.shape[0] > 1, "Wrong image format. Assuming shape [H W C], with C = 3."
        assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8."
        # compute necessary number of chunks
        n_cond_frames = self.model.inference_params.num_conditional_frames
        n_frames_per_gen = self.model.sampler.guider.num_frames
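        # The first pass produces n_frames_per_gen frames; each further pass
        # re-uses n_cond_frames frames as conditioning and therefore adds only
        # (n_frames_per_gen - n_cond_frames) new frames.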
        n_autoregressive_generations = math.ceil(
            (num_frames - n_frames_per_gen) / (n_frames_per_gen - n_cond_frames))
        self.model.inference_params.n_autoregressive_generations = int(
            n_autoregressive_generations)

        print(" --- STREAMING ----- [START]")
        video, scaled_outpainted_image, expanded_size = self.streaming(
            image=image)
        print(f" --- STREAMING ----- [FINISHED]: {video.shape}")

        video = video[:num_frames]
        return video, scaled_outpainted_image, expanded_size

    def enhance_video(self, video: Union[np.ndarray, str], image: np.ndarray = None, chunk_size=38, overlap_size=12, strength=0.97, use_randomized_blending=False, seed=33, num_frames=None):
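        """Enhance the generated video at 1280x720 with the image-to-video
        enhancement pipeline, optionally using randomized blending over
        overlapping chunks."""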

        seed_everything(seed)
        if isinstance(video, str):
            video = IImage.open(video).numpy()
            if image is None:
                image = video[0]
            print("ATTENTION: We take first frame of previous stage as input frame for enhance. ")

        if num_frames is not None:
            video = video[:num_frames, ...]

        if not use_randomized_blending:
            chunk_size = video.shape[0]
            overlap_size = 0
        if image.ndim == 3:
            image = image[None]
        image = [Image.fromarray(
            IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])]

        video = np.split(video, video.shape[0])
        video = [Image.fromarray(frame[0]).resize((1280, 720))
                 for frame in video]

        print(
            f"---- ENHANCE  ---- [START]. Video length = {len(video)}. Randomized Blending = {use_randomized_blending}. Chunk size = {chunk_size}. Overlap size = {overlap_size}.")
        video_enhanced = i2v_enhance_interface.i2v_enhance_process(
            image=image, video=video, pipeline=self.enhance_pipeline, generator=self.enhance_generator,
            chunk_size=chunk_size, overlap_size=overlap_size, strength=strength, use_randomized_blending=use_randomized_blending)
        video_enhanced = np.stack([np.asarray(frame)
                                   for frame in video_enhanced], axis=0)
        print("---- ENHANCE  ---- [FINISHED].")
        return video_enhanced

    def interpolate_video(self, video: np.ndarray, dest_num_frames: int):
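        """Temporally interpolate the video to `dest_num_frames` frames using
        the video frame interpolation (VFI) model."""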
        video = np.split(video, len(video))
        video = [frame[0] for frame in video]

        print(" ---- VFI  ---- [START]")
        self.vfi.device()
        video_vfi = i2v_enhance_interface.vfi_process(
            video=video, vfi=self.vfi, video_len=dest_num_frames)
        video_vfi = np.stack([np.asarray(frame)
                              for frame in video_vfi], axis=0)
        self.vfi.unload()
        print(f"---- VFI  ---- [FINISHED]. Video length = {len(video_vfi)}")
        return video_vfi

    # T2I method

    def text_to_image(self, prompt, seed=33):
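        """Generate a 1280x720 start image from a text prompt with FLUX.1-schnell."""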
        # FLUX
        print("[FLUX] Generating image from text prompt")
        out = self.flux_pipe(
            prompt=prompt,
            guidance_scale=0,
            height=720,
            width=1280,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator(
                device=self.model.device).manual_seed(seed),
        ).images[0]
        print("[FLUX] Finished")
        return np.array(out)


if __name__ == "__main__":

    @ensure_annotation_class
    def get_input_data(input_path: Path = None):
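        """Collect the input image(s): either a single file or all supported
        images (jpg, jpeg, png, bmp) found in the given folder."""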
        if input_path.is_file():
            inputs = [input_path]
        else:
            suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]",
                        "*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"]  # loading png, jpg and bmp images
            inputs = []
            for suffix in suffixes:
                inputs.extend(list(input_path.glob(suffix)))
        assert len(
            inputs) > 0, "No images found. Please make sure the input path is correct."

        img_as_np = [IImage.open(input_file).numpy() for input_file in inputs]
        return zip(img_as_np, inputs)

    streaming_svd = StreamingSVD()
    num_frames = streaming_svd.num_frames
    chunk_size = streaming_svd.chunk_size
    overlap_size = streaming_svd.overlap_size
    use_randomized_blending = streaming_svd.use_randomized_blending
    if not use_randomized_blending:
        chunk_size = (num_frames + 1)//2
        overlap_size = 0
    result_path = Path(streaming_svd.output_path)
    seed = 33

    assert not result_path.exists() or result_path.is_dir(), \
        "Output path must be the path to a folder."
    prompt = streaming_svd.prompt
    if len(prompt) == 0:
        for img, img_path in get_input_data(streaming_svd.input_path):
            video = streaming_svd.streaming_i2v(
                image=img, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
            if not result_path.exists():
                result_path.mkdir(parents=True)
            result_file = result_path / (img_path.stem+".mp4")
            result_file = result_file.as_posix()
            IImage(video, vmin=0, vmax=255).setFps(
                streaming_svd.fps).save(result_file)
            print(f"Video created at: {result_file}")
    else:
        video = streaming_svd.streaming_t2v(
            prompt=prompt, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
        prompt_file = prompt.replace(" ", "_").replace(
            ".", "_").replace("/", "_").replace(":", "_")
        prompt_file = prompt_file[:15]
        if not result_path.exists():
            result_path.mkdir(parents=True)
        result_file = result_path / (prompt_file+".mp4")
        result_file = result_file.as_posix()
        IImage(video, vmin=0, vmax=255).setFps(
            streaming_svd.fps).save(result_file)
        print(f"Video created at: {result_file}")