# Provenance (residue from the hosting page, kept as comments so the module parses):
# uploaded by MohamedRashad, "Upload code", commit 6dd488f, 1.07 kB
import torch
def unet_add_concat_conds(unet, new_channels=4):
    """Widen ``unet.conv_in`` and hook ``unet.forward`` to take extra concat channels.

    ``unet.conv_in`` is replaced by a ``torch.nn.Conv2d`` with ``new_channels``
    additional input channels. The original weights are copied into the
    leading input channels, and the weights for the new channels are zero.
    The original bias parameter is reused. Until the new weights are changed
    (e.g. by loading fine-tuned weights), the patched conv therefore gives the
    same output as before for the original channels.

    ``unet.forward`` is replaced by a wrapper. Every call must supply the
    extra conditioning tensor as ``cross_attention_kwargs['concat_conds']``.
    The wrapper does the following:

    * removes ``concat_conds`` from a shallow copy of
      ``cross_attention_kwargs``, so the caller's dict is not mutated;
    * tiles it along dim 0 to match ``sample``'s batch size;
    * casts it to ``sample``'s dtype/device;
    * concatenates it to ``sample`` along dim 1 before calling the original
      forward.

    NOTE(review): ``unet`` is presumably a diffusers ``UNet2DConditionModel``.
    The code only relies on a ``conv_in`` ``Conv2d`` attribute and a
    ``forward(sample, timestep, encoder_hidden_states, **kwargs)`` that
    accepts ``cross_attention_kwargs``. Calling this twice stacks two hooks
    and widens ``conv_in`` twice.

    Args:
        unet: Module to patch in place.
        new_channels: Number of input channels to add to ``conv_in``.

    Returns:
        None; ``unet`` is modified in place.
    """
    old_conv_in = unet.conv_in
    old_in_channels = old_conv_in.in_channels

    with torch.no_grad():
        # Create the widened conv on the same device/dtype as the original, so
        # the copied weights and the shared bias are consistent with each other.
        new_conv_in = torch.nn.Conv2d(
            old_in_channels + new_channels,
            old_conv_in.out_channels,
            old_conv_in.kernel_size,
            old_conv_in.stride,
            old_conv_in.padding,
            dilation=old_conv_in.dilation,
            bias=old_conv_in.bias is not None,
            padding_mode=old_conv_in.padding_mode,
            device=old_conv_in.weight.device,
            dtype=old_conv_in.weight.dtype,
        )
        # Zero weights make the new channels contribute nothing initially.
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :old_in_channels, :, :].copy_(old_conv_in.weight)
        # Reuse the original bias Parameter object (may be None).
        new_conv_in.bias = old_conv_in.bias
    unet.conv_in = new_conv_in

    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        # Shallow copy so popping 'concat_conds' does not touch the caller's dict.
        cross_attention_kwargs = dict(kwargs.get('cross_attention_kwargs') or {})
        if 'concat_conds' not in cross_attention_kwargs:
            raise KeyError(
                "cross_attention_kwargs['concat_conds'] is required by a UNet "
                "patched with unet_add_concat_conds()"
            )
        c_concat = cross_attention_kwargs.pop('concat_conds')
        kwargs['cross_attention_kwargs'] = cross_attention_kwargs

        batch_size = sample.shape[0]
        cond_batch_size = c_concat.shape[0]
        # The conditioning batch is tiled (not interleaved) to cover the sample
        # batch, so it must divide it evenly.
        if cond_batch_size == 0 or batch_size % cond_batch_size != 0:
            raise RuntimeError(
                f"concat_conds batch size {cond_batch_size} must evenly divide "
                f"sample batch size {batch_size}"
            )
        c_concat = torch.cat([c_concat] * (batch_size // cond_batch_size), dim=0).to(sample)
        new_sample = torch.cat([sample, c_concat], dim=1)
        return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

    unet.forward = hooked_unet_forward
    return