# StreamingSVD / streaming_svd_inference.py
# Last commit: "Update streaming_svd_inference.py" by hpoghos (4358fb0, verified).
import numpy as np
from lib.farancia import IImage
from PIL import Image
from i2v_enhance import i2v_enhance_interface
from dataloader.dataset_factory import SingleImageDatasetFactory
from pytorch_lightning import Trainer, LightningDataModule, seed_everything
import math
from diffusion_trainer import streaming_svd as streaming_svd_model
import torch
from safetensors.torch import load_file as load_safetensors
from utils.loader import download_ckpt
from functools import partial
from dataloader.video_data_module import VideoDataModule
from pathlib import Path
from pytorch_lightning.cli import LightningCLI, LightningArgumentParser
from pytorch_lightning import LightningModule
import sys
import os
from copy import deepcopy
from utils.aux import ensure_annotation_class
from diffusers import FluxPipeline
from typing import Union
class CustomCLI(LightningCLI):
    """LightningCLI extended with the options of the StreamingSVD inference script."""

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> LightningArgumentParser:
        """Register the inference-specific options on ``parser`` and return it.

        The parsed values are available in ``cli.config`` under the option
        name without leading dashes (e.g. ``cli.config["num_frames"]``).
        """
        parser.add_argument("--image", type=Path,
                            help="Path to the input image(s)")
        parser.add_argument("--output", type=Path,
                            help="Path to the output folder")
        parser.add_argument("--num_frames", type=int, default=100,
                            help="Number of frames to generate.")
        parser.add_argument("--out_fps", type=int, default=24,
                            help="Framerate of the generated video.")
        parser.add_argument("--chunk_size", type=int, default=38,
                            help="Chunk size used in randomized blending.")
        parser.add_argument("--overlap_size", type=int, default=12,
                            help="Overlap size used in randomized blending.")
        parser.add_argument("--use_randomized_blending", action="store_true",
                            help="Whether to use randomized blending.")
        parser.add_argument("--use_fp16", action="store_true",
                            help="Whether to use float16 quantization.")
        parser.add_argument("--prompt", type=str, default="",
                            help="Text prompt. If non-empty, the input image is generated "
                                 "from it (text-to-video) instead of being read from --image.")
        return parser
class StreamingSVD():
    """Inference wrapper around the StreamingSVD pipeline.

    On construction the configuration (``config.yaml`` next to this file plus
    optional command-line overrides) is parsed with ``CustomCLI`` and all
    models are loaded. Entry points:

    * ``streaming_t2v``: text -> image (FLUX) -> long video.
    * ``streaming_i2v``: image -> long video (StreamingSVD, enhancement, VFI).
    * the individual stages ``image_to_video``, ``enhance_video``,
      ``interpolate_video`` and ``text_to_image``.
    """

    def __init__(self, load_argv=True) -> None:
        """Parse the configuration and load all models.

        Args:
            load_argv: if True, ``sys.argv[1:]`` is forwarded to the CLI as
                overrides of ``config.yaml``; otherwise only the config file
                is used.
        """
        call_fol = Path(os.getcwd()).resolve()
        code_fol = Path(__file__).resolve().parent
        try:
            code_fol = os.path.relpath(code_fol, call_fol)
        except ValueError:
            # No relative path exists (e.g. cwd and code on different drives on Windows).
            code_fol = str(code_fol)
        argv_backup = deepcopy(sys.argv)
        # NOTE(review): this env flag is taken from sys.argv even when load_argv is False.
        if "--use_fp16" in sys.argv:
            os.environ["STREAMING_USE_FP16"] = "True"
        # The CLI parses sys.argv, so temporarily swap in the generated command
        # line and always restore the caller's argv, also when parsing
        # (argparse may raise SystemExit) or model loading fails.
        try:
            sys.argv = [__file__, *self.__config_call(argv_backup[1:] if load_argv else [], code_fol)]
            cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True,
                            parser_kwargs={"parser_mode": "omegaconf"},
                            save_config_callback=None)
            self.__init_models(cli)
            self.__init_fields(cli)
        finally:
            sys.argv = argv_backup

    def __init_models(self, cli):
        """Load the StreamingSVD checkpoint and build VFI, enhancement and FLUX pipelines."""
        model = cli.model
        trainer = cli.trainer
        path = download_ckpt(
            local_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_local,
            global_path=model.diff_trainer_params.streamingsvd_ckpt.ckpt_path_global
        )
        if path.endswith(".safetensors"):
            ckpt = load_safetensors(path)
        else:
            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            ckpt = torch.load(path, map_location="cpu")["state_dict"]
        model.load_state_dict(ckpt)  # load trained model
        data_module_loader = partial(VideoDataModule, workers=0)
        vfi = i2v_enhance_interface.vfi_init(model.vfi)
        enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init(
            model.i2v_enhance)
        # Feed-forward chunking in the enhancement UNet, presumably to lower peak memory.
        enhance_pipeline.unet.enable_forward_chunking(chunk_size=1, dim=1)
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        flux_pipe.enable_model_cpu_offload()

        self.model = model
        self.vfi = vfi
        self.data_module_loader = data_module_loader
        self.enhance_pipeline = enhance_pipeline
        self.enhance_generator = enhance_generator
        self.trainer = trainer
        self.flux_pipe = flux_pipe

    def __init_fields(self, cli):
        """Copy the inference options parsed by ``CustomCLI`` onto the instance."""
        self.input_path = cli.config["image"]
        self.output_path = cli.config["output"]
        self.num_frames = cli.config["num_frames"]
        self.fps = cli.config["out_fps"]
        self.use_randomized_blending = cli.config["use_randomized_blending"]
        self.chunk_size = cli.config["chunk_size"]
        self.overlap_size = cli.config["overlap_size"]
        self.prompt = cli.config["prompt"]

    def __config_call(self, config_cmds, code_fol):
        """Build the CLI argument list.

        The bundled ``config.yaml`` comes first, then fp16 precision if
        ``--use_fp16`` was requested, then the caller's non-empty arguments.
        """
        cmds = [cmd for cmd in config_cmds if len(cmd) > 0]
        cmd_init = ["--config", f"{code_fol}/config.yaml"]
        if "--use_fp16" in config_cmds:
            cmd_init.append("--trainer.precision=16-true")
        cmd_init.extend(cmds)
        return cmd_init

    @staticmethod
    def _count_autoregressive_generations(num_frames, n_frames_per_gen, n_cond_frames) -> int:
        """Return how many autoregressive generations follow the first one.

        Per the formula below, the first generation is assumed to yield
        ``n_frames_per_gen`` frames and every further one
        ``n_frames_per_gen - n_cond_frames`` new frames. The result is clamped
        at 0 so that short videos never request a negative count.
        """
        new_frames_per_gen = n_frames_per_gen - n_cond_frames
        return max(0, math.ceil((num_frames - n_frames_per_gen) / new_frames_per_gen))

    # interfaces
    def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33):
        """Text-to-video: create the first frame with FLUX, then run ``streaming_i2v``.

        ``seed`` is used for both the text-to-image and the video stage.
        """
        image = self.text_to_image(prompt=prompt, seed=seed)
        return self.streaming_i2v(image, num_frames=num_frames, use_randomized_blending=use_randomized_blending,
                                  chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)

    def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.ndarray:
        """Image-to-video: generate a ``num_frames``-frame video from ``image``.

        StreamingSVD generates ``(num_frames + 1) // 2`` frames. These are
        enhanced and then interpolated (VFI) to ``num_frames`` frames. The
        result is resized to the expanded size and cropped back to the input
        image's height and width.

        Args:
            image: uint8 image [H, W, 3] (a leading axis of size 1 is allowed)
                or a path to an image file.
            num_frames: number of frames of the final video.
            use_randomized_blending, chunk_size, overlap_size: enhancement
                chunking, see ``enhance_video``.
            seed: random seed for all stages.

        Returns:
            The video as a numpy array (presumably [T, H, W, C]).
        """
        if isinstance(image, (str, os.PathLike)):
            image = IImage.open(os.fspath(image)).numpy()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            image, num_frames=(num_frames + 1) // 2, seed=seed)
        print(
            f"max_memory_allocated at image_to_video: {torch.cuda.max_memory_allocated()}")
        video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video, chunk_size=chunk_size, overlap_size=overlap_size,
                                   use_randomized_blending=use_randomized_blending, seed=seed)
        video = self.interpolate_video(video, dest_num_frames=num_frames)
        # scale/crop back to input size (same batch-axis rule as image_to_video)
        if image.ndim == 4 and image.shape[0] == 1:
            image = image[0]
        video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop(
            (0, 0, image.shape[1], image.shape[0])).numpy()
        # Re-query the peak: the value read above predates enhancement and VFI.
        print(
            f"max_memory_allocated at interpolate_video: {torch.cuda.max_memory_allocated()}")
        return video

    # StreamingSVD pipeline
    def streaming(self, image: np.ndarray):
        """Run the StreamingSVD model on one image through ``trainer.predict``.

        Returns ``(video, scaled_outpainted_image, expanded_size)``, read back
        from the trainer after prediction (presumably set by the model's
        predict step).
        """
        datamodule = self.data_module_loader(predict_dataset_factory=SingleImageDatasetFactory(
            file=image))
        self.trainer.predict(model=self.model, datamodule=datamodule)
        video = self.trainer.generated_video
        expanded_size = self.trainer.expanded_size
        scaled_outpainted_image = self.trainer.scaled_outpainted_image
        return video, scaled_outpainted_image, expanded_size

    def image_to_video(self, image: Union[np.ndarray, str, os.PathLike], num_frames: int, seed=33) -> "tuple[np.ndarray, Image.Image, list[int]]":
        """Generate ``num_frames`` frames from ``image`` with StreamingSVD.

        Args:
            image: uint8 image [H, W, 3] (a leading axis of size 1 is dropped)
                or a path to an image file.
            num_frames: number of frames to return.
            seed: random seed passed to ``seed_everything``.

        Returns:
            ``(video, scaled_outpainted_image, expanded_size)`` from
            ``streaming``, with ``video`` truncated to ``num_frames`` frames.
        """
        seed_everything(seed)
        if isinstance(image, (str, os.PathLike)):
            image = IImage.open(os.fspath(image)).numpy()
        if image.shape[0] == 1 and image.ndim == 4:
            image = image[0]
        assert image.shape[-1] == 3 and image.shape[0] > 1, "Wrong image format. Assuming shape [H W C], with C = 3."
        assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8."
        # compute necessary number of chunks
        n_cond_frames = self.model.inference_params.num_conditional_frames
        n_frames_per_gen = self.model.sampler.guider.num_frames
        self.model.inference_params.n_autoregressive_generations = self._count_autoregressive_generations(
            num_frames, n_frames_per_gen, n_cond_frames)
        print(" --- STREAMING ----- [START]")
        video, scaled_outpainted_image, expanded_size = self.streaming(
            image=image)
        print(f" --- STREAMING ----- [FINISHED]: {video.shape}")
        # The last chunk may overshoot the requested length.
        video = video[:num_frames]
        return video, scaled_outpainted_image, expanded_size

    def enhance_video(self, video: Union[np.ndarray, str, os.PathLike], image: np.ndarray = None, chunk_size=38, overlap_size=12, strength=0.97, use_randomized_blending=False, seed=33, num_frames=None):
        """Refine ``video`` at 1280x720 with the i2v-enhance pipeline.

        Args:
            video: frames stacked along axis 0, or a path to a video file.
            image: conditioning frame. Defaults to the first frame of ``video``.
            chunk_size, overlap_size: chunking used with randomized blending.
                Without randomized blending the whole video is one chunk.
            strength: enhancement strength forwarded to the pipeline.
            use_randomized_blending: whether to blend overlapping chunks.
            seed: random seed passed to ``seed_everything``.
            num_frames: if given, only the first ``num_frames`` frames are used.

        Returns:
            The enhanced frames stacked along axis 0.
        """
        seed_everything(seed)
        if isinstance(video, (str, os.PathLike)):
            video = IImage.open(os.fspath(video)).numpy()
        if image is None:
            image = video[0]
            print("ATTENTION: We take first frame of previous stage as input frame for enhance. ")
        if num_frames is not None:
            video = video[:num_frames, ...]
        if not use_randomized_blending:
            chunk_size = video.shape[0]
            overlap_size = 0
        if image.ndim == 3:
            image = image[None]
        # Both target 720x1280 (H x W). IImage.resize presumably takes (H, W),
        # while PIL's Image.resize takes (W, H).
        cond_image = [Image.fromarray(
            IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])]
        frames = [Image.fromarray(frame).resize((1280, 720)) for frame in video]
        print(
            f"---- ENHANCE ---- [START]. Video length = {len(frames)}. Randomized Blending = {use_randomized_blending}. Chunk size = {chunk_size}. Overlap size = {overlap_size}.")
        video_enhanced = i2v_enhance_interface.i2v_enhance_process(
            image=cond_image, video=frames, pipeline=self.enhance_pipeline, generator=self.enhance_generator,
            chunk_size=chunk_size, overlap_size=overlap_size, strength=strength, use_randomized_blending=use_randomized_blending)
        video_enhanced = np.stack([np.asarray(frame)
                                   for frame in video_enhanced], axis=0)
        print("---- ENHANCE ---- [FINISHED].")
        return video_enhanced

    def interpolate_video(self, video: np.ndarray, dest_num_frames: int):
        """Interpolate ``video`` (frames along axis 0) to ``dest_num_frames`` frames.

        The VFI model is unloaded afterwards, even if interpolation fails.
        """
        frames = list(np.asarray(video))
        print(" ---- VFI ---- [START]")
        self.vfi.device()  # presumably moves the VFI model onto the compute device
        try:
            video_vfi = i2v_enhance_interface.vfi_process(
                video=frames, vfi=self.vfi, video_len=dest_num_frames)
            video_vfi = np.stack([np.asarray(frame)
                                  for frame in video_vfi], axis=0)
        finally:
            self.vfi.unload()
        print(f"---- VFI ---- [FINISHED]. Video length = {len(video_vfi)}")
        return video_vfi

    # T2I method
    def text_to_image(self, prompt, seed=33):
        """Generate a 720x1280 (H x W) image from ``prompt`` with FLUX.1-schnell.

        Uses 4 inference steps and no guidance, and returns the image as a
        numpy array.
        """
        print("[FLUX] Generating image from text prompt")
        out = self.flux_pipe(
            prompt=prompt,
            guidance_scale=0,
            height=720,
            width=1280,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=torch.Generator(
                device=self.model.device).manual_seed(seed),
        ).images[0]
        print("[FLUX] Finished")
        return np.array(out)
if __name__ == "__main__":
@ensure_annotation_class
def get_input_data(input_path: Path = None):
if input_path.is_file():
inputs = [input_path]
else:
suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]",
"*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"] # loading png, jpg and bmp images
inputs = []
for suffix in suffixes:
inputs.extend(list(input_path.glob(suffix)))
assert len(
inputs) > 0, "No images found. Please make sure the input path is correct."
img_as_np = [IImage.open(input).numpy() for input in inputs]
return zip(img_as_np, inputs)
streaming_svd = StreamingSVD()
num_frames = streaming_svd.num_frames
chunk_size = streaming_svd.chunk_size
overlap_size = streaming_svd.overlap_size
use_randomized_blending = streaming_svd.use_randomized_blending
if not use_randomized_blending:
chunk_size = (num_frames + 1)//2
overlap_size = 0
result_path = Path(streaming_svd.output_path)
seed = 33
assert result_path.exists() is False or result_path.is_dir(
), "Output path must be the path to a folder."
prompt = streaming_svd.prompt
if len(prompt) == 0:
for img, img_path in get_input_data(streaming_svd.input_path):
video = streaming_svd.streaming_i2v(
image=img, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=33)
if not result_path.exists():
result_path.mkdir(parents=True)
result_file = result_path / (img_path.stem+".mp4")
result_file = result_file.as_posix()
IImage(video, vmin=0, vmax=255).setFps(
streaming_svd.fps).save(result_file)
print(f"Video created at: {result_file}")
else:
video = streaming_svd.streaming_t2v(
prompt=prompt, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=33)
prompt_file = prompt.replace(" ", "_").replace(
".", "_").replace("/", "_").replace(":", "_")
prompt_file = prompt_file[:15]
if not result_path.exists():
result_path.mkdir(parents=True)
result_file = result_path / (prompt_file+".mp4")
result_file = result_file.as_posix()
IImage(video, vmin=0, vmax=255).setFps(
streaming_svd.fps).save(result_file)
print(f"Video created at: {result_file}")