import pathlib
import tempfile
from collections import OrderedDict
from typing import Tuple, Union
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath, trunc_normal_

from .image_encoder import build_image_encoder
from .text_encoder import build_text_encoder
from .text_encoder import build_tokenizer
from .templates import DEFAULT_TEMPLATES
# NOTE: assumed import path; `get_imnet_embeddings` below needs the ImageNet
# class names and prompt templates, so import them from wherever
# IMAGENET_CLASSES / IMAGENET_DEFAULT_TEMPLATES live in this repo.
from .imagenet.constants import IMAGENET_CLASSES, IMAGENET_DEFAULT_TEMPLATES

logger = logging.getLogger(__name__)

class UniCLModel(nn.Module):
    def __init__(self, config: dict):
        super().__init__()

        self.conf_lang_encoder = config['MODEL']['TEXT_ENCODER']
        self.tokenizer = build_tokenizer(self.conf_lang_encoder)

        self.text_encoder = build_text_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])

        dim_projection = config['MODEL']['DIM_PROJECTION']
        if hasattr(self.text_encoder, 'dim_out'):
            dim_out = self.text_encoder.dim_out
        else:
            # Infer the text feature dimension with a dry forward pass when
            # the encoder does not expose `dim_out` directly.
            with torch.no_grad():
                dim_out = self.text_encoder(
                    torch.zeros(1, 1, dtype=torch.long)
                )['last_hidden_state'].size(2)

        self.text_projection = nn.Parameter(torch.empty(dim_out, dim_projection))

        self.conf_image_encoder = config['MODEL']['IMAGE_ENCODER']
        self.image_encoder = build_image_encoder(self.conf_image_encoder)

        self.image_projection = nn.Parameter(
            torch.empty(self.image_encoder.dim_out, dim_projection)
        )

        # Learnable temperature for the image-text contrastive logits.
        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)

    def _convert_old_weights(self, model_dict):
        # Map legacy checkpoint keys ('visual.*', 'text.*', 'vision_projection')
        # onto the attribute names used by this module. Note the text branch
        # must map to 'text_encoder.*' to match this model's state_dict keys.
        model_dict_updated = {}
        for k, v in model_dict.items():
            if k.startswith('visual.'):
                model_dict_updated['image_encoder.' + k[7:]] = v
            elif k.startswith('text.'):
                model_dict_updated['text_encoder.' + k[5:]] = v
            elif k == 'vision_projection':
                model_dict_updated['image_projection'] = v
            elif k == 'text_projection':
                model_dict_updated['text_projection'] = v
            else:
                model_dict_updated[k] = v

        return model_dict_updated

    def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
        if not os.path.isfile(pretrained):
            logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
            return

        pretrained_dict = torch.load(pretrained, map_location='cpu')
        logger.info(f'=> Loading pretrained model {pretrained}')
        pretrained_dict = self._convert_old_weights(pretrained_dict)
        model_dict = self.state_dict()
        # Keep only the weights that exist in the current model.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict
        }
        need_init_state_dict = {}
        image_encoder_state_dict = {}
        for k, v in pretrained_dict.items():
            # Guard the '*' check so an empty `pretrained_layers` list does
            # not raise an IndexError.
            need_init = (
                k.split('.')[0] in pretrained_layers
                or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
            )
            if need_init:
                if k.startswith('image_encoder.'):
                    image_encoder_state_dict[k] = v
                else:
                    if verbose:
                        logger.info(f'=> init {k} from {pretrained}')
                    need_init_state_dict[k] = v
        self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
        self.load_state_dict(need_init_state_dict, strict=False)

    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        if hasattr(self.text_encoder, 'no_weight_decay'):
            for k in self.text_encoder.no_weight_decay():
                # Prefix with the attribute name used by this module.
                no_weight_decay.add('text_encoder.' + k)

        if hasattr(self.image_encoder, 'no_weight_decay'):
            for k in self.image_encoder.no_weight_decay():
                no_weight_decay.add('image_encoder.' + k)

        return no_weight_decay
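
    # Example sketch: splitting parameters into weight-decay / no-weight-decay
    # groups with the set returned above; `model` and the hyperparameter
    # values are placeholders, not defined in this module:
    #
    #   skip = model.no_weight_decay()
    #   decay, no_decay = [], []
    #   for name, param in model.named_parameters():
    #       if param.requires_grad:
    #           (no_decay if name in skip else decay).append(param)
    #   optimizer = torch.optim.AdamW(
    #       [{'params': decay, 'weight_decay': 0.05},
    #        {'params': no_decay, 'weight_decay': 0.0}], lr=1e-4)
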
    @property
    def dtype(self):
        return self.logit_scale.dtype

    def get_imnet_embeddings(self):
        # Only the first prompt template is used here.
        templates = IMAGENET_DEFAULT_TEMPLATES[:1]
        clss_embeddings = []
        for clss in IMAGENET_CLASSES:
            txts = [template.format(clss) for template in templates]
            tokens = self.tokenizer(
                txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
            )
            tokens = {
                key: (val.cuda() if next(self.parameters()).is_cuda else val)
                for key, val in tokens.items()
            }
            clss_embedding = self.encode_text(tokens)
            # Average over templates, then re-normalize to unit length.
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            clss_embeddings.append(clss_embedding)
        imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
        return imnet_text_embeddings

    def get_text_embeddings(self, texts):
        templates = DEFAULT_TEMPLATES[:1]
        clss_embeddings = []
        for clss in texts:
            txts = [template.format(clss) for template in templates]
            tokens = self.tokenizer(
                txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
            )
            tokens = {
                key: (val.cuda() if next(self.parameters()).is_cuda else val)
                for key, val in tokens.items()
            }
            clss_embedding = self.encode_text(tokens)
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            clss_embeddings.append(clss_embedding)
        text_embeddings = torch.stack(clss_embeddings, dim=0)
        return text_embeddings
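
    # Example sketch: zero-shot classification with the embeddings above.
    # `model`, `images`, and `class_names` are placeholders, not defined here:
    #
    #   text_feats = model.get_text_embeddings(class_names)   # (C, D), unit norm
    #   image_feats = model.encode_image(images)               # (B, D), unit norm
    #   logits = model.logit_scale.exp() * image_feats @ text_feats.t()
    #   predictions = logits.argmax(dim=-1)
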
    def encode_image(self, image, norm=True, output_map=False):
        x = self.image_encoder.forward_features(image, output_map=output_map)

        if output_map:
            x, x_map = x
        x = x @ self.image_projection
        if output_map:
            # Project the spatial feature map into the joint space as well.
            x_map = self.image_projection.unsqueeze(0).transpose(1, 2) @ x_map

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)
            if output_map:
                x_map = x_map / x_map.norm(dim=1, keepdim=True)

        if output_map:
            return x, x_map
        return x
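
    # Example sketch: `output_map=True` additionally returns the projected
    # spatial feature map, e.g. for image-text similarity heatmaps; `model`
    # and `images` are placeholders:
    #
    #   feats, feat_map = model.encode_image(images, output_map=True)
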
    def encode_text(self, text, norm=True):
        x = self.text_encoder(**text)
        x = x['last_hidden_state']

        if self.conf_lang_encoder['TOKENIZER'] == 'clip':
            # CLIP tokenizer: pool the features at the end-of-text token,
            # which has the largest token id in each sequence.
            x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
        else:
            # Otherwise pool the [CLS] token at position 0.
            x = x[:, 0]

        x = x @ self.text_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def forward(self, image, text):
        features_image = self.encode_image(image)
        features_text = self.encode_text(text)

        # cosine similarity as logits
        T = self.logit_scale.exp()

        return features_image, features_text, T
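

# Example sketch: a symmetric InfoNCE-style contrastive loss built from the
# (features_image, features_text, T) triple returned by `forward`. It assumes
# the batch holds matched image-text pairs; the function name is illustrative
# and UniCL's actual training objective may differ.
def example_contrastive_loss(features_image, features_text, T):
    # Scaled cosine similarities; rows are images, columns are texts.
    logits_per_image = T * features_image @ features_text.t()
    logits_per_text = logits_per_image.t()
    # The i-th image matches the i-th text.
    targets = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
    # Cross-entropy in both directions, averaged.
    return (F.cross_entropy(logits_per_image, targets)
            + F.cross_entropy(logits_per_text, targets)) / 2
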
def build_unicl_model(config, **kwargs):
    model = UniCLModel(config)
    if config['MODEL']['PRETRAINED'] != '':
        pretrained_path = config['MODEL']['PRETRAINED']
        from ..Utils.Utils import is_valid_url, download_file
        if is_valid_url(pretrained_path):
            with tempfile.TemporaryDirectory() as tmp_path:
                file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
                download_file(pretrained_path, file_local_path)
                model.from_pretrained(str(file_local_path), config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
        else:
            model.from_pretrained(pretrained_path, config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])

    return model
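
# Example sketch: building and querying the model end-to-end. The config file
# name and YAML loading are assumptions; only the config keys used above
# (MODEL.*, VERBOSE) are known from this module. `images` and `tokens` are
# placeholders for a preprocessed image batch and tokenizer output:
#
#   import yaml
#   with open('configs/unicl_swin_base.yaml') as f:
#       config = yaml.safe_load(f)
#   model = build_unicl_model(config).eval()
#   with torch.no_grad():
#       image_feats, text_feats, T = model(images, tokens)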