|
import os |
|
import sys |
|
from pathlib import Path |
|
import torch |
|
import argparse |
|
import logging |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
import json |
|
|
|
|
|
from diffusers import ( |
|
DDIMInverseScheduler, |
|
DDIMScheduler, |
|
) |
|
from diffusers.utils import load_image, export_to_video, export_to_gif |
|
|
|
|
|
from utils import ( |
|
seed_everything, |
|
load_video_frames, |
|
convert_video_to_frames, |
|
load_ddim_latents_at_T, |
|
load_ddim_latents_at_t, |
|
) |
|
from pipelines.pipeline_i2vgen_xl import I2VGenXLPipeline |
|
|
|
|
|
def ddim_inversion(config, first_frame, frame_list, pipe: I2VGenXLPipeline, inverse_scheduler, g):
    """Run DDIM inversion on a clip with the I2VGen-XL pipeline.

    The function installs ``inverse_scheduler`` on ``pipe`` (the pipeline keeps
    that scheduler after the call returns), VAE-encodes ``frame_list`` through
    ``pipe.encode_vae_video`` and passes the encoded latents to ``pipe.invert``
    as the starting ``latents``.

    Args:
        config: Inversion settings. Attributes read here: ``image_size``
            (index 0 is forwarded as width, index 1 as height), ``prompt``,
            ``negative_prompt``, ``n_frames``, ``n_steps``, ``cfg``,
            ``target_fps`` and ``output_dir``.
        first_frame: Image forwarded to ``pipe.invert`` as ``image``.
        frame_list: Frames forwarded to ``pipe.encode_vae_video``.
        pipe: Pipeline to drive; its ``scheduler`` attribute is overwritten.
        inverse_scheduler: Scheduler assigned to ``pipe.scheduler``.
        g: Forwarded to ``pipe.invert`` as ``generator``.

    Returns:
        Element 0 of the object returned by ``pipe.invert``.
        NOTE(review): presumably the latents of the single batch item, one
        entry per inversion step -- confirm against ``pipe.invert``.
    """
    log = logging.getLogger(__name__)

    # Swap the scheduler first so every later pipeline call uses it.
    pipe.scheduler = inverse_scheduler

    # image_size is indexed as (width, height).
    width, height = config.image_size[0], config.image_size[1]

    source_latents = pipe.encode_vae_video(
        frame_list,
        device=pipe._execution_device,
        height=height,
        width=width,
    )

    invert_kwargs = {
        "prompt": config.prompt,
        "image": first_frame,
        "height": height,
        "width": width,
        "num_frames": config.n_frames,
        "num_inference_steps": config.n_steps,
        "guidance_scale": config.cfg,
        "negative_prompt": config.negative_prompt,
        "target_fps": config.target_fps,
        "latents": source_latents,
        "generator": g,
        "return_dict": False,
        "output_dir": config.output_dir,
    }
    inverted = pipe.invert(**invert_kwargs)
    log.debug(f"ddim_latents.shape: {inverted.shape}")

    # Only the first entry along the leading axis is handed back to the caller.
    return inverted[0]
|
|
|
|
|
def ddim_sampling(
    config, first_frame, ddim_latents_at_T, pipe: I2VGenXLPipeline, ddim_scheduler, ddim_init_latents_t_idx, g
):
    """Generate a video with the I2VGen-XL pipeline starting from given latents.

    The function installs ``ddim_scheduler`` on ``pipe`` (the pipeline keeps
    that scheduler after the call returns) and invokes the pipeline with
    ``ddim_latents_at_T`` as the initial ``latents``.

    Args:
        config: Sampling settings. Attributes read here: ``image_size``
            (index 0 is forwarded as width, index 1 as height), ``prompt``,
            ``negative_prompt``, ``n_frames``, ``n_steps``, ``cfg`` and
            ``target_fps``.
        first_frame: Image forwarded to the pipeline as ``image``.
        ddim_latents_at_T: Latents forwarded to the pipeline as ``latents``.
        pipe: Pipeline to call; its ``scheduler`` attribute is overwritten.
        ddim_scheduler: Scheduler assigned to ``pipe.scheduler``.
        ddim_init_latents_t_idx: Forwarded unchanged to the pipeline under the
            same keyword name.
        g: Forwarded to the pipeline as ``generator``.

    Returns:
        ``frames[0]`` of the pipeline output, i.e. the first generated video.
    """
    pipe.scheduler = ddim_scheduler

    # image_size is indexed as (width, height).
    width, height = config.image_size[0], config.image_size[1]

    sampling_kwargs = {
        "prompt": config.prompt,
        "image": first_frame,
        "height": height,
        "width": width,
        "num_frames": config.n_frames,
        "num_inference_steps": config.n_steps,
        "guidance_scale": config.cfg,
        "negative_prompt": config.negative_prompt,
        "target_fps": config.target_fps,
        "latents": ddim_latents_at_T,
        "generator": g,
        "return_dict": True,
        "ddim_init_latents_t_idx": ddim_init_latents_t_idx,
    }
    pipeline_output = pipe(**sampling_kwargs)
    return pipeline_output.frames[0]
|
|
|
|
|
def main(template_config, configs_list):
    """Run DDIM inversion, and optionally DDIM reconstruction, for each config entry.

    For every entry in ``configs_list`` whose ``"active"`` value is not false,
    the entry is merged on top of ``template_config``, the source frames are
    loaded (falling back to extracting them from the ``.mp4``), a GIF of the
    source frames is written next to them, and ``ddim_inversion`` is run.
    When ``recon_config.enable_recon`` is set, the video is additionally
    re-sampled with ``ddim_sampling`` and exported as ``.mp4`` and ``.gif``
    into ``config.output_dir``.

    Args:
        template_config: OmegaConf config providing the defaults. Attributes
            read directly here: ``device``, ``seed`` and ``video_dir``.
        configs_list: Iterable of dicts; each must contain an ``"active"`` key
            and is merged over ``template_config``.

    Raises:
        FileNotFoundError: If ``template_config.video_dir`` does not exist.
            (This used to be an ``assert``, which is stripped under ``-O``.)
    """
    # Resolve these locally instead of relying on module globals that only
    # exist when the file is executed as a script.
    logger = logging.getLogger(__name__)
    device = torch.device(template_config.device)

    pipe = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(device)

    # One generator is seeded once and shared by every config entry below.
    g = torch.Generator(device=device)
    g = g.manual_seed(template_config.seed)

    inverse_scheduler = DDIMInverseScheduler.from_pretrained(
        "ali-vilab/i2vgen-xl",
        subfolder="scheduler",
    )
    ddim_scheduler = DDIMScheduler.from_pretrained(
        "ali-vilab/i2vgen-xl",
        subfolder="scheduler",
    )

    video_dir = template_config.video_dir
    if not os.path.exists(video_dir):
        raise FileNotFoundError(f"video_dir: {video_dir} does not exist")

    for config_entry in configs_list:
        # Explicit comparison kept on purpose: only a value equal to False
        # deactivates an entry (a null value does not).
        if config_entry["active"] == False:  # noqa: E712
            logger.info(f"Skipping config_entry: {config_entry}")
            continue
        logger.info(f"Processing config_entry: {config_entry}")

        # Entry values override the template defaults.
        config = OmegaConf.merge(template_config, OmegaConf.create(config_entry))

        # Derive the source video file and its frame directory from the name.
        config.video_path = os.path.join(config.video_dir, config.video_name + ".mp4")
        config.video_frames_path = os.path.join(config.video_dir, config.video_name)

        # An existing output directory is treated as "already processed".
        if os.path.exists(config.output_dir) and not config.force_recompute_latents:
            logger.info(f"### Skipping !!! {config.output_dir} already exists. ")
            continue

        logger.info(f"config: {OmegaConf.to_yaml(config)}")

        # Prefer frames that are already on disk; otherwise extract them from
        # the mp4. Only ``Exception`` is caught so that Ctrl-C / SystemExit
        # still abort the run, and the traceback is logged so the reason for
        # the fallback is visible.
        logger.info(f"Loading frames from: {config.video_frames_path}")
        try:
            _, frame_list = load_video_frames(config.video_frames_path, config.n_frames, config.image_size)
        except Exception:
            logger.error(f"Failed to load frames from: {config.video_frames_path}", exc_info=True)
            logger.info(f"Converting mp4 video to frames: {config.video_path}")
            frame_list = convert_video_to_frames(config.video_path, config.image_size, save_frames=True)
            frame_list = frame_list[: config.n_frames]
        logger.debug(f"len(frame_list): {len(frame_list)}")

        # Keep a GIF preview of the source clip next to its frames.
        export_to_gif(
            frame_list,
            os.path.join(config.video_frames_path, config.video_name + ".gif"),
        )
        logger.info(f"Saved source video as gif to {config.video_frames_path}")
        first_frame = frame_list[0]

        # Optional: invert a clip made of the first frame repeated n_frames times.
        if config.inverse_config.inverse_static_video:
            logger.info("### Inverse a static video!")
            frame_list = [frame_list[0]] * config.n_frames

        # Optional: condition the inversion on an all-black image instead of
        # the real first frame.
        if config.inverse_config.null_image_inversion:
            logger.info("### Inverse a null image!")
            first_frame = Image.new("RGB", (config.image_size[0], config.image_size[1]), (0, 0, 0))

        logger.info(f"config: {OmegaConf.to_yaml(config)}")

        # The return value is not used here; reconstruction below reads its
        # latents from ``recon_config.ddim_latents_path``.
        # NOTE(review): presumably ``pipe.invert`` persists the latents under
        # ``output_dir`` -- confirm in the pipeline implementation.
        ddim_inversion(config.inverse_config, first_frame, frame_list, pipe, inverse_scheduler, g)

        recon_config = config.recon_config
        if recon_config.enable_recon:
            ddim_init_latents_t_idx = recon_config.ddim_init_latents_t_idx
            ddim_scheduler.set_timesteps(recon_config.n_steps)
            logger.info(f"ddim_scheduler.timesteps: {ddim_scheduler.timesteps}")

            # Load the stored latents for the timestep sampling starts from.
            ddim_latents_path = recon_config.ddim_latents_path
            ddim_latents_at_t = load_ddim_latents_at_t(
                ddim_scheduler.timesteps[ddim_init_latents_t_idx],
                ddim_latents_path=ddim_latents_path,
            )
            logger.debug(f"ddim_scheduler.timesteps[t_idx]: {ddim_scheduler.timesteps[ddim_init_latents_t_idx]}")

            reconstructed_video = ddim_sampling(
                recon_config,
                first_frame,
                ddim_latents_at_t,
                pipe,
                ddim_scheduler,
                ddim_init_latents_t_idx,
                g,
            )

            os.makedirs(config.output_dir, exist_ok=True)

            # Frames are resized to a fixed 512x512 before export, independent
            # of ``config.image_size``.
            reconstructed_video = [frame.resize((512, 512), resample=Image.LANCZOS) for frame in reconstructed_video]
            export_to_video(
                reconstructed_video,
                os.path.join(config.output_dir, "ddim_reconstruction.mp4"),
                fps=10,
            )
            export_to_gif(
                reconstructed_video,
                os.path.join(config.output_dir, "ddim_reconstruction.gif"),
            )
            logger.info(f"Saved reconstructed video to {config.output_dir}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, load the template config and
    # the list of per-video config entries, then hand both to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--template_config", type=str, default="./configs/group_ddim_inversion/template.yaml")
    parser.add_argument("--configs_json", type=str, default="./configs/group_config.json")

    args = parser.parse_args()
    template_config = OmegaConf.load(args.template_config)

    # Verbosity is driven by the ``debug`` flag of the template config.
    logging_level = logging.DEBUG if template_config.debug else logging.INFO
    logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s")
    # NOTE: ``logger`` (and ``device`` below) are bound at module level here;
    # keep them, since functions in this file may reference them as globals.
    logger = logging.getLogger(__name__)
    logger.info(f"template_config: {OmegaConf.to_yaml(template_config)}")

    # Validate with an explicit exception: an ``assert`` would be stripped
    # under ``python -O`` and carried no message.
    configs_json = args.configs_json
    if not Path(configs_json).exists():
        raise FileNotFoundError(f"configs_json: {configs_json} does not exist")
    with open(configs_json, "r", encoding="utf-8") as file:
        configs_list = json.load(file)
    logger.info(f"Loaded {len(configs_list)} configs from {configs_json}")

    device = torch.device(template_config.device)
    # Gradients are disabled globally for the whole run.
    torch.set_grad_enabled(False)
    seed_everything(template_config.seed)
    main(template_config, configs_list)
|
|