import os

os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')

result_dir = os.path.join('./', 'results')
os.makedirs(result_dir, exist_ok=True)

import functools
import random
import gradio as gr
import numpy as np
import torch
import wd14tagger
import uuid

from PIL import Image
from diffusers_helper.code_cond import unet_add_coded_conds
from diffusers_helper.cat_cond import unet_add_concat_conds
from diffusers_helper.k_diffusion import KDiffusionSampler
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
import spaces

# Disable gradients globally
torch.set_grad_enabled(False)


class ModifiedUNet(UNet2DConditionModel):
    @classmethod
    def from_config(cls, *args, **kwargs):
        m = super().from_config(*args, **kwargs)
        unet_add_concat_conds(unet=m, new_channels=4)
        unet_add_coded_conds(unet=m, added_number_count=1)
        return m


model_name = 'lllyasviel/paints_undo_single_frame'
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
    'lllyasviel/paints_undo_multi_frame',
    fp16=True
).to("cuda")

k_sampler = KDiffusionSampler(
    unet=unet,
    timesteps=1000,
    linear_start=0.00085,
    linear_end=0.020,
    linear=True
)


def find_best_bucket(h, w, options):
    # Pick the (height, width) bucket whose aspect ratio is closest to the input image's.
    min_metric = float('inf')
    best_bucket = None
    for (bucket_h, bucket_w) in options:
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket = (bucket_h, bucket_w)
    return best_bucket


def encode_cropped_prompt_77tokens(txt: str):
    cond_ids = tokenizer(txt,
                         padding="max_length",
                         max_length=tokenizer.model_max_length,
                         truncation=True,
                         return_tensors="pt").input_ids.to(device="cuda")
    text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
    return text_cond


def pytorch2numpy(imgs):
    # [-1, 1] float tensors (C, H, W) -> uint8 arrays (H, W, C)
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        results.append(y)
    return results


def numpy2pytorch(imgs):
    # uint8 arrays (H, W, C) -> [-1, 1] float tensor (B, C, H, W)
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h


def resize_without_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


def interrogator_process(x):
    image_description = wd14tagger.default_interrogator(x)
    return image_description


@spaces.GPU()
def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
            progress=gr.Progress()):
    rng = torch.Generator(device="cuda").manual_seed(int(seed))

    fg = resize_and_center_crop(input_fg, image_width, image_height)
    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor

    conds = encode_cropped_prompt_77tokens(prompt)
    unconds = encode_cropped_prompt_77tokens(n_prompt)

    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
    initial_latents = torch.zeros_like(concat_conds)
    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
    latents = k_sampler(
        initial_latent=initial_latents,
        strength=1.0,
        num_inference_steps=steps,
        guidance_scale=cfg,
        batch_size=len(input_undo_steps),
        generator=rng,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
        same_noise_in_batch=True,
        progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
    ).to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
    pixels = [fg] + pixels + [np.zeros_like(fg) + 255]

    return pixels


def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    frames = 16

    target_height, target_width = find_best_bucket(
        image_1.shape[0], image_1.shape[1],
        options=[(320, 512), (384, 448), (448, 384), (512, 320)]
    )

    image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
    image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
    input_frames = numpy2pytorch([image_1, image_2])
    input_frames = input_frames.unsqueeze(0).movedim(1, 2)

    positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
    negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")

    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
    positive_image_cond = video_pipe.encode_clip_vision(input_frames)
    positive_image_cond = video_pipe.image_projection(positive_image_cond)
    negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
    negative_image_cond = video_pipe.image_projection(negative_image_cond)

    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
    input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
    first_frame = input_frame_latents[:, :, 0]
    last_frame = input_frame_latents[:, :, 1]
    concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)

    latents = video_pipe(
        batch_size=1,
        steps=int(steps),
        guidance_scale=cfg_scale,
        positive_text_cond=positive_text_cond,
        negative_text_cond=negative_text_cond,
        positive_image_cond=positive_image_cond,
        negative_image_cond=negative_image_cond,
        concat_cond=concat_cond,
        fs=fs,
        progress_tqdm=progress_tqdm
    )

    video = video_pipe.decode_latents(latents, vae_hidden_states)
    return video, image_1, image_2


@spaces.GPU(duration=333360)
def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
    result_frames = []
    cropped_images = []

    for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
        im1 = np.array(Image.open(im1[0]))
        im2 = np.array(Image.open(im2[0]))
        frames, im1, im2 = process_video_inner(
            im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
            progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
        )
        result_frames.append(frames[:, :, :-1, :, :])
        cropped_images.append([im1, im2])

    video = torch.cat(result_frames, dim=2)
    video = torch.flip(video, dims=[2])

    uuid_name = str(uuid.uuid4())
    output_filename = os.path.join(result_dir, uuid_name + '.mp4')
    Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
    video = save_bcthw_as_mp4(video, output_filename, fps=fps)
    video = [x.cpu().numpy() for x in video]
    return output_filename, video


block = gr.Blocks().queue()
with block:
    gr.HTML("