import torch.nn as nn
from transformers import CLIPVisionModel

from .xf import LayerNorm, Transformer


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision transformer to embed images (from Hugging Face)."""

    def __init__(self, version="openai/clip-vit-large-patch14"):
        super().__init__()
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.final_ln = LayerNorm(1024)
        # Small trainable transformer applied to the pooled CLIP embedding
        # (treated as a single 1024-wide token).
        self.mapper = Transformer(
            1,
            1024,
            5,
            1,
        )
        self.freeze()

    def freeze(self):
        # Freeze the CLIP vision tower; only the mapper and the final
        # LayerNorm remain trainable.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        for param in self.mapper.parameters():
            param.requires_grad = True
        for param in self.final_ln.parameters():
            param.requires_grad = True

    def forward(self, image):
        outputs = self.transformer(pixel_values=image)
        z = outputs.pooler_output  # pooled image embedding, shape (B, 1024)
        z = z.unsqueeze(1)         # add a token dimension: (B, 1, 1024)
        z = self.mapper(z)
        z = self.final_ln(z)
        return z

    def encode(self, image):
        # Some pipelines pass the image wrapped in a list; unwrap it.
        if isinstance(image, list):
            image = image[0]
        return self(image)


if __name__ == "__main__":
    from ldm.util import count_params

    model = FrozenCLIPImageEmbedder()
    count_params(model, verbose=True)
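
    # Quick shape check (a minimal sketch added here, not part of the original
    # file): ViT-L/14 expects 224x224 pixel_values, and the random tensor below
    # only exercises the forward pass rather than producing a meaningful embedding.
    import torch

    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)
        z = model.encode(dummy)
    print(z.shape)  # expected: torch.Size([1, 1, 1024])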