import numpy as np
from lib.farancia import IImage
from PIL import Image
from i2v_enhance import i2v_enhance_interface
from dataloader.dataset_factory import SingleImageDatasetFactory
from pytorch_lightning import Trainer, LightningDataModule, seed_everything
import math
from diffusion_trainer import streaming_svd as streaming_svd_model
import torch
from safetensors.torch import load_file as load_safetensors
from utils.loader import download_ckpt
from functools import partial
from dataloader.video_data_module import VideoDataModule
from pathlib import Path
from pytorch_lightning.cli import LightningCLI, LightningArgumentParser
from pytorch_lightning import LightningModule
import sys
import os
from copy import deepcopy
from utils.aux import ensure_annotation_class
from diffusers import FluxPipeline
from typing import Union

# NOTE: the import order above is kept unchanged on purpose; several are project
# modules whose import-time side effects cannot be checked from this file.


class CustomCLI(LightningCLI):
    """LightningCLI extended with the inference-only options used by this script."""

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Register input/output, video length, blending, fp16 and prompt options."""
        parser.add_argument("--image", type=Path,
                            help="Path to the input image(s)")
        parser.add_argument("--output", type=Path,
                            help="Path to the output folder")
        parser.add_argument("--num_frames", type=int, default=200,
                            help="Number of frames to generate.")
        parser.add_argument("--out_fps", type=int, default=24,
                            help="Framerate of the generated video.")
        parser.add_argument("--chunk_size", type=int, default=38,
                            help="Chunk size used in randomized blending.")
        parser.add_argument("--overlap_size", type=int, default=12,
                            help="Overlap size used in randomized blending.")
        parser.add_argument("--use_randomized_blending", action="store_true",
                            help="Whether to use randomized blending.")
        parser.add_argument("--use_fp16", action="store_true",
                            help="Whether to use float16 quantization.")
        parser.add_argument("--prompt", type=str, default="",
                            help="Text prompt. If non-empty, the first frame is generated "
                                 "from it with FLUX and --image is not used.")
        # Kept for backward compatibility with any direct caller of this hook.
        return parser


class StreamingSVD:
    """End-to-end StreamingSVD inference wrapper.

    ``streaming_i2v`` chains: StreamingSVD generation (``image_to_video``) ->
    enhancement (``enhance_video``) -> frame interpolation (``interpolate_video``)
    -> resize/crop back to the input size. ``streaming_t2v`` first creates the
    input frame from a text prompt with FLUX (``text_to_image``).

    Configuration is read from ``config.yaml`` next to this file, optionally
    overridden by command-line arguments (see ``CustomCLI``).
    """

    def __init__(self, load_argv=True) -> None:
        """Parse the configuration, load all models and read the inference options.

        Args:
            load_argv: if True, ``sys.argv[1:]`` is forwarded to the Lightning CLI
                after the default ``--config``; if False only ``config.yaml`` is used.
                ``sys.argv`` is restored afterwards in both cases, even on failure.
        """
        call_fol = Path(os.getcwd()).resolve()
        code_fol = Path(__file__).resolve().parent
        try:
            code_fol = os.path.relpath(code_fol, call_fol)
        except ValueError:
            # os.path.relpath raises on Windows when the working directory and this
            # file are on different drives; the absolute path works just as well.
            code_fol = str(code_fol)
        argv_backup = deepcopy(sys.argv)
        # NOTE: checked on the real argv even when load_argv is False.
        if "--use_fp16" in sys.argv:
            os.environ["STREAMING_USE_FP16"] = "True"
        try:
            # The CLI is constructed from a temporary argv built by __config_call.
            sys.argv = [__file__]
            sys.argv.extend(self.__config_call(
                argv_backup[1:] if load_argv else [], code_fol))
            cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True,
                            parser_kwargs={"parser_mode": "omegaconf"},
                            save_config_callback=None)
            self.__init_models(cli)
            self.__init_fields(cli)
        finally:
            # Restore the caller's argv even if parsing or model loading fails.
            sys.argv = argv_backup

    def __init_models(self, cli):
        """Load the StreamingSVD weights and build the auxiliary pipelines.

        Sets ``model``, ``trainer``, ``data_module_loader``, ``vfi``,
        ``enhance_pipeline``, ``enhance_generator`` and ``flux_pipe`` on ``self``.
        """
        model = cli.model
        trainer = cli.trainer
        ckpt_cfg = model.diff_trainer_params.streamingsvd_ckpt
        path = download_ckpt(
            local_path=ckpt_cfg.ckpt_path_local,
            global_path=ckpt_cfg.ckpt_path_global,
        )
        if path.endswith(".safetensors"):
            ckpt = load_safetensors(path)
        else:
            # Non-safetensors checkpoints store the weights under "state_dict".
            ckpt = torch.load(path, map_location="cpu")["state_dict"]
        model.load_state_dict(ckpt)  # load trained model

        # Factory: each call builds a VideoDataModule with workers=0.
        data_module_loader = partial(VideoDataModule, workers=0)
        vfi = i2v_enhance_interface.vfi_init(model.vfi)
        enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init(
            model.i2v_enhance)
        # FLUX.1-schnell is only used by text_to_image (i.e. streaming_t2v).
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        flux_pipe.enable_model_cpu_offload()

        # store of objects
        self.model: streaming_svd_model = model
        self.vfi = vfi
        self.data_module_loader = data_module_loader
        self.enhance_pipeline = enhance_pipeline
        self.enhance_generator = enhance_generator
        self.trainer: Trainer = trainer
        self.flux_pipe = flux_pipe

    def __init_fields(self, cli):
        """Copy the inference options parsed by ``CustomCLI`` onto ``self``."""
        config = cli.config
        self.input_path = config["image"]
        self.output_path = config["output"]
        self.num_frames = config["num_frames"]
        self.fps = config["out_fps"]
        self.use_randomized_blending = config["use_randomized_blending"]
        self.chunk_size = config["chunk_size"]
        self.overlap_size = config["overlap_size"]
        self.prompt = config["prompt"]

    def __config_call(self, config_cmds, code_fol):
        """Build the CLI arguments: default config first, then the user's arguments.

        Args:
            config_cmds: user arguments; empty strings are dropped.
            code_fol: folder containing ``config.yaml``.

        Returns:
            List of argument strings. ``--trainer.precision=16-true`` is inserted
            after the config file when ``--use_fp16`` is among ``config_cmds``.
        """
        cmd_init = ["--config", f"{code_fol}/config.yaml"]
        if "--use_fp16" in config_cmds:
            cmd_init.append("--trainer.precision=16-true")
        cmd_init.extend(cmd for cmd in config_cmds if cmd)
        return cmd_init

    # interfaces

    def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False,
                      chunk_size: int = 38, overlap_size: int = 12, seed=33):
        """Generate a video from a text prompt.

        The first frame is produced by ``text_to_image`` and then passed to
        ``streaming_i2v``. ``seed`` is used for both stages; the remaining
        arguments are forwarded to ``streaming_i2v`` unchanged.
        """
        image = self.text_to_image(prompt=prompt, seed=seed)
        return self.streaming_i2v(image, num_frames=num_frames,
                                  use_randomized_blending=use_randomized_blending,
                                  chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)

    def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False,
                      chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.ndarray:
        """Generate a ``num_frames``-frame video conditioned on ``image``.

        Args:
            image: uint8 array [H, W, 3] or [1, H, W, 3], or a path to an image file.
            num_frames: length of the returned video.
            use_randomized_blending, chunk_size, overlap_size: enhancement settings,
                forwarded to ``enhance_video``.
            seed: seed for generation and enhancement.

        Returns:
            The video as an array, resized and cropped to the input image's size.
        """
        if isinstance(image, str):
            # The final crop needs the array shape, so load a path input once here.
            image = IImage.open(image).numpy()
        # Generate about half the frames; interpolate_video brings the count to num_frames.
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            image, num_frames=(num_frames + 1) // 2, seed=seed)
        print(
            f"max_memory_allocated at image_to_video: {torch.cuda.max_memory_allocated()}")

        video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video,
                                   chunk_size=chunk_size, overlap_size=overlap_size,
                                   use_randomized_blending=use_randomized_blending, seed=seed)
        video = self.interpolate_video(video, dest_num_frames=num_frames)
        # scale/crop back to input size
        if image.ndim == 4 and image.shape[0] == 1:
            image = image[0]
        video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop(
            (0, 0, image.shape[1], image.shape[0])).numpy()
        # Measured again so the value reflects enhancement + interpolation as well.
        print(
            f"max_memory_allocated at interpolate_video: {torch.cuda.max_memory_allocated()}")
        return video

    # StreamingSVD pipeline
    def streaming(self, image: np.ndarray):
        """Run one StreamingSVD prediction for ``image``.

        Returns:
            ``(video, scaled_outpainted_image, expanded_size)``, read from the
            trainer after ``trainer.predict``. The values are attached to the
            trainer outside this file (presumably by the model's predict step).
        """
        datamodule = self.data_module_loader(
            predict_dataset_factory=SingleImageDatasetFactory(file=image))
        self.trainer.predict(model=self.model, datamodule=datamodule)
        video = self.trainer.generated_video
        expanded_size = self.trainer.expanded_size
        scaled_outpainted_image = self.trainer.scaled_outpainted_image
        return video, scaled_outpainted_image, expanded_size

    def image_to_video(self, image: Union[np.ndarray, str], num_frames: int,
                       seed=33) -> tuple[np.ndarray, Image.Image, list[int]]:
        """Generate ``num_frames`` frames conditioned on ``image`` with StreamingSVD.

        Args:
            image: uint8 array [H, W, 3] or [1, H, W, 3], or a path to an image file.
            num_frames: number of frames to return.
            seed: seed passed to ``seed_everything``.

        Returns:
            ``(video, scaled_outpainted_image, expanded_size)`` with ``video``
            truncated to ``num_frames`` frames.

        Raises:
            AssertionError: if ``image`` is not a 3-channel uint8 image.
        """
        seed_everything(seed)
        if isinstance(image, str):
            image = IImage.open(image).numpy()
        if image.shape[0] == 1 and image.ndim == 4:
            image = image[0]
        assert image.shape[-1] == 3 and image.shape[0] > 1, "Wrong image format. Assuming shape [H W C], with C = 3."
        assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8."

        # compute necessary number of chunks: the formula assumes the first
        # generation yields n_frames_per_gen frames and every further one adds
        # n_frames_per_gen - n_cond_frames new frames.
        n_cond_frames = self.model.inference_params.num_conditional_frames
        n_frames_per_gen = self.model.sampler.guider.num_frames
        self.model.inference_params.n_autoregressive_generations = int(math.ceil(
            (num_frames - n_frames_per_gen) / (n_frames_per_gen - n_cond_frames)))

        print(" --- STREAMING ----- [START]")
        video, scaled_outpainted_image, expanded_size = self.streaming(image=image)
        print(f" --- STREAMING ----- [FINISHED]: {video.shape}")
        video = video[:num_frames]
        return video, scaled_outpainted_image, expanded_size

    def enhance_video(self, video: Union[np.ndarray, str], image: np.ndarray = None,
                      chunk_size=38, overlap_size=12, strength=0.97,
                      use_randomized_blending=False, seed=33, num_frames=None):
        """Enhance ``video`` with the i2v_enhance pipeline, conditioned on ``image``.

        Args:
            video: frames as an array [T, H, W, C] (uint8 assumed - TODO confirm),
                or a path to a video file.
            image: conditioning frame ([H, W, C] or [1, H, W, C]); defaults to ``video[0]``.
            chunk_size, overlap_size: only used with ``use_randomized_blending``;
                otherwise the whole video is one chunk without overlap.
            strength: forwarded to ``i2v_enhance_process``.
            use_randomized_blending: forwarded to ``i2v_enhance_process``.
            seed: seed passed to ``seed_everything``.
            num_frames: if given, only the first ``num_frames`` frames are enhanced.

        Returns:
            The enhanced frames stacked along a new first axis.
        """
        seed_everything(seed)
        if isinstance(video, str):
            video = IImage.open(video).numpy()
        if image is None:
            image = video[0]
            print("ATTENTION: We take first frame of previous stage as input frame for enhance. ")
        if num_frames is not None:
            video = video[:num_frames, ...]
        if not use_randomized_blending:
            chunk_size = video.shape[0]
            overlap_size = 0
        if image.ndim == 3:
            image = image[None]
        # Both inputs are brought to 1280x720; IImage.resize presumably takes
        # (H, W), while PIL's Image.resize takes (W, H).
        image = [Image.fromarray(
            IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])]
        video = [Image.fromarray(frame).resize((1280, 720)) for frame in video]

        print(
            f"---- ENHANCE ---- [START]. Video length = {len(video)}. Randomized Blending = {use_randomized_blending}. Chunk size = {chunk_size}. Overlap size = {overlap_size}.")
        video_enhanced = i2v_enhance_interface.i2v_enhance_process(
            image=image, video=video, pipeline=self.enhance_pipeline,
            generator=self.enhance_generator, chunk_size=chunk_size,
            overlap_size=overlap_size, strength=strength,
            use_randomized_blending=use_randomized_blending)
        video_enhanced = np.stack([np.asarray(frame) for frame in video_enhanced], axis=0)
        print("---- ENHANCE ---- [FINISHED].")
        return video_enhanced

    def interpolate_video(self, video: np.ndarray, dest_num_frames: int):
        """Interpolate ``video`` to ``dest_num_frames`` frames with the VFI model.

        ``self.vfi.device()`` is called before processing and ``self.vfi.unload()``
        afterwards, including when interpolation raises.
        """
        frames = list(video)  # one array per frame along the first axis
        print(" ---- VFI ---- [START]")
        self.vfi.device()
        try:
            video_vfi = i2v_enhance_interface.vfi_process(
                video=frames, vfi=self.vfi, video_len=dest_num_frames)
            video_vfi = np.stack([np.asarray(frame) for frame in video_vfi], axis=0)
        finally:
            self.vfi.unload()
        print(f"---- VFI ---- [FINISHED]. Video length = {len(video_vfi)}")
        return video_vfi

    # T2I method
    def text_to_image(self, prompt, seed=33):
        """Generate a 1280x720 first frame from ``prompt`` with FLUX.1-schnell.

        Returns:
            The generated image converted with ``np.array``.
        """
        # FLUX
        print("[FLUX] Generating image from text prompt")
        out = self.flux_pipe(
            prompt=prompt,
            guidance_scale=0,
            height=720,
            width=1280,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator(device=self.model.device).manual_seed(seed),
        ).images[0]
        print("[FLUX] Finished")
        return np.array(out)


if __name__ == "__main__":

    @ensure_annotation_class
    def get_input_data(input_path: Path = None):
        """Return an iterator of ``(image_array, image_path)`` pairs.

        ``input_path`` is either one image file or a folder searched
        (non-recursively) for jpg/jpeg/png/bmp files in any letter case. Folder
        contents are processed in sorted order and decoded one at a time.
        """
        assert input_path is not None, "No input given. Please provide --image (or --prompt)."
        if input_path.is_file():
            inputs = [input_path]
        else:
            # loading jpg, png, jpeg and bmp images
            suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]",
                        "*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"]
            inputs = sorted(path for suffix in suffixes for path in input_path.glob(suffix))
        assert len(inputs) > 0, "No images found. Please make sure the input path is correct."
        # Decode lazily so only one input image is held in memory at a time.
        return ((IImage.open(path).numpy(), path) for path in inputs)

    streaming_svd = StreamingSVD()
    num_frames = streaming_svd.num_frames
    chunk_size = streaming_svd.chunk_size
    overlap_size = streaming_svd.overlap_size
    use_randomized_blending = streaming_svd.use_randomized_blending
    if not use_randomized_blending:
        # enhance_video overrides these in this mode anyway.
        chunk_size = (num_frames + 1) // 2
        overlap_size = 0
    seed = 33

    assert streaming_svd.output_path is not None, "Please provide an output folder via --output."
    result_path = Path(streaming_svd.output_path)
    assert not result_path.exists() or result_path.is_dir(), "Output path must be the path to a folder."

    def save_video(video, name):
        """Write ``video`` to ``<result_path>/<name>.mp4`` at the configured fps."""
        result_path.mkdir(parents=True, exist_ok=True)
        result_file = (result_path / (name + ".mp4")).as_posix()
        IImage(video, vmin=0, vmax=255).setFps(streaming_svd.fps).save(result_file)
        print(f"Video created at: {result_file}")

    prompt = streaming_svd.prompt
    if not prompt:
        for img, img_path in get_input_data(streaming_svd.input_path):
            video = streaming_svd.streaming_i2v(
                image=img, num_frames=num_frames,
                use_randomized_blending=use_randomized_blending,
                chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
            save_video(video, img_path.stem)
    else:
        video = streaming_svd.streaming_t2v(
            prompt=prompt, num_frames=num_frames,
            use_randomized_blending=use_randomized_blending,
            chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
        # Replace spaces, dots, path separators and characters that are invalid in
        # Windows file names with "_", then keep a short prefix as the file name.
        unsafe_chars = ' ./:\\*?"<>|'
        prompt_file = prompt.translate(
            str.maketrans(unsafe_chars, "_" * len(unsafe_chars)))[:15]
        save_video(video, prompt_file)