Spaces: Running on Zero

init demo

This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- ImageConductor_app.py +586 -0
- app.py +577 -0
- configs/.DS_Store +0 -0
- configs/inference/flow_condition.yaml +18 -0
- configs/inference/image_condition.yaml +18 -0
- configs/inference/inference.yaml +22 -0
- models/.DS_Store +0 -0
- modules/__pycache__/attention.cpython-310.pyc +0 -0
- modules/__pycache__/flow_controlnet.cpython-310.pyc +0 -0
- modules/__pycache__/image_controlnet.cpython-310.pyc +0 -0
- modules/__pycache__/motion_module.cpython-310.pyc +0 -0
- modules/__pycache__/resnet.cpython-310.pyc +0 -0
- modules/__pycache__/unet.cpython-310.pyc +0 -0
- modules/__pycache__/unet_blocks.cpython-310.pyc +0 -0
- modules/attention.py +396 -0
- modules/flow_controlnet.py +591 -0
- modules/image_controlnet.py +721 -0
- modules/motion_module.py +355 -0
- modules/resnet.py +261 -0
- modules/unet.py +591 -0
- modules/unet_blocks.py +866 -0
- peft/__init__.py +98 -0
- peft/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/__pycache__/auto.cpython-310.pyc +0 -0
- peft/__pycache__/config.cpython-310.pyc +0 -0
- peft/__pycache__/import_utils.cpython-310.pyc +0 -0
- peft/__pycache__/mapping.cpython-310.pyc +0 -0
- peft/__pycache__/mixed_model.cpython-310.pyc +0 -0
- peft/__pycache__/peft_model.cpython-310.pyc +0 -0
- peft/auto.py +170 -0
- peft/config.py +270 -0
- peft/helpers.py +148 -0
- peft/import_utils.py +89 -0
- peft/mapping.py +181 -0
- peft/mixed_model.py +415 -0
- peft/peft_model.py +0 -0
- peft/py.typed +0 -0
- peft/tuners/__init__.py +35 -0
- peft/tuners/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc +0 -0
- peft/tuners/__pycache__/tuners_utils.cpython-310.pyc +0 -0
- peft/tuners/adalora/__init__.py +37 -0
- peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/config.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/layer.cpython-310.pyc +0 -0
- peft/tuners/adalora/__pycache__/model.cpython-310.pyc +0 -0
- peft/tuners/adalora/bnb.py +145 -0
.DS_Store
ADDED
Binary file (6.15 kB)
ImageConductor_app.py
ADDED
@@ -0,0 +1,586 @@
import os
import gradio as gr
import numpy as np
import cv2
import uuid
import torch
import torchvision
import json

from PIL import Image
from omegaconf import OmegaConf
from einops import rearrange, repeat
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from pipelines.pipeline_imagecoductor import ImageConductorPipeline
from modules.unet import UNet3DConditionFlowModel
from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil, image2arr
from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian
from utils.lora_utils import add_LoRA_to_controlnet
from utils.visualizer import Visualizer, vis_flow_to_video

#### Description ####
title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""

head = r"""
<div style="text-align: center;">
    <h1>Image Conductor: Precision Control for Interactive Video Synthesis</h1>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href=""></a>
        <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
        <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
        <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
    </div>
    </br>
</div>
"""


descriptions = r"""
Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
🧙 Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
"""


instructions = r"""
- ⭐️ <b>step1: </b>Upload or select one image from Example.
- ⭐️ <b>step2: </b>Click 'Add Drag' to draw some drags.
- ⭐️ <b>step3: </b>Input a text prompt that complements the image (necessary).
- ⭐️ <b>step4: </b>Select 'Drag Mode' to specify whether to control camera transition or object movement.
- ⭐️ <b>step5: </b>Click the 'Run' button to generate video assets.
- ⭐️ <b>others: </b>Click 'Delete last drag' to delete the whole latest path. Click 'Delete last step' to delete the latest clicked control point.
"""

citation = r"""
If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
---

📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2024imageconductor,
    title={Image Conductor: Precision Control for Interactive Video Synthesis},
    author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
    year={2024},
    eprint={2406.15339},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.

# """

os.makedirs("models/personalized", exist_ok=True)
# Use /resolve/ (not /blob/) so wget fetches the actual checkpoint bytes rather than the Hub's HTML viewer page.
os.system('wget https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt -P models/')
os.system('wget https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt -P models/')
os.system('wget https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt -P models/')
os.system('wget https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors -P models/personalized')
os.system('wget https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors -P models/personalized')
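A note on the download block above: on the Hugging Face Hub, `/blob/` URLs serve the web file viewer, so direct downloads need the `/resolve/` form used above. A hedged alternative sketch uses the `huggingface_hub` client instead of shelling out to `wget`; it assumes the same repo id and filenames that the code references.

```python
# Sketch: fetching the same checkpoints with huggingface_hub instead of wget.
# Assumes the filenames referenced above exist in TencentARC/ImageConductor.
from huggingface_hub import hf_hub_download

for filename, target_dir in [
    ("flow_controlnet.ckpt", "models"),
    ("image_controlnet.ckpt", "models"),
    ("unet.ckpt", "models"),
    ("helloobjects_V12c.safetensors", "models/personalized"),
    ("TUSUN.safetensors", "models/personalized"),
]:
    hf_hub_download(repo_id="TencentARC/ImageConductor", filename=filename, local_dir=target_dir)
```

Either way, the files land under `models/` and `models/personalized/`, matching the paths the loading code below expects.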



# - - - - - examples - - - - - #
image_examples = [
    ["__asset__/images/object/turtle-1.jpg",
     "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
     "object",
     11318446767408804497,
     "",
     json.load(open("__asset__/trajs/object/turtle-1.json")),
     "__asset__/images/object/turtle-1.jpg",
     ],

    ["__asset__/images/object/rose-1.jpg",
     "a red rose engulfed in flames.",
     "object",
     6854275249656120509,
     "",
     json.load(open("__asset__/trajs/object/rose-1.json")),
     "__asset__/images/object/rose-1.jpg",
     ],

    ["__asset__/images/object/jellyfish-1.jpg",
     "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
     "object",
     17966188172968903484,
     "HelloObject",
     json.load(open("__asset__/trajs/object/jellyfish-1.json")),
     "__asset__/images/object/jellyfish-1.jpg",
     ],

    ["__asset__/images/camera/lush-1.jpg",
     "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
     "camera",
     7970487946960948963,
     "HelloObject",
     json.load(open("__asset__/trajs/camera/lush-1.json")),
     "__asset__/images/camera/lush-1.jpg",
     ],

    ["__asset__/images/camera/tusun-1.jpg",
     "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
     "camera",
     996953226890228361,
     "TUSUN",
     json.load(open("__asset__/trajs/camera/tusun-1.json")),
     "__asset__/images/camera/tusun-1.jpg",
     ],

    ["__asset__/images/camera/painting-1.jpg",
     "A oil painting.",
     "camera",
     16867854766769816385,
     "",
     json.load(open("__asset__/trajs/camera/painting-1.json")),
     "__asset__/images/camera/painting-1.jpg",
     ],
]


DREAM_BOOTH = {
    'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
}

LORA = {
    'TUSUN': 'models/personalized/TUSUN.safetensors',
}

LORA_ALPHA = {
    'TUSUN': 0.6,
}

NPROMPT = {
    "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
}

output_dir = "outputs"
ensure_dirname(output_dir)


def points_to_flows(track_points, model_length, height, width):
    input_drag = np.zeros((model_length - 1, height, width, 2))
    for splited_track in track_points:
        if len(splited_track) == 1:  # stationary point
            displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
            splited_track = tuple([splited_track[0], displacement_point])
        # interpolate the track
        splited_track = interpolate_trajectory(splited_track, model_length)
        splited_track = splited_track[:model_length]
        if len(splited_track) < model_length:
            splited_track = splited_track + [splited_track[-1]] * (model_length - len(splited_track))
        for i in range(model_length - 1):
            start_point = splited_track[i]
            end_point = splited_track[i+1]
            input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
            input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
    return input_drag
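To make the flow encoding concrete, the toy call below (illustrative coordinates, run in this module's context so `interpolate_trajectory` is available) turns a single two-point drag into a sparse `(model_length - 1, height, width, 2)` displacement volume: each frame step writes one `(dx, dy)` vector at the interpolated start pixel of the drag.

```python
# Toy check of points_to_flows: one drag with two clicked points on a 384x256 canvas.
toy_tracks = [[(50.0, 100.0), (150.0, 120.0)]]
flows = points_to_flows(toy_tracks, model_length=16, height=256, width=384)
print(flows.shape)               # (15, 256, 384, 2)
print(np.count_nonzero(flows))   # small: only the drag's start pixels carry (dx, dy)
```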

class ImageConductor:
    def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
        self.device = device
        tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").cuda()
        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").cuda()
        inference_config = OmegaConf.load("configs/inference/inference.yaml")
        unet = UNet3DConditionFlowModel.from_pretrained_2d("runwayml/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))

        self.vae = vae

        ### >>> Initialize UNet module >>> ###
        load_model(unet, unet_path)

        ### >>> Initialize image controlnet module >>> ###
        image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet)
        load_model(image_controlnet, image_controlnet_path)
        ### >>> Initialize flow controlnet module >>> ###
        flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet)
        add_LoRA_to_controlnet(lora_rank, flow_controlnet)
        load_model(flow_controlnet, flow_controlnet_path)

        unet.eval().to(device)
        image_controlnet.eval().to(device)
        flow_controlnet.eval().to(device)

        self.pipeline = ImageConductorPipeline(
            unet=unet,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            image_controlnet=image_controlnet,
            flow_controlnet=flow_controlnet,
        ).to(device)

        self.height = height
        self.width = width
        # _, model_step, _ = split_filename(model_path)
        # self.ouput_prefix = f'{model_step}_{width}X{height}'
        self.model_length = model_length

        blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True)

        self.blur_kernel = blur_kernel

    @torch.no_grad()
    def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized):

        original_width, original_height = 384, 256
        if isinstance(tracking_points, list):
            input_all_points = tracking_points
        else:
            input_all_points = tracking_points.constructor_args['value']

        resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]

        dir, base, ext = split_filename(first_frame_path)
        id = base.split('_')[-1]

        with open(f'{output_dir}/points-{id}.json', 'w') as f:
            json.dump(input_all_points, f)

        visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)

        ## image condition
        image_transforms = transforms.Compose([
            transforms.RandomResizedCrop(
                (self.height, self.width), (1.0, 1.0),
                ratio=(self.width/self.height, self.width/self.height)
            ),
            transforms.ToTensor(),
        ])

        image_norm = lambda x: x
        image_paths = [first_frame_path]
        controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
        controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda()
        controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
        num_controlnet_images = controlnet_images.shape[2]
        controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
        controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
        controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)

        # flow condition
        controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
        for i in range(0, self.model_length-1):
            controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
        controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0)  # pad the first frame with zero flow
        os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
        trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length)  # T-1 x H x W x 3
        torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
        controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...]
        controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w")

        dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
        lora_model_path = LORA.get(personalized, '')
        lora_alpha = LORA_ALPHA.get(personalized, 0.6)
        self.pipeline = load_weights(
            self.pipeline,
            dreambooth_model_path=dreambooth_model_path,
            lora_model_path=lora_model_path,
            lora_alpha=lora_alpha,
        ).to(device)

        if NPROMPT.get(personalized, '') != '':
            negative_prompt = NPROMPT.get(personalized)

        if randomize_seed:
            random_seed = torch.seed()
        else:
            seed = int(seed)
            random_seed = seed
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        print(f"current seed: {torch.initial_seed()}")
        sample = self.pipeline(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=self.width,
            height=self.height,
            video_length=self.model_length,
            controlnet_images=controlnet_images,  # 1 4 1 32 48
            controlnet_image_index=[0],
            controlnet_flows=controlnet_flows,  # [1, 2, 16, 256, 384]
            control_mode=drag_mode,
            eval_mode=True,
        ).videos

        outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
        vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
        torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})

        return visualized_drag, outputs_path
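For orientation, here is a minimal sketch of driving the class above without the Gradio UI. It is illustrative only: it assumes the checkpoints have already been downloaded under `models/`, that a 384×256 first frame named in the `first_frame_<id>.jpg` pattern has been saved under `outputs/` (as `preprocess_image` below does), and that the module-level `device` and `output_dir` names are set as elsewhere in this file; the trajectory, prompt, and file name are made up.

```python
# Sketch: calling ImageConductor.run directly (hypothetical paths and trajectory).
if __name__ == "__main__":
    conductor = ImageConductor(device="cuda",
                               unet_path="models/unet.ckpt",
                               image_controlnet_path="models/image_controlnet.ckpt",
                               flow_controlnet_path="models/flow_controlnet.ckpt",
                               height=256, width=384, model_length=16)
    drag_vis, video_path = conductor.run(first_frame_path="outputs/first_frame_demo.jpg",
                                         tracking_points=[[(96, 128), (288, 128)]],  # one left-to-right drag
                                         prompt="a sea turtle swimming over a coral reef.",
                                         drag_mode="object",
                                         negative_prompt="worst quality, low quality",
                                         seed=42, randomize_seed=False,
                                         guidance_scale=8.5, num_inference_steps=25,
                                         personalized="")
    print(video_path)
```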


def reset_states(first_frame_path, tracking_points):
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    return None, first_frame_path, tracking_points


def preprocess_image(image):
    image_pil = image2pil(image.name)
    raw_w, raw_h = image_pil.size
    resize_ratio = max(384/raw_w, 256/raw_h)
    image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
    id = str(uuid.uuid4())[:4]
    first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
    image_pil.save(first_frame_path, quality=95)
    return first_frame_path, first_frame_path, gr.State([])


def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData):  # SelectData is a subclass of EventData
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)

    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    tracking_points.constructor_args['value'][-1].append(evt.index)
    print(tracking_points.constructor_args)

    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))
    for track in tracking_points.constructor_args['value']:
        if len(track) > 1:
            for i in range(len(track)-1):
                start_point = track[i]
                end_point = track[i+1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track)-2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return tracking_points, trajectory_map


def add_drag(tracking_points):
    tracking_points.constructor_args['value'].append([])
    print(tracking_points.constructor_args)
    return tracking_points


def delete_last_drag(tracking_points, first_frame_path, drag_mode):
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)
    tracking_points.constructor_args['value'].pop()
    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))
    for track in tracking_points.constructor_args['value']:
        if len(track) > 1:
            for i in range(len(track)-1):
                start_point = track[i]
                end_point = track[i+1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track)-2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return tracking_points, trajectory_map


def delete_last_step(tracking_points, first_frame_path, drag_mode):
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)
    tracking_points.constructor_args['value'][-1].pop()
    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))
    for track in tracking_points.constructor_args['value']:
        if len(track) > 1:
            for i in range(len(track)-1):
                start_point = track[i]
                end_point = track[i+1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track)-2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return tracking_points, trajectory_map

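The drag-editing callbacks above mutate `tracking_points.constructor_args['value']`, which reaches into `gr.State` internals and is tied to the Gradio version pinned for this Space. Since `run()` already accepts a plain list of tracks, a version-agnostic variant would treat the state value as data that is passed in and returned, leaving the overlay drawing unchanged. A sketch of that state-handling skeleton, assuming standard `gr.State` value passing (the drawing code is omitted):

```python
# Sketch: the same callbacks written against the state *value* instead of gr.State internals.
# Assumes Gradio passes the current state value into callbacks and stores the returned one.
def add_drag_sketch(tracking_points):
    tracking_points.append([])              # start a new, empty drag path
    return tracking_points

def delete_last_drag_sketch(tracking_points):
    if tracking_points:
        tracking_points.pop()               # drop the most recent path
    return tracking_points

def add_tracking_points_sketch(tracking_points, evt: gr.SelectData):
    if not tracking_points:
        tracking_points.append([])
    tracking_points[-1].append(evt.index)   # evt.index is the clicked (x, y)
    return tracking_points
```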
block = gr.Blocks(
    theme=gr.themes.Soft(
        radius_size=gr.themes.sizes.radius_none,
        text_size=gr.themes.sizes.text_md
    )
).queue()
with block as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)
            gr.Markdown(descriptions)

            with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
                with gr.Row(equal_height=True):
                    gr.Markdown(instructions)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    unet_path = 'models/unet.ckpt'
    image_controlnet_path = 'models/image_controlnet.ckpt'
    flow_controlnet_path = 'models/flow_controlnet.ckpt'
    ImageConductor_net = ImageConductor(device=device,
                                        unet_path=unet_path,
                                        image_controlnet_path=image_controlnet_path,
                                        flow_controlnet_path=flow_controlnet_path,
                                        height=256,
                                        width=384,
                                        model_length=16)
    first_frame_path = gr.State()
    tracking_points = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            add_drag_button = gr.Button(value="Add Drag")
            reset_button = gr.Button(value="Reset")
            delete_last_drag_button = gr.Button(value="Delete last drag")
            delete_last_step_button = gr.Button(value="Delete last step")

        with gr.Column(scale=7):
            with gr.Row():
                with gr.Column(scale=6):
                    input_image = gr.Image(label=None,
                                           interactive=True,
                                           height=256,
                                           width=384)
                with gr.Column(scale=6):
                    output_image = gr.Image(label="Motion Path",
                                            interactive=False,
                                            height=256,
                                            width=384)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
            negative_prompt = gr.Text(
                label="Negative Prompt",
                max_lines=5,
                placeholder="Please input your negative prompt",
                value='worst quality, low quality, letterboxed', lines=1
            )
            drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
            run_button = gr.Button(value="Run")

            with gr.Accordion("More input params", open=False, elem_id="accordion1"):
                with gr.Group():
                    seed = gr.Textbox(
                        label="Seed: ", value=561793204,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=8.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )

                with gr.Group():
                    personalized = gr.Dropdown(label="Personalized template", choices=['HelloObject', 'TUSUN'], value="")

        with gr.Column(scale=7):
            output_video = gr.Video(value=None,
                                    label="Output Video",
                                    width=384,
                                    height=256)

    with gr.Row():
        def process_example(input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path):
            return input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path

        example = gr.Examples(
            label="Input Example",
            examples=image_examples,
            inputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
            outputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
            fn=process_example,
            run_on_click=True,
            examples_per_page=10
        )

    with gr.Row():
        gr.Markdown(citation)

    image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])

    add_drag_button.click(add_drag, tracking_points, tracking_points)

    delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])

    delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])

    reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points])

    input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])

    run_button.click(ImageConductor_net.run,
                     [first_frame_path, tracking_points, prompt, drag_mode,
                      negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
                     [output_image, output_video])

demo.launch(server_name="0.0.0.0", debug=True, server_port=12345)
app.py
ADDED
@@ -0,0 +1,577 @@
The portion of app.py shown in this view is identical to ImageConductor_app.py above, except that app.py omits the `os.makedirs`/`wget` model-download block (so its line numbers run lower from that point on). The rendered diff is cut off at app.py line 563; the remaining lines of the file are not shown in this view.
|
564 |
+
|
565 |
+
delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
|
566 |
+
|
567 |
+
delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
|
568 |
+
|
569 |
+
reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points])
|
570 |
+
|
571 |
+
input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
|
572 |
+
|
573 |
+
run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode,
|
574 |
+
negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
|
575 |
+
[output_image, output_video])
|
576 |
+
|
577 |
+
demo.launch(server_name="0.0.0.0", debug=True, server_port=12345)
|
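
The drag-editing helpers above all share one overlay routine: draw the accumulated tracks onto a transparent RGBA layer with OpenCV, then alpha-composite that layer over the first frame. The following is a minimal standalone sketch of that routine, detached from the Gradio state object; the image path and track coordinates are placeholders, not values from this repo.

import cv2
import numpy as np
from PIL import Image

def draw_tracks(first_frame_path, tracks, color=(255, 0, 0, 255)):
    # same drawing scheme as the app: line segments between consecutive points,
    # an arrow head on the final segment, and a dot for single-point tracks
    background = Image.open(first_frame_path).convert('RGBA')
    w, h = background.size
    layer = np.zeros((h, w, 4))
    for track in tracks:
        if len(track) > 1:
            for i in range(len(track) - 1):
                start, end = track[i], track[i + 1]
                length = np.hypot(end[0] - start[0], end[1] - start[1])
                if i == len(track) - 2:
                    cv2.arrowedLine(layer, tuple(start), tuple(end), color, 2, tipLength=8 / length)
                else:
                    cv2.line(layer, tuple(start), tuple(end), color, 2)
        else:
            cv2.circle(layer, tuple(track[0]), 5, color, -1)
    return Image.alpha_composite(background, Image.fromarray(layer.astype(np.uint8)))

# usage sketch with placeholder coordinates:
# draw_tracks("first_frame.png", [[(50, 60), (120, 80), (200, 90)]]).save("trajectory_map.png")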
configs/.DS_Store
ADDED
Binary file (6.15 kB). View file
configs/inference/flow_condition.yaml
ADDED
@@ -0,0 +1,18 @@
controlnet_additional_kwargs:
  set_noisy_sample_input_to_zero: true
  use_simplified_condition_embedding: true
  conditioning_channels: 2
  concate_conditioning_mask: false

  use_motion_module: true
  motion_module_resolutions: [1,2,4,8]
  motion_module_mid_block: false
  motion_module_type: "Vanilla"

  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types: [ "Temporal_Self" ]
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
configs/inference/image_condition.yaml
ADDED
@@ -0,0 +1,18 @@
controlnet_additional_kwargs:
  set_noisy_sample_input_to_zero: true
  use_simplified_condition_embedding: true
  conditioning_channels: 4
  concate_conditioning_mask: true

  use_motion_module: true
  motion_module_resolutions: [1,2,4,8]
  motion_module_mid_block: false
  motion_module_type: "Vanilla"

  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types: [ "Temporal_Self" ]
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
configs/inference/inference.yaml
ADDED
@@ -0,0 +1,22 @@
unet_additional_kwargs:
  use_inflated_groupnorm: true
  use_motion_module: true
  motion_module_resolutions: [1,2,4,8]
  motion_module_mid_block: false
  motion_module_type: Vanilla

  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1
    zero_initialize: true

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: False
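
These three YAML files are plain key/value configs; the diff does not show the loader, but configs laid out in this AnimateDiff style are commonly read with OmegaConf and the nested blocks splatted into model constructors as keyword arguments. A minimal sketch under that assumption (the paths are the files added above):

from omegaconf import OmegaConf

inference_config = OmegaConf.load("configs/inference/inference.yaml")
flow_config = OmegaConf.load("configs/inference/flow_condition.yaml")

# forward the nested blocks as constructor kwargs
unet_kwargs = OmegaConf.to_container(inference_config.unet_additional_kwargs)
scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
controlnet_kwargs = OmegaConf.to_container(flow_config.controlnet_additional_kwargs)

print(scheduler_kwargs["beta_schedule"])   # "linear"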
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
modules/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (6.61 kB). View file
modules/__pycache__/flow_controlnet.cpython-310.pyc
ADDED
Binary file (14.5 kB). View file
modules/__pycache__/image_controlnet.cpython-310.pyc
ADDED
Binary file (16.9 kB). View file
modules/__pycache__/motion_module.cpython-310.pyc
ADDED
Binary file (8.54 kB). View file
modules/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (5.89 kB). View file
modules/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (14.3 kB). View file
modules/__pycache__/unet_blocks.cpython-310.pyc
ADDED
Binary file (13.9 kB). View file
modules/attention.py
ADDED
@@ -0,0 +1,396 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import Tensor, nn

logger = logging.getLogger(__name__)


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


@maybe_allow_in_graph
class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        # validate input dim
        if hidden_states.dim() != 5:
            raise ValueError(f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # shenanigans for motion module: fold the frame axis into the batch axis
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)

        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
                encoder_attention_mask=encoder_attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        unet_use_cross_frame_attention: bool = False,
        unet_use_temporal_attention: bool = False,
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # Define 3 blocks. Each block has its own normalization layer.
        # Self-Attn / SC-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        if unet_use_cross_frame_attention:
            # this isn't actually implemented anywhere in the AnimateDiff codebase or in Diffusers...
            raise NotImplementedError("SC-Attn is not implemented yet.")
        else:
            self.attn1 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Temporal Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            # the temporal norm gets its own attribute (`norm_temp`), matching its use in forward();
            # assigning `self.norm1` here would silently overwrite the self-attention norm
            if self.use_ada_layer_norm:
                self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm)
            else:
                self.norm_temp = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        video_length=None,
    ):
        # SparseCausal-Attention
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        if self.unet_use_cross_frame_attention:
            cross_attention_kwargs["video_length"] = video_length

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # 4. Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
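
A quick shape check of the module above can help: Transformer3DModel expects 5-D video latents in (batch, channels, frames, height, width) order and folds the frame axis into the batch before the 2-D attention blocks. A minimal sketch with made-up sizes (320 channels, 77 text tokens, a CLIP-style 768-dim context); the two unet_use_* flags are passed explicitly because BasicTransformerBlock asserts the temporal flag is not None.

import torch

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,          # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)

latents = torch.randn(1, 320, 8, 16, 24)       # (b, c, f, h, w)
text_context = torch.randn(1, 77, 768)         # (b, tokens, dim)

out = model(latents, encoder_hidden_states=text_context).sample
print(out.shape)                               # torch.Size([1, 320, 8, 16, 24])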
modules/flow_controlnet.py
ADDED
@@ -0,0 +1,591 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin
from diffusers.utils import BaseOutput, logging
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F

from .resnet import InflatedConv3d
from .unet_blocks import (
    CrossAttnDownBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    get_down_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowControlNetOutput(BaseOutput):
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class FlowControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class FlowControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,

        use_motion_module=True,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_type="Vanilla",
        motion_module_kwargs={
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 32,
            "temporal_attention_dim_div": 1,
            "causal_temporal_attention": False,
        },

        concate_conditioning_mask: bool = True,
        use_simplified_condition_embedding: bool = False,

        set_noisy_sample_input_to_zero: bool = False,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero

        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = InflatedConv3d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
        conditioning_channels = conditioning_channels * 8 * 8
        if concate_conditioning_mask:
            conditioning_channels = conditioning_channels + 1
        self.concate_conditioning_mask = concate_conditioning_mask

        # control net conditioning embedding
        if use_simplified_condition_embedding:
            self.controlnet_cond_embedding = zero_module(
                InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
            )
        else:
            self.controlnet_cond_embedding = FlowControlNetConditioningEmbedding(
                conditioning_embedding_channels=block_out_channels[0],
                block_out_channels=conditioning_embedding_out_channels,
                conditioning_channels=conditioning_channels,
            )
        self.use_simplified_condition_embedding = use_simplified_condition_embedding

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                use_inflated_groupnorm=True,

                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,

            use_inflated_groupnorm=True,
            use_motion_module=use_motion_module and motion_module_mid_block,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,

        controlnet_additional_kwargs: dict = {},
    ):
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,

            **controlnet_additional_kwargs,
        )
        controlnet.unshuffle = nn.PixelUnshuffle(8)

        if load_weights_from_unet:
            m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
            assert len(u) == 0
            m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
            assert len(u) == 0
            m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
            assert len(u) == 0

            if controlnet.class_embedding:
                m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
                assert len(u) == 0
            m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
            assert len(u) == 0
            m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
            assert len(u) == 0

        return controlnet

    @staticmethod
    def image_layer_filter(state_dict):
        new_state_dict = {}
        for name, param in state_dict.items():
            if "motion_modules." in name or "lora" in name:
                continue
            new_state_dict[name] = param
        return new_state_dict

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # the 3D block variants are what this model is built from (the 2D names are not imported in this module)
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,

        controlnet_cond: torch.FloatTensor,
        conditioning_mask: Optional[torch.FloatTensor] = None,

        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[FlowControlNetOutput, Tuple]:
        # set input noise to zero
        if self.set_noisy_sample_input_to_zero:
            sample = torch.zeros_like(sample).to(sample.device)

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
        encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        if self.concate_conditioning_mask:
            controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
        controlnet_cond = self.unshuffle(controlnet_cond.permute(0, 2, 1, 3, 4))
        controlnet_cond = controlnet_cond.contiguous().permute(0, 2, 1, 3, 4)
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    # cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                # cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0

            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlowControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
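
One detail of FlowControlNetModel worth noting: __init__ multiplies conditioning_channels by 8 * 8 because from_unet attaches an nn.PixelUnshuffle(8), and forward() runs the per-frame condition through it before the conditioning embedding. Below is a small standalone illustration of that reshuffling, not repo code; the tensor sizes are placeholders matching the 256x384, 16-frame demo above, with a two-channel condition as in flow_condition.yaml.

import torch
from torch import nn

unshuffle = nn.PixelUnshuffle(8)

flow = torch.randn(1, 2, 16, 256, 384)        # (b, c=2, f, H, W) per-frame condition
x = unshuffle(flow.permute(0, 2, 1, 3, 4))    # (b, f, 2*8*8, H/8, W/8)
x = x.contiguous().permute(0, 2, 1, 3, 4)     # back to channels-first: (b, 128, f, 32, 48)
print(x.shape)                                # torch.Size([1, 128, 16, 32, 48])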
modules/image_controlnet.py
ADDED
@@ -0,0 +1,721 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F

from .resnet import InflatedConv3d
from .unet_blocks import (
    CrossAttnDownBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    get_down_block,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ImageControlNetOutput(BaseOutput):
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ImageControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = InflatedConv3d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ImageControlNetModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        use_motion_module=True,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_type="Vanilla",
        motion_module_kwargs={
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 32,
            "temporal_attention_dim_div": 1,
            "causal_temporal_attention": False,
        },
        concate_conditioning_mask: bool = True,
        use_simplified_condition_embedding: bool = False,
        set_noisy_sample_input_to_zero: bool = False,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(
            only_cross_attention
        ) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
            down_block_types
        ):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero

        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = InflatedConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=conv_in_kernel,
            padding=conv_in_padding,
        )

        if concate_conditioning_mask:
            conditioning_channels = conditioning_channels + 1
        self.concate_conditioning_mask = concate_conditioning_mask

        # control net conditioning embedding
        if use_simplified_condition_embedding:
            self.controlnet_cond_embedding = zero_module(
                InflatedConv3d(
                    conditioning_channels,
                    block_out_channels[0],
                    kernel_size=conv_in_kernel,
                    padding=conv_in_padding,
                )
            )
        else:
            self.controlnet_cond_embedding = ImageControlNetConditioningEmbedding(
                conditioning_embedding_channels=block_out_channels[0],
                block_out_channels=conditioning_embedding_out_channels,
                conditioning_channels=conditioning_channels,
            )
        self.use_simplified_condition_embedding = use_simplified_condition_embedding

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(
                projection_class_embeddings_input_dim, time_embed_dim
            )
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            res = 2**i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=(
                    attention_head_dim[i]
                    if attention_head_dim[i] is not None
                    else output_channel
                ),
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                use_inflated_groupnorm=True,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = InflatedConv3d(
                    output_channel, output_channel, kernel_size=1
                )
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = InflatedConv3d(
                    output_channel, output_channel, kernel_size=1
                )
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = InflatedConv3d(
            mid_block_channel, mid_block_channel, kernel_size=1
        )
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            use_inflated_groupnorm=True,
            use_motion_module=use_motion_module and motion_module_mid_block,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        controlnet_additional_kwargs: dict = {},
    ):
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            **controlnet_additional_kwargs,
        )

        if load_weights_from_unet:
            m, u = controlnet.conv_in.load_state_dict(
                cls.image_layer_filter(unet.conv_in.state_dict()), strict=False
            )
            assert len(u) == 0
            m, u = controlnet.time_proj.load_state_dict(
                cls.image_layer_filter(unet.time_proj.state_dict()), strict=False
            )
            assert len(u) == 0
            m, u = controlnet.time_embedding.load_state_dict(
                cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False
            )
            assert len(u) == 0

            if controlnet.class_embedding:
                m, u = controlnet.class_embedding.load_state_dict(
                    cls.image_layer_filter(unet.class_embedding.state_dict()),
                    strict=False,
                )
                assert len(u) == 0
            m, u = controlnet.down_blocks.load_state_dict(
                cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False
            )
            assert len(u) == 0
            m, u = controlnet.mid_block.load_state_dict(
                cls.image_layer_filter(unet.mid_block.state_dict()), strict=False
            )
            assert len(u) == 0

        return controlnet

    @staticmethod
    def image_layer_filter(state_dict):
        new_state_dict = {}
        for name, param in state_dict.items():
            if "motion_modules." in name or "lora" in name:
                continue
            new_state_dict[name] = param
        return new_state_dict

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = (
            num_sliceable_layers * [slice_size]
            if not isinstance(slice_size, list)
            else slice_size
        )

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(
            module: torch.nn.Module, slice_size: List[int]
        ):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # Note: the original referenced the 2D block classes, which are not imported in
        # this module; the 3D blocks imported above are the ones actually instantiated.
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_mask: Optional[torch.FloatTensor] = None,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ImageControlNetOutput, Tuple]:

        # set input noise to zero
        if self.set_noisy_sample_input_to_zero:
            sample = torch.zeros_like(sample).to(sample.device)

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
        encoder_hidden_states = encoder_hidden_states.repeat(
            sample.shape[0] // encoder_hidden_states.shape[0], 1, 1
        )

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0"
                )

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        if self.concate_conditioning_mask:
            controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if (
                hasattr(downsample_block, "has_cross_attention")
                and downsample_block.has_cross_attention
            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    # cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                # cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(
            down_block_res_samples, self.controlnet_down_blocks
        ):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
                down_block_res_sample,
            )

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(
                -1, 0, len(down_block_res_samples) + 1, device=sample.device
            )  # 0.1 to 1.0

            scales = scales * conditioning_scale
            down_block_res_samples = [
                sample * scale for sample, scale in zip(down_block_res_samples, scales)
            ]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [
                sample * conditioning_scale for sample in down_block_res_samples
            ]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True)
                for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(
                mid_block_res_sample, dim=(2, 3), keepdim=True
            )

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ImageControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.
        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.
                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_attn_processor(name, module, processor)


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
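ImageControlNetModel.from_unet() copies the spatial UNet weights into the ControlNet after passing each sub-state-dict through image_layer_filter(), which drops motion-module and LoRA entries so only the image layers are transferred. A small standalone sketch of that filtering rule (the keys below are made up for illustration):

import torch

def image_layer_filter(state_dict):
    # Same rule as ImageControlNetModel.image_layer_filter above.
    return {k: v for k, v in state_dict.items()
            if "motion_modules." not in k and "lora" not in k}

sd = {
    "conv_in.weight": torch.zeros(1),
    "down_blocks.0.motion_modules.0.proj_in.weight": torch.zeros(1),
    "down_blocks.0.attentions.0.to_q_lora.weight": torch.zeros(1),
}
print(sorted(image_layer_filter(sd)))  # ['conv_in.weight'] (only the spatial weight survives)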
modules/motion_module.py
ADDED
@@ -0,0 +1,355 @@
import math
from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention, FeedForward
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import Tensor, nn


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    else:
        raise ValueError


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels
            // num_attention_heads
            // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(
                self.temporal_transformer.proj_out
            )
        self.skip_temporal_layers = False  # Whether to skip temporal layer

    def forward(
        self,
        input_tensor,
        temb,
        encoder_hidden_states,
        attention_mask=None,
        anchor_frame_idx=None,
    ):
        if self.skip_temporal_layers is True:
            return input_tensor

        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(
            hidden_states, encoder_hidden_states, attention_mask
        )

        output = hidden_states
        return output


@maybe_allow_in_graph
class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states: Tensor,
        encoder_hidden_states: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ):
        assert (
            hidden_states.dim() == 5
        ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
            batch, height * weight, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
            )

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, height, weight, inner_dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output


@maybe_allow_in_graph
class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: int = 768,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        upcast_attention: bool = False,
        cross_frame_attention_mode=None,
        temporal_position_encoding: bool = False,
        temporal_position_encoding_max_len: int = 24,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=(
                        cross_attention_dim if block_name.endswith("_Cross") else None
                    ),
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
    ):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = (
                attention_block(
                    norm_hidden_states,
                    encoder_hidden_states=(
                        encoder_hidden_states
                        if attention_block.is_cross_attention
                        else None
                    ),
                    video_length=video_length,
                )
                + hidden_states
            )

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout: float = 0.0, max_len: int = 24):
        super().__init__()
        self.dropout: nn.Module = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe: Tensor = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


@maybe_allow_in_graph
class VersatileAttention(Attention):
    def __init__(
        self,
        attention_mode: str = None,
        cross_frame_attention_mode: Optional[str] = None,
        temporal_position_encoding: bool = False,
        temporal_position_encoding_max_len: int = 24,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if attention_mode.lower() != "temporal":
            raise ValueError(f"Attention mode {attention_mode} is not supported.")

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = (
            PositionalEncoding(
                kwargs["query_dim"],
                dropout=0.0,
                max_len=temporal_position_encoding_max_len,
            )
            if (temporal_position_encoding and attention_mode == "Temporal")
            else None
        )

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(
        self,
        hidden_states: Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        video_length=None,
    ):
        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(
                hidden_states, "(b f) d c -> (b d) f c", f=video_length
            )

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = (
                repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
                if encoder_hidden_states is not None
                else encoder_hidden_states
            )
        else:
            raise NotImplementedError

        # attention processor makes this easy so that's nice
        hidden_states = self.processor(
            self, hidden_states, encoder_hidden_states, attention_mask
        )

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states

    def set_use_memory_efficient_attention_xformers(
        self,
        use_memory_efficient_attention_xformers: bool,
        attention_op: Optional[Callable] = None,
    ):
        return None
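The motion module is a drop-in temporal layer: a 5-D (batch, channels, frames, height, width) feature map is folded so attention runs along the frame axis at each spatial location, and proj_out is zero-initialised so the block starts as an identity. A minimal usage sketch, assuming this repo's modules package is importable and a diffusers version whose default attention processor accepts positional (hidden_states, encoder_hidden_states, attention_mask); the kwargs mirror the defaults used elsewhere in this repo:

import torch
from modules.motion_module import get_motion_module

mm = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "attention_block_types": ["Temporal_Self"],
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 32,
        "temporal_attention_dim_div": 1,
    },
)

x = torch.randn(2, 320, 16, 8, 8)  # (batch, channels, frames, height, width)
out = mm(x, temb=None, encoder_hidden_states=None)
print(out.shape)                   # torch.Size([2, 320, 16, 8, 8])
print(torch.allclose(out, x))      # True at init: proj_out is zero-initialised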
modules/resnet.py
ADDED
@@ -0,0 +1,261 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn


class InflatedConv3d(nn.Conv2d):
    def forward(self, x: Tensor) -> Tensor:
        ori_dim = x.ndim
        if ori_dim == 5:
            frames = x.shape[2]
            x = rearrange(x, "b c f h w -> (b f) c h w")
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        if ori_dim == 5:
            x = rearrange(x, "(b f) c h w -> b c f h w", f=frames)
        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name="conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states: Tensor, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(
                hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
            )
        else:
            hidden_states = F.interpolate(
                hidden_states, size=output_size, mode="nearest"
            )

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name="conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(
                self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm != None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        else:
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )

        self.conv1 = InflatedConv3d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(
                    f"unknown time_embedding_norm : {self.time_embedding_norm} "
                )

            self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
            )
        else:
            self.norm2 = nn.GroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
            )

        self.dropout = nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = (
            self.in_channels != self.out_channels
            if use_in_shortcut is None
            else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
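InflatedConv3d and InflatedGroupNorm apply ordinary 2D operations frame by frame by folding the frame axis into the batch axis, which is what lets the 3D blocks above reuse pretrained 2D weights. A minimal usage sketch, assuming this repo's modules package is on the Python path:

import torch
from modules.resnet import InflatedConv3d, InflatedGroupNorm

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
norm = InflatedGroupNorm(num_groups=4, num_channels=8)

video = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, frames, h, w)
out = norm(conv(video))                # each frame is processed as a 2D image
print(out.shape)                       # torch.Size([2, 8, 16, 32, 32])

image = torch.randn(2, 4, 32, 32)      # 4-D input passes straight through conv2d
print(conv(image).shape)               # torch.Size([2, 8, 32, 32])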
modules/unet.py
ADDED
@@ -0,0 +1,591 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py

import json
import os
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin
from diffusers.models import ModelMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
from safetensors.torch import load_file
from torch import Tensor, nn

from .resnet import InflatedConv3d, InflatedGroupNorm
from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionFlowModelOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionFlowModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        use_inflated_groupnorm=False,

        # Additional
        use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,

                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)


    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)


    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value


    def get_image_controlnet(self, controlnet_noisy_latents, timesteps,
                             encoder_hidden_states=None,
                             controlnet_cond=None,
                             conditioning_mask=None,
                             conditioning_scale=None,
                             guess_mode=False,
                             return_dict=False,):
        down_block_additional_residuals, mid_block_additional_residual = self.image_controlnet(
            controlnet_noisy_latents, timesteps,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_mask=conditioning_mask,
            conditioning_scale=conditioning_scale,
            guess_mode=guess_mode,
            return_dict=return_dict,
        )
        return down_block_additional_residuals, mid_block_additional_residual


    def get_flow_controlnet(self, controlnet_noisy_latents, timesteps,
                            encoder_hidden_states=None,
                            controlnet_cond=None,
                            conditioning_mask=None,
                            conditioning_scale=None,
                            guess_mode=False,
                            return_dict=False,):
        down_block_additional_residuals, mid_block_additional_residual = self.omcm_controlnet(
            controlnet_noisy_latents, timesteps,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_mask=conditioning_mask,
            conditioning_scale=conditioning_scale,
            guess_mode=guess_mode,
            return_dict=return_dict,
        )
        return down_block_additional_residuals, mid_block_additional_residual


    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,

        # support image controlnet
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,

        # support flow controlnet
        flow_down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        flow_mid_block_additional_residual: Optional[torch.Tensor] = None,

        return_dict: bool = True,
    ) -> Union[UNet3DConditionFlowModelOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)

            down_block_res_samples += res_samples

        # support controlnet
        # image controlnet
        down_block_res_samples = list(down_block_res_samples)
        if down_block_additional_residuals is not None:
            for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
                if down_block_additional_residual.dim() == 4:  # broadcast
                    down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
                down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

        # flow controlnet
        if flow_down_block_additional_residuals is not None:
            for i, down_block_additional_residual in enumerate(flow_down_block_additional_residuals):
                if down_block_additional_residual.dim() == 4:  # broadcast
                    down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
                down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual


        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # support controlnet
        # image controlnet
        if mid_block_additional_residual is not None:
            if mid_block_additional_residual.dim() == 4:  # broadcast
                mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
            sample = sample + mid_block_additional_residual

        # flow controlnet
        if flow_mid_block_additional_residual is not None:
            if flow_mid_block_additional_residual.dim() == 4:  # broadcast
                flow_mid_block_additional_residual = flow_mid_block_additional_residual.unsqueeze(2)
            sample = sample + flow_mid_block_additional_residual


        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionFlowModelOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D"
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ]

        from diffusers.utils import WEIGHTS_NAME
        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = torch.load(model_file, map_location="cpu")

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        motion_params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
        motion_name = [n for n in model.state_dict().keys() if "motion_modules." in n]

        print(f"### Motion Module Parameters: {sum(motion_params) / 1e6} M")
        print(f"### Motion Module keys: {len(motion_name)}")

        unnormal = []
        for n in m:
            if n not in motion_name:
                unnormal.append(n)

        return model

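For orientation, here is a minimal, hedged sketch of how the `from_pretrained_2d` and `forward` methods defined above might be exercised; the checkpoint path is a placeholder, and the motion-module settings would presumably come from the configs/inference YAML files rather than the empty `unet_additional_kwargs` used here:

import torch
from modules.unet import UNet3DConditionFlowModel

# Placeholder path to a local Stable Diffusion 1.x checkpoint directory with an `unet` subfolder.
unet = UNet3DConditionFlowModel.from_pretrained_2d(
    "checkpoints/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs={}
)
unet.eval()

sample = torch.randn(1, 4, 16, 32, 32)           # noisy latents: (batch, channels, frames, height, width)
timestep = torch.tensor([999])                   # one diffusion timestep per batch element
encoder_hidden_states = torch.randn(1, 77, 768)  # CLIP text embeddings for SD 1.x
with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states).sample  # same shape as `sample`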
modules/unet_blocks.py
ADDED
@@ -0,0 +1,866 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py

from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn

from .attention import Transformer3DModel
from .motion_module import get_motion_module
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    down_block_type = (
        down_block_type[7:]
        if down_block_type.startswith("UNetRes")
        else down_block_type
    )
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock3D"
            )
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    up_block_type = (
        up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    )
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock3D"
            )
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class UNetMidBlock3DCrossAttn(nn.Module):

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet, motion_module in zip(
            self.attentions, self.resnets[1:], self.motion_modules
        ):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            if motion_module is not None:
                hidden_states = motion_module(
                    hidden_states,
                    temb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class CrossAttnDownBlock3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    num_attention_heads=attn_num_head_channels,
                    attention_head_dim=out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        output_states = ()

        for resnet, attn, motion_module in zip(
            self.resnets, self.attentions, self.motion_modules
        ):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                # add motion module
                hidden_states = (
                    motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states
                    )
                    if motion_module is not None
                    else hidden_states
                )

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                if motion_module:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states
                    )

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlock3D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        for resnet, attn, motion_module in zip(
            self.resnets, self.attentions, self.motion_modules
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                # add motion module
                if motion_module:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                if motion_module:
                    hidden_states = motion_module(
|
859 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states
|
860 |
+
)
|
861 |
+
|
862 |
+
if self.upsamplers is not None:
|
863 |
+
for upsampler in self.upsamplers:
|
864 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
865 |
+
|
866 |
+
return hidden_states
|
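Note: in the gradient-checkpointing branches above, each sub-module is wrapped in a small closure so that `torch.utils.checkpoint.checkpoint` can re-execute it during the backward pass instead of storing its activations. A minimal, self-contained sketch of that pattern with plain `torch.nn` layers (toy shapes, not the ImageConductor blocks themselves):

import torch
import torch.nn as nn


def create_custom_forward(module):
    # checkpoint() expects a callable, so wrap the module call in a closure.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Activations inside `block` are recomputed on backward rather than kept in memory.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
y.sum().backward()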
peft/__init__.py
ADDED
@@ -0,0 +1,98 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.11.1"

from .auto import (
    AutoPeftModel,
    AutoPeftModelForCausalLM,
    AutoPeftModelForSequenceClassification,
    AutoPeftModelForSeq2SeqLM,
    AutoPeftModelForTokenClassification,
    AutoPeftModelForQuestionAnswering,
    AutoPeftModelForFeatureExtraction,
)
from .mapping import (
    MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
    PEFT_TYPE_TO_CONFIG_MAPPING,
    get_peft_config,
    get_peft_model,
    inject_adapter_in_model,
)
from .mixed_model import PeftMixedModel
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
    PeftModelForQuestionAnswering,
    PeftModelForFeatureExtraction,
    get_layer_status,
    get_model_status,
)
from .tuners import (
    AdaptionPromptConfig,
    AdaptionPromptModel,
    LoraConfig,
    LoftQConfig,
    LoraModel,
    LoHaConfig,
    LoHaModel,
    LoKrConfig,
    LoKrModel,
    IA3Config,
    IA3Model,
    AdaLoraConfig,
    AdaLoraModel,
    BOFTConfig,
    BOFTModel,
    PrefixEncoder,
    PrefixTuningConfig,
    PromptEmbedding,
    PromptEncoder,
    PromptEncoderConfig,
    PromptEncoderReparameterizationType,
    PromptTuningConfig,
    PromptTuningInit,
    MultitaskPromptTuningConfig,
    MultitaskPromptTuningInit,
    OFTConfig,
    OFTModel,
    PolyConfig,
    PolyModel,
    LNTuningConfig,
    LNTuningModel,
    VeraConfig,
    VeraModel,
)
from .utils import (
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    PeftType,
    TaskType,
    bloom_model_postprocess_past_key_value,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
    replace_lora_weights_loftq,
    set_peft_model_state_dict,
    shift_tokens_right,
    load_peft_weights,
    cast_mixed_precision_params,
)
from .config import PeftConfig, PromptLearningConfig
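Note: this vendored `peft/__init__.py` re-exports the usual public entry points (`get_peft_model`, the config classes, the Auto classes). A hedged sketch of the typical import path it enables; the base checkpoint and LoRA hyperparameters below are placeholders, not values used by this Space:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
peft_model = get_peft_model(base, config)
peft_model.print_trainable_parameters()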
peft/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.37 kB)
peft/__pycache__/auto.cpython-310.pyc
ADDED
Binary file (4.84 kB)
peft/__pycache__/config.cpython-310.pyc
ADDED
Binary file (8.79 kB)
peft/__pycache__/import_utils.cpython-310.pyc
ADDED
Binary file (2.18 kB)
peft/__pycache__/mapping.cpython-310.pyc
ADDED
Binary file (4.98 kB)
peft/__pycache__/mixed_model.cpython-310.pyc
ADDED
Binary file (14.8 kB)
peft/__pycache__/peft_model.cpython-310.pyc
ADDED
Binary file (71.9 kB)
peft/auto.py
ADDED
@@ -0,0 +1,170 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import importlib
import os
from typing import Optional

from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

from .config import PeftConfig
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
    PeftModelForFeatureExtraction,
    PeftModelForQuestionAnswering,
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
)
from .utils.constants import TOKENIZER_CONFIG_NAME
from .utils.other import check_file_exists_on_hf_hub


class _BaseAutoPeftModel:
    _target_class = None
    _target_peft_class = None

    def __init__(self, *args, **kwargs):
        # For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400
        raise EnvironmentError(  # noqa: UP024
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_config(config)` methods."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        adapter_name: str = "default",
        is_trainable: bool = False,
        config: Optional[PeftConfig] = None,
        **kwargs,
    ):
        r"""
        A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
        are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and
        the config object init.
        """
        peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        base_model_path = peft_config.base_model_name_or_path

        task_type = getattr(peft_config, "task_type", None)

        if cls._target_class is not None:
            target_class = cls._target_class
        elif cls._target_class is None and task_type is not None:
            # this is only in the case where we use `AutoPeftModel`
            raise ValueError(
                "Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)"
            )

        if task_type is not None:
            expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type]
            if cls._target_peft_class.__name__ != expected_target_class.__name__:
                raise ValueError(
                    f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }"
                    " make sure that you are loading the correct model for your task type."
                )
        elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None:
            auto_mapping = getattr(peft_config, "auto_mapping", None)
            base_model_class = auto_mapping["base_model_class"]
            parent_library_name = auto_mapping["parent_library"]

            parent_library = importlib.import_module(parent_library_name)
            target_class = getattr(parent_library, base_model_class)
        else:
            raise ValueError(
                "Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type."
            )

        base_model = target_class.from_pretrained(base_model_path, **kwargs)

        tokenizer_exists = False
        if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
            tokenizer_exists = True
        else:
            token = kwargs.get("token", None)
            if token is None:
                token = kwargs.get("use_auth_token", None)

            tokenizer_exists = check_file_exists_on_hf_hub(
                repo_id=pretrained_model_name_or_path,
                filename=TOKENIZER_CONFIG_NAME,
                revision=kwargs.get("revision", None),
                repo_type=kwargs.get("repo_type", None),
                token=token,
            )

        if tokenizer_exists:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
            )
            base_model.resize_token_embeddings(len(tokenizer))

        return cls._target_peft_class.from_pretrained(
            base_model,
            pretrained_model_name_or_path,
            adapter_name=adapter_name,
            is_trainable=is_trainable,
            config=config,
            **kwargs,
        )


class AutoPeftModel(_BaseAutoPeftModel):
    _target_class = None
    _target_peft_class = PeftModel


class AutoPeftModelForCausalLM(_BaseAutoPeftModel):
    _target_class = AutoModelForCausalLM
    _target_peft_class = PeftModelForCausalLM


class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel):
    _target_class = AutoModelForSeq2SeqLM
    _target_peft_class = PeftModelForSeq2SeqLM


class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel):
    _target_class = AutoModelForSequenceClassification
    _target_peft_class = PeftModelForSequenceClassification


class AutoPeftModelForTokenClassification(_BaseAutoPeftModel):
    _target_class = AutoModelForTokenClassification
    _target_peft_class = PeftModelForTokenClassification


class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel):
    _target_class = AutoModelForQuestionAnswering
    _target_peft_class = PeftModelForQuestionAnswering


class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel):
    _target_class = AutoModel
    _target_peft_class = PeftModelForFeatureExtraction
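Note: `_BaseAutoPeftModel.from_pretrained` reads the adapter's `PeftConfig`, loads the matching base model, optionally resizes embeddings when a tokenizer config is found, and then delegates to the task-specific `PeftModel` subclass. A hedged usage sketch; the adapter id is a placeholder:

from peft import AutoPeftModelForCausalLM

# "username/my-lora-adapter" stands in for a saved adapter directory or Hub repo.
model = AutoPeftModelForCausalLM.from_pretrained("username/my-lora-adapter", is_trainable=False)
model.eval()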
peft/config.py
ADDED
@@ -0,0 +1,270 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Dict, Optional, Union

from huggingface_hub import hf_hub_download
from transformers.utils import PushToHubMixin

from .utils import CONFIG_NAME, PeftType, TaskType


@dataclass
class PeftConfigMixin(PushToHubMixin):
    r"""
    This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
    PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
    push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
    directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.

    Args:
        peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
    """

    peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
    auto_mapping: Optional[dict] = field(
        default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
    )

    def to_dict(self) -> Dict:
        r"""
        Returns the configuration for your adapter model as a dictionary.
        """
        return asdict(self)

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        r"""
        This method saves the configuration of your adapter model in a directory.

        Args:
            save_directory (`str`):
                The directory where the configuration will be saved.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
                method.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)
        auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)

        output_dict = asdict(self)
        # converting set type to list
        for key, value in output_dict.items():
            if isinstance(value, set):
                output_dict[key] = list(value)

        output_path = os.path.join(save_directory, CONFIG_NAME)

        # Add auto mapping details for custom models.
        if auto_mapping_dict is not None:
            output_dict["auto_mapping"] = auto_mapping_dict

        # save it
        with open(output_path, "w") as writer:
            writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

    @classmethod
    def from_peft_type(cls, **kwargs):
        r"""
        This method loads the configuration of your adapter model from a set of kwargs.

        The appropriate configuration type is determined by the `peft_type` argument. If `peft_type` is not provided,
        the calling class type is instantiated.

        Args:
            kwargs (configuration keyword arguments):
                Keyword arguments passed along to the configuration initialization.
        """
        # Avoid circular dependency .. TODO: fix this with a larger refactor
        from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        # TODO: this hack is needed to fix the following issue (on commit 702f937):
        # if someone saves a default config and loads it back with `PeftConfig` class it yields to
        # not loading the correct config class.

        # from peft import AdaLoraConfig, PeftConfig
        # peft_config = AdaLoraConfig()
        # print(peft_config)
        # >>> AdaLoraConfig(peft_type=<PeftType.ADALORA: 'ADALORA'>, auto_mapping=None, base_model_name_or_path=None,
        # revision=None, task_type=None, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, ...
        #
        # peft_config.save_pretrained("./test_config")
        # peft_config = PeftConfig.from_pretrained("./test_config")
        # print(peft_config)
        # >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False)

        if "peft_type" in kwargs:
            peft_type = kwargs["peft_type"]
            config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
        else:
            config_cls = cls

        return config_cls(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
        r"""
        This method loads the configuration of your adapter model from a directory.

        Args:
            pretrained_model_name_or_path (`str`):
                The directory or the Hub repository id where the configuration is saved.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments passed along to the child class initialization.
        """
        path = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder is not None
            else pretrained_model_name_or_path
        )

        hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)

        if os.path.isfile(os.path.join(path, CONFIG_NAME)):
            config_file = os.path.join(path, CONFIG_NAME)
        else:
            try:
                config_file = hf_hub_download(
                    pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
                )
            except Exception as exc:
                raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from exc

        loaded_attributes = cls.from_json_file(config_file)
        kwargs = {**class_kwargs, **loaded_attributes}
        return cls.from_peft_type(**kwargs)

    @classmethod
    def from_json_file(cls, path_json_file: str, **kwargs):
        r"""
        Loads a configuration file from a json file.

        Args:
            path_json_file (`str`):
                The path to the json file.
        """
        with open(path_json_file) as file:
            json_object = json.load(file)

        return json_object

    @classmethod
    def _split_kwargs(cls, kwargs):
        hf_hub_download_kwargs = {}
        class_kwargs = {}
        other_kwargs = {}

        for key, value in kwargs.items():
            if key in inspect.signature(hf_hub_download).parameters:
                hf_hub_download_kwargs[key] = value
            elif key in list(cls.__annotations__):
                class_kwargs[key] = value
            else:
                other_kwargs[key] = value

        return hf_hub_download_kwargs, class_kwargs, other_kwargs

    @classmethod
    def _get_peft_type(
        cls,
        model_id: str,
        **hf_hub_download_kwargs,
    ):
        subfolder = hf_hub_download_kwargs.get("subfolder", None)

        path = os.path.join(model_id, subfolder) if subfolder is not None else model_id

        if os.path.isfile(os.path.join(path, CONFIG_NAME)):
            config_file = os.path.join(path, CONFIG_NAME)
        else:
            try:
                config_file = hf_hub_download(
                    model_id,
                    CONFIG_NAME,
                    **hf_hub_download_kwargs,
                )
            except Exception:
                raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")

        loaded_attributes = cls.from_json_file(config_file)
        return loaded_attributes["peft_type"]

    @property
    def is_prompt_learning(self) -> bool:
        r"""
        Utility method to check if the configuration is for prompt learning.
        """
        return False

    @property
    def is_adaption_prompt(self) -> bool:
        """Return True if this is an adaption prompt config."""
        return False


@dataclass
class PeftConfig(PeftConfigMixin):
    """
    This is the base configuration class to store the configuration of a [`PeftModel`].

    Args:
        peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
        task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
        inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
    """

    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name of the base model to use."}
    )
    revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."})
    peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"})
    task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"})
    inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})


@dataclass
class PromptLearningConfig(PeftConfig):
    """
    This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
    [`PromptTuning`].

    Args:
        num_virtual_tokens (`int`): The number of virtual tokens to use.
        token_dim (`int`): The hidden embedding dimension of the base transformer model.
        num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
        num_attention_heads (`int`): The number of attention heads in the base transformer model.
        num_layers (`int`): The number of layers in the base transformer model.
    """

    num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
    token_dim: int = field(
        default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
    )
    num_transformer_submodules: Optional[int] = field(
        default=None, metadata={"help": "Number of transformer submodules"}
    )
    num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
    num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})

    @property
    def is_prompt_learning(self) -> bool:
        r"""
        Utility method to check if the configuration is for prompt learning.
        """
        return True
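Note: `save_pretrained` serializes the config dataclass to the file named by `CONFIG_NAME` (adapter_config.json), and `from_pretrained` dispatches back to the concrete config class through `from_peft_type`. A small round-trip sketch under that assumption:

from peft import LoraConfig, PeftConfig

cfg = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"])
cfg.save_pretrained("./my_adapter_config")  # writes adapter_config.json

reloaded = PeftConfig.from_pretrained("./my_adapter_config")
print(type(reloaded).__name__, reloaded.peft_type)  # LoraConfig PeftType.LORA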
peft/helpers.py
ADDED
@@ -0,0 +1,148 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from copy import deepcopy
from functools import update_wrapper
from types import MethodType

from .peft_model import PeftConfig, PeftModel


def update_forward_signature(model: PeftModel) -> None:
    """
    Updates the forward signature of the PeftModel to include parents class signature
        model (`PeftModel`): Peft model to update the forward signature

    Example:

    ```python
    >>> from transformers import WhisperForConditionalGeneration
    >>> from peft import get_peft_model, LoraConfig, update_forward_signature

    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    >>> peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])

    >>> peft_model = get_peft_model(model, peft_config)
    >>> update_forward_signature(peft_model)
    ```
    """

    # Only update signature when the current forward signature only has *args and **kwargs
    current_signature = inspect.signature(model.forward)
    if (
        len(current_signature.parameters) == 2
        and "args" in current_signature.parameters
        and "kwargs" in current_signature.parameters
    ):
        forward = deepcopy(model.forward.__func__)
        update_wrapper(
            forward, type(model.get_base_model()).forward, assigned=("__doc__", "__name__", "__annotations__")
        )
        model.forward = MethodType(forward, model)


def update_generate_signature(model: PeftModel) -> None:
    """
    Updates the generate signature of a PeftModel with overriding generate to include parents class signature
        model (`PeftModel`): Peft model to update the generate signature

    Example:

    ```python
    >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    >>> from peft import get_peft_model, LoraConfig, TaskType, update_generate_signature

    >>> model_name_or_path = "bigscience/mt0-large"
    >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

    >>> peft_config = LoraConfig(
    ...     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    ... )
    >>> peft_model = get_peft_model(model, peft_config)
    >>> update_generate_signature(peft_model)
    >>> help(peft_model.generate)
    ```
    """
    if not hasattr(model, "generate"):
        return
    current_signature = inspect.signature(model.generate)
    if (
        len(current_signature.parameters) == 2
        and "args" in current_signature.parameters
        and "kwargs" in current_signature.parameters
    ) or (len(current_signature.parameters) == 1 and "kwargs" in current_signature.parameters):
        generate = deepcopy(model.generate.__func__)
        update_wrapper(
            generate,
            type(model.get_base_model()).generate,
            assigned=("__doc__", "__name__", "__annotations__"),
        )
        model.generate = MethodType(generate, model)


def update_signature(model: PeftModel, method: str = "all") -> None:
    """
    Updates the signature of a PeftModel include parents class signature for forward or generate method
        model (`PeftModel`): Peft model to update generate or forward signature method (`str`): method to update
        signature choose one of "forward", "generate", "all"

    Example:
    ```python
    >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    >>> from peft import get_peft_model, LoraConfig, TaskType, update_signature

    >>> model_name_or_path = "bigscience/mt0-large"
    >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

    >>> peft_config = LoraConfig(
    ...     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    ... )
    >>> peft_model = get_peft_model(model, peft_config)
    >>> update_signature(peft_model)
    >>> help(peft_model.generate)
    ```
    """
    if method == "forward":
        update_forward_signature(model)
    elif method == "generate":
        update_generate_signature(model)
    elif method == "all":
        update_forward_signature(model)
        update_generate_signature(model)
    else:
        raise ValueError(f"method {method} is not supported please choose one of ['forward', 'generate', 'all']")


def check_if_peft_model(model_name_or_path: str) -> bool:
    """
    Check if the model is a PEFT model.

    Args:
        model_name_or_path (`str`):
            Model id to check, can be local or on the Hugging Face Hub.

    Returns:
        `bool`: True if the model is a PEFT model, False otherwise.
    """
    is_peft_model = True
    try:
        PeftConfig.from_pretrained(model_name_or_path)
    except Exception:
        # allow broad exceptions so that this works even if new exceptions are added on HF Hub side
        is_peft_model = False

    return is_peft_model
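Note: `check_if_peft_model` simply attempts to resolve a `PeftConfig` for the given id and reports whether that succeeds. A short sketch; the adapter id below is a placeholder:

from peft.helpers import check_if_peft_model

print(check_if_peft_model("username/my-lora-adapter"))  # True only if an adapter_config.json resolves
print(check_if_peft_model("gpt2"))                      # plain base model -> False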
peft/import_utils.py
ADDED
@@ -0,0 +1,89 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata as importlib_metadata
from functools import lru_cache

import packaging.version


@lru_cache
def is_bnb_available() -> bool:
    return importlib.util.find_spec("bitsandbytes") is not None


@lru_cache
def is_bnb_4bit_available() -> bool:
    if not is_bnb_available():
        return False

    import bitsandbytes as bnb

    return hasattr(bnb.nn, "Linear4bit")


@lru_cache
def is_auto_gptq_available():
    if importlib.util.find_spec("auto_gptq") is not None:
        AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0")
        version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
        if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
            return True
        else:
            raise ImportError(
                f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, "
                f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported"
            )


@lru_cache
def is_optimum_available() -> bool:
    return importlib.util.find_spec("optimum") is not None


@lru_cache
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    if importlib.util.find_spec("torch_xla") is not None:
        if check_device:
            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
            try:
                import torch_xla.core.xla_model as xm

                _ = xm.xla_device()
                return True
            except RuntimeError:
                return False
        return True
    return False


@lru_cache
def is_aqlm_available():
    return importlib.util.find_spec("aqlm") is not None


@lru_cache
def is_auto_awq_available():
    return importlib.util.find_spec("awq") is not None


@lru_cache
def is_eetq_available():
    return importlib.util.find_spec("eetq") is not None


@lru_cache
def is_hqq_available():
    return importlib.util.find_spec("hqq") is not None
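Note: these `lru_cache`-wrapped probes let callers branch on optional backends without importing them eagerly. A minimal sketch of how such a guard is typically consumed (the 8-bit layer choice is illustrative, not this repo's code):

import torch.nn as nn
from peft.import_utils import is_bnb_available

if is_bnb_available():
    import bitsandbytes as bnb  # only imported when the optional dependency is installed
    linear_cls = bnb.nn.Linear8bitLt
else:
    linear_cls = nn.Linear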
peft/mapping.py
ADDED
@@ -0,0 +1,181 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch

from .config import PeftConfig
from .mixed_model import PeftMixedModel
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
    PeftModelForFeatureExtraction,
    PeftModelForQuestionAnswering,
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
)
from .tuners import (
    AdaLoraConfig,
    AdaLoraModel,
    AdaptionPromptConfig,
    BOFTConfig,
    BOFTModel,
    IA3Config,
    IA3Model,
    LNTuningConfig,
    LNTuningModel,
    LoHaConfig,
    LoHaModel,
    LoKrConfig,
    LoKrModel,
    LoraConfig,
    LoraModel,
    MultitaskPromptTuningConfig,
    OFTConfig,
    OFTModel,
    PolyConfig,
    PolyModel,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    VeraConfig,
    VeraModel,
)
from .tuners.tuners_utils import BaseTuner as _BaseTuner
from .utils import _prepare_prompt_learning_config


if TYPE_CHECKING:
    from transformers import PreTrainedModel


MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
    "SEQ_CLS": PeftModelForSequenceClassification,
    "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
    "CAUSAL_LM": PeftModelForCausalLM,
    "TOKEN_CLS": PeftModelForTokenClassification,
    "QUESTION_ANS": PeftModelForQuestionAnswering,
    "FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}

PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
    "ADAPTION_PROMPT": AdaptionPromptConfig,
    "PROMPT_TUNING": PromptTuningConfig,
    "PREFIX_TUNING": PrefixTuningConfig,
    "P_TUNING": PromptEncoderConfig,
    "LORA": LoraConfig,
    "LOHA": LoHaConfig,
    "LOKR": LoKrConfig,
    "ADALORA": AdaLoraConfig,
    "BOFT": BOFTConfig,
    "IA3": IA3Config,
    "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
    "OFT": OFTConfig,
    "POLY": PolyConfig,
    "LN_TUNING": LNTuningConfig,
    "VERA": VeraConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
    "LORA": LoraModel,
    "LOHA": LoHaModel,
    "LOKR": LoKrModel,
    "ADALORA": AdaLoraModel,
    "BOFT": BOFTModel,
    "IA3": IA3Model,
    "OFT": OFTModel,
    "POLY": PolyModel,
    "LN_TUNING": LNTuningModel,
    "VERA": VeraModel,
}


def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
    """
    Returns a Peft config object from a dictionary.

    Args:
        config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
    """

    return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)


def get_peft_model(
    model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
) -> PeftModel | PeftMixedModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        mixed (`bool`, `optional`, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
    """
    model_config = getattr(model, "config", {"model_type": "custom"})
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if mixed:
        return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
        return PeftModel(model, peft_config, adapter_name=adapter_name)

    if peft_config.is_prompt_learning:
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)


def inject_adapter_in_model(
    peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default"
) -> torch.nn.Module:
    r"""
    A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
    methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
    calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.

    Args:
        peft_config (`PeftConfig`):
            Configuration object containing the parameters of the Peft model.
        model (`torch.nn.Module`):
            The input model where the adapter will be injected.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
    """
    if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
        raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")

    if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
        raise ValueError(
            f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
        )

    tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]

    # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)

    return peft_model.model
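Note: `get_peft_model` wraps the model in a task-specific `PeftModel` (or a `PeftMixedModel` when `mixed=True`), while `inject_adapter_in_model` only mutates the module tree and hands back a plain `nn.Module`. A hedged sketch contrasting the two on a tiny toy module (the module and target names are illustrative, not part of this repo):

import torch.nn as nn
from peft import LoraConfig, get_peft_model, inject_adapter_in_model


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


# Full PEFT wrapper: adapter management, state-dict helpers, etc.
peft_model = get_peft_model(Toy(), LoraConfig(r=4, target_modules=["proj"]))

# In-place injection only: returns the bare module with LoRA layers inserted.
bare_model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["proj"]), Toy())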
peft/mixed_model.py
ADDED
@@ -0,0 +1,415 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
from contextlib import contextmanager
from typing import Any, Optional, Union

import torch
from accelerate.hooks import remove_hook_from_submodules
from torch import nn
from transformers.utils import PushToHubMixin

from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES

from .config import PeftConfig
from .peft_model import PeftModel
from .tuners import (
    AdaLoraModel,
    IA3Model,
    LoHaModel,
    LoKrModel,
    LoraModel,
    MixedModel,
    OFTModel,
)
from .utils import PeftType, _set_adapter, _set_trainable


PEFT_TYPE_TO_MODEL_MAPPING = {
    PeftType.LORA: LoraModel,
    PeftType.LOHA: LoHaModel,
    PeftType.LOKR: LoKrModel,
    PeftType.ADALORA: AdaLoraModel,
    PeftType.IA3: IA3Model,
    PeftType.OFT: OFTModel,
}


def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None:
    r"""
    Prepares the model for gradient checkpointing if necessary
    """
    # Note: same as PeftModel._prepare_model_for_gradient_checkpointing
    if not getattr(model, "is_gradient_checkpointing", True):
        return model

    if not (
        getattr(model, "is_loaded_in_8bit", False)
        or getattr(model, "is_loaded_in_4bit", False)
        or getattr(model, "is_quantized", False)
    ):
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        elif hasattr(model, "get_input_embeddings"):

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)


def _check_config_compatible(peft_config: PeftConfig) -> None:
    if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES:
        raise ValueError(
            f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. "
            f"Compatible types are: {COMPATIBLE_TUNER_TYPES}"
        )


class PeftMixedModel(PushToHubMixin, torch.nn.Module):
    """
    PeftMixedModel for loading mixing different types of adapters for inference.

    This class does not support loading/saving, and it shouldn't usually be initialized directly. Instead, use
    `get_peft_model` with the argument `mixed=True`.

    <Tip>

    Read the [Mixed adapter types](https://huggingface.co/docs/peft/en/developer_guides/mixed_models) guide to learn
    more about using different adapter types.

    </Tip>

    Example:

    ```py
    >>> from peft import get_peft_model

    >>> base_model = ...  # load the base model, e.g. from transformers
    >>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval()
    >>> peft_model.load_adapter(path_to_adapter2, "adapter2")
    >>> peft_model.set_adapter(["adapter1", "adapter2"])  # activate both adapters
    >>> peft_model(data)  # forward pass using both adapters
    ```

    Args:
        model (`torch.nn.Module`):
            The model to be tuned.
        config (`PeftConfig`):
            The config of the model to be tuned. The adapter type must be compatible.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the first adapter.
    """

    def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__()
        _check_config_compatible(peft_config)
        _prepare_model_for_gradient_checkpointing(model)
        self.modules_to_save = None
        self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
        self.set_modules_to_save(peft_config, adapter_name)

        self.config = getattr(model, "config", {"model_type": "custom"})

        # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
        # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
        # behavior we disable that in this line.
        if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
            self.base_model.config.pretraining_tp = 1

    @property
    def peft_config(self) -> dict[str, PeftConfig]:
        return self.base_model.peft_config

    @property
    def active_adapter(self) -> str:
        return self.base_model.active_adapter

    @property
    def active_adapters(self) -> list[str]:
        return self.base_model.active_adapters

    def get_nb_trainable_parameters(self):
        r"""
        Returns the number of trainable parameters and number of all parameters in the model.
        """
        # note: same as PeftModel.get_nb_trainable_parameters
        trainable_params = 0
        all_param = 0
        for _, param in self.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param

    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters in the model.

        Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from
        num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns
        (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model.
        For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for
        prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number
        of trainable parameters of the backbone transformer model which can be different.
        """
        # note: same as PeftModel.print_trainable_parameters
        trainable_params, all_param = self.get_nb_trainable_parameters()

        print(
            f"trainable params: {trainable_params:,d} || "
            f"all params: {all_param:,d} || "
            f"trainable%: {100 * trainable_params / all_param:.4f}"
        )

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.base_model, name)

    def forward(self, *args: Any, **kwargs: Any):
        """
        Forward pass of the model.
        """
        return self.base_model(*args, **kwargs)

    def generate(self, *args: Any, **kwargs: Any):
        """
        Generate output.
        """
        return self.base_model.generate(*args, **kwargs)

    @contextmanager
    def disable_adapter(self):
        """
        Disables the adapter module.
        """
        try:
            self.base_model.disable_adapter_layers()
            yield
        finally:
            self.base_model.enable_adapter_layers()

    def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
        _check_config_compatible(peft_config)

        try:
            self.peft_config[adapter_name] = peft_config
            self.base_model.inject_adapter(self, adapter_name)
        except Exception:  # something went wrong, roll back
            if adapter_name in self.peft_config:
                del self.peft_config[adapter_name]
            raise

        self.set_modules_to_save(peft_config, adapter_name)

    def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None:
        if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None:
            return

        if self.modules_to_save is None:
            self.modules_to_save = set(modules_to_save)
        else:
            self.modules_to_save.update(modules_to_save)
        _set_trainable(self, adapter_name)

    def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
        """
        Sets the active adapter(s) for the model.

        Note that the order in which the adapters are applied during the forward pass may not be the same as the order
        in which they are passed to this function. Instead, the order during the forward pass is determined by the
        order in which the adapters were loaded into the model. The active adapters only determine which adapters are
        active during the forward pass, but not the order in which they are applied.

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
        not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'lora' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `List[str]`):
                The name of the adapter(s) to be activated.
        """
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
|
268 |
+
|
269 |
+
mismatched = set(adapter_name) - set(self.peft_config.keys())
|
270 |
+
if mismatched:
|
271 |
+
raise ValueError(
|
272 |
+
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
|
273 |
+
)
|
274 |
+
|
275 |
+
self.base_model.set_adapter(adapter_name)
|
276 |
+
_set_adapter(self, adapter_name)
|
277 |
+
|
278 |
+
def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None:
|
279 |
+
if isinstance(adapter_name, str):
|
280 |
+
adapter_name = [adapter_name]
|
281 |
+
|
282 |
+
mismatched = set(adapter_name) - set(self.peft_config.keys())
|
283 |
+
if mismatched:
|
284 |
+
raise ValueError(
|
285 |
+
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
|
286 |
+
)
|
287 |
+
|
288 |
+
self.base_model.delete_adapter(adapter_name)
|
289 |
+
|
290 |
+
def merge_and_unload(self, *args: Any, **kwargs: Any):
|
291 |
+
r"""
|
292 |
+
This method merges the adapter layers into the base model. This is needed if someone wants to use the base
|
293 |
+
model as a standalone model.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
progressbar (`bool`):
|
297 |
+
whether to show a progressbar indicating the unload and merge process
|
298 |
+
safe_merge (`bool`):
|
299 |
+
whether to activate the safe merging check to check if there is any potential Nan in the adapter
|
300 |
+
weights
|
301 |
+
adapter_names (`List[str]`, *optional*):
|
302 |
+
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
303 |
+
to `None`.
|
304 |
+
"""
|
305 |
+
return self.base_model.merge_and_unload(*args, **kwargs)
|
306 |
+
|
307 |
+
def unload(self, *args: Any, **kwargs: Any):
|
308 |
+
"""
|
309 |
+
Gets back the base model by removing all the adapter modules without merging. This gives back the original base
|
310 |
+
model.
|
311 |
+
"""
|
312 |
+
return self.base_model.unload(*args, **kwargs)
|
313 |
+
|
314 |
+
def get_layer_status(self):
|
315 |
+
raise TypeError(f"get_layer_status is not supported for {self.__class__.__name__}.")
|
316 |
+
|
317 |
+
def get_model_status(self):
|
318 |
+
raise TypeError(f"get_model_status is not supported for {self.__class__.__name__}.")
|
319 |
+
|
320 |
+
@classmethod
|
321 |
+
def _split_kwargs(cls, kwargs: dict[str, Any]):
|
322 |
+
return PeftModel._split_kwargs(kwargs)
|
323 |
+
|
324 |
+
def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
|
325 |
+
output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs)
|
326 |
+
# TODO: not quite clear why this is necessary but tests fail without it
|
327 |
+
self.set_adapter(self.active_adapters)
|
328 |
+
return output
|
329 |
+
|
330 |
+
def create_or_update_model_card(self, output_dir: str):
|
331 |
+
raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).")
|
332 |
+
|
333 |
+
def save_pretrained(
|
334 |
+
self,
|
335 |
+
save_directory: str,
|
336 |
+
safe_serialization: bool = False,
|
337 |
+
selected_adapters: Optional[list[str]] = None,
|
338 |
+
**kwargs: Any,
|
339 |
+
):
|
340 |
+
raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).")
|
341 |
+
|
342 |
+
@classmethod
|
343 |
+
def from_pretrained(
|
344 |
+
cls,
|
345 |
+
model: nn.Module,
|
346 |
+
model_id: str | os.PathLike,
|
347 |
+
adapter_name: str = "default",
|
348 |
+
is_trainable: bool = False,
|
349 |
+
config: Optional[PeftConfig] = None,
|
350 |
+
**kwargs: Any,
|
351 |
+
):
|
352 |
+
r"""
|
353 |
+
Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights.
|
354 |
+
|
355 |
+
Note that the passed `model` may be modified inplace.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
model (`nn.Module`):
|
359 |
+
The model to be adapted.
|
360 |
+
model_id (`str` or `os.PathLike`):
|
361 |
+
The name of the PEFT configuration to use. Can be either:
|
362 |
+
- A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
|
363 |
+
Hub.
|
364 |
+
- A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
|
365 |
+
method (`./my_peft_config_directory/`).
|
366 |
+
adapter_name (`str`, *optional*, defaults to `"default"`):
|
367 |
+
The name of the adapter to be loaded. This is useful for loading multiple adapters.
|
368 |
+
is_trainable (`bool`, *optional*, defaults to `False`):
|
369 |
+
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for
|
370 |
+
inference
|
371 |
+
config ([`~peft.PeftConfig`], *optional*):
|
372 |
+
The configuration object to use instead of an automatically loaded configuration. This configuration
|
373 |
+
object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
|
374 |
+
loaded before calling `from_pretrained`.
|
375 |
+
kwargs: (`optional`):
|
376 |
+
Additional keyword arguments passed along to the specific PEFT configuration class.
|
377 |
+
"""
|
378 |
+
# note: adapted from PeftModel.from_pretrained
|
379 |
+
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
380 |
+
|
381 |
+
# load the config
|
382 |
+
if config is None:
|
383 |
+
config = PEFT_TYPE_TO_CONFIG_MAPPING[
|
384 |
+
PeftConfig._get_peft_type(
|
385 |
+
model_id,
|
386 |
+
subfolder=kwargs.get("subfolder", None),
|
387 |
+
revision=kwargs.get("revision", None),
|
388 |
+
cache_dir=kwargs.get("cache_dir", None),
|
389 |
+
use_auth_token=kwargs.get("use_auth_token", None),
|
390 |
+
)
|
391 |
+
].from_pretrained(model_id, **kwargs)
|
392 |
+
elif isinstance(config, PeftConfig):
|
393 |
+
config.inference_mode = not is_trainable
|
394 |
+
else:
|
395 |
+
raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")
|
396 |
+
|
397 |
+
# note: this is different from PeftModel.from_pretrained
|
398 |
+
if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING:
|
399 |
+
raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.")
|
400 |
+
|
401 |
+
if (getattr(model, "hf_device_map", None) is not None) and len(
|
402 |
+
set(model.hf_device_map.values()).intersection({"cpu", "disk"})
|
403 |
+
) > 0:
|
404 |
+
remove_hook_from_submodules(model)
|
405 |
+
|
406 |
+
if config.is_prompt_learning and is_trainable:
|
407 |
+
# note: should not be possible to reach, but just in case
|
408 |
+
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
|
409 |
+
else:
|
410 |
+
config.inference_mode = not is_trainable
|
411 |
+
|
412 |
+
# note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel
|
413 |
+
model = cls(model, config, adapter_name)
|
414 |
+
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
|
415 |
+
return model
|
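For orientation, here is a minimal usage sketch of the `PeftMixedModel` class above. It relies only on the methods shown in this file (`from_pretrained`, `load_adapter`, `set_adapter`, `print_trainable_parameters`); the base model id and the adapter repo ids are placeholders, not artifacts of this repo.

```py
# Hypothetical usage sketch; "gpt2" and the adapter repo ids are placeholders.
from transformers import AutoModelForCausalLM

from peft.mixed_model import PeftMixedModel

base = AutoModelForCausalLM.from_pretrained("gpt2")

# Wrap the base model and load a first pretrained adapter (e.g. a LoRA adapter).
peft_model = PeftMixedModel.from_pretrained(base, "user/lora-adapter", adapter_name="lora")

# Load a second adapter of a different, compatible type (e.g. LoHa).
peft_model.load_adapter("user/loha-adapter", adapter_name="loha")

# Activate both adapters; they are applied in the order they were loaded.
peft_model.set_adapter(["lora", "loha"])
peft_model.print_trainable_parameters()
```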
peft/peft_model.py
ADDED
The diff for this file is too large to render.
See raw diff
peft/py.typed
ADDED
File without changes
peft/tuners/__init__.py
ADDED
@@ -0,0 +1,35 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
from .lora import LoraConfig, LoraModel, LoftQConfig
from .loha import LoHaConfig, LoHaModel
from .lokr import LoKrConfig, LoKrModel
from .ia3 import IA3Config, IA3Model
from .adalora import AdaLoraConfig, AdaLoraModel
from .boft import BOFTConfig, BOFTModel
from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from .oft import OFTConfig, OFTModel
from .mixed import MixedModel
from .poly import PolyConfig, PolyModel
from .ln_tuning import LNTuningConfig, LNTuningModel
from .vera import VeraConfig, VeraModel
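Because `peft.tuners` re-exports each tuner's config/model pair, adapters of different types can also be built in memory and combined via `PeftMixedModel.add_adapter`. A hedged sketch under the assumption that the targeted submodules are plain `nn.Linear` layers; the ranks and target module names are illustrative:

```py
# Illustrative sketch: mixing two adapter types on a toy model.
import torch.nn as nn

from peft.mixed_model import PeftMixedModel
from peft.tuners import LoHaConfig, LoraConfig

base = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

lora_cfg = LoraConfig(r=8, target_modules=["0"])  # target the first Linear (illustrative)
loha_cfg = LoHaConfig(r=4, target_modules=["2"])  # target the second Linear (illustrative)

mixed = PeftMixedModel(base, lora_cfg, adapter_name="lora")
mixed.add_adapter("loha", loha_cfg)
mixed.set_adapter(["lora", "loha"])  # both deltas are applied in load order
```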
peft/tuners/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.38 kB). View file
peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc
ADDED
Binary file (14.3 kB). View file
peft/tuners/__pycache__/tuners_utils.cpython-310.pyc
ADDED
Binary file (27.1 kB). View file
peft/tuners/adalora/__init__.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.import_utils import is_bnb_4bit_available, is_bnb_available

from .config import AdaLoraConfig
from .gptq import SVDQuantLinear
from .layer import AdaLoraLayer, RankAllocator, SVDLinear
from .model import AdaLoraModel


__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]


def __getattr__(name):
    if (name == "SVDLinear8bitLt") and is_bnb_available():
        from .bnb import SVDLinear8bitLt

        return SVDLinear8bitLt

    if (name == "SVDLinear4bit") and is_bnb_4bit_available():
        from .bnb import SVDLinear4bit

        return SVDLinear4bit

    raise AttributeError(f"module {__name__} has no attribute {name}")
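The module-level `__getattr__` above (PEP 562) resolves the bitsandbytes-backed classes lazily, so importing `peft.tuners.adalora` does not require `bitsandbytes` unless `SVDLinear8bitLt` or `SVDLinear4bit` is actually requested. A stripped-down sketch of the same pattern; the module and class names here are illustrative, not part of this repo:

```py
# optional_backend.py -- illustrative lazy-import guard using module-level __getattr__.
import importlib.util


def _backend_available() -> bool:
    # True if the optional dependency can be located without importing it.
    return importlib.util.find_spec("bitsandbytes") is not None


def __getattr__(name):  # called only for names not found in this module
    if name == "QuantizedThing" and _backend_available():
        from ._quantized import QuantizedThing  # hypothetical submodule

        return QuantizedThing
    raise AttributeError(f"module {__name__} has no attribute {name}")
```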
peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (867 Bytes). View file
peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc
ADDED
Binary file (3.18 kB). View file
peft/tuners/adalora/__pycache__/config.cpython-310.pyc
ADDED
Binary file (2.85 kB). View file
peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc
ADDED
Binary file (1.6 kB). View file
peft/tuners/adalora/__pycache__/layer.cpython-310.pyc
ADDED
Binary file (10.7 kB). View file
peft/tuners/adalora/__pycache__/model.cpython-310.pyc
ADDED
Binary file (10.2 kB). View file
peft/tuners/adalora/bnb.py
ADDED
@@ -0,0 +1,145 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available

from .layer import AdaLoraLayer


if is_bnb_available():

    class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
        # Low-rank matrix for SVD-based adaptation
        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            super().__init__()
            AdaLoraLayer.__init__(self, base_layer)
            # Freezing the pre-trained weight matrix
            self.get_base_layer().weight.requires_grad = False

            self._active_adapter = adapter_name
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # note: no check for self.merged because merging is not supported (yet)
            result = self.base_layer(x)

            if self.disable_adapters:
                return result

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    if x.dtype != torch.float32:
                        x = x.float()

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                lora_E = self.lora_E[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                ranknum = self.ranknum[active_adapter] + 1e-5

                output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling / ranknum
                # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
                result = result + output
            return result

        def __repr__(self) -> str:
            rep = super().__repr__()
            return "adalora." + rep


if is_bnb_4bit_available():

    class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
        # Low-rank matrix for SVD-based adaptation
        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            super().__init__()
            AdaLoraLayer.__init__(self, base_layer)
            # Freezing the pre-trained weight matrix
            self.get_base_layer().weight.requires_grad = False

            self._active_adapter = adapter_name
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
            # note: no check for self.merged because merging is not supported (yet)
            result = self.base_layer(x, *args, **kwargs)

            if self.disable_adapters:
                return result

            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                lora_E = self.lora_E[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                ranknum = self.ranknum[active_adapter] + 1e-5

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    compute_dtype = lora_A.dtype
                    if x.dtype != compute_dtype:
                        x = x.to(compute_dtype)

                output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling / ranknum
                result += output
            return result

        def __repr__(self) -> str:
            rep = super().__repr__()
            return "adalora." + rep
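Both forward passes above add the same SVD-parameterized AdaLoRA update on top of the quantized base layer's output: `dropout(x) @ (lora_A * lora_E).T @ lora_B.T`, rescaled by `scaling / ranknum`. A self-contained sketch of that delta with plain tensors (shapes are illustrative, dropout omitted):

```py
import torch

batch, in_features, out_features, r = 2, 16, 8, 4
x = torch.randn(batch, in_features)

lora_A = torch.randn(r, in_features)   # right singular vectors
lora_E = torch.randn(r, 1)             # learned singular values, pruned by the rank allocator
lora_B = torch.randn(out_features, r)  # left singular vectors
scaling, ranknum = 8.0, r + 1e-5

# Same expression as in SVDLinear8bitLt.forward / SVDLinear4bit.forward.
delta = (x @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
print(delta.shape)  # torch.Size([2, 8])
```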