import torch.nn as nn
from transformers import CLIPVisionModel

from .xf import LayerNorm, Transformer


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision transformer to embed images (from Hugging Face)."""

    def __init__(self, version="openai/clip-vit-large-patch14"):
        super().__init__()
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.final_ln = LayerNorm(1024)
        # Small trainable transformer that maps the pooled CLIP embedding
        # (a single 1024-dim token) into the conditioning space.
        self.mapper = Transformer(1, 1024, 5, 1)
        self.freeze()

    def freeze(self):
        # Freeze the pretrained CLIP vision encoder; only the mapper and
        # the final LayerNorm remain trainable.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        for param in self.mapper.parameters():
            param.requires_grad = True
        for param in self.final_ln.parameters():
            param.requires_grad = True

    def forward(self, image):
        outputs = self.transformer(pixel_values=image)
        z = outputs.pooler_output  # (batch, 1024) pooled image embedding
        z = z.unsqueeze(1)         # (batch, 1, 1024): one token per image
        z = self.mapper(z)
        z = self.final_ln(z)
        return z

    def encode(self, image):
        if isinstance(image, list):
            # Some data pipelines wrap the image tensor in a list.
            image = image[0]
        return self(image)


if __name__ == "__main__":
    from ldm.util import count_params

    model = FrozenCLIPImageEmbedder()
    count_params(model, verbose=True)
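

# --- Hedged usage sketch (not part of the original module) ---
# Minimal example of driving the embedder end to end. The checkpoint name
# matches the default version above; the image path "example.jpg" is a
# placeholder, and CLIPImageProcessor is used here only as one convenient
# way to produce pixel_values of the shape the vision model expects.
if __name__ == "__main__":
    import torch
    from PIL import Image
    from transformers import CLIPImageProcessor

    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    embedder = FrozenCLIPImageEmbedder()

    image = Image.open("example.jpg").convert("RGB")  # placeholder path
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        z = embedder.encode(pixel_values)
    print(z.shape)  # expected: (1, 1, 1024), one mapped token per image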