import torch
from open_clip import create_model
from transformers import PretrainedConfig, PreTrainedModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
from typing import Optional, Tuple, Union


class MarqoFashionCLIPConfig(PretrainedConfig):
    """Minimal config: only stores the open_clip model name to instantiate."""

    def __init__(
        self,
        open_clip_model_name: str = "",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.open_clip_model_name = open_clip_model_name


class MarqoFashionCLIP(PreTrainedModel):
    """Thin transformers wrapper around an open_clip model."""

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        super().__init__(config)
        self.config = config
        # Build the underlying open_clip model; output_dict=True makes its
        # forward return a dict instead of a tuple.
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self, pixel_values: torch.FloatTensor, normalize: bool = False, **kwargs
    ) -> torch.FloatTensor:
        with torch.inference_mode():
            image_features = self.model.encode_image(pixel_values, normalize=normalize)
        return image_features

    def get_text_features(
        self, input_ids: torch.Tensor, normalize: bool = False, **kwargs
    ) -> torch.FloatTensor:
        with torch.inference_mode():
            text_features = self.model.encode_text(input_ids, normalize=normalize)
        return text_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        # Encode both modalities with L2-normalized features so the matrix
        # product below is a cosine-similarity matrix.
        vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)

        logits_per_text = text_outputs @ vision_outputs.T
        logits_per_image = logits_per_text.T

        if not return_dict:
            return logits_per_image, logits_per_text, text_outputs, vision_outputs

        return CLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_outputs,
            image_embeds=vision_outputs,
        )
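

# Usage sketch. This assumes the module is loaded through transformers'
# trust_remote_code mechanism; the checkpoint id "Marqo/marqo-fashionCLIP",
# the image path "example.jpg", and the candidate captions below are
# illustrative placeholders, not values taken from this file.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModel, AutoProcessor

    model = AutoModel.from_pretrained("Marqo/marqo-fashionCLIP", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("Marqo/marqo-fashionCLIP", trust_remote_code=True)

    image = Image.open("example.jpg")  # placeholder image path
    texts = ["a photo of a red dress", "a photo of black boots"]
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)

    # Normalized embeddings, so the dot products below are cosine similarities.
    image_embeds = model.get_image_features(inputs["pixel_values"], normalize=True)
    text_embeds = model.get_text_features(inputs["input_ids"], normalize=True)

    # Probability of each caption matching the image.
    probs = (image_embeds @ text_embeds.T).softmax(dim=-1)
    print(probs)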