Spaces:
Paused
Paused
# This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py
# Modified by Xingyi Zhou
# The original code is under MIT license
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Union, List | |
from collections import OrderedDict | |
import torch | |
from torch import nn | |
import torch | |
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer | |
# Public names for ``from <module> import *``. ``tokenize`` is a CLIPTEXT
# method rather than a module-level function, so it must not be listed here
# (listing a missing name makes star-imports raise AttributeError).
__all__ = ["CLIPTEXT", "build_text_encoder"]
count = 0  # module-level counter; not referenced elsewhere in this module
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32.

    The input is upcast to float32 for the normalization and the result is
    cast back to the input's original dtype, so reduced-precision (e.g. fp16)
    activations can be passed through it.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a 4x-wide MLP.

    Each sub-layer is applied to a LayerNorm'd copy of the input and its
    output is added back to the input (residual connection).

    Args:
        d_model: feature size of the input and output.
        n_head: number of attention heads.
        attn_mask: optional additive float mask forwarded to
            ``nn.MultiheadAttention``; ``None`` disables masking.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        # Submodules are created in the same order as before so that
        # state_dict keys (attn, ln_1, mlp.c_fc/gelu/c_proj, ln_2) match.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden_dim = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = LayerNorm(d_model)
        # Plain attribute (not a registered buffer), so it is not part of
        # state_dict and is moved to the right device/dtype lazily below.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x`` (query = key = value), weights not returned.

        Assumes ``x`` uses nn.MultiheadAttention's default sequence-first layout.
        """
        if self.attn_mask is None:
            mask = None
        else:
            mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        # Store the converted mask so later calls find it already in place.
        self.attn_mask = mask
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks applied one after another.

    Args:
        width: feature size of every block.
        layers: number of blocks in the stack.
        heads: attention heads per block.
        attn_mask: optional additive attention mask given to every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Feed ``x`` through all blocks in order and return the final output."""
        output = self.resblocks(x)
        return output
class CLIPTEXT(nn.Module):
    """Stand-alone CLIP text tower used as a caption encoder.

    ``forward`` tokenizes caption strings, embeds them, runs a causally-masked
    transformer and returns, per caption, the final-layer feature at the
    end-of-text position multiplied by ``text_projection``.

    Args:
        embed_dim: size of the returned features (columns of ``text_projection``).
        context_length: number of token positions; sizes the positional
            embedding and the causal attention mask.
        vocab_size: number of rows in the token embedding table.
        transformer_width: hidden size of the transformer.
        transformer_heads: attention heads per block.
        transformer_layers: number of residual attention blocks.
    """

    def __init__(self,
                 embed_dim=512,
                 # text
                 context_length=77,
                 vocab_size=49408,
                 transformer_width=512,
                 transformer_heads=8,
                 transformer_layers=12
                 ):
        super().__init__()
        self._tokenizer = _Tokenizer()
        self.context_length = context_length
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # No logit_scale here: build_text_encoder drops it from the CLIP checkpoint.
        self.initialize_parameters()

    def initialize_parameters(self):
        """Randomly (re)initialize embeddings, block weights and the projection.

        Uses normal distributions whose std depends on the transformer width and
        depth. Only weights are touched; Linear/attention biases keep the values
        set by their constructors.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a ``(context_length, context_length)`` additive causal mask.

        Entries strictly above the diagonal are ``-inf`` (a position cannot
        attend to later positions); the diagonal and everything below are 0.
        PyTorch adds a float ``attn_mask`` to the attention scores.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero the rest
        return mask

    @property
    def device(self):
        """Device of the module's parameters (read from ``text_projection``)."""
        return self.text_projection.device

    @property
    def dtype(self):
        """Dtype used for activations (read from ``text_projection``)."""
        return self.text_projection.dtype

    def tokenize(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.LongTensor:
        """Convert caption(s) into a zero-padded matrix of token ids.

        Each caption becomes ``[<|startoftext|>] + encode(text) + [<|endoftext|>]``.

        Args:
            texts: a single caption or a list of captions.
            context_length: number of columns in the returned matrix.

        Returns:
            ``torch.long`` tensor of shape ``(len(texts), context_length)``;
            positions past the end of each sequence are 0.

        Sequences longer than ``context_length`` are not rejected: a random
        window of ``context_length`` tokens is kept instead (drawn from torch's
        global RNG), which can drop the start- and/or end-of-text tokens.
        """
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                st = torch.randint(
                    len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st: st + context_length]
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def encode_text(self, text):
        """Encode token-id rows into projected text features.

        Args:
            text: long tensor of token ids, one row per caption; its row length
                must match ``context_length`` (positional embedding and mask size).

        Returns:
            Tensor with one row per caption and ``text_projection.shape[1]`` columns.
        """
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # Take features at the eot embedding, relying on eot_token being the
        # highest id in each sequence. NOTE(review): if tokenize() cropped the
        # eot token away, argmax selects the highest-id remaining token instead.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, captions):
        """Tokenize ``captions`` (a list of strings) and return their features.

        Returns:
            One feature row per caption (B x D).
        """
        # Tokenize to this model's own context length so the token matrix matches
        # the positional embedding and the causal mask.
        text = self.tokenize(captions, context_length=self.context_length).to(self.device)  # B x L token ids
        features = self.encode_text(text)  # B x D
        return features
def build_text_encoder(pretrain=True):
    """Create a CLIPTEXT encoder, optionally filled with CLIP ViT-B/32 weights.

    Args:
        pretrain: when true, load the "ViT-B/32" CLIP checkpoint on CPU, strip
            the entries the text-only module does not have (``visual.*`` and a
            few top-level keys), and load the rest strictly.

    Returns:
        The CLIPTEXT module.
    """
    text_encoder = CLIPTEXT()
    if not pretrain:
        return text_encoder

    import clip  # only needed when pretrained weights are requested
    pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
    state_dict = pretrained_model.state_dict()

    # Keys with no counterpart in CLIPTEXT; missing ones are simply skipped.
    unwanted = {"logit_scale", "input_resolution", "context_length", "vocab_size"}
    unwanted.update(key for key in state_dict if key.startswith('visual.'))
    for key in unwanted:
        state_dict.pop(key, None)

    print('Loading pretrained CLIP')
    text_encoder.load_state_dict(state_dict)
    return text_encoder