StreamingSVD / i2v_enhance /i2v_enhance_interface.py
lev1's picture
Initial commit
8fd2f2f
raw
history blame
4.64 kB
import torch
from i2v_enhance.pipeline_i2vgen_xl import I2VGenXLPipeline
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
import i2v_enhance.thirdparty.VFI.config as cfg
from i2v_enhance.thirdparty.VFI.Trainer import Model as VFI
from pathlib import Path
from modules.params.vfi import VFIParams
from modules.params.i2v_enhance import I2VEnhanceParams
from utils.loader import download_ckpt
def vfi_init(ckpt_cfg: VFIParams, device_id=0):
    """Build the frame-interpolation (VFI) model and load its checkpoint.

    Args:
        ckpt_cfg: checkpoint locations; ``ckpt_path_local`` and
            ``ckpt_path_global`` are handed to ``download_ckpt``, and the
            path it returns is loaded into the model.
        device_id: device rank. Only ``0`` is supported.

    Returns:
        The VFI model, switched to eval mode and with ``vfi.device()`` called
        on it (presumably moves it to the GPU -- confirm).

    Raises:
        NotImplementedError: if ``device_id`` is not 0. The check runs before
            any download or model construction, so nothing is wasted.
    """
    # Fail fast: previously this was an `assert` evaluated only after the
    # checkpoint had been downloaded and loaded, and it vanished under `-O`.
    if device_id != 0:
        raise NotImplementedError("VFI on rank!=0 not implemented yet.")
    # The architecture is written into the shared config module before the
    # model is constructed below.
    cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(
        F=32, depth=[2, 2, 2, 4, 4])
    # NOTE(review): -1 presumably means "no distributed local rank" -- confirm.
    vfi = VFI(-1)
    ckpt_file = Path(download_ckpt(
        local_path=ckpt_cfg.ckpt_path_local,
        global_path=ckpt_cfg.ckpt_path_global))
    vfi.load_model(ckpt_file.as_posix())
    vfi.eval()
    vfi.device()
    return vfi
def _frame_to_rgb_uint8(frame):
    """Turn one channel-reversed (C, H, W) float tensor into an (H, W, C) uint8 array.

    Undoes the channel reversal applied on input in ``vfi_process``.
    """
    array = frame.detach().cpu().numpy().transpose(1, 2, 0)
    # Clip first: float values outside [0, 1] would otherwise wrap around when
    # cast to uint8 instead of saturating. In-range values are unchanged.
    array = (np.clip(array, 0.0, 1.0) * 255.0).astype(np.uint8)
    return array[:, :, ::-1]


def vfi_process(video, vfi, video_len, *, size=(1280, 720), device="cuda"):
    """Insert one interpolated frame between every pair of consecutive frames.

    Uses the first ``video_len // 2 + 1`` frames of ``video``. For each
    consecutive pair it emits the first frame followed by the frame predicted
    by ``vfi.inference``, then emits the last input frame. When ``video_len``
    is even, the last frame is appended once more.

    NOTE(review): with at least ``video_len // 2 + 1`` input frames this
    returns ``video_len`` frames for odd ``video_len`` but ``video_len + 2``
    for even ``video_len``. If the intent is exactly ``video_len`` frames, the
    slice should probably be ``(video_len + 1) // 2``. It is left unchanged
    because callers may rely on the current count.

    Args:
        video: sequence of frames supporting slicing and ``frame[:, :, :3]``;
            presumably (H, W, C) uint8 numpy arrays in [0, 255] -- confirm.
        vfi: model exposing ``inference(I0, I2, TTA=..., fast_TTA=...)``,
            e.g. the one returned by ``vfi_init``.
        video_len: target length; selects how many input frames are used and
            whether the last frame is duplicated.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.
        device: torch device the frames are moved to for inference.

    Returns:
        list of ``PIL.Image.Image`` resized to ``size``.

    Raises:
        ValueError: if no input frames are selected.
    """
    selected = video[:(video_len // 2 + 1)]
    if len(selected) == 0:
        raise ValueError("vfi_process: no input frames to interpolate")

    # Drop any alpha channel, scale to [0, 1] and reverse the channel order
    # (presumably RGB -> BGR as the VFI model expects -- confirm).
    batch = np.stack(
        [(frame[:, :, :3] / 255.)[:, :, ::-1] for frame in selected], axis=0)
    frames_t = rearrange(torch.from_numpy(batch),
                         'b h w c -> b c h w').to(device, torch.float32)
    # Free the float64 host copy right away; it can be large.
    del batch

    frames = []
    with torch.no_grad():
        for idx in tqdm(range(frames_t.shape[0] - 1), desc="VFI"):
            first = frames_t[idx:idx + 1]
            second = frames_t[idx + 1:idx + 2]
            frames.append(_frame_to_rgb_uint8(first[0]))
            middle = vfi.inference(first, second, TTA=True, fast_TTA=True)[0]
            frames.append(_frame_to_rgb_uint8(middle))

    last = _frame_to_rgb_uint8(frames_t[-1])
    frames.append(last)
    if video_len % 2 == 0:
        frames.append(last.copy())

    # Release GPU memory before building the PIL images. `del vfi` only drops
    # this function's reference; the model is freed only if the caller holds
    # no other reference to it.
    del vfi
    del frames_t
    torch.cuda.empty_cache()
    video = [Image.fromarray(frame).resize(size) for frame in frames]
    del frames
    return video
def i2v_enhance_init(i2vgen_cfg: I2VEnhanceParams, seed=8888):
    """Load the I2VGen-XL enhancement pipeline and seed torch's RNG.

    First tries ``i2vgen_cfg.ckpt_path_local``. If that raises any
    ``Exception``, it loads from ``i2vgen_cfg.ckpt_path_global`` and saves the
    result to ``ckpt_path_local`` so later runs can load it locally. The
    fallback reason is logged instead of being discarded.

    Args:
        i2vgen_cfg: checkpoint locations (``ckpt_path_local``,
            ``ckpt_path_global``).
        seed: value passed to ``torch.manual_seed``. The default keeps the
            original hard-coded 8888.

    Returns:
        ``(pipeline, generator)``: the fp16 pipeline with model CPU offload
        enabled, and the generator returned by ``torch.manual_seed``.
    """
    import logging

    # torch.manual_seed also seeds torch's global RNG as a side effect.
    generator = torch.manual_seed(seed)
    try:
        pipeline = I2VGenXLPipeline.from_pretrained(
            i2vgen_cfg.ckpt_path_local, torch_dtype=torch.float16, variant="fp16")
    except Exception as err:  # deliberate best-effort: any local failure -> global
        logging.getLogger(__name__).warning(
            "Loading I2VGen-XL from %s failed (%s); falling back to %s",
            i2vgen_cfg.ckpt_path_local, err, i2vgen_cfg.ckpt_path_global)
        pipeline = I2VGenXLPipeline.from_pretrained(
            i2vgen_cfg.ckpt_path_global, torch_dtype=torch.float16, variant="fp16")
        # Cache locally so the next run takes the fast path.
        pipeline.save_pretrained(i2vgen_cfg.ckpt_path_local)
    pipeline.enable_model_cpu_offload()
    return pipeline, generator
def i2v_enhance_process(image, video, pipeline, generator, overlap_size, strength, chunk_size=38, use_randomized_blending=False):
    """Enhance ``video`` chunk-wise with the I2VGen-XL ``pipeline``.

    Args:
        image: conditioning image passed to the pipeline as ``image``.
        video: frames supporting ``len()``, indexing and slicing (e.g. the
            list returned by ``vfi_process``).
        pipeline: callable invoked with keyword arguments only; its result
            must expose ``.frames`` (e.g. the pipeline from
            ``i2v_enhance_init``).
        generator: passed through to every pipeline call.
        overlap_size: number of frames shared by consecutive chunks.
        strength: passed through to every pipeline call.
        chunk_size: number of frames per chunk.
        use_randomized_blending: if True, first enhance the key frames (the
            first frame of every full chunk) and pass that enhanced list as
            ``image`` to the main call. Trailing frames that do not fill a
            whole chunk are dropped.

    Returns:
        ``pipeline(...).frames[0]`` of the main call.

    Raises:
        ValueError: with randomized blending, if ``overlap_size >= chunk_size``
            or ``video`` has fewer than ``chunk_size`` frames. Previously the
            first case hit a zero/negative range step and the second silently
            called the pipeline with ``num_frames=0``.
    """
    # Settings identical across both pipeline calls.
    shared = dict(
        prompt="High Quality, HQ, detailed.",
        negative_prompt="Distorted, blurry, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms",
        height=720,
        width=1280,
        strength=strength,
        num_inference_steps=30,
        decode_chunk_size=1,
        guidance_scale=9.0,
        generator=generator,
    )
    if use_randomized_blending:
        stride = chunk_size - overlap_size
        if stride <= 0:
            raise ValueError(
                f"overlap_size ({overlap_size}) must be smaller than "
                f"chunk_size ({chunk_size})")
        # Start index of every chunk that fits completely inside the video.
        starts = range(0, len(video) - chunk_size + 1, stride)
        if len(starts) == 0:
            raise ValueError(
                f"randomized blending needs at least chunk_size ({chunk_size}) "
                f"frames, got {len(video)}")
        # Enhance the key frames (first frame of each chunk) as one short clip;
        # the result is the list of starting images, one per chunk.
        key_frames = [video[start] for start in starts]
        image = pipeline(
            image=image,
            video=key_frames,
            overlap_size=0,
            chunk_size=len(key_frames),
            num_frames=len(key_frames),
            **shared,
        ).frames[0]
        # Drop the trailing frames (< chunk_size) that do not fill a chunk.
        video = video[:starts[-1] + chunk_size]
    return pipeline(
        image=image,
        video=video,
        overlap_size=overlap_size,
        chunk_size=chunk_size,
        num_frames=chunk_size,
        **shared,
    ).frames[0]