GraCo / isegm /model /modeling /clip_text_encoding.py
zhaoyian01's picture
Add application file
6d1366a
raw
history blame
1.11 kB
import torch
from torch import nn
from .clip import clip
class ClipTextEncoder(nn.Module):
    """Frozen CLIP text encoder followed by a trainable linear projection.

    The CLIP model is loaded through ``clip.load`` and every one of its
    parameters is frozen. Only ``out_proj`` (``embedding_dim -> out_dim``)
    keeps ``requires_grad=True``.
    """

    # Model names accepted by the constructor.
    SUPPORTED_ENCODERS = (
        'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
        'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px',
    )

    def __init__(self, clip_enocder_name: str = "ViT-B/32",
                 embedding_dim: int = 512, out_dim: int = 768) -> None:
        """Load the CLIP model, freeze it and create the output projection.

        Args:
            clip_enocder_name: Name of the CLIP variant to load; must be one
                of ``SUPPORTED_ENCODERS``.
            embedding_dim: Input width of ``out_proj``. Presumably this has
                to match the text-embedding width of the chosen CLIP
                variant -- TODO confirm per model.
            out_dim: Output width of ``out_proj``.

        Raises:
            AssertionError: If ``clip_enocder_name`` is not supported.
        """
        super().__init__()
        assert clip_enocder_name in self.SUPPORTED_ENCODERS, (
            f"unsupported CLIP encoder {clip_enocder_name!r}; "
            f"expected one of {self.SUPPORTED_ENCODERS}"
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(clip_enocder_name, device=self.device)

        # Freeze the CLIP backbone; only the projection below is trainable.
        for _, param in self.model.named_parameters():
            param.requires_grad = False

        # NOTE(review): out_proj is created on the default device while CLIP
        # is loaded onto self.device; presumably the caller moves the whole
        # module with .to(...) afterwards -- confirm.
        self.out_proj = nn.Linear(embedding_dim, out_dim)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, prompt: torch.Tensor) -> torch.Tensor:
        """Encode tokenised text and project it to ``out_dim``.

        Args:
            prompt: Text tokens passed straight to ``self.model.encode_text``
                (presumably the output of the CLIP tokenizer -- TODO confirm).

        Returns:
            The float32 text features after ``out_proj``; the last dimension
            has size ``out_dim``.
        """
        # Only the frozen CLIP encoder runs without autograd. The original
        # code disabled gradients for the whole forward pass, which also
        # prevented out_proj from ever receiving gradients.
        with torch.no_grad():
            text_features = self.model.encode_text(prompt).type(torch.float32)

        # L2 normalisation of the text features is currently disabled:
        # text_features /= text_features.norm(dim=-1, keepdim=True)

        # Trainable projection, executed with autograd enabled.
        return self.out_proj(text_features)