diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bab002189355d42f5aba56499d615fa23e43e745 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright 2023 MagicAnimate Team All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md index 689d44137e986820d38ecc07d357da5c865a53b2..851dd02426b3fa1a28bb50d1d644f75b222cac6d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,128 @@ ---- -title: Magicanimate -emoji: 📊 -colorFrom: yellow -colorTo: pink -sdk: gradio -sdk_version: 4.7.1 -app_file: app.py -pinned: false -license: bsd-3-clause ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + + +

+ +

MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model

+

+ Zhongcong Xu + · + Jianfeng Zhang + · + Jun Hao Liew + · + Hanshu Yan + · + Jia-Wei Liu + · + Chenxu Zhang + · + Jiashi Feng + · + Mike Zheng Shou +
+
+ Paper PDF + Project Page + +
+ National University of Singapore   |   ByteDance +

+ + + + + + + + + + +
+ + + +
+ + + +
+ +## 📢 News +* **[2023.12.4]** Release inference code and gradio demo. We are working to improve MagicAnimate, stay tuned! +* **[2023.11.23]** Release MagicAnimate paper and project page. + +## 🏃‍♂️ Getting Started +Please download the pretrained base models for [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse). + +Download our MagicAnimate [checkpints](https://huggingface.co/zcxu-eric/MagicAnimate). + +**Place them as following:** +```bash +magic-animate +|----pretrained_models + |----MagicAnimate + |----appearance_encoder + |----diffusion_pytorch_model.safetensors + |----config.json + |----densepose_controlnet + |----diffusion_pytorch_model.safetensors + |----config.json + |----temporal_attention + |----temporal_attention.ckpt + |----sd-vae-ft-mse + |----... + |----stable-diffusion-v1-5 + |----... +|----... +``` + +## ⚒️ Installation +prerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`. + +Install with `conda`: +```bash +conda env create -f environment.yml +conda activate manimate +``` +or `pip`: +```bash +pip3 install -r requirements.txt +``` + +## 💃 Inference +Run inference on single GPU: +```bash +bash scripts/animate.sh +``` +Run inference with multiple GPUs: +```bash +bash scripts/animate_dist.sh +``` + +## 🎨 Gradio Demo + +#### Online Gradio Demo: +Try our [online gradio demo]() quickly. + +#### Local Gradio Demo: +Launch local gradio demo on single GPU: +```bash +python3 -m demo.gradio_animate +``` +Launch local gradio demo if you have multiple GPUs: +```bash +python3 -m demo.gradio_animate_dist +``` +Then open gradio demo in local browser. + +## 🎓 Citation +If you find this codebase useful for your research, please use the following entry. +```BibTeX +@inproceedings{xu2023magicanimate, + author = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng}, + title = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model}, + booktitle = {arXiv}, + year = {2023} +} +``` + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..01b8fd5a9e113f6a78257a6d34501dfa4a266c6f --- /dev/null +++ b/app.py @@ -0,0 +1,107 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Copyright (2023) MagicAnimate Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. +import argparse +import imageio +import numpy as np +import gradio as gr +from PIL import Image +from subprocess import PIPE, run + +from demo.animate import MagicAnimate + +for command in [ + 'mkdir ./pretrained_models && cd pretrained_models', + 'git lfs clone https://huggingface.co/zcxu-eric/MagicAnimate', + 'git lfs clone https://huggingface.co/runwayml/stable-diffusion-v1-5', + 'git lfs clone https://huggingface.co/stabilityai/sd-vae-ft-mse', + 'cd ..', +]: + run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True) + +animator = MagicAnimate() + +def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale): + return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale) + +with gr.Blocks() as demo: + + gr.HTML( + """ +
+

+ MagicAnimate: Temporally Consistent Human Image Animation +

+
+

+ Project page | + GitHub | + arXiv +

+
+ """) + animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) + + with gr.Row(): + reference_image = gr.Image(label="Reference Image") + motion_sequence = gr.Video(format="mp4", label="Motion Sequence") + + with gr.Column(): + random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") + sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") + guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") + submit = gr.Button("Animate") + + def read_video(video): + size = int(size) + reader = imageio.get_reader(video) + fps = reader.get_meta_data()['fps'] + assert fps == 25.0, f'Expected video fps: 25, but {fps} fps found' + return video + + def read_image(image, size=512): + return np.array(Image.fromarray(image).resize((size, size))) + + # when user uploads a new video + motion_sequence.upload( + read_video, + motion_sequence, + motion_sequence + ) + # when `first_frame` is updated + reference_image.upload( + read_image, + reference_image, + reference_image + ) + # when the `submit` button is clicked + submit.click( + animate, + [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], + animation + ) + + # Examples + gr.Markdown("## Examples") + gr.Examples( + examples=[ + ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"], + ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"], + ["inputs/applications/source_image/0002.png", "inputs/applications/driving/densepose/demo4.mp4"], + ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"], + ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"], + ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"], + ], + inputs=[reference_image, motion_sequence], + outputs=animation, + ) + + +demo.launch(share=True) \ No newline at end of file diff --git a/configs/inference/inference.yaml b/configs/inference/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86f377721c0637372229c024e2b5127bc6ca1604 --- /dev/null +++ b/configs/inference/inference.yaml @@ -0,0 +1,26 @@ +unet_additional_kwargs: + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 24 + temporal_attention_dim_div: 1 + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" diff --git a/configs/prompts/animation.yaml b/configs/prompts/animation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4c15903c4a396cc9e8e2fa59e62bd7d721539f3 --- /dev/null +++ b/configs/prompts/animation.yaml @@ -0,0 +1,42 @@ +pretrained_model_path: "pretrained_models/stable-diffusion-v1-5" +pretrained_vae_path: "pretrained_models/sd-vae-ft-mse" +pretrained_controlnet_path: "pretrained_models/MagicAnimate/densepose_controlnet" +pretrained_appearance_encoder_path: "pretrained_models/MagicAnimate/appearance_encoder" +pretrained_unet_path: "" + +motion_module: "pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt" + +savename: null + +fusion_blocks: "midup" + +seed: [1] +steps: 25 +guidance_scale: 7.5 + +source_image: + - "inputs/applications/source_image/monalisa.png" + - "inputs/applications/source_image/0002.png" + - "inputs/applications/source_image/demo4.png" + - "inputs/applications/source_image/dalle2.jpeg" + - "inputs/applications/source_image/dalle8.jpeg" + - "inputs/applications/source_image/multi1_source.png" +video_path: + - "inputs/applications/driving/densepose/running.mp4" + - "inputs/applications/driving/densepose/demo4.mp4" + - "inputs/applications/driving/densepose/demo4.mp4" + - "inputs/applications/driving/densepose/running2.mp4" + - "inputs/applications/driving/densepose/dancing2.mp4" + - "inputs/applications/driving/densepose/multi_dancing.mp4" + +inference_config: "configs/inference/inference.yaml" +size: 512 +L: 16 +S: 1 +I: 0 +clip: 0 +offset: 0 +max_length: null +video_type: "condition" +invert_video: false +save_individual_videos: false diff --git a/demo/animate.py b/demo/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71f1940212bcc4901828644be613a8d7f0525c8 --- /dev/null +++ b/demo/animate.py @@ -0,0 +1,195 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Copyright (2023) MagicAnimate Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. +import argparse +import argparse +import datetime +import inspect +import os +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +from collections import OrderedDict + +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler + +from tqdm import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from magicanimate.models.unet_controlnet import UNet3DConditionModel +from magicanimate.models.controlnet import ControlNetModel +from magicanimate.models.appearance_encoder import AppearanceEncoderModel +from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.pipelines.pipeline_animation import AnimationPipeline +from magicanimate.utils.util import save_videos_grid +from accelerate.utils import set_seed + +from magicanimate.utils.videoreader import VideoReader + +from einops import rearrange, repeat + +import csv, pdb, glob +from safetensors import safe_open +import math +from pathlib import Path + +class MagicAnimate(): + def __init__(self, config="configs/prompts/animation.yaml") -> None: + print("Initializing MagicAnimate Pipeline...") + *_, func_args = inspect.getargvalues(inspect.currentframe()) + func_args = dict(func_args) + + config = OmegaConf.load(config) + + inference_config = OmegaConf.load(config.inference_config) + + motion_module = config.motion_module + + ### >>> create animation pipeline >>> ### + tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") + if config.pretrained_unet_path: + unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + else: + unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda() + self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) + self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) + if config.pretrained_vae_path is not None: + vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) + else: + vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + + ### Load controlnet + controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) + + vae.to(torch.float16) + unet.to(torch.float16) + text_encoder.to(torch.float16) + controlnet.to(torch.float16) + self.appearance_encoder.to(torch.float16) + + unet.enable_xformers_memory_efficient_attention() + self.appearance_encoder.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + + self.pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + # NOTE: UniPCMultistepScheduler + ).to("cuda") + + # 1. unet ckpt + # 1.1 motion module + motion_module_state_dict = torch.load(motion_module, map_location="cpu") + if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) + motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict + try: + # extra steps for self-trained models + state_dict = OrderedDict() + for key in motion_module_state_dict.keys(): + if key.startswith("module."): + _key = key.split("module.")[-1] + state_dict[_key] = motion_module_state_dict[key] + else: + state_dict[key] = motion_module_state_dict[key] + motion_module_state_dict = state_dict + del state_dict + missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + assert len(unexpected) == 0 + except: + _tmp_ = OrderedDict() + for key in motion_module_state_dict.keys(): + if "motion_modules" in key: + if key.startswith("unet."): + _key = key.split('unet.')[-1] + _tmp_[_key] = motion_module_state_dict[key] + else: + _tmp_[key] = motion_module_state_dict[key] + missing, unexpected = unet.load_state_dict(_tmp_, strict=False) + assert len(unexpected) == 0 + del _tmp_ + del motion_module_state_dict + + self.pipeline.to("cuda") + self.L = config.L + + print("Initialization Done!") + + def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): + prompt = n_prompt = "" + random_seed = int(random_seed) + step = int(step) + guidance_scale = float(guidance_scale) + samples_per_video = [] + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + set_seed(random_seed) + else: + torch.seed() + + if motion_sequence.endswith('.mp4'): + control = VideoReader(motion_sequence).read() + if control[0].shape[0] != size: + control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] + control = np.array(control) + + if source_image.shape[0] != size: + source_image = np.array(Image.fromarray(source_image).resize((size, size))) + H, W, C = source_image.shape + + init_latents = None + original_length = control.shape[0] + if control.shape[0] % self.L > 0: + control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') + generator = torch.Generator(device=torch.device("cuda:0")) + generator.manual_seed(torch.initial_seed()) + sample = self.pipeline( + prompt, + negative_prompt = n_prompt, + num_inference_steps = step, + guidance_scale = guidance_scale, + width = W, + height = H, + video_length = len(control), + controlnet_condition = control, + init_latents = init_latents, + generator = generator, + appearance_encoder = self.appearance_encoder, + reference_control_writer = self.reference_control_writer, + reference_control_reader = self.reference_control_reader, + source_image = source_image, + ).videos + + source_images = np.array([source_image] * original_length) + source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + samples_per_video.append(source_images) + + control = control / 255.0 + control = rearrange(control, "t h w c -> 1 c t h w") + control = torch.from_numpy(control) + samples_per_video.append(control[:, :, :original_length]) + + samples_per_video.append(sample[:, :, :original_length]) + + samples_per_video = torch.cat(samples_per_video) + + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"demo/outputs" + animation_path = f"{savedir}/{time_str}.mp4" + + os.makedirs(savedir, exist_ok=True) + save_videos_grid(samples_per_video, animation_path) + + return animation_path + \ No newline at end of file diff --git a/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 b/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 new file mode 100644 index 0000000000000000000000000000000000000000..b8949299d2c18a8eece9e9687f3a2b9230fcedae Binary files /dev/null and b/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 differ diff --git a/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 b/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 new file mode 100644 index 0000000000000000000000000000000000000000..4c83ce4f197a580ef873be6b54ef320f0f228fb6 Binary files /dev/null and b/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 differ diff --git a/inputs/applications/driving/densepose/dancing2.mp4 b/inputs/applications/driving/densepose/dancing2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4af7ef94f3fa7791ff95a96c8850607cb070f978 Binary files /dev/null and b/inputs/applications/driving/densepose/dancing2.mp4 differ diff --git a/inputs/applications/driving/densepose/demo4.mp4 b/inputs/applications/driving/densepose/demo4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..740d9c08967e243fd7a93d213ffd3d7e328c90a9 Binary files /dev/null and b/inputs/applications/driving/densepose/demo4.mp4 differ diff --git a/inputs/applications/driving/densepose/multi_dancing.mp4 b/inputs/applications/driving/densepose/multi_dancing.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7a51a8337a8c1fbde56548b21c81ba94958891a5 Binary files /dev/null and b/inputs/applications/driving/densepose/multi_dancing.mp4 differ diff --git a/inputs/applications/driving/densepose/running.mp4 b/inputs/applications/driving/densepose/running.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1189d254f34f82a7ab8bd7f96afb3b713c7f3547 Binary files /dev/null and b/inputs/applications/driving/densepose/running.mp4 differ diff --git a/inputs/applications/driving/densepose/running2.mp4 b/inputs/applications/driving/densepose/running2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a1ccccf5101089eeab3e671a9a0fa5315c889c9 Binary files /dev/null and b/inputs/applications/driving/densepose/running2.mp4 differ diff --git a/inputs/applications/source_image/0002.png b/inputs/applications/source_image/0002.png new file mode 100644 index 0000000000000000000000000000000000000000..fd236a8ba11aa5698a264f7b1a96df2ca47ac2fb Binary files /dev/null and b/inputs/applications/source_image/0002.png differ diff --git a/inputs/applications/source_image/dalle2.jpeg b/inputs/applications/source_image/dalle2.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..9dd175f1739865343c7cc0f5e50f88079f4cb064 Binary files /dev/null and b/inputs/applications/source_image/dalle2.jpeg differ diff --git a/inputs/applications/source_image/dalle8.jpeg b/inputs/applications/source_image/dalle8.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..04574c3fbc7956124c3639e092e3489089a2bd63 Binary files /dev/null and b/inputs/applications/source_image/dalle8.jpeg differ diff --git a/inputs/applications/source_image/demo4.png b/inputs/applications/source_image/demo4.png new file mode 100644 index 0000000000000000000000000000000000000000..110684f9baac2931306524f20439eb7e181d3dc7 Binary files /dev/null and b/inputs/applications/source_image/demo4.png differ diff --git a/inputs/applications/source_image/monalisa.png b/inputs/applications/source_image/monalisa.png new file mode 100644 index 0000000000000000000000000000000000000000..15eeb1c8b40499b44852bbdaca5ce7e8e7ac9db0 Binary files /dev/null and b/inputs/applications/source_image/monalisa.png differ diff --git a/inputs/applications/source_image/multi1_source.png b/inputs/applications/source_image/multi1_source.png new file mode 100644 index 0000000000000000000000000000000000000000..5791d81c7bf90ab56f187f65664cb2e1a023f996 Binary files /dev/null and b/inputs/applications/source_image/multi1_source.png differ diff --git a/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc b/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8912a67e911ee76135846f91566e2b0d69ba0810 Binary files /dev/null and b/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/attention.cpython-38.pyc b/magicanimate/models/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f6f4c520a51b2c9ae95e33d482ddf2779c26ee Binary files /dev/null and b/magicanimate/models/__pycache__/attention.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/controlnet.cpython-38.pyc b/magicanimate/models/__pycache__/controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b8fc075e92b8d4c2aa21632aff66abe7125aa4c Binary files /dev/null and b/magicanimate/models/__pycache__/controlnet.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/embeddings.cpython-38.pyc b/magicanimate/models/__pycache__/embeddings.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f81d7c99111b9f5d8898740a43a337ec90807f8 Binary files /dev/null and b/magicanimate/models/__pycache__/embeddings.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/motion_module.cpython-38.pyc b/magicanimate/models/__pycache__/motion_module.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b49ee55eece55056b428c26781fa7f9433b0eb Binary files /dev/null and b/magicanimate/models/__pycache__/motion_module.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc b/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b632ff58ffa5eb480b07606db27c052da06ae38 Binary files /dev/null and b/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc b/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736e3e2834636c6c9f4b82dc0fa563febbf6c610 Binary files /dev/null and b/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/resnet.cpython-38.pyc b/magicanimate/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6379480913b4d142976095bba360b3963eff28b4 Binary files /dev/null and b/magicanimate/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc b/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d17a09a13e7ce520932d7fbc369f219fc740a469 Binary files /dev/null and b/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc b/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5e37e70c10eb4ba8b69aa6ef899c73ab4e1d514 Binary files /dev/null and b/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc differ diff --git a/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc b/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9f78a15cd972abe1a57ef7fea51af166f01420c Binary files /dev/null and b/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc differ diff --git a/magicanimate/models/appearance_encoder.py b/magicanimate/models/appearance_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..81736db540c7ddd029b639e5978fc5b63762013a --- /dev/null +++ b/magicanimate/models/appearance_encoder.py @@ -0,0 +1,1066 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.lora import LoRALinearLayer +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + PositionNet, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Identity(torch.nn.Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + def __init__(self, scale=None, *args, **kwargs) -> None: + super(Identity, self).__init__() + + def forward(self, input, *args, **kwargs): + return input + + + +class _LoRACompatibleLinear(nn.Module): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self): + pass + + def _unfuse_lora(self): + pass + + def forward(self, hidden_states, scale=None, lora_scale: int = 1): + return hidden_states + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()]) + self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity() + self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None + self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity() + self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity() + self.up_blocks[3].attentions[2].proj_out = Identity() + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = PositionNet( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) \ No newline at end of file diff --git a/magicanimate/models/attention.py b/magicanimate/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..796d861a6d6727aff179fcb452ca351d87148a7f --- /dev/null +++ b/magicanimate/models/attention.py @@ -0,0 +1,320 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.models.attention import Attention as CrossAttention + +from einops import rearrange, repeat + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + # JH: need not repeat when a list of prompts are given + if encoder_hidden_states.shape[0] != hidden_states.shape[0]: + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention = None, + unet_use_temporal_attention = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # SC-Attn + assert unet_use_cross_frame_attention is not None + if unet_use_cross_frame_attention: + self.attn1 = SparseCausalAttention2D( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + self.use_ada_layer_norm_zero = False + + # Temp-Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + # if self.only_cross_attention: + # hidden_states = ( + # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + # ) + # else: + # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # pdb.set_trace() + if self.unet_use_cross_frame_attention: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/magicanimate/models/controlnet.py b/magicanimate/models/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8a74923e91e9211e59fa22d168983d5db95c0f7d --- /dev/null +++ b/magicanimate/models/controlnet.py @@ -0,0 +1,578 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from .embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from diffusers.models.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class ControlNetOutput(BaseOutput): + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + ): + super().__init__() + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=attention_head_dim[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=attention_head_dim[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + ): + r""" + Instantiate Controlnet class from UNet2DConditionModel. + + Parameters: + unet (`UNet2DConditionModel`): + UNet model which weights are copied to the ControlNet. Note that all configuration options are also + copied where applicable. + """ + controlnet = cls( + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + # @property + # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # def attn_processors(self) -> Dict[str, AttentionProcessor]: + # r""" + # Returns: + # `dict` of attention processors: A dictionary containing all attention processors used in the model with + # indexed by its weight name. + # """ + # # set recursively + # processors = {} + + # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + # if hasattr(module, "set_processor"): + # processors[f"{name}.processor"] = module.processor + + # for sub_name, child in module.named_children(): + # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + # return processors + + # for name, module in self.named_children(): + # fn_recursive_add_processors(name, module, processors) + + # return processors + + # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + # r""" + # Parameters: + # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + # The instantiated processor class or a dictionary of processor classes that will be set as the processor + # of **all** `Attention` layers. + # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + # """ + # count = len(self.attn_processors.keys()) + + # if isinstance(processor, dict) and len(processor) != count: + # raise ValueError( + # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + # f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + # ) + + # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + # if hasattr(module, "set_processor"): + # if not isinstance(processor, dict): + # module.set_processor(processor) + # else: + # module.set_processor(processor.pop(f"{name}.processor")) + + # for sub_name, child in module.named_children(): + # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + # for name, module in self.named_children(): + # fn_recursive_attn_processor(name, module, processor) + + # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # def set_default_attn_processor(self): + # """ + # Disables custom attention processors and sets the default attention implementation. + # """ + # self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample += controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module \ No newline at end of file diff --git a/magicanimate/models/embeddings.py b/magicanimate/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..25fb3b2115a3559c2dc994820ee8297f96dbc146 --- /dev/null +++ b/magicanimate/models/embeddings.py @@ -0,0 +1,385 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional + +import numpy as np +import torch +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + elif act_fn == "gelu": + self.act = nn.GELU() + else: + raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + elif post_act_fn == "silu": + self.post_act = nn.SiLU() + elif post_act_fn == "mish": + self.post_act = nn.Mish() + elif post_act_fn == "gelu": + self.post_act = nn.GELU() + else: + raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + self.weight = self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning \ No newline at end of file diff --git a/magicanimate/models/motion_module.py b/magicanimate/models/motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b272b69d4fe57c7baf18574a300bab51c7bebb83 --- /dev/null +++ b/magicanimate/models/motion_module.py @@ -0,0 +1,334 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/guoyww/AnimateDiff +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward +from magicanimate.models.orig_attention import CrossAttention + +from einops import rearrange, repeat +import math + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_motion_module( + in_channels, + motion_module_type: str, + motion_module_kwargs: dict +): + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", "Temporal_Self" ), + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + temporal_attention_dim_div = 1, + zero_initialize = True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + ) + hidden_states + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 24 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class VersatileAttention(CrossAttention): + def __init__( + self, + attention_mode = None, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if (temporal_position_encoding and attention_mode == "Temporal") else None + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + else: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/magicanimate/models/mutual_self_attention.py b/magicanimate/models/mutual_self_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..688332e6f6790d683d2fadbe544175d657cba65a --- /dev/null +++ b/magicanimate/models/mutual_self_attention.py @@ -0,0 +1,642 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Copyright (2023) MagicAnimate Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. + +import torch +import torch.nn.functional as F + +from einops import rearrange +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from diffusers.models.attention import BasicTransformerBlock +from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from .stable_diffusion_controlnet_reference import torch_dfs + + +class AttentionBase: + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + def after_step(self): + pass + + def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): + out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + # after step + self.after_step() + return out + + def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): + out = torch.einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads) + return out + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + +class MutualSelfAttentionControl(AttentionBase): + + def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'): + """ + Mutual self-attention control for Stable-Diffusion MODEl + Args: + total_steps: the total number of steps + """ + super().__init__() + self.total_steps = total_steps + self.hijack = hijack_init_state + self.with_negative_guidance = with_negative_guidance + + # alpha: mutual self attention intensity + # TODO: make alpha learnable + self.alpha = appearance_control_alpha + self.GLOBAL_ATTN_QUEUE = [] + assert mode in ['enqueue', 'dequeue'] + MODE = mode + + def attn_batch(self, q, k, v, num_heads, **kwargs): + """ + Performing attention for a batch of queries, keys, and values + """ + b = q.shape[0] // num_heads + q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) + k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) + v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) + + sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") + attn = sim.softmax(-1) + out = torch.einsum("h i j, h j d -> h i d", attn, v) + out = rearrange(out, "h (b n) d -> b n (h d)", b=b) + return out + + def mutual_self_attn(self, q, k, v, num_heads, **kwargs): + q_tgt, q_src = q.chunk(2) + k_tgt, k_src = k.chunk(2) + v_tgt, v_src = v.chunk(2) + + # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \ + # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha) + out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs) + out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs) + out = torch.cat([out_tgt, out_src], dim=0) + return out + + def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): + if self.MODE == 'dequeue' and len(self.kv_queue) > 0: + k_src, v_src = self.kv_queue.pop(0) + out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs) + return out + else: + self.kv_queue.append([k.clone(), v.clone()]) + return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) + + def get_queue(self): + return self.GLOBAL_ATTN_QUEUE + + def set_queue(self, attn_queue): + self.GLOBAL_ATTN_QUEUE = attn_queue + + def clear_queue(self): + self.GLOBAL_ATTN_QUEUE = [] + + def to(self, dtype): + self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE] + + def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): + """ + Attention forward function + """ + return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) + + +class ReferenceAttentionControl(): + + def __init__(self, + unet, + mode="write", + do_classifier_free_guidance=False, + attention_auto_machine_weight = float('inf'), + gn_auto_machine_weight = 1.0, + style_fidelity = 1.0, + reference_attn=True, + reference_adain=False, + fusion_blocks="midup", + batch_size=1, + ) -> None: + # 10. Modify self attention and group norm + self.unet = unet + assert mode in ["read", "write"] + assert fusion_blocks in ["midup", "full"] + self.reference_attn = reference_attn + self.reference_adain = reference_adain + self.fusion_blocks = fusion_blocks + self.register_reference_hooks( + mode, + do_classifier_free_guidance, + attention_auto_machine_weight, + gn_auto_machine_weight, + style_fidelity, + reference_attn, + reference_adain, + fusion_blocks, + batch_size=batch_size, + ) + + def register_reference_hooks( + self, + mode, + do_classifier_free_guidance, + attention_auto_machine_weight, + gn_auto_machine_weight, + style_fidelity, + reference_attn, + reference_adain, + dtype=torch.float16, + batch_size=1, + num_images_per_prompt=1, + device=torch.device("cpu"), + fusion_blocks='midup', + ): + MODE = mode + do_classifier_free_guidance = do_classifier_free_guidance + attention_auto_machine_weight = attention_auto_machine_weight + gn_auto_machine_weight = gn_auto_machine_weight + style_fidelity = style_fidelity + reference_attn = reference_attn + reference_adain = reference_adain + fusion_blocks = fusion_blocks + num_images_per_prompt = num_images_per_prompt + dtype=dtype + if do_classifier_free_guidance: + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16) + .to(device) + .bool() + ) + else: + uc_mask = ( + torch.Tensor([0] * batch_size * num_images_per_prompt * 2) + .to(device) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + video_length=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank] + hidden_states_uc = self.attn1(norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + attention_mask=attention_mask) + hidden_states + hidden_states_c = hidden_states_uc.clone() + _uc_mask = uc_mask.clone() + if do_classifier_free_guidance: + if hidden_states.shape[0] != _uc_mask.shape[0]: + _uc_mask = ( + torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2)) + .to(device) + .bool() + ) + hidden_states_c[_uc_mask] = self.attn1( + norm_hidden_states[_uc_mask], + encoder_hidden_states=norm_hidden_states[_uc_mask], + attention_mask=attention_mask, + ) + hidden_states[_uc_mask] + hidden_states = hidden_states_c.clone() + + self.bank.clear() + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if self.reference_attn: + if self.fusion_blocks == "midup": + attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] + elif self.fusion_blocks == "full": + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if self.reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + module.var_bank = [] + module.gn_weight *= 2 + + def update(self, writer, dtype=torch.float16): + if self.reference_attn: + if self.fusion_blocks == "midup": + reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)] + writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)] + elif self.fusion_blocks == "full": + reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)] + writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)] + reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + for r, w in zip(reader_attn_modules, writer_attn_modules): + r.bank = [v.clone().to(dtype) for v in w.bank] + # w.bank.clear() + if self.reference_adain: + reader_gn_modules = [self.unet.mid_block] + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + reader_gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + reader_gn_modules.append(module) + + writer_gn_modules = [writer.unet.mid_block] + + down_blocks = writer.unet.down_blocks + for w, module in enumerate(down_blocks): + writer_gn_modules.append(module) + + up_blocks = writer.unet.up_blocks + for w, module in enumerate(up_blocks): + writer_gn_modules.append(module) + + for r, w in zip(reader_gn_modules, writer_gn_modules): + if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list): + r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank] + r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank] + else: + r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank] + r.var_bank = [v.clone().to(dtype) for v in w.var_bank] + + def clear(self): + if self.reference_attn: + if self.fusion_blocks == "midup": + reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] + elif self.fusion_blocks == "full": + reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] + reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + for r in reader_attn_modules: + r.bank.clear() + if self.reference_adain: + reader_gn_modules = [self.unet.mid_block] + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + reader_gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + reader_gn_modules.append(module) + + for r in reader_gn_modules: + r.mean_bank.clear() + r.var_bank.clear() + \ No newline at end of file diff --git a/magicanimate/models/orig_attention.py b/magicanimate/models/orig_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3eba094cd37665839911535d5fb7bb76a0cb18 --- /dev/null +++ b/magicanimate/models/orig_attention.py @@ -0,0 +1,988 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + norm_num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + self._use_memory_efficient_attention_xformers = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + scale = 1 / math.sqrt(self.channels / self.num_heads) + + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + + if self._use_memory_efficient_attention_xformers: + # Memory efficient attention + hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + hidden_states = hidden_states.to(query_proj.dtype) + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + hidden_states = torch.bmm(attention_probs, value_proj) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) \ No newline at end of file diff --git a/magicanimate/models/resnet.py b/magicanimate/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f1231ee0668f1989ec9075b8f3c741cdc34816f5 --- /dev/null +++ b/magicanimate/models/resnet.py @@ -0,0 +1,212 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/guoyww/AnimateDiff + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/magicanimate/models/stable_diffusion_controlnet_reference.py b/magicanimate/models/stable_diffusion_controlnet_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..436e74cce0b7705f72db463376a384c895bcc0ac --- /dev/null +++ b/magicanimate/models/stable_diffusion_controlnet_reference.py @@ -0,0 +1,840 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers import StableDiffusionControlNetPipeline +from diffusers.models import ControlNetModel +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import logging +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> # get canny image + >>> image = cv2.Canny(np.array(input_image), 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + controlnet=controlnet, + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + image=canny_image, + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline): + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + ref_image (`torch.FloatTensor`, `PIL.Image.Image`): + The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can + also be accepted as an image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + attention_auto_machine_weight (`float`): + Weight of using reference query for self attention's context. + If attention_auto_machine_weight=1.0, use reference query for all self attention's context. + gn_auto_machine_weight (`float`): + Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. + style_fidelity (`float`): + style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, + elif style_fidelity=0.0, prompt more important, else balanced. + reference_attn (`bool`): + Whether to use reference query for self attention's context. + reference_adain (`bool`): + Whether to use reference adain. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Preprocess reference image + ref_image = self.prepare_image( + image=ref_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare reference latent variables + ref_image_latents = self.prepare_ref_latents( + ref_image, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Modify self attention and group norm + MODE = "write" + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) + .type_as(ref_image_latents) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + module.var_bank = [] + module.gn_weight *= 2 + + # 11. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/magicanimate/models/unet.py b/magicanimate/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..09e5e11fdb9c985fd8627ce90d56791d0d8f469c --- /dev/null +++ b/magicanimate/models/unet.py @@ -0,0 +1,508 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/guoyww/AnimateDiff + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json +import pdb + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + + # Additional + use_motion_module = False, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type = None, + motion_module_kwargs = {}, + unet_use_cross_frame_attention = None, + unet_use_temporal_attention = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") + + params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] + print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/magicanimate/models/unet_3d_blocks.py b/magicanimate/models/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..55cf1559598a34dcfcc48756a71549f5b0fd9d7b --- /dev/null +++ b/magicanimate/models/unet_3d_blocks.py @@ -0,0 +1,751 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/guoyww/AnimateDiff + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn + +from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from .motion_module import get_motion_module + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/magicanimate/models/unet_controlnet.py b/magicanimate/models/unet_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0ccd9cad7d1f9e02aed3592c07a064d2a04dbd69 --- /dev/null +++ b/magicanimate/models/unet_controlnet.py @@ -0,0 +1,525 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from magicanimate.models.unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + + # Additional + use_motion_module = False, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type = None, + motion_module_kwargs = {}, + unet_use_cross_frame_attention = None, + unet_use_temporal_attention = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + # for controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") + + params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] + print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc b/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9739b1d9a8220f44aadf90537dc32314d648738 Binary files /dev/null and b/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc differ diff --git a/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08d058d7ca8b2a8532ad5e97e109d7190c16127e Binary files /dev/null and b/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc differ diff --git a/magicanimate/pipelines/__pycache__/context.cpython-38.pyc b/magicanimate/pipelines/__pycache__/context.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b37b304d6644940cda2985ef107a6f6aece61bf3 Binary files /dev/null and b/magicanimate/pipelines/__pycache__/context.cpython-38.pyc differ diff --git a/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc b/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae41088706b703811c27a47becd7d34519045571 Binary files /dev/null and b/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc differ diff --git a/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b79af89fea8aa02000c36379ff57c831f174390 Binary files /dev/null and b/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc differ diff --git a/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..016270e040ee9631dbbc932680ce7b5c5c868e91 Binary files /dev/null and b/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc differ diff --git a/magicanimate/pipelines/animation.py b/magicanimate/pipelines/animation.py new file mode 100644 index 0000000000000000000000000000000000000000..899583ed4ac71dac05e1971553696d3ccc8cde81 --- /dev/null +++ b/magicanimate/pipelines/animation.py @@ -0,0 +1,282 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Copyright (2023) MagicAnimate Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. +import argparse +import datetime +import inspect +import os +import random +import numpy as np + +from PIL import Image +from omegaconf import OmegaConf +from collections import OrderedDict + +import torch +import torch.distributed as dist + +from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler + +from tqdm import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from magicanimate.models.unet_controlnet import UNet3DConditionModel +from magicanimate.models.controlnet import ControlNetModel +from magicanimate.models.appearance_encoder import AppearanceEncoderModel +from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.pipelines.pipeline_animation import AnimationPipeline +from magicanimate.utils.util import save_videos_grid +from magicanimate.utils.dist_tools import distributed_init +from accelerate.utils import set_seed + +from magicanimate.utils.videoreader import VideoReader + +from einops import rearrange + +from pathlib import Path + + +def main(args): + + *_, func_args = inspect.getargvalues(inspect.currentframe()) + func_args = dict(func_args) + + config = OmegaConf.load(args.config) + + # Initialize distributed training + device = torch.device(f"cuda:{args.rank}") + dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} + + if config.savename is None: + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"samples/{Path(args.config).stem}-{time_str}" + else: + savedir = f"samples/{config.savename}" + + if args.dist: + dist.broadcast_object_list([savedir], 0) + dist.barrier() + + if args.rank == 0: + os.makedirs(savedir, exist_ok=True) + + inference_config = OmegaConf.load(config.inference_config) + + motion_module = config.motion_module + + ### >>> create animation pipeline >>> ### + tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") + if config.pretrained_unet_path: + unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + else: + unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) + reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) + reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) + if config.pretrained_vae_path is not None: + vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) + else: + vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + + ### Load controlnet + controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) + + unet.enable_xformers_memory_efficient_attention() + appearance_encoder.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + + vae.to(torch.float16) + unet.to(torch.float16) + text_encoder.to(torch.float16) + appearance_encoder.to(torch.float16) + controlnet.to(torch.float16) + + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + # NOTE: UniPCMultistepScheduler + ) + + # 1. unet ckpt + # 1.1 motion module + motion_module_state_dict = torch.load(motion_module, map_location="cpu") + if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) + motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict + try: + # extra steps for self-trained models + state_dict = OrderedDict() + for key in motion_module_state_dict.keys(): + if key.startswith("module."): + _key = key.split("module.")[-1] + state_dict[_key] = motion_module_state_dict[key] + else: + state_dict[key] = motion_module_state_dict[key] + motion_module_state_dict = state_dict + del state_dict + missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + assert len(unexpected) == 0 + except: + _tmp_ = OrderedDict() + for key in motion_module_state_dict.keys(): + if "motion_modules" in key: + if key.startswith("unet."): + _key = key.split('unet.')[-1] + _tmp_[_key] = motion_module_state_dict[key] + else: + _tmp_[key] = motion_module_state_dict[key] + missing, unexpected = unet.load_state_dict(_tmp_, strict=False) + assert len(unexpected) == 0 + del _tmp_ + del motion_module_state_dict + + pipeline.to(device) + ### <<< create validation pipeline <<< ### + + random_seeds = config.get("seed", [-1]) + random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds + + # input test videos (either source video/ conditions) + + test_videos = config.video_path + source_images = config.source_image + num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps) + + # read size, step from yaml file + sizes = [config.size] * len(test_videos) + steps = [config.S] * len(test_videos) + + config.random_seed = [] + prompt = n_prompt = "" + for idx, (source_image, test_video, random_seed, size, step) in tqdm( + enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), + total=len(test_videos), + disable=(args.rank!=0) + ): + samples_per_video = [] + samples_per_clip = [] + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + set_seed(random_seed) + else: + torch.seed() + config.random_seed.append(torch.initial_seed()) + + if test_video.endswith('.mp4'): + control = VideoReader(test_video).read() + if control[0].shape[0] != size: + control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] + if config.max_length is not None: + control = control[config.offset: (config.offset+config.max_length)] + control = np.array(control) + + if source_image.endswith(".mp4"): + source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size))) + else: + source_image = np.array(Image.open(source_image).resize((size, size))) + H, W, C = source_image.shape + + print(f"current seed: {torch.initial_seed()}") + init_latents = None + + # print(f"sampling {prompt} ...") + original_length = control.shape[0] + if control.shape[0] % config.L > 0: + control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge') + generator = torch.Generator(device=torch.device("cuda:0")) + generator.manual_seed(torch.initial_seed()) + sample = pipeline( + prompt, + negative_prompt = n_prompt, + num_inference_steps = config.steps, + guidance_scale = config.guidance_scale, + width = W, + height = H, + video_length = len(control), + controlnet_condition = control, + init_latents = init_latents, + generator = generator, + num_actual_inference_steps = num_actual_inference_steps, + appearance_encoder = appearance_encoder, + reference_control_writer = reference_control_writer, + reference_control_reader = reference_control_reader, + source_image = source_image, + **dist_kwargs, + ).videos + + if args.rank == 0: + source_images = np.array([source_image] * original_length) + source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + samples_per_video.append(source_images) + + control = control / 255.0 + control = rearrange(control, "t h w c -> 1 c t h w") + control = torch.from_numpy(control) + samples_per_video.append(control[:, :, :original_length]) + + samples_per_video.append(sample[:, :, :original_length]) + + samples_per_video = torch.cat(samples_per_video) + + video_name = os.path.basename(test_video)[:-4] + source_name = os.path.basename(config.source_image[idx]).split(".")[0] + save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4") + save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4") + + if config.save_individual_videos: + save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4") + save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4") + + if args.dist: + dist.barrier() + + if args.rank == 0: + OmegaConf.save(config, f"{savedir}/config.yaml") + + +def distributed_main(device_id, args): + args.rank = device_id + args.device_id = device_id + if torch.cuda.is_available(): + torch.cuda.set_device(args.device_id) + torch.cuda.init() + distributed_init(args) + main(args) + + +def run(args): + + if args.dist: + args.world_size = max(1, torch.cuda.device_count()) + assert args.world_size <= torch.cuda.device_count() + + if args.world_size > 0 and torch.cuda.device_count() > 1: + port = random.randint(10000, 20000) + args.init_method = f"tcp://localhost:{port}" + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args,), + nprocs=args.world_size, + ) + else: + main(args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--dist", action="store_true", required=False) + parser.add_argument("--rank", type=int, default=0, required=False) + parser.add_argument("--world_size", type=int, default=1, required=False) + + args = parser.parse_args() + run(args) diff --git a/magicanimate/pipelines/context.py b/magicanimate/pipelines/context.py new file mode 100644 index 0000000000000000000000000000000000000000..59fb8864ad14844e7c018234f9030f12e801707b --- /dev/null +++ b/magicanimate/pipelines/context.py @@ -0,0 +1,76 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main +import numpy as np +from typing import Callable, Optional, List + + +def ordered_halving(val): + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + + return as_int / (1 << 64) + + +def uniform( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + if num_frames <= context_size: + yield list(range(num_frames)) + return + + context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] + + +def get_context_scheduler(name: str) -> Callable: + if name == "uniform": + return uniform + else: + raise ValueError(f"Unknown context_overlap policy {name}") + + +def get_total_steps( + scheduler, + timesteps: List[int], + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + return sum( + len( + list( + scheduler( + i, + num_steps, + num_frames, + context_size, + context_stride, + context_overlap, + ) + ) + ) + for i in range(len(timesteps)) + ) diff --git a/magicanimate/pipelines/pipeline_animation.py b/magicanimate/pipelines/pipeline_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..08a77bacc4f11f66bee1703e76cb1df21368a234 --- /dev/null +++ b/magicanimate/pipelines/pipeline_animation.py @@ -0,0 +1,799 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TODO: +1. support multi-controlnet +2. [DONE] support DDIM inversion +3. support Prompt-to-prompt +""" + +import inspect, math +from typing import Callable, List, Optional, Union +from dataclasses import dataclass +from PIL import Image +import numpy as np +import torch +import torch.distributed as dist +from tqdm import tqdm +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange + +from magicanimate.models.unet_controlnet import UNet3DConditionModel +from magicanimate.models.controlnet import ControlNetModel +from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.pipelines.context import ( + get_context_scheduler, + get_total_steps +) +from magicanimate.utils.util import get_tensor_interpolation_method + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class AnimationPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class AnimationPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + controlnet: ControlNetModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents, rank, decoder_consistency=None): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)): + if decoder_consistency is not None: + video.append(decoder_consistency(latents[frame_idx:frame_idx+1])) + else: + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16): + shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + + latents = latents.repeat(1, 1, video_length//clip_length, 1, 1) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance): + # prepare conditions for controlnet + condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0 + condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0) + condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone() + if do_classifier_free_guidance: + condition = torch.cat([condition] * 2) + return condition + + def next_step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + eta=0., + verbose=False + ): + """ + Inverse sampling for DDIM Inversion + """ + if verbose: + print("timestep: ", timestep) + next_step = timestep + timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output + x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir + return x_next, pred_x0 + + @torch.no_grad() + def images2latents(self, images, dtype): + """ + Convert RGB image to VAE latents + """ + device = self._execution_device + images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1 + images = rearrange(images, "f h w c -> f c h w").to(device) + latents = [] + for frame_idx in range(images.shape[0]): + latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215) + latents = torch.cat(latents) + return latents + + @torch.no_grad() + def invert( + self, + image: torch.Tensor, + prompt, + num_inference_steps=20, + num_actual_inference_steps=10, + eta=0.0, + return_intermediates=False, + **kwargs): + """ + Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440 + invert a real image into noise map with determinisc DDIM inversion + """ + device = self._execution_device + batch_size = image.shape[0] + if isinstance(prompt, list): + if batch_size == 1: + image = image.expand(len(prompt), -1, -1, -1) + elif isinstance(prompt, str): + if batch_size > 1: + prompt = [prompt] * batch_size + + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] + print("input text embeddings :", text_embeddings.shape) + # define initial latents + latents = self.images2latents(image) + + print("latents shape: ", latents.shape) + # interative sampling + self.scheduler.set_timesteps(num_inference_steps) + print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + latents_list = [latents] + pred_x0_list = [latents] + for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): + + if num_actual_inference_steps is not None and i >= num_actual_inference_steps: + continue + model_inputs = latents + + # predict the noise + # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w) + model_inputs = rearrange(model_inputs, "f c h w -> 1 c f h w") + noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample + noise_pred = rearrange(noise_pred, "b c f h w -> (b f) c h w") + + # compute the previous noise sample x_t-1 -> x_t + latents, pred_x0 = self.next_step(noise_pred, t, latents) + latents_list.append(latents) + pred_x0_list.append(pred_x0) + + if return_intermediates: + # return the intermediate laters during inversion + return latents, latents_list + return latents + + def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, device ): + if interpolation_factor < 2: + return latents + + new_latents = torch.zeros( + (latents.shape[0],latents.shape[1],((latents.shape[2]-1) * interpolation_factor)+1, latents.shape[3],latents.shape[4]), + device=latents.device, + dtype=latents.dtype, + ) + + org_video_length = latents.shape[2] + rate = [i/interpolation_factor for i in range(interpolation_factor)][1:] + + new_index = 0 + + v0 = None + v1 = None + + for i0,i1 in zip( range( org_video_length ),range( org_video_length )[1:] ): + v0 = latents[:,:,i0,:,:] + v1 = latents[:,:,i1,:,:] + + new_latents[:,:,new_index,:,:] = v0 + new_index += 1 + + for f in rate: + v = get_tensor_interpolation_method()(v0.to(device=device),v1.to(device=device),f) + new_latents[:,:,new_index,:,:] = v.to(latents.device) + new_index += 1 + + new_latents[:,:,new_index,:,:] = v1 + new_index += 1 + + return new_latents + + def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f): + _down_block_res_samples = [] + _mid_block_res_sample = [] + for i in np.concatenate(np.array(context)): + _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0]) + _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1]) + down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))] + for res_t in _down_block_res_samples: + for i, res in enumerate(res_t): + down_block_res_samples[i].append(res) + down_block_res_samples = [torch.cat(res) for res in down_block_res_samples] + mid_block_res_sample = torch.cat(_mid_block_res_sample) + + # reshape controlnet output to match the unet3d inputs + b = b // 2 if do_classifier_free_guidance else b + _down_block_res_samples = [] + for sample in down_block_res_samples: + sample = rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f) + if do_classifier_free_guidance: + sample = sample.repeat(2, 1, 1, 1, 1) + _down_block_res_samples.append(sample) + down_block_res_samples = _down_block_res_samples + mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f) + if do_classifier_free_guidance: + mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1) + + return down_block_res_samples, mid_block_res_sample + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + controlnet_condition: list = None, + controlnet_conditioning_scale: float = 1.0, + context_frames: int = 16, + context_stride: int = 1, + context_overlap: int = 4, + context_batch_size: int = 1, + context_schedule: str = "uniform", + init_latents: Optional[torch.FloatTensor] = None, + num_actual_inference_steps: Optional[int] = None, + appearance_encoder = None, + reference_control_writer = None, + reference_control_reader = None, + source_image: str = None, + decoder_consistency = None, + **kwargs, + ): + """ + New args: + - controlnet_condition : condition map (e.g., depth, canny, keypoints) for controlnet + - controlnet_conditioning_scale : conditioning scale for controlnet + - init_latents : initial latents to begin with (used along with invert()) + - num_actual_inference_steps : number of actual inference steps (while total steps is num_inference_steps) + """ + controlnet = self.controlnet + + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + text_embeddings = torch.cat([text_embeddings] * context_batch_size) + + reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size) + reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size) + + is_dist_initialized = kwargs.get("dist", False) + rank = kwargs.get("rank", 0) + world_size = kwargs.get("world_size", 1) + + # Prepare video + assert num_videos_per_prompt == 1 # FIXME: verify if num_videos_per_prompt > 1 works + assert batch_size == 1 # FIXME: verify if batch_size > 1 works + control = self.prepare_condition( + condition=controlnet_condition, + device=device, + dtype=controlnet.dtype, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + controlnet_uncond_images, controlnet_cond_images = control.chunk(2) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + if init_latents is not None: + latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length) + else: + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Prepare text embeddings for controlnet + controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0) + _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2) + + controlnet_res_samples_cache_dict = {i:None for i in range(video_length)} + + # For img2img setting + if num_actual_inference_steps is None: + num_actual_inference_steps = num_inference_steps + + if isinstance(source_image, str): + ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).cuda() + elif isinstance(source_image, np.ndarray): + ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).cuda() + + context_scheduler = get_context_scheduler(context_schedule) + + # Denoising loop + for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)): + if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: + continue + + noise_pred = torch.zeros( + (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]), + device=latents.device, + dtype=latents.dtype, + ) + counter = torch.zeros( + (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype + ) + + appearance_encoder( + ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1), + t, + encoder_hidden_states=text_embeddings, + return_dict=False, + ) + + context_queue = list(context_scheduler( + 0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0 + )) + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + for i in range(num_context_batches): + context = context_queue[i*context_batch_size: (i+1)*context_batch_size] + # expand the latents if we are doing classifier free guidance + controlnet_latent_input = ( + torch.cat([latents[:, :, c] for c in context]) + .to(device) + ) + controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t) + + # prepare inputs for controlnet + b, c, f, h, w = controlnet_latent_input.shape + controlnet_latent_input = rearrange(controlnet_latent_input, "b c f h w -> (b f) c h w") + + # controlnet inference + down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_input, + t, + encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]), + controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]), + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + for j, k in enumerate(np.concatenate(np.array(context))): + controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1]) + + context_queue = list(context_scheduler( + 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap + )) + + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + global_context = [] + for i in range(num_context_batches): + global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size]) + + for context in global_context[rank::world_size]: + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents[:, :, c] for c in context]) + .to(device) + .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1) + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + b, c, f, h, w = latent_model_input.shape + down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples( + controlnet_res_samples_cache_dict, + context, + do_classifier_free_guidance, + b, f + ) + + reference_control_reader.update(reference_control_writer) + + # predict the noise residual + pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings[:b], + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + reference_control_reader.clear() + + pred_uc, pred_c = pred.chunk(2) + pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)]) + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j] + counter[:, :, c] = counter[:, :, c] + 1 + + if is_dist_initialized: + noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)] + if rank == 0: + dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0) + else: + dist.gather(tensor=noise_pred, gather_list=[], dst=0) + dist.barrier() + + if rank == 0: + for k in range(1, world_size): + for context in global_context[k::world_size]: + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c] + counter[:, :, c] = counter[:, :, c] + 1 + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if is_dist_initialized: + dist.broadcast(latents, 0) + dist.barrier() + + reference_control_writer.clear() + + interpolation_factor = 1 + latents = self.interpolate_latents(latents, interpolation_factor, device) + # Post-processing + video = self.decode_latents(latents, rank, decoder_consistency=decoder_consistency) + + if is_dist_initialized: + dist.barrier() + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + if not return_dict: + return video + + return AnimationPipelineOutput(videos=video) diff --git a/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc b/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc362cbaa73edc4be0ab32379c983a089477b803 Binary files /dev/null and b/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc differ diff --git a/magicanimate/utils/__pycache__/util.cpython-38.pyc b/magicanimate/utils/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203fd0d0bf60eef6d6657a785b6208b3476b376c Binary files /dev/null and b/magicanimate/utils/__pycache__/util.cpython-38.pyc differ diff --git a/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc b/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b9d98e5208b4bee650f5fa487628acd82be67e Binary files /dev/null and b/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc differ diff --git a/magicanimate/utils/dist_tools.py b/magicanimate/utils/dist_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..0761c6d197a14b0d77d167d0e27bde3f4cc2ec6b --- /dev/null +++ b/magicanimate/utils/dist_tools.py @@ -0,0 +1,105 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Copyright (2023) MagicAnimate Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. +import os +import socket +import warnings +import torch +from torch import distributed as dist + + +def distributed_init(args): + + if dist.is_initialized(): + warnings.warn("Distributed is already initialized, cannot initialize twice!") + args.rank = dist.get_rank() + else: + print( + f"Distributed Init (Rank {args.rank}): " + f"{args.init_method}" + ) + dist.init_process_group( + backend='nccl', + init_method=args.init_method, + world_size=args.world_size, + rank=args.rank, + ) + print( + f"Initialized Host {socket.gethostname()} as Rank " + f"{args.rank}" + ) + + if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ: + # Set for onboxdataloader support + split = args.init_method.split("//") + assert len(split) == 2, ( + "host url for distributed should be split by '//' " + + "into exactly two elements" + ) + + split = split[1].split(":") + assert ( + len(split) == 2 + ), "host url should be of the form :" + os.environ["MASTER_ADDR"] = split[0] + os.environ["MASTER_PORT"] = split[1] + + # perform a dummy all-reduce to initialize the NCCL communicator + dist.all_reduce(torch.zeros(1).cuda()) + + suppress_output(is_master()) + args.rank = dist.get_rank() + return args.rank + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_nccl_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_master(): + return get_rank() == 0 + + +def synchronize(): + if dist.is_initialized(): + dist.barrier() + + +def suppress_output(is_master): + """Suppress printing on the current device. Force printing with `force=True`.""" + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + import warnings + + builtin_warn = warnings.warn + + def warn(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_warn(*args, **kwargs) + + # Log warnings only once + warnings.warn = warn + warnings.simplefilter("once", UserWarning) \ No newline at end of file diff --git a/magicanimate/utils/util.py b/magicanimate/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..76f2245cf13367694ad00faf3be721de70ff57bd --- /dev/null +++ b/magicanimate/utils/util.py @@ -0,0 +1,138 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Adapted from https://github.com/guoyww/AnimateDiff +import os +import imageio +import numpy as np + +import torch +import torchvision + +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + +def save_images_grid(images: torch.Tensor, path: str): + assert images.shape[2] == 1 # no time dimension + images = images.squeeze(2) + grid = torchvision.utils.make_grid(images) + grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8) + os.makedirs(os.path.dirname(path), exist_ok=True) + Image.fromarray(grid).save(path) + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents + + +def video2images(path, step=4, length=16, start=0): + reader = imageio.get_reader(path) + frames = [] + for frame in reader: + frames.append(np.array(frame)) + frames = frames[start::step][:length] + return frames + + +def images2video(video, path, fps=8): + imageio.mimsave(path, video, fps=fps) + return + + +tensor_interpolation = None + +def get_tensor_interpolation_method(): + return tensor_interpolation + +def set_tensor_interpolation_method(is_slerp): + global tensor_interpolation + tensor_interpolation = slerp if is_slerp else linear + +def linear(v1, v2, t): + return (1.0 - t) * v1 + t * v2 + +def slerp( + v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 +) -> torch.Tensor: + u0 = v0 / v0.norm() + u1 = v1 / v1.norm() + dot = (u0 * u1).sum() + if dot.abs() > DOT_THRESHOLD: + #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') + return (1.0 - t) * v0 + t * v1 + omega = dot.acos() + return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() \ No newline at end of file diff --git a/magicanimate/utils/videoreader.py b/magicanimate/utils/videoreader.py new file mode 100644 index 0000000000000000000000000000000000000000..e26a37e01a1deb223e0ca8d24beb306d1703d01e --- /dev/null +++ b/magicanimate/utils/videoreader.py @@ -0,0 +1,157 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +# Copyright 2022 ByteDance and/or its affiliates. +# +# Copyright (2022) PV3D Authors +# +# ByteDance, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from ByteDance or +# its affiliates is strictly prohibited. +import av, gc +import torch +import warnings +import numpy as np + + +_CALLED_TIMES = 0 +_GC_COLLECTION_INTERVAL = 20 + + +# remove warnings +av.logging.set_level(av.logging.ERROR) + + +class VideoReader(): + """ + Simple wrapper around PyAV that exposes a few useful functions for + dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries. + Acknowledgement: Codes are borrowed from Bruno Korbar + """ + def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False): + """ + Arguments: + video_path (str): path or byte of the video to be loaded + """ + self.container = av.open(video) + self.num_frames = num_frames + self.bi_frame = bi_frame + + self.resampler = None + if audio_resample_rate is not None: + self.resampler = av.AudioResampler(rate=audio_resample_rate) + + if self.container.streams.video: + # enable multi-threaded video decoding + if decode_lossy: + warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) + self.container.streams.video[0].thread_type = 'AUTO' + self.video_stream = self.container.streams.video[0] + else: + self.video_stream = None + + self.fps = self._get_video_frame_rate() + + def seek(self, pts, backward=True, any_frame=False): + stream = self.video_stream + self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream) + + def _occasional_gc(self): + # there are a lot of reference cycles in PyAV, so need to manually call + # the garbage collector from time to time + global _CALLED_TIMES, _GC_COLLECTION_INTERVAL + _CALLED_TIMES += 1 + if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: + gc.collect() + + def _read_video(self, offset): + self._occasional_gc() + + pts = self.container.duration * offset + time_ = pts / float(av.time_base) + self.container.seek(int(pts)) + + video_frames = [] + count = 0 + for _, frame in enumerate(self._iter_frames()): + if frame.pts * frame.time_base >= time_: + video_frames.append(frame) + if count >= self.num_frames - 1: + break + count += 1 + return video_frames + + def _iter_frames(self): + for packet in self.container.demux(self.video_stream): + for frame in packet.decode(): + yield frame + + def _compute_video_stats(self): + if self.video_stream is None or self.container is None: + return 0 + num_of_frames = self.container.streams.video[0].frames + if num_of_frames == 0: + num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base) + self.seek(0, backward=False) + count = 0 + time_base = 512 + for p in self.container.decode(video=0): + count = count + 1 + if count == 1: + start_pts = p.pts + elif count == 2: + time_base = p.pts - start_pts + break + return start_pts, time_base, num_of_frames + + def _get_video_frame_rate(self): + return float(self.container.streams.video[0].guessed_rate) + + def sample(self, debug=False): + + if self.container is None: + raise RuntimeError('video stream not found') + sample = dict() + _, _, total_num_frames = self._compute_video_stats() + offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item() + video_frames = self._read_video(offset/total_num_frames) + video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) + sample["frames"] = video_frames + sample["frame_idx"] = [offset] + + if self.bi_frame: + frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)] + frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)] + frames.sort() + video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]]) + Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)] + sample["frames"] = video_frames + sample["real_t"] = torch.tensor(Ts, dtype=torch.float32) + sample["frame_idx"] = [offset+min(frames), offset+max(frames)] + return sample + + return sample + + def read_frames(self, frame_indices): + self.num_frames = frame_indices[1] - frame_indices[0] + video_frames = self._read_video(frame_indices[0]/self.get_num_frames()) + video_frames = np.array([ + np.uint8(video_frames[0].to_rgb().to_ndarray()), + np.uint8(video_frames[-1].to_rgb().to_ndarray()) + ]) + return video_frames + + def read(self): + video_frames = self._read_video(0) + video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) + return video_frames + + def get_num_frames(self): + _, _, total_num_frames = self._compute_video_stats() + return total_num_frames \ No newline at end of file diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..82f59c87c395ea9454a2578f8b5e84cd361f4342 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,127 @@ +absl-py==1.4.0 +accelerate==0.22.0 +aiofiles==23.2.1 +aiohttp==3.8.5 +aiosignal==1.3.1 +altair==5.0.1 +annotated-types==0.5.0 +antlr4-python3-runtime==4.9.3 +anyio==3.7.1 +async-timeout==4.0.3 +attrs==23.1.0 +cachetools==5.3.1 +certifi==2023.7.22 +charset-normalizer==3.2.0 +click==8.1.7 +cmake==3.27.2 +contourpy==1.1.0 +cycler==0.11.0 +datasets==2.14.4 +dill==0.3.7 +einops==0.6.1 +exceptiongroup==1.1.3 +fastapi==0.103.0 +ffmpy==0.3.1 +filelock==3.12.2 +fonttools==4.42.1 +frozenlist==1.4.0 +fsspec==2023.6.0 +google-auth==2.22.0 +google-auth-oauthlib==1.0.0 +gradio==3.41.2 +gradio-client==0.5.0 +grpcio==1.57.0 +h11==0.14.0 +httpcore==0.17.3 +httpx==0.24.1 +huggingface-hub==0.16.4 +idna==3.4 +importlib-metadata==6.8.0 +importlib-resources==6.0.1 +jinja2==3.1.2 +joblib==1.3.2 +jsonschema==4.19.0 +jsonschema-specifications==2023.7.1 +kiwisolver==1.4.5 +lightning-utilities==0.9.0 +lit==16.0.6 +markdown==3.4.4 +markupsafe==2.1.3 +matplotlib==3.7.2 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.15 +networkx==3.1 +numpy==1.24.4 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +oauthlib==3.2.2 +omegaconf==2.3.0 +opencv-python==4.8.0.76 +orjson==3.9.5 +pandas==2.0.3 +pillow==9.5.0 +pkgutil-resolve-name==1.3.10 +protobuf==4.24.2 +psutil==5.9.5 +pyarrow==13.0.0 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pydantic==2.3.0 +pydantic-core==2.6.3 +pydub==0.25.1 +pyparsing==3.0.9 +python-multipart==0.0.6 +pytorch-lightning==2.0.7 +pytz==2023.3 +pyyaml==6.0.1 +referencing==0.30.2 +regex==2023.8.8 +requests==2.31.0 +requests-oauthlib==1.3.1 +rpds-py==0.9.2 +rsa==4.9 +safetensors==0.3.3 +semantic-version==2.10.0 +sniffio==1.3.0 +starlette==0.27.0 +sympy==1.12 +tensorboard==2.14.0 +tensorboard-data-server==0.7.1 +tokenizers==0.13.3 +toolz==0.12.0 +torchmetrics==1.1.0 +tqdm==4.66.1 +transformers==4.32.0 +triton==2.0.0 +tzdata==2023.3 +urllib3==1.26.16 +uvicorn==0.23.2 +websockets==11.0.3 +werkzeug==2.3.7 +xxhash==3.3.0 +yarl==1.9.2 +zipp==3.16.2 +decord +imageio==2.9.0 +imageio-ffmpeg==0.4.3 +timm +scipy +scikit-image +av +imgaug +lpips +ffmpeg-python +torch==2.0.1 +torchvision==0.15.2 +xformers==0.0.22 +diffusers==0.21.4 diff --git a/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a9f0602367e1a8732fb0be34b8ab48ce22f0b769 Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 differ diff --git a/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8c03c1eb2db19845476cc0ea99502a6ea5a582f5 Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 differ diff --git a/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8f84ada3cbf1efa43b23138505387260ad642c83 Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 differ diff --git a/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..462f7a7a5675834ed62e590b5102b3e78b563bee Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 differ diff --git a/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4080a0155411670dc3839a7b8fd76517e3212e49 Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 differ diff --git a/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..665b46f9551b4f4c8cb6253662ae7b5546d11f51 Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 differ diff --git a/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..98e8cf53ecfefda176bb6ce86bab59fa484f141f Binary files /dev/null and b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 differ diff --git a/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5029c554a8aaa1e92d2a41143ae6bd2b2288dbc0 Binary files /dev/null and b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 differ