# Spaces:
# Runtime error
import torch
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    """Return a clone of ``unet`` with a configured copy of ``controlnet`` patched in.

    A copy of ``controlnet`` receives the control hint through
    ``set_cond_hint(image_bchw, strength, (start_percent, end_percent))``, and
    the optional advanced weightings are stored on that copy as attributes.
    The copy is then registered on ``unet.clone()`` via
    ``add_patched_controlnet`` and the clone is returned. Only the copy and the
    clone are modified, so the caller's ``unet`` and ``controlnet`` are
    presumably left untouched. This depends on the ``copy``/``clone``
    implementations, which are not visible here.

    Args:
        unet: Model object exposing ``clone()``; the clone must expose
            ``add_patched_controlnet()``.
        controlnet: Control model exposing ``copy()``; the copy must expose a
            chainable ``set_cond_hint(hint, strength, (start, end))``.
        image_bchw: Control hint, passed through to ``set_cond_hint``. The
            name suggests a B x C x H x W tensor; not checked here.
        strength: Control strength, passed through to ``set_cond_hint``.
        start_percent: First element of the ``(start, end)`` tuple given to
            ``set_cond_hint``. This is presumably the fraction of the sampling
            schedule at which control starts.
        end_percent: Second element of that tuple; presumably where control ends.
        positive_advanced_weighting: Optional per-layer weights (see below).
        negative_advanced_weighting: Optional per-layer weights (see below).
        advanced_frame_weighting: Optional per-image weights (see below).
        advanced_sigma_weighting: Optional callable of sigma (see below).
        advanced_mask_weighting: Optional ``torch.Tensor`` mask (see below).

    Returns:
        The patched clone of ``unet``.

    Raises:
        TypeError: ``advanced_mask_weighting`` is given but is not a
            ``torch.Tensor``.
        ValueError: ``advanced_mask_weighting`` does not have shape
            ``(B, 1, H, W)`` with ``B``, ``H`` and ``W`` all non-zero.

    positive_advanced_weighting / negative_advanced_weighting
        The UNet has input, middle and output blocks, and each layer in every
        block can be given its own weight. The example below gives stronger
        control in the middle block, which is helpful for some high-res fix
        passes::

            positive_advanced_weighting = {
                'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
                'middle': [1.0],
                'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
            }
            negative_advanced_weighting = {
                'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
                'middle': [1.0],
                'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
            }

    advanced_frame_weighting
        A weight applied to each image in a batch. The length of this list
        must equal the batch size. For example, with a batch size of 5 you can
        use ``advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0]``. If the 5
        images are viewed as 5 frames of a video, this gives progressively
        stronger control over time. The length is not validated here.

    advanced_sigma_weighting
        A callable that computes the control weight dynamically from the
        diffusion timestep (sigma). For example, the code below gradually
        makes the early (high-sigma) steps stronger than the late steps::

            sigma_max = unet.model.model_sampling.sigma_max
            sigma_min = unet.model.model_sampling.sigma_min
            advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)

    advanced_mask_weighting
        A mask applied to the control signals. It must be a tensor of shape
        ``(B, 1, H, W)``, where ``H`` and ``W`` can be arbitrary. The mask is
        resized automatically, elsewhere, to match the shape of every
        injection layer.
    """
    # Validate before touching any model so bad input fails fast. Explicit
    # raises (instead of ``assert``) keep the check active under ``python -O``.
    if advanced_mask_weighting is not None:
        if not isinstance(advanced_mask_weighting, torch.Tensor):
            raise TypeError(
                'advanced_mask_weighting must be a torch.Tensor, got '
                f'{type(advanced_mask_weighting).__name__}'
            )
        shape = tuple(advanced_mask_weighting.shape)
        # Tensor sizes are never negative, so "0 in shape" catches empty B/H/W.
        if len(shape) != 4 or shape[1] != 1 or 0 in shape:
            raise ValueError(
                'advanced_mask_weighting must have shape (B, 1, H, W) with '
                f'non-zero B, H and W, got {shape}'
            )

    # set_cond_hint is chained, so it returns the (copied) control object.
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting
    # Set unconditionally (None included), matching the other weightings.
    cnet.advanced_mask_weighting = advanced_mask_weighting

    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m