Spaces:
Runtime error
Runtime error
import numpy as np | |
from lib.farancia import IImage | |
from PIL import Image | |
from i2v_enhance import i2v_enhance_interface | |
from dataloader.dataset_factory import SingleImageDatasetFactory | |
from pytorch_lightning import Trainer, LightningDataModule, seed_everything | |
import math | |
from diffusion_trainer import streaming_svd as streaming_svd_model | |
import torch | |
from safetensors.torch import load_file as load_safetensors | |
from utils.loader import download_ckpt | |
from functools import partial | |
from dataloader.video_data_module import VideoDataModule | |
from pathlib import Path | |
from pytorch_lightning.cli import LightningCLI, LightningArgumentParser | |
from pytorch_lightning import LightningModule | |
import sys | |
import os | |
from copy import deepcopy | |
from utils.aux import ensure_annotation_class | |
from diffusers import FluxPipeline | |
from typing import Union | |
class CustomCLI(LightningCLI):
    """LightningCLI extended with the StreamingSVD inference options."""

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Register the inference-specific command-line options on ``parser``.

        Args:
            parser: the Lightning argument parser to extend.

        Returns:
            The same ``parser`` (returned for convenience; LightningCLI mutates it in place).
        """
        parser.add_argument("--image", type=Path,
                            help="Path to the input image(s)")
        parser.add_argument("--output", type=Path,
                            help="Path to the output folder")
        parser.add_argument("--num_frames", type=int, default=200,
                            help="Number of frames to generate.")
        parser.add_argument("--out_fps", type=int, default=24,
                            help="Framerate of the generated video.")
        parser.add_argument("--chunk_size", type=int, default=38,
                            help="Chunk size used in randomized blending.")
        parser.add_argument("--overlap_size", type=int, default=12,
                            help="Overlap size used in randomized blending.")
        parser.add_argument("--use_randomized_blending", action="store_true",
                            help="Whether to use randomized blending.")
        parser.add_argument("--use_fp16", action="store_true",
                            help="Whether to use float16 quantization.")
        parser.add_argument("--prompt", type=str, default="",
                            help="Text prompt. If non-empty, the input image is generated "
                                 "from this prompt instead of being read from --image.")
        return parser
class StreamingSVD():
    """Image-to-video / text-to-video pipeline configured through ``CustomCLI``.

    Construction parses ``config.yaml`` (located next to this file) plus,
    optionally, the process's command-line arguments. It then loads the
    StreamingSVD checkpoint and initialises the enhancement pipeline, the
    frame-interpolation (VFI) model and a FLUX.1-schnell text-to-image pipeline.
    """

    def __init__(self, load_argv = True) -> None:
        """Build the CLI, load all models and read the run options.

        Args:
            load_argv: if True, ``sys.argv[1:]`` is forwarded to the CLI after
                ``--config``; otherwise only the config file is used.
        """
        call_fol = Path(os.getcwd()).resolve()
        code_fol = Path(__file__).resolve().parent
        try:
            code_fol = os.path.relpath(code_fol, call_fol)
        except ValueError:
            # os.path.relpath raises on Windows when the two paths are on different drives.
            code_fol = str(code_fol)
        argv_backup = deepcopy(sys.argv)
        # NOTE(review): this inspects the real sys.argv even when load_argv is False,
        # while the precision flag in __config_call only sees the forwarded args -- confirm intended.
        if "--use_fp16" in sys.argv:
            os.environ["STREAMING_USE_FP16"] = "True"
        # LightningCLI parses sys.argv, so temporarily replace it with the composed command line.
        sys.argv = [__file__]
        sys.argv.extend(self.__config_call(argv_backup[1:] if load_argv else [], code_fol))
        try:
            cli = CustomCLI(LightningModule, run=False, subclass_mode_model=True, parser_kwargs={
                "parser_mode": "omegaconf"}, save_config_callback=None)
            self.__init_models(cli)
            self.__init_fields(cli)
        finally:
            # Always restore the caller's argv, even if parsing or model loading fails.
            sys.argv = argv_backup

    def __init_models(self, cli):
        """Load the StreamingSVD checkpoint into ``cli.model`` and build the auxiliary pipelines.

        Stores the model, trainer, data-module factory, VFI model, enhancement
        pipeline/generator and FLUX pipeline on ``self``.
        """
        model = cli.model
        trainer = cli.trainer
        ckpt_cfg = model.diff_trainer_params.streamingsvd_ckpt
        path = download_ckpt(
            local_path=ckpt_cfg.ckpt_path_local,
            global_path=ckpt_cfg.ckpt_path_global,
        )
        # str() so the suffix test works whether download_ckpt returns a str or a Path.
        if str(path).endswith(".safetensors"):
            ckpt = load_safetensors(path)
        else:
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            ckpt = torch.load(path, map_location="cpu")["state_dict"]
        model.load_state_dict(ckpt)  # load trained model

        data_module_loader = partial(VideoDataModule, workers=0)
        vfi = i2v_enhance_interface.vfi_init(model.vfi)
        enhance_pipeline, enhance_generator = i2v_enhance_interface.i2v_enhance_init(
            model.i2v_enhance)
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        flux_pipe.enable_model_cpu_offload()

        # store of objects
        self.model = model
        self.vfi = vfi
        self.data_module_loader = data_module_loader  # factory: call with predict_dataset_factory=...
        self.enhance_pipeline = enhance_pipeline
        self.enhance_generator = enhance_generator
        self.trainer = trainer
        self.flux_pipe = flux_pipe

    def __init_fields(self, cli):
        """Copy the run options parsed by ``CustomCLI`` onto the instance."""
        self.input_path = cli.config["image"]
        self.output_path = cli.config["output"]
        self.num_frames = cli.config["num_frames"]
        self.fps = cli.config["out_fps"]
        self.use_randomized_blending = cli.config["use_randomized_blending"]
        self.chunk_size = cli.config["chunk_size"]
        self.overlap_size = cli.config["overlap_size"]
        self.prompt = cli.config["prompt"]

    def __config_call(self, config_cmds, code_fol):
        """Compose the CLI arguments: the config file, fp16 precision if requested, then the non-empty user args."""
        cmds = [cmd for cmd in config_cmds if len(cmd) > 0]
        cmd_init = ["--config", f"{code_fol}/config.yaml"]
        if "--use_fp16" in config_cmds:
            cmd_init.append("--trainer.precision=16-true")
        cmd_init.extend(cmds)
        return cmd_init

    # interfaces
    def streaming_t2v(self, prompt, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33):
        """Generate a video from a text prompt: a FLUX keyframe followed by ``streaming_i2v``.

        ``seed`` drives both the FLUX keyframe and the video stages.
        """
        # Forward the seed so the keyframe also depends on it (text_to_image otherwise defaults to 33).
        image = self.text_to_image(prompt=prompt, seed=seed)
        return self.streaming_i2v(image, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)

    def streaming_i2v(self, image, num_frames: int, use_randomized_blending: bool = False, chunk_size: int = 38, overlap_size: int = 12, seed=33) -> np.ndarray:
        """Generate a ``num_frames``-frame video from a single image.

        The pipeline has four stages:
        1. StreamingSVD generates ``(num_frames + 1) // 2`` frames.
        2. The enhancement stage refines them.
        3. Frame interpolation brings the count to ``num_frames``.
        4. The result is resized to ``expanded_size`` and cropped to the input image's width and height.

        Args:
            image: uint8 image array ([H, W, 3] or [1, H, W, 3]) or a path readable by ``IImage.open``.
            num_frames: number of frames of the output video.
            use_randomized_blending, chunk_size, overlap_size: enhancement chunking options.
            seed: seed for the generation and enhancement stages.

        Returns:
            The generated video as a numpy array.
        """
        if isinstance(image, str):
            # Load here as well: the final crop below needs the image's shape.
            image = IImage.open(image).numpy()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            image, num_frames=(num_frames + 1) // 2, seed=seed)
        print(
            f"max_memory_allocated at image_to_video: {torch.cuda.max_memory_allocated()}")
        video = self.enhance_video(image=IImage(scaled_outpainted_image).numpy(), video=video, chunk_size=chunk_size, overlap_size=overlap_size,
                                   use_randomized_blending=use_randomized_blending, seed=seed)
        video = self.interpolate_video(video, dest_num_frames=num_frames)
        # scale/crop back to input size
        if image.ndim == 4 and image.shape[0] == 1:
            image = image[0]
        video = IImage(video, vmin=0, vmax=255).resize(expanded_size[::-1]).crop((0, 0, image.shape[1], image.shape[0])).numpy()
        # Re-query the peak counter so this reflects the enhance/VFI stages too.
        print(
            f"max_memory_allocated at interpolate_video: {torch.cuda.max_memory_allocated()}")
        return video

    # StreamingSVD pipeline
    def streaming(self, image: np.ndarray):
        """Run one StreamingSVD prediction for ``image``.

        Wraps the image in a single-image dataset and runs ``trainer.predict``.
        It then reads back ``generated_video``, ``scaled_outpainted_image`` and
        ``expanded_size`` from the trainer.
        """
        datamodule = self.data_module_loader(
            predict_dataset_factory=SingleImageDatasetFactory(file=image))
        self.trainer.predict(model=self.model, datamodule=datamodule)
        trainer = self.trainer
        return trainer.generated_video, trainer.scaled_outpainted_image, trainer.expanded_size

    def image_to_video(self, image: Union[np.ndarray, str], num_frames: int, seed=33) -> tuple[np.ndarray,Image,list[int]]:
        """Generate (up to) ``num_frames`` raw StreamingSVD frames from ``image``.

        Args:
            image: uint8 array ([H, W, 3] or [1, H, W, 3]) or a path readable by ``IImage.open``.
            num_frames: requested frame count; determines the number of autoregressive generations.
            seed: global seed set via ``seed_everything``.

        Returns:
            ``(video[:num_frames], scaled_outpainted_image, expanded_size)`` as produced by ``streaming``.

        Raises:
            AssertionError: if the image is not 3-channel or not uint8.
        """
        seed_everything(seed)
        if isinstance(image, str):
            image = IImage.open(image).numpy()
        if image.shape[0] == 1 and image.ndim == 4:
            image = image[0]
        assert image.shape[-1] == 3 and image.shape[0] > 1, "Wrong image format. Assuming shape [H W C], with C = 3."
        assert image.dtype == "uint8", "Wrong dtype for input image. Must be uint8."
        # compute necessary number of chunks
        n_cond_frames = self.model.inference_params.num_conditional_frames
        n_frames_per_gen = self.model.sampler.guider.num_frames
        n_autoregressive_generations = math.ceil(
            (num_frames - n_frames_per_gen) / (n_frames_per_gen - n_cond_frames))
        # Clamp at 0: for short videos the ceil above can go negative, which is not a valid count.
        self.model.inference_params.n_autoregressive_generations = max(
            0, int(n_autoregressive_generations))
        print(" --- STREAMING ----- [START]")
        video, scaled_outpainted_image, expanded_size = self.streaming(
            image=image)
        print(f" --- STREAMING ----- [FINISHED]: {video.shape}")
        video = video[:num_frames]
        return video, scaled_outpainted_image, expanded_size

    def enhance_video(self, video: Union[np.ndarray, str], image: np.ndarray = None, chunk_size = 38, overlap_size=12, strength=0.97, use_randomized_blending=False, seed=33,num_frames = None):
        """Refine ``video`` with the i2v enhancement pipeline, conditioned on ``image``.

        Args:
            video: numpy array of frames (axis 0 = time) or a path readable by ``IImage.open``.
            image: conditioning frame; defaults to the first frame of ``video``.
            chunk_size, overlap_size: chunking for randomized blending. When
                ``use_randomized_blending`` is False they are overridden, so the
                whole video is one chunk with no overlap.
            strength: forwarded to ``i2v_enhance_process``.
            seed: global seed set via ``seed_everything``.
            num_frames: if given, only the first ``num_frames`` frames are enhanced.

        Returns:
            The enhanced frames stacked along axis 0 as a numpy array.
        """
        seed_everything(seed)
        if isinstance(video, str):
            video = IImage.open(video).numpy()
        if image is None:
            image = video[0]
            print("ATTENTION: We take first frame of previous stage as input frame for enhance. ")
        if num_frames is not None:
            video = video[:num_frames, ...]
        if not use_randomized_blending:
            chunk_size = video.shape[0]
            overlap_size = 0
        if image.ndim == 3:
            image = image[None]
        # NOTE(review): IImage.resize gets (720, 1280) while PIL's resize gets (width, height) = (1280, 720);
        # presumably IImage expects (H, W), so both are 1280x720 -- confirm.
        image = [Image.fromarray(
            IImage(image, vmin=0, vmax=255).resize((720, 1280)).numpy()[0])]
        video = [Image.fromarray(frame).resize((1280, 720)) for frame in video]
        print(
            f"---- ENHANCE ---- [START]. Video length = {len(video)}. Randomized Blending = {use_randomized_blending}. Chunk size = {chunk_size}. Overlap size = {overlap_size}.")
        video_enhanced = i2v_enhance_interface.i2v_enhance_process(
            image=image, video=video, pipeline=self.enhance_pipeline, generator=self.enhance_generator,
            chunk_size=chunk_size, overlap_size=overlap_size, strength=strength, use_randomized_blending=use_randomized_blending)
        video_enhanced = np.stack([np.asarray(frame) for frame in video_enhanced], axis=0)
        print("---- ENHANCE ---- [FINISHED].")
        return video_enhanced

    def interpolate_video(self, video: np.ndarray, dest_num_frames: int):
        """Interpolate ``video`` to ``dest_num_frames`` frames with the VFI model.

        Calls ``self.vfi.device()`` before processing and ``self.vfi.unload()``
        afterwards; the unload also runs if processing raises.

        Returns:
            The interpolated frames stacked along axis 0 as a numpy array.
        """
        frames = list(np.asarray(video))  # one array per frame along axis 0
        print(" ---- VFI ---- [START]")
        self.vfi.device()
        try:
            video_vfi = i2v_enhance_interface.vfi_process(
                video=frames, vfi=self.vfi, video_len=dest_num_frames)
            video_vfi = np.stack([np.asarray(frame) for frame in video_vfi], axis=0)
        finally:
            # Release the VFI model even when interpolation fails.
            self.vfi.unload()
        print(f"---- VFI ---- [FINISHED]. Video length = {len(video_vfi)}")
        return video_vfi

    # T2I method
    def text_to_image(self, prompt, seed=33):
        """Generate a 1280x720 keyframe for ``prompt`` with FLUX.1-schnell (4 steps, no guidance).

        Returns:
            The generated image as a numpy array.
        """
        # FLUX
        print("[FLUX] Generating image from text prompt")
        generator = torch.Generator(device=self.model.device).manual_seed(seed)
        out = self.flux_pipe(
            prompt=prompt,
            guidance_scale=0,
            height=720,
            width=1280,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=generator,
        ).images[0]
        print("[FLUX] Finished")
        return np.array(out)
if __name__ == "__main__":
    def get_input_data(input_path: Path = None):
        """Return an iterator of ``(image_as_numpy, image_path)`` pairs.

        ``input_path`` is either a single image file or a folder. For a folder,
        every ``.jpg``/``.jpeg``/``.png``/``.bmp`` file (any letter case)
        directly inside it is used, in sorted order. Images are loaded lazily,
        one per iteration step.

        Raises:
            ValueError: if no input path was given.
            AssertionError: if no images were found.
        """
        if input_path is None:
            raise ValueError("No input given. Please pass --image (or a non-empty --prompt).")
        if input_path.is_file():
            inputs = [input_path]
        else:
            suffixes = ["*.[jJ][pP][gG]", "*.[pP][nN][gG]",
                        "*.[jJ][pP][eE][gG]", "*.[bB][mM][pP]"]  # loading png, jpg and bmp images
            # sorted() gives a deterministic processing order independent of the filesystem.
            inputs = sorted(path for suffix in suffixes for path in input_path.glob(suffix))
        assert len(
            inputs) > 0, "No images found. Please make sure the input path is correct."
        return ((IImage.open(path).numpy(), path) for path in inputs)

    streaming_svd = StreamingSVD()
    num_frames = streaming_svd.num_frames
    chunk_size = streaming_svd.chunk_size
    overlap_size = streaming_svd.overlap_size
    use_randomized_blending = streaming_svd.use_randomized_blending
    if not use_randomized_blending:
        chunk_size = (num_frames + 1) // 2
        overlap_size = 0
    if streaming_svd.output_path is None:
        raise ValueError("No output folder given. Please pass --output.")
    result_path = Path(streaming_svd.output_path)
    seed = 33
    if result_path.exists() and not result_path.is_dir():
        raise NotADirectoryError("Output path must be the path to a folder.")

    # Path separators and characters that are invalid in Windows file names.
    filename_table = str.maketrans({char: "_" for char in ' ./:\\*?"<>|\n\r\t'})

    def save_video(video, stem):
        """Write ``video`` to ``<result_path>/<stem>.mp4`` at the configured fps."""
        result_path.mkdir(parents=True, exist_ok=True)
        result_file = (result_path / (stem + ".mp4")).as_posix()
        IImage(video, vmin=0, vmax=255).setFps(
            streaming_svd.fps).save(result_file)
        print(f"Video created at: {result_file}")

    prompt = streaming_svd.prompt
    if not prompt:
        for img, img_path in get_input_data(streaming_svd.input_path):
            video = streaming_svd.streaming_i2v(
                image=img, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
            save_video(video, img_path.stem)
    else:
        video = streaming_svd.streaming_t2v(
            prompt=prompt, num_frames=num_frames, use_randomized_blending=use_randomized_blending, chunk_size=chunk_size, overlap_size=overlap_size, seed=seed)
        # File name: first 15 characters of the prompt, unsafe characters replaced by "_".
        save_video(video, prompt.translate(filename_table)[:15])