import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps


def unet_add_coded_conds(unet, added_number_count=1, time_embed_dim=1280):
    """Patch a UNet in place so it accepts scalar "coded conditions".

    The instance is modified as follows:

    * ``unet.add_time_proj`` becomes ``Timesteps(256, True, 0)``.
    * ``unet.add_embedding`` becomes a newly constructed
      ``TimestepEmbedding(256 * added_number_count, time_embed_dim)``.
    * ``unet.get_aug_embed`` is replaced by a closure that embeds
      ``added_cond_kwargs["coded_conds"]`` through the two modules above.
    * ``unet.forward`` is wrapped so that callers supply the conditions as
      ``cross_attention_kwargs["coded_conds"]``. The wrapper removes them from
      ``cross_attention_kwargs``, tiles them to the sample batch size, and
      passes them on in ``added_cond_kwargs``.

    The new modules are built without explicit device/dtype arguments, so
    they use PyTorch's defaults. Call this before moving/casting ``unet``, or
    move the new modules yourself.

    Args:
        unet: The UNet to patch. It is modified in place. Assumed to be a
            diffusers ``UNet2DConditionModel`` (TODO confirm at call sites).
        added_number_count: Number of scalar conditions per sample. Each
            scalar is projected to 256 channels, so every row of
            ``coded_conds`` must contain exactly this many values.
        time_embed_dim: Output width of the new ``add_embedding``. Defaults
            to 1280, the previously hard-coded value.
            NOTE(review): this presumably has to match the UNet's own timestep
            embedding width, since diffusers adds the two together. Verify
            this for non-SD UNets.

    Returns:
        None.
    """
    unet.add_time_proj = Timesteps(256, True, 0)
    unet.add_embedding = TimestepEmbedding(256 * added_number_count, time_embed_dim)

    def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
        # Keep these parameter names unchanged: presumably diffusers calls
        # get_aug_embed by keyword. This closure is an instance attribute, so
        # it takes no ``self``.
        coded_conds = added_cond_kwargs.get("coded_conds")
        batch_size = coded_conds.shape[0]
        # Embed each scalar independently, then concatenate the per-scalar
        # embeddings of a sample into one row of 256 * added_number_count.
        time_embeds = unet.add_time_proj(coded_conds.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        # Match the dtype and device of the incoming embedding ``emb``.
        time_embeds = time_embeds.to(emb)
        return unet.add_embedding(time_embeds)

    unet.get_aug_embed = get_aug_embed

    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        # Shallow-copy so that the pop below does not mutate the caller's
        # dict. A None or missing value (the diffusers default) is treated
        # as empty.
        cross_attention_kwargs = dict(kwargs.get('cross_attention_kwargs') or {})
        if 'coded_conds' not in cross_attention_kwargs:
            raise KeyError(
                "unet_add_coded_conds: pass the conditions as "
                "cross_attention_kwargs['coded_conds']"
            )
        coded_conds = cross_attention_kwargs.pop('coded_conds')
        kwargs['cross_attention_kwargs'] = cross_attention_kwargs

        # Tile along the batch axis so there is one condition row per sample.
        # This handles a latent batch that is a multiple of the condition
        # batch (presumably classifier-free guidance duplicating the batch).
        cond_batch, sample_batch = coded_conds.shape[0], sample.shape[0]
        if cond_batch == 0 or sample_batch % cond_batch != 0:
            raise ValueError(
                f"coded_conds batch ({cond_batch}) must evenly divide the "
                f"sample batch ({sample_batch})"
            )
        repeats = sample_batch // cond_batch
        coded_conds = torch.cat([coded_conds] * repeats, dim=0).to(sample.device)

        # Merge rather than overwrite, so that any other added conditions
        # the caller supplied are still forwarded.
        added_cond_kwargs = dict(kwargs.get('added_cond_kwargs') or {})
        added_cond_kwargs['coded_conds'] = coded_conds
        kwargs['added_cond_kwargs'] = added_cond_kwargs
        return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)

    unet.forward = hooked_unet_forward
    return