|
import os |
|
import sys |
|
from pathlib import Path |
|
import torch |
|
import argparse |
|
import logging |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
import json |
|
|
|
|
|
from diffusers import ( |
|
DDIMInverseScheduler, |
|
DDIMScheduler, |
|
) |
|
from diffusers.utils import load_image, export_to_video, export_to_gif |
|
|
|
|
|
from utils import ( |
|
seed_everything, |
|
load_video_frames, |
|
convert_video_to_frames, |
|
load_ddim_latents_at_T, |
|
load_ddim_latents_at_t, |
|
) |
|
from pipelines.pipeline_i2vgen_xl import I2VGenXLPipeline |
|
from pnp_utils import ( |
|
register_time, |
|
register_conv_injection, |
|
register_spatial_attention_pnp, |
|
register_temp_attention_pnp, |
|
) |
|
|
|
|
|
def init_pnp(pipe, scheduler, config):
    """Install PnP (plug-and-play) feature-injection hooks on the pipeline.

    The leading fraction of the denoising schedule receives injected features:
    ``pnp_f_t`` controls conv injection, ``pnp_spatial_attn_t`` spatial
    attention Q/K injection, and ``pnp_temp_attn_t`` temporal attention Q/K
    injection. Each fraction is converted to a step count over ``n_steps`` and
    mapped to the corresponding prefix of ``scheduler.timesteps``.
    """
    log = logging.getLogger(__name__)

    def _prefix_timesteps(n):
        # First `n` timesteps of the schedule; empty when n is negative.
        return scheduler.timesteps[:n] if n >= 0 else []

    conv_injection_t = int(config.n_steps * config.pnp_f_t)
    spatial_attn_qk_injection_t = int(config.n_steps * config.pnp_spatial_attn_t)
    temp_attn_qk_injection_t = int(config.n_steps * config.pnp_temp_attn_t)

    conv_injection_timesteps = _prefix_timesteps(conv_injection_t)
    spatial_attn_qk_injection_timesteps = _prefix_timesteps(spatial_attn_qk_injection_t)
    temp_attn_qk_injection_timesteps = _prefix_timesteps(temp_attn_qk_injection_t)

    register_conv_injection(pipe, conv_injection_timesteps)
    register_spatial_attention_pnp(pipe, spatial_attn_qk_injection_timesteps)
    register_temp_attention_pnp(pipe, temp_attn_qk_injection_timesteps)

    log.debug(f"conv_injection_t: {conv_injection_t}")
    log.debug(f"spatial_attn_qk_injection_t: {spatial_attn_qk_injection_t}")
    log.debug(f"temp_attn_qk_injection_t: {temp_attn_qk_injection_t}")
    log.debug(f"conv_injection_timesteps: {conv_injection_timesteps}")
    log.debug(f"spatial_attn_qk_injection_timesteps: {spatial_attn_qk_injection_timesteps}")
    log.debug(f"temp_attn_qk_injection_timesteps: {temp_attn_qk_injection_timesteps}")
|
|
|
|
|
def main(template_config, configs_list):
    """Run PnP video editing for every active entry in ``configs_list``.

    Args:
        template_config: OmegaConf base config shared by all entries.
        configs_list: list of per-video override dicts (from JSON); each must
            carry an ``"active"`` flag plus the path/hyperparameter fields
            referenced below.

    Side effects: writes the edited video (.mp4/.gif) and per-frame PNGs under
    each entry's output directory. Relies on the module-level globals
    ``device`` and ``logger`` set in the ``__main__`` block.
    """
    # Load the heavy pipeline once and reuse it across all config entries.
    pipe = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(device)

    ddim_scheduler = DDIMScheduler.from_pretrained(
        "ali-vilab/i2vgen-xl",
        subfolder="scheduler",
    )

    for config_entry in configs_list:
        if not config_entry["active"]:  # idiomatic truthiness (was `== False`)
            logger.info(f"Skipping config_entry: {config_entry}")
            continue
        logger.info(f"Processing config_entry: {config_entry}")

        # Per-entry values override the shared template.
        config = OmegaConf.merge(template_config, OmegaConf.create(config_entry))

        # Derive the concrete paths used below.
        config.video_path = os.path.join(config.video_dir, config.video_name + ".mp4")
        config.video_frames_path = os.path.join(config.video_dir, config.video_name)
        config.edited_first_frame_path = os.path.join(config.data_dir, config.edited_first_frame_path)
        logger.info(f"config: {OmegaConf.to_yaml(config)}")

        # BUG FIX: the original `continue` was inside the field-validation
        # loop, so it only advanced that inner loop and entries containing
        # 'ReplaceMe' placeholders were still processed. Now the whole
        # config entry is skipped after logging every offending field.
        placeholder_fields = [k for k, v in config.items() if "ReplaceMe" in str(v)]
        if placeholder_fields:
            for k in placeholder_fields:
                logger.error(f"Field {k} contains 'ReplaceMe'")
            continue

        # Prefer pre-extracted frames; fall back to decoding the mp4.
        try:
            logger.info(f"Loading frames from: {config.video_frames_path}")
            _, frame_list = load_video_frames(config.video_frames_path, config.n_frames, config.image_size)
        except Exception:  # was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
            logger.error(f"Failed to load frames from: {config.video_frames_path}")
            logger.info(f"Converting mp4 video to frames: {config.video_path}")
            frame_list = convert_video_to_frames(config.video_path, config.image_size, save_frames=True)
        frame_list = frame_list[: config.n_frames]
        logger.debug(f"len(frame_list): {len(frame_list)}")
        src_frame_list = frame_list
        src_1st_frame = src_frame_list[0]

        # The user-edited first frame drives the appearance of the edit.
        edited_1st_frame = load_image(config.edited_first_frame_path)
        edited_1st_frame = edited_1st_frame.resize(config.image_size, resample=Image.Resampling.LANCZOS)

        # Pick up the DDIM-inverted latent at the requested timestep index.
        ddim_init_latents_t_idx = config.ddim_init_latents_t_idx
        ddim_scheduler.set_timesteps(config.n_steps)
        logger.info(f"ddim_scheduler.timesteps: {ddim_scheduler.timesteps}")
        ddim_latents_at_t = load_ddim_latents_at_t(
            ddim_scheduler.timesteps[ddim_init_latents_t_idx], ddim_latents_path=config.ddim_latents_path
        )
        logger.debug(f"ddim_scheduler.timesteps[t_idx]: {ddim_scheduler.timesteps[ddim_init_latents_t_idx]}")
        logger.debug(f"ddim_latents_at_t.shape: {ddim_latents_at_t.shape}")

        # Optionally blend in fresh noise (random_ratio=1 -> pure random latent).
        random_latents = torch.randn_like(ddim_latents_at_t)
        logger.info(f"Blending random_ratio (1 means random latent): {config.random_ratio}")
        mixed_latents = random_latents * config.random_ratio + ddim_latents_at_t * (1 - config.random_ratio)

        # Install the PnP injection hooks for this entry's hyperparameters.
        init_pnp(pipe, ddim_scheduler, config)

        pipe.register_modules(scheduler=ddim_scheduler)
        edited_video = pipe.sample_with_pnp(
            prompt=config.editing_prompt,
            image=edited_1st_frame,
            height=config.image_size[1],
            width=config.image_size[0],
            num_frames=config.n_frames,
            num_inference_steps=config.n_steps,
            guidance_scale=config.cfg,
            negative_prompt=config.editing_negative_prompt,
            target_fps=config.target_fps,
            latents=mixed_latents,
            generator=torch.manual_seed(config.seed),
            return_dict=True,
            ddim_init_latents_t_idx=ddim_init_latents_t_idx,
            ddim_inv_latents_path=config.ddim_latents_path,
            ddim_inv_prompt=config.ddim_inv_prompt,
            ddim_inv_1st_frame=src_1st_frame,
        ).frames[0]

        # Encode the run's hyperparameters into the output directory name.
        config_suffix = (
            "ddim_init_latents_t_idx_"
            + str(ddim_init_latents_t_idx)
            + "_nsteps_"
            + str(config.n_steps)
            + "_cfg_"
            + str(config.cfg)
            + "_pnpf"
            + str(config.pnp_f_t)
            + "_pnps"
            + str(config.pnp_spatial_attn_t)
            + "_pnpt"
            + str(config.pnp_temp_attn_t)
        )
        output_dir = os.path.join(config.output_dir, config_suffix)
        os.makedirs(output_dir, exist_ok=True)
        # Consistency fix: use Image.Resampling.LANCZOS as above
        # (bare Image.LANCZOS is the deprecated alias removed in Pillow 10).
        edited_video = [frame.resize(config.image_size, resample=Image.Resampling.LANCZOS) for frame in edited_video]

        edited_video_file_name = "video"
        export_to_video(edited_video, os.path.join(output_dir, f"{edited_video_file_name}.mp4"), fps=config.target_fps)
        export_to_gif(edited_video, os.path.join(output_dir, f"{edited_video_file_name}.gif"))
        logger.info(f"Saved video to: {os.path.join(output_dir, f'{edited_video_file_name}.mp4')}")
        logger.info(f"Saved gif to: {os.path.join(output_dir, f'{edited_video_file_name}.gif')}")
        for i, frame in enumerate(edited_video):
            frame.save(os.path.join(output_dir, f"{edited_video_file_name}_{i:05d}.png"))
            logger.info(f"Saved frames to: {os.path.join(output_dir, f'{edited_video_file_name}_{i:05d}.png')}")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--template_config", type=str, default="./configs/group_pnp_edit/template.yaml")
    parser.add_argument(
        "--configs_json", type=str, default="./configs/group_config.json"
    )

    args = parser.parse_args()
    template_config = OmegaConf.load(args.template_config)

    # Configure logging before anything else logs.
    logging_level = logging.DEBUG if template_config.debug else logging.INFO
    logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s")
    logger = logging.getLogger(__name__)
    logger.info(f"template_config: {OmegaConf.to_yaml(template_config)}")

    # Explicit existence check instead of `assert`: asserts are stripped
    # under `python -O`, and FileNotFoundError carries a clearer message.
    configs_json = args.configs_json
    if not Path(configs_json).exists():
        raise FileNotFoundError(f"configs_json not found: {configs_json}")
    with open(configs_json, "r") as file:
        configs_list = json.load(file)
    logger.info(f"Loaded {len(configs_list)} configs from {configs_json}")

    # Module-level globals consumed by main() and init_pnp().
    device = torch.device(template_config.device)
    torch.set_grad_enabled(False)
    seed_everything(template_config.seed)
    main(template_config, configs_list)