import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",  # "openai/clip-vit-base-patch32" is a smaller alternative
        device="cuda",
        max_length=77,  # CLIP's maximum context length
        freeze=True,
    ):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version).to(device)
self.device = device
self.hidden_size = self.transformer.config.hidden_size
self.max_length = max_length
if freeze:
self.freeze()

    def freeze(self):
        # Switch to eval mode and disable gradients for every parameter.
        # (Upstream latent-diffusion also overrides .train() via a no-op
        # `disabled_train` so that eval mode cannot be undone later; see the
        # sketch below.)
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
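
    # A sketch of the `disabled_train` guard hinted at above (the override used
    # by upstream latent-diffusion); `train` is the standard nn.Module hook,
    # and the requires_grad check is an assumption introduced here so the
    # override stays inert when the encoder is not frozen.
    def train(self, mode: bool = True):
        # If every parameter is frozen, ignore requests to enter train mode so
        # the encoder keeps eval-mode behavior (e.g. no dropout updates).
        if mode and not any(p.requires_grad for p in self.parameters()):
            return self
        return super().train(mode)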

    def forward(self, text):
        # Tokenize to a fixed length: truncate longer prompts and pad shorter
        # ones so every sequence is exactly `max_length` tokens.
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        # Encode under automatic mixed precision.
        with torch.cuda.amp.autocast():
            outputs = self.transformer(**batch_encoding)
        # Turn the 0/1 attention mask into an additive bias: 0.0 at real tokens,
        # -inf at padding, ready to be added to downstream attention logits.
        attn_bias = batch_encoding["attention_mask"].float()
        attn_bias[attn_bias == 0] = -float("inf")
        attn_bias[attn_bias == 1] = 0.0
        outputs["attn_bias"] = attn_bias
        return outputs

    def encode(self, text):
        # Convenience alias for forward().
        return self(text)
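

# Minimal usage sketch: embed a batch of prompts and inspect the outputs.
# Assumes a CUDA device is available (matching the default device="cuda");
# the prompt strings are illustrative only.
if __name__ == "__main__":
    embedder = FrozenCLIPEmbedder()
    out = embedder.encode(["a photograph of an astronaut", "a red cube on a table"])
    # last_hidden_state: (batch, max_length, hidden_size) -> (2, 77, 768) for ViT-L/14
    print(out["last_hidden_state"].shape)
    # attn_bias: (batch, max_length) with 0.0 at real tokens and -inf at padding
    print(out["attn_bias"].shape)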