Spaces:
Paused
Paused
import torch | |
from diffusers.models.embeddings import TimestepEmbedding, Timesteps | |
def unet_add_coded_conds(unet, added_number_count=1): | |
unet.add_time_proj = Timesteps(256, True, 0) | |
unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280) | |
def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs): | |
coded_conds = added_cond_kwargs.get("coded_conds") | |
batch_size = coded_conds.shape[0] | |
time_embeds = unet.add_time_proj(coded_conds.flatten()) | |
time_embeds = time_embeds.reshape((batch_size, -1)) | |
time_embeds = time_embeds.to(emb) | |
aug_emb = unet.add_embedding(time_embeds) | |
return aug_emb | |
unet.get_aug_embed = get_aug_embed | |
unet_original_forward = unet.forward | |
def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): | |
cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()} | |
coded_conds = cross_attention_kwargs.pop('coded_conds') | |
kwargs['cross_attention_kwargs'] = cross_attention_kwargs | |
coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device) | |
kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds) | |
return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs) | |
unet.forward = hooked_unet_forward | |
return | |