import torch
from open_clip import create_model
from transformers import PretrainedConfig, PreTrainedModel
from transformers.models.clip.modeling_clip import CLIPOutput
from typing import Optional, Tuple, Union


class MarqoFashionCLIPConfig(PretrainedConfig):
    """Configuration holding the open_clip architecture name used to build the model."""

    def __init__(
        self,
        open_clip_model_name: str = "",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.open_clip_model_name = open_clip_model_name


class MarqoFashionCLIP(PreTrainedModel):
    """transformers wrapper around an open_clip CLIP model, used for inference only."""

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        super().__init__(config)
        self.config = config
        # Build the underlying open_clip model; output_dict=True makes its own
        # forward return a dict, though only encode_image/encode_text are used here.
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        # Encode images without autograd overhead; open_clip applies L2
        # normalization when normalize=True.
        with torch.inference_mode():
            image_features = self.model.encode_image(pixel_values, normalize=normalize)
        return image_features

    def get_text_features(
        self,
        input_ids: torch.Tensor,
        normalize: bool = False,
        **kwargs
    ) -> torch.FloatTensor:
        # Encode token ids without autograd overhead.
        with torch.inference_mode():
            text_features = self.model.encode_text(input_ids, normalize=normalize)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Both embeddings are L2-normalized, so the matrix product below gives
        # cosine similarities; no learned logit scale is applied.
        vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)

        logits_per_text = text_outputs @ vision_outputs.T
        logits_per_image = logits_per_text.T

        if not return_dict:
            return logits_per_image, logits_per_text, text_outputs, vision_outputs

        return CLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_outputs,
            image_embeds=vision_outputs,
        )
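

# Minimal smoke-test sketch (not part of the original module). It assumes the
# open_clip architecture name "ViT-B-32" with its default 224x224 input size
# and 77-token context; weights are randomly initialized here, so the
# similarity values are meaningless and only the API surface is exercised.
if __name__ == "__main__":
    from open_clip import get_tokenizer

    config = MarqoFashionCLIPConfig(open_clip_model_name="ViT-B-32")
    model = MarqoFashionCLIP(config)

    tokenizer = get_tokenizer("ViT-B-32")
    input_ids = tokenizer(["a red dress", "a leather boot"])  # (2, 77) token ids
    pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images

    outputs = model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    print(outputs.logits_per_image.shape)  # torch.Size([2, 2])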