import torch


def tca_transform_model(model):
    """Apply tca_transform() to the first transformer block of every
    attention module in the model's down, mid, and up blocks."""
    # Blocks without an `attentions` attribute (e.g. plain ResNet blocks)
    # contribute nothing; getattr's default skips them cleanly instead of
    # relying on a bare except.
    for down_block in model.down_blocks:
        for attention in getattr(down_block, "attentions", []):
            attention.transformer_blocks[0].tca_transform()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].tca_transform()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", []):
            attention.transformer_blocks[0].tca_transform()
    return model
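
# tca_transform() is not defined in this file; it is assumed to be a method
# patched onto the transformer blocks elsewhere in this repo. Typical call:
#   unet = tca_transform_model(unet)  # unet: a diffusers-style UNet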


class ImageProjModel(torch.nn.Module):
    """Projects a CLIP image embedding into a few extra cross-attention
    context tokens (an IP-Adapter-style projection head)."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        # (batch, clip_embeddings_dim) -> (batch, num_tokens, cross_attention_dim)
        clip_extra_context_tokens = self.proj(image_embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(clip_extra_context_tokens)
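
# Shape sketch (dims mirror the ip_transform_model call below and assume a
# 1024-dim CLIP image embedding; the names here are illustrative):
#   proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024)
#   tokens = proj(torch.randn(2, 1024))   # -> torch.Size([2, 4, 768])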


def ip_transform_model(model):
    """Attach an ImageProjModel head and apply ip_transform() to the
    cross-attention layer (attn2) of every transformer block."""
    model.image_proj_model = ImageProjModel(
        cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4
    ).to(model.device)
    for down_block in model.down_blocks:
        for attention in getattr(down_block, "attentions", []):
            attention.transformer_blocks[0].attn2.ip_transform()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_transform()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", []):
            attention.transformer_blocks[0].attn2.ip_transform()
    return model


def ip_scale_set(model, scale):
    """Set the image-prompt attention weight on every cross-attention layer."""
    for down_block in model.down_blocks:
        for attention in getattr(down_block, "attentions", []):
            attention.transformer_blocks[0].attn2.set_scale(scale)
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.set_scale(scale)
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", []):
            attention.transformer_blocks[0].attn2.set_scale(scale)
    return model


def ip_train_set(model):
    """Freeze the whole model, then re-enable gradients only for the image
    projection head and the image-prompt parameters inside each attn2."""
    model.requires_grad_(False)
    model.image_proj_model.requires_grad_(True)
    for down_block in model.down_blocks:
        for attention in getattr(down_block, "attentions", []):
            attention.transformer_blocks[0].attn2.ip_train_set()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_train_set()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", []):
            attention.transformer_blocks[0].attn2.ip_train_set()
    return model
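
# Minimal wiring sketch, assuming a diffusers-style UNet that exposes
# down_blocks / mid_block / up_blocks, plus the tca_transform / ip_transform /
# set_scale / ip_train_set hooks defined elsewhere in this repo:
#   unet = tca_transform_model(unet)
#   unet = ip_transform_model(unet)   # adds image_proj_model + IP attention
#   unet = ip_scale_set(unet, 0.5)    # 0.5 is an illustrative scale value
#   unet = ip_train_set(unet)         # train only the IP parameters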