from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from huggingface).

    Wraps a ``CLIPTokenizer`` / ``CLIPTextModel`` pair: raw text is
    tokenized (truncated and padded to ``max_length``), run through the
    text transformer, and the model output is returned with an extra
    ``"attn_bias"`` entry derived from the tokenizer's attention mask.
    """

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: Union[str, torch.device] = "cuda",
        max_length: int = 77,
        freeze: bool = True,
    ) -> None:
        """Load the tokenizer and text model and optionally freeze them.

        Args:
            version: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the text model. The original author's note
                mentions ``clip-vit-base-patch32`` as an alternative.
            device: Device the text model is moved to after loading.
            max_length: Length every tokenized prompt is truncated/padded to.
            freeze: When true, call :meth:`freeze` after loading.
        """
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
        # Kept for callers that read it; ``forward`` itself follows the
        # transformer's current parameter device instead.
        self.device = device
        self.hidden_size = self.transformer.config.hidden_size
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self) -> None:
        """Put the transformer in eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        # NOTE: ``train()`` is not overridden here, so a later ``.train()``
        # call on this module (or a parent) switches the transformer back
        # to training mode; the parameters still stay non-trainable.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text: Union[str, List[str]]):
        """Tokenize ``text`` and run it through the CLIP text model.

        Args:
            text: Prompt(s) handed to the tokenizer as-is.

        Returns:
            The text model's output object with one added key,
            ``"attn_bias"``: a float tensor shaped like the tokenizer's
            attention mask, holding ``0.0`` where the mask is 1 and
            ``-inf`` where the mask is 0.
        """
        # Use the device the weights live on right now. ``self.device`` is
        # only the value given at construction and goes stale if the module
        # is later moved with ``.to()`` / ``.cpu()`` / ``.cuda()``.
        device = next(self.transformer.parameters()).device

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        ).to(device)

        # Only turn CUDA autocast on when the model actually runs on CUDA.
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            outputs = self.transformer(**batch_encoding)

        # Build the additive bias in a fresh tensor so the tokenizer's
        # integer attention mask is left untouched.
        attention_mask = batch_encoding["attention_mask"]
        attn_bias = torch.zeros_like(
            attention_mask, dtype=torch.float32
        ).masked_fill(attention_mask == 0, float("-inf"))

        outputs["attn_bias"] = attn_bias
        return outputs

    def encode(self, text: Union[str, List[str]]):
        """Alias for calling the module: ``encode(text)`` is ``self(text)``."""
        return self(text)