Spaces:
Paused
Paused
# A CLIP Vision model that supports arbitrary aspect ratios, by lllyasviel.
# NOTE: the input range is [-1, 1] rather than [0, 1] (the same range as the VAE's).
import torch | |
import types | |
import einops | |
from abc import ABCMeta | |
from transformers import CLIPVisionModelWithProjection | |
def preprocess(image, *, grid_size=16, patch_size=14):
    """Resize and normalize an image batch for the CLIP vision tower.

    The shorter spatial side is resized to ``grid_size * patch_size`` pixels
    (224 with the defaults). The longer side is scaled by the same factor and
    rounded to a whole number of patches, so the aspect ratio is kept
    approximately. The result is normalized with the CLIP channel mean/std.

    Args:
        image: 4-D tensor indexed as (batch, channel, height, width) with 3
            channels (the mean/std below broadcast over dim 1). Values are
            expected in [0, 1]. NOTE(review): the range is not checked here.
        grid_size: Number of patches along the shorter side. Keyword-only; the
            default of 16 reproduces the previously hard-coded value.
        patch_size: Patch edge length in pixels. Keyword-only; the default of
            14 reproduces the previously hard-coded value.

    Returns:
        The resized, normalized tensor, with the same dtype and device as
        ``image``. Both output spatial sizes are multiples of ``patch_size``.

    Raises:
        ValueError: If the image has an empty spatial dimension.
    """
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
    height, width = image.shape[2], image.shape[3]
    short_side = min(height, width)
    if short_side == 0:
        raise ValueError(f"preprocess() got an image with an empty spatial dimension: {tuple(image.shape)}")
    # Scale so the shorter side spans exactly `grid_size` patches; the longer
    # side is rounded to the nearest whole number of patches.
    scale = grid_size / short_side
    # Bicubic can overshoot the input range slightly; values are not clamped.
    image = torch.nn.functional.interpolate(
        image,
        size=(patch_size * round(scale * height), patch_size * round(scale * width)),
        mode="bicubic",
        antialias=True,
    )
    return (image - mean) / std
def arbitrary_positional_encoding(p, H, W):
    """Resample a square grid of positional embeddings onto an ``H x W`` grid.

    ``p.weight`` is read as one class-token row followed by ``side * side``
    patch rows laid out row-major on a square grid. The side length is
    inferred from the row count; previously it was hard-coded to 16. The
    patch rows are resized to ``H x W`` with nearest-neighbour interpolation,
    and the class row is prepended again unchanged.

    Args:
        p: Module exposing a 2-D ``weight`` of shape (1 + side * side, C).
            Presumably an ``nn.Embedding`` (a position-embedding table); any
            object with such a ``weight`` works.
        H: Target number of patch rows.
        W: Target number of patch columns.

    Returns:
        Tensor of shape (1, 1 + H * W, C): class row first, then the
        resampled patch rows in row-major order.

    Raises:
        ValueError: If the number of patch rows is not a perfect square.
    """
    weight = p.weight
    cls = weight[:1]
    pos = weight[1:]
    num_patches, channels = pos.shape
    side = round(num_patches ** 0.5)
    if side * side != num_patches:
        raise ValueError(
            f"expected a square grid of positional embeddings after the class token, got {num_patches} patch rows"
        )
    # (side*side, C) -> (1, C, side, side); row-major, i.e. index = row * side + col.
    grid = pos.reshape(side, side, channels).permute(2, 0, 1).unsqueeze(0)
    grid = torch.nn.functional.interpolate(grid, size=(H, W), mode="nearest")
    # (1, C, H, W) -> (H*W, C), row-major again.
    pos = grid[0].permute(1, 2, 0).reshape(H * W, channels)
    return torch.cat([cls, pos])[None]
def improved_clipvision_embedding_forward(self, pixel_values, interpolate_pos_encoding=False):
    """Replacement ``forward`` for a CLIP vision embeddings module, for any aspect ratio.

    Maps ``pixel_values`` from [-1, 1] to [0, 1], resizes and normalizes them
    with ``preprocess``, and patch-embeds them. It then prepends the class
    embedding and adds positional embeddings resampled to the actual patch
    grid with ``arbitrary_positional_encoding``.

    Args:
        self: The embeddings module this function is bound to. It must provide
            ``patch_embedding``, ``class_embedding`` and ``position_embedding``.
        pixel_values: Image batch, assumed to be (batch, 3, height, width)
            with values in [-1, 1]. NOTE(review): the range is not checked.
        interpolate_pos_encoding: Accepted only for call compatibility. Newer
            ``transformers`` releases pass it as a keyword to the embeddings
            module, which raised TypeError before. It is ignored, because the
            positional embeddings are always resampled here.

    Returns:
        Tensor of shape (batch, 1 + h * w, C). Here h x w is the patch grid
        produced by ``patch_embedding``, and C is its output channel count.
    """
    # Inputs arrive in [-1, 1]; CLIP normalization expects [0, 1].
    pixel_values = preprocess(pixel_values * 0.5 + 0.5)
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
    H, W = patch_embeds.shape[-2:]
    # (B, C, H, W) -> (B, H*W, C), row-major over the patch grid.
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    return embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
    """``CLIPVisionModelWithProjection`` whose vision embeddings accept arbitrary aspect ratios.

    After the parent constructor builds the model, the ``forward`` attribute
    of ``self.vision_model.embeddings`` is replaced, on that instance only, by
    ``improved_clipvision_embedding_forward``. That function expects pixel
    values in [-1, 1] rather than [0, 1].
    """

    def __init__(self, config):
        super().__init__(config)
        embeddings = self.vision_model.embeddings
        # Bind the module-level function as a method of this particular
        # embeddings instance and install it as that instance's `forward`.
        patched_forward = types.MethodType(improved_clipvision_embedding_forward, embeddings)
        embeddings.forward = patched_forward