import torch.nn as nn
from transformers import CLIPVisionModel
from .xf import LayerNorm, Transformer


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision transformer to embed images (from Hugging Face)."""
    def __init__(self, version="openai/clip-vit-large-patch14"):
        super().__init__()
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.final_ln = LayerNorm(1024)
        self.mapper = Transformer(
            1,     # n_ctx: a single pooled image token
            1024,  # width: CLIP ViT-L/14 hidden size
            5,     # layers
            1,     # heads
        )
        self.freeze()
    def freeze(self):
        # Freeze the pretrained CLIP backbone; only the mapper and the
        # final layer norm remain trainable.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        for param in self.mapper.parameters():
            param.requires_grad = True
        for param in self.final_ln.parameters():
            param.requires_grad = True
    def forward(self, image):
        outputs = self.transformer(pixel_values=image)
        z = outputs.pooler_output  # (batch, 1024) pooled image embedding
        z = z.unsqueeze(1)         # (batch, 1, 1024): a one-token sequence
        z = self.mapper(z)
        z = self.final_ln(z)
        return z
    def encode(self, image):
        if isinstance(image, list):
            image = image[0]
        return self(image)


if __name__ == "__main__":
    from ldm.util import count_params

    model = FrozenCLIPImageEmbedder()
    count_params(model, verbose=True)
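
    # Minimal smoke test, added for illustration: the random tensor below is a
    # stand-in for a real CLIP-preprocessed batch (3x224x224, normalized).
    import torch

    dummy = torch.randn(1, 3, 224, 224)
    z = model.encode(dummy)
    print(z.shape)  # expected: torch.Size([1, 1, 1024])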