|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from typing import List, Optional, Tuple, Union |
|
from io import BytesIO |
|
import requests |
|
import base64 |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as f |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from transformers import ( |
|
AutoImageProcessor, |
|
AutoTokenizer, |
|
BatchEncoding, |
|
BatchFeature, |
|
PreTrainedModel, |
|
logging, |
|
) |
|
from transformers.models.clip.modeling_clip import ( |
|
CLIPOutput, |
|
CLIPTextModelOutput, |
|
CLIPVisionModelOutput, |
|
clip_loss, |
|
) |
|
|
|
try: |
|
from tqdm.autonotebook import trange |
|
|
|
has_tqdm = True |
|
except ImportError: |
|
has_tqdm = False |
|
|
|
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig |
|
from .eva_model import EVAVisionTransformer |
|
from .hf_model import HFTextEncoder |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
""" Jina CLIP model implementation """ |
|
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """LayerNorm variant that returns its result in the dtype of the input tensor."""

    def forward(self, x: torch.Tensor):
        # Remember the incoming dtype so the normalized result can be cast back
        # to it before returning.
        input_dtype = x.dtype
        normalized = f.layer_norm(
            x,
            normalized_shape=self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.to(input_dtype)
|
|
|
|
|
def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
    """
    Instantiate the text encoder described by ``config``.

    Model name, output dimension, pooling, projection settings and extra HF
    config kwargs are read from ``config``; the remaining arguments are fixed
    here (``pretrained=False``, ``output_tokens=False``,
    ``trust_remote_code=True``, ``revision=None``).
    """
    encoder_kwargs = dict(
        model_name_or_path=config.hf_model_name_or_path,
        output_dim=config.embed_dim,
        pooler_type=config.pooler_type,
        proj_type=config.proj_type,
        proj_bias=config.proj_bias,
        # NOTE(review): presumably weights are loaded afterwards through the
        # regular `from_pretrained` path - confirm against HFTextEncoder.
        pretrained=False,
        output_tokens=False,
        trust_remote_code=True,
        revision=None,
        model_config_kwargs=config.hf_model_config_kwargs,
    )
    return HFTextEncoder(**encoder_kwargs)
|
|
|
|
|
def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
    """
    Instantiate the EVA vision transformer described by ``config``.

    The normalization layer defaults to this module's ``LayerNorm`` with
    ``eps=1e-6``. When ``config.fused_layer_norm`` is set, apex's
    ``FusedLayerNorm`` is used instead if it can be imported; otherwise a
    warning is logged and the default is kept.
    """
    layer_norm_cls = LayerNorm
    if config.fused_layer_norm:
        try:
            from apex.normalization import FusedLayerNorm
        except (ModuleNotFoundError, ImportError):
            logger.warning('Please install apex to use fused layer norm, ignoring')
        else:
            layer_norm_cls = FusedLayerNorm
    norm_layer = partial(layer_norm_cls, eps=1e-6)

    # The number of attention heads is derived from the total width and the
    # per-head width.
    head_count = config.width // config.head_width

    tower_kwargs = dict(
        img_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.embed_dim,
        use_mean_pooling=False,
        init_values=config.ls_init_value,
        patch_dropout=config.patch_dropout,
        embed_dim=config.width,
        depth=config.layers,
        num_heads=head_count,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        drop_path_rate=config.drop_path_rate,
        norm_layer=norm_layer,
        xattn=config.x_attention,
        rope=config.rope_embeddings,
        postnorm=config.post_norm,
        pt_hw_seq_len=config.pt_hw_seq_len,
        intp_freq=config.intp_freq,
        naiveswiglu=config.naive_swiglu,
        subln=config.subln,
        proj_type=config.proj_type,
    )
    return EVAVisionTransformer(**tower_kwargs)
|
|
|
|
|
class JinaCLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = JinaCLIPConfig
    base_model_prefix = 'clip'
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Initialize the weights of ``module``.

        * ``JinaCLIPModel``: each projection that is an ``nn.Linear`` gets a
          normal init with ``std = embed_dim ** -0.5 * initializer_factor``
          (``nn.Identity`` projections have no weights and are skipped).
        * ``nn.LayerNorm``: bias is zeroed and weight is set to 1.
        * ``nn.Linear`` with a bias: the bias is zeroed.
        """
        if isinstance(module, JinaCLIPModel):
            factor = self.config.initializer_factor
            if isinstance(module.text_projection, nn.Linear):
                nn.init.normal_(
                    module.text_projection.weight,
                    std=module.text_embed_dim**-0.5 * factor,
                )
            # Each projection is checked on its own attribute; the vision
            # branch previously tested `text_projection` by mistake.
            if isinstance(module.visual_projection, nn.Linear):
                nn.init.normal_(
                    module.visual_projection.weight,
                    std=module.vision_embed_dim**-0.5 * factor,
                )
        if isinstance(module, nn.LayerNorm):
            # Non-affine (or bias-free) LayerNorms expose None here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
    """Stand-alone wrapper around the text tower built by ``_build_text_tower``."""

    config_class = JinaCLIPTextConfig

    def __init__(self, config: JinaCLIPTextConfig):
        super().__init__(config)
        self.text_model = _build_text_tower(config)
        self.post_init()

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
        """
        Run the text tower on ``input_ids``.

        Args:
            input_ids: Either a tensor, or a ``BatchEncoding`` whose
                ``input_ids`` attribute is used.
            return_dict: Return a ``CLIPTextModelOutput`` when true, its tuple
                form otherwise. Falls back to ``config.use_return_dict`` when
                ``None``.

        Any additional positional or keyword arguments are accepted and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Unwrap tokenizer output so the tower always receives the raw ids.
        token_ids = input_ids
        if isinstance(input_ids, BatchEncoding):
            token_ids = input_ids.input_ids

        text_embeds = self.text_model(x=token_ids)
        output = CLIPTextModelOutput(text_embeds=text_embeds)
        if return_dict:
            return output
        return output.to_tuple()
|
|
|
|
|
class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
    """Stand-alone wrapper around the vision tower built by ``_build_vision_tower``."""

    config_class = JinaCLIPVisionConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: JinaCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = _build_vision_tower(config)
        self.post_init()

    def forward(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
        """
        Run the vision tower on ``pixel_values``.

        Args:
            pixel_values: Either a tensor, or a ``BatchFeature`` whose
                ``pixel_values`` attribute is used.
            return_dict: Return a ``CLIPVisionModelOutput`` when true, its
                tuple form otherwise. Falls back to ``config.use_return_dict``
                when ``None``.

        Any additional positional or keyword arguments are accepted and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Unwrap image-processor output so the tower always receives the tensor.
        pixels = pixel_values
        if isinstance(pixel_values, BatchFeature):
            pixels = pixel_values.pixel_values

        image_embeds = self.vision_model(x=pixels)
        output = CLIPVisionModelOutput(image_embeds=image_embeds)
        if return_dict:
            return output
        return output.to_tuple()
|
|
|
|
|
class JinaCLIPModel(JinaCLIPPreTrainedModel):
    """
    CLIP-style model combining a text tower and a vision tower.

    Besides the regular ``forward`` pass, the class offers ``encode_text`` and
    ``encode_image`` convenience methods that tokenize / preprocess raw inputs,
    run them in batches and return the embeddings.
    """

    config_class = JinaCLIPConfig

    def __init__(self, config: JinaCLIPConfig):
        """
        Build both towers, the optional projections and the logit scale.

        Raises:
            ValueError: if ``config.text_config`` or ``config.vision_config``
                is not of the expected config type.
        """
        super().__init__(config)

        if not isinstance(config.text_config, JinaCLIPTextConfig):
            raise ValueError(
                'Attribute config.text_config is expected to be of type '
                f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
            )

        if not isinstance(config.vision_config, JinaCLIPVisionConfig):
            raise ValueError(
                'Attribute config.vision_config is expected to be of type '
                f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
            )

        text_config = config.text_config
        vision_config = config.vision_config

        # Top-level switches, when set, override the per-tower settings.
        # Note that this writes into the sub-configs in place.
        if config.use_text_flash_attn is not None:
            text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
        if config.use_vision_xformers is not None:
            vision_config.x_attention = config.use_vision_xformers

        self.add_projections = config.add_projections
        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.embed_dim
        self.vision_embed_dim = vision_config.embed_dim

        self.text_model = _build_text_tower(text_config)
        self.vision_model = _build_vision_tower(vision_config)
        # Stored un-exponentiated; `forward` applies `.exp()` before use.
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        if self.add_projections:
            self.visual_projection = nn.Linear(
                self.vision_embed_dim, self.projection_dim, bias=False
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim, self.projection_dim, bias=False
            )
        else:
            self.visual_projection = nn.Identity()
            self.text_projection = nn.Identity()

        # Created lazily by `get_tokenizer` / `get_preprocess`.
        self.tokenizer = None
        self.preprocess = None
        self.post_init()

    def get_tokenizer(self):
        """Return the tokenizer, loading it from ``config._name_or_path`` on first use."""
        # Compare against None explicitly: truthiness of a tokenizer goes
        # through `__len__` and is not a reliable "is it loaded" check.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config._name_or_path, trust_remote_code=True
            )
        return self.tokenizer

    def get_preprocess(self):
        """Return the image processor, loading it from ``config._name_or_path`` on first use."""
        if self.preprocess is None:
            self.preprocess = AutoImageProcessor.from_pretrained(
                self.config._name_or_path, trust_remote_code=True
            )
        return self.preprocess

    def get_text_features(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """
        Embed ``input_ids`` with the text tower followed by the text projection.

        ``input_ids`` may be a tensor or a ``BatchEncoding`` (its ``input_ids``
        attribute is used). Extra arguments are accepted and ignored.
        """
        x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
        return self.text_projection(self.text_model(x=x))

    def get_image_features(
        self,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        *_,
        **__,
    ) -> torch.FloatTensor:
        """
        Embed ``pixel_values`` with the vision tower followed by the visual projection.

        ``pixel_values`` may be a tensor or a ``BatchFeature`` (its
        ``pixel_values`` attribute is used). Extra arguments are accepted and
        ignored.
        """
        x = (
            pixel_values.pixel_values
            if isinstance(pixel_values, BatchFeature)
            else pixel_values
        )
        return self.visual_projection(self.vision_model(x=x))

    @staticmethod
    def _batch_starts(total: int, batch_size: int, show_progress_bar: Optional[bool]):
        """Return an iterable of batch start offsets, wrapped in tqdm when available."""
        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (
                logging.INFO,
                logging.DEBUG,
            )
        if has_tqdm:
            return trange(
                0,
                total,
                batch_size,
                desc='Encoding',
                disable=not show_progress_bar,
            )
        return range(0, total, batch_size)

    @staticmethod
    def _finalize_embeddings(
        embeddings: list,
        inverse_permutation,
        convert_to_tensor: bool,
        convert_to_numpy: bool,
        unwrap_single: bool,
    ):
        """Restore the caller's input order and convert to the requested container."""
        embeddings = [embeddings[idx] for idx in inverse_permutation]

        if convert_to_tensor:
            # torch.stack rejects an empty list; return an empty tensor instead.
            embeddings = torch.stack(embeddings) if embeddings else torch.Tensor()
        elif convert_to_numpy:
            # Cast to float32 first so every row converts to numpy uniformly.
            embeddings = np.asarray(
                [emb.to(torch.float32).numpy() for emb in embeddings]
            )

        if unwrap_single:
            embeddings = embeddings[0]
        return embeddings

    @torch.inference_mode()
    def encode_text(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Optional[torch.device] = None,
        normalize_embeddings: bool = False,
        **tokenizer_kwargs,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
        Computes sentence embeddings

        Args:
            sentences(`str` or `List[str]`):
                Sentence or sentences to be encoded
            batch_size(`int`, *optional*, defaults to 32):
                Batch size for the computation
            show_progress_bar(`bool`, *optional*, defaults to None):
                Show a progress bar when encoding sentences.
                If set to None, progress bar is only shown when the logger's
                effective level is `logging.INFO` or `logging.DEBUG`.
            convert_to_numpy(`bool`, *optional*, defaults to True):
                If true, the output is a numpy matrix.
                Else, it is a list of pytorch tensors.
            convert_to_tensor(`bool`, *optional*, defaults to False):
                If true, you get one large tensor as return.
                Overwrites any setting from convert_to_numpy
            device(`torch.device`, *optional*, defaults to None):
                If given, the model is moved to this device before encoding
            normalize_embeddings(`bool`, *optional*, defaults to False):
                If set to true, returned vectors will have length 1. In that case,
                the faster dot-product (util.dot_score) instead of cosine similarity
                can be used.
            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                Keyword arguments for the tokenizer. `padding=True`,
                `max_length=512` and `truncation=True` are used unless
                overridden here.
        Returns:
            With the default arguments, a numpy matrix is returned.
            If convert_to_tensor, a stacked tensor is returned.
            If both flags are false, a list of tensors is returned.
            If a single string was passed, only its embedding is returned.
        """
        is_training = self.training
        self.eval()
        # try/finally guarantees the previous train/eval mode is restored even
        # if tokenization or the forward pass raises.
        try:
            self.tokenizer = self.get_tokenizer()

            if convert_to_tensor:
                convert_to_numpy = False

            input_was_string = False
            if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
                sentences = [sentences]
                input_was_string = True

            if device is not None:
                self.to(device)

            # Process longest inputs first; the inverse permutation restores the
            # caller's order afterwards.
            permutation = np.argsort([-len(i) for i in sentences])
            inverse_permutation = np.argsort(permutation)
            sentences = [sentences[idx] for idx in permutation]

            tokenizer_kwargs.setdefault('padding', True)
            tokenizer_kwargs.setdefault('max_length', 512)
            tokenizer_kwargs.setdefault('truncation', True)

            all_embeddings = []
            for i in self._batch_starts(len(sentences), batch_size, show_progress_bar):
                encoded_input = self.tokenizer(
                    sentences[i : i + batch_size],
                    return_tensors='pt',
                    **tokenizer_kwargs,
                ).to(self.device)

                embeddings = self.get_text_features(input_ids=encoded_input)
                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                if convert_to_numpy:
                    embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)

            return self._finalize_embeddings(
                all_embeddings,
                inverse_permutation,
                convert_to_tensor,
                convert_to_numpy,
                input_was_string,
            )
        finally:
            self.train(is_training)

    @staticmethod
    def decode_data_image(data_image_str):
        """
        Decode a ``data:image/...;base64,<payload>`` string into a PIL image.

        Everything up to the first comma is treated as the header and
        discarded; the remainder is base64-decoded and opened with PIL.

        Raises:
            ValueError: if the string contains no comma.
        """
        # Imported locally, like in `encode_image`, so PIL is only required
        # when images are actually handled.
        from PIL import Image

        _, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def _load_rgb_image(self, img):
        """Turn a path, URL, data URI or PIL image into an RGB PIL image."""
        from PIL import Image

        if isinstance(img, str):
            if img.startswith('http'):
                # NOTE(review): no timeout or HTTP status check is applied to
                # this request - consider adding both.
                response = requests.get(img)
                image = Image.open(BytesIO(response.content))
            elif img.startswith('data:image/'):
                image = self.decode_data_image(img)
            else:
                image = Image.open(img)
        elif isinstance(img, Image.Image):
            image = img
        else:
            raise ValueError("Unsupported image format")
        return image.convert('RGB')

    @torch.inference_mode()
    def encode_image(
        self,
        images: Union[str, List[Union[str, "Image.Image"]]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Optional[torch.device] = None,
        normalize_embeddings: bool = False,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
        Computes image embeddings.

        Args:
            images(`str` or `List[Union[str, Image.Image]]`):
                image paths, URLs, PIL images, or data:image/ strings to be encoded
            batch_size(`int`, *optional*, defaults to 32):
                Batch size for the computation
            show_progress_bar(`bool`, *optional*, defaults to None):
                Show a progress bar when encoding images.
                If set to None, progress bar is only shown when the logger's
                effective level is `logging.INFO` or `logging.DEBUG`.
            convert_to_numpy(`bool`, *optional*, defaults to True):
                If true, the output is a numpy matrix.
                Else, it is a list of pytorch tensors.
            convert_to_tensor(`bool`, *optional*, defaults to False):
                If true, you get one large tensor as return.
                Overwrites any setting from convert_to_numpy
            device(`torch.device`, *optional*, defaults to None):
                If given, the model is moved to this device before encoding
            normalize_embeddings(`bool`, *optional*, defaults to False):
                If set to true, returned vectors will have length 1. In that case,
                the faster dot-product (util.dot_score) instead of cosine similarity
                can be used.
        Returns:
            With the default arguments, a numpy matrix is returned.
            If convert_to_tensor, a stacked tensor is returned.
            If both flags are false, a list of tensors is returned.
            If a single image was passed, only its embedding is returned.
        Raises:
            ValueError: if an element is neither a string nor a PIL image.
        """
        is_training = self.training
        self.eval()
        # try/finally guarantees the previous train/eval mode is restored even
        # if image loading or the forward pass raises.
        try:
            self.preprocess = self.get_preprocess()

            if convert_to_tensor:
                convert_to_numpy = False

            input_was_single_img = False
            if isinstance(images, str) or not hasattr(images, '__len__'):
                images = [images]
                input_was_single_img = True

            if device is not None:
                self.to(device)

            # Ordering by the length of each item's string form mirrors
            # `encode_text`; the inverse permutation restores the caller's order.
            permutation = np.argsort([-len(str(i)) for i in images])
            inverse_permutation = np.argsort(permutation)
            images = [images[idx] for idx in permutation]

            all_embeddings = []
            for i in self._batch_starts(len(images), batch_size, show_progress_bar):
                batch_images = [
                    self._load_rgb_image(img) for img in images[i : i + batch_size]
                ]

                processed_inputs = self.preprocess(batch_images)
                processed_inputs = processed_inputs.to(self.device)
                embeddings = self.get_image_features(processed_inputs)

                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                if convert_to_numpy:
                    embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)

            return self._finalize_embeddings(
                all_embeddings,
                inverse_permutation,
                convert_to_tensor,
                convert_to_numpy,
                input_was_single_img,
            )
        finally:
            self.train(is_training)

    def forward(
        self,
        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
        *_,
        **__,
    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
        """
        Embed text and images and compute their scaled similarity logits.

        Args:
            input_ids: Token ids (tensor or ``BatchEncoding``) for the text tower.
            pixel_values: Pixel data (tensor or ``BatchFeature``) for the vision tower.
            return_dict: Return a ``CLIPOutput`` when true, a plain tuple
                otherwise. Falls back to ``config.use_return_dict`` when ``None``.
            return_loss: When truthy, ``clip_loss`` is computed on the
                text-to-image logits and included in the output.

        Returns:
            ``CLIPOutput`` or the tuple ``(logits_per_image, logits_per_text,
            text_embeds, image_embeds, None, None)``, prefixed with the loss
            when one was computed. The returned embeddings are L2-normalized.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        image_embeds = self.get_image_features(pixel_values=pixel_values)
        text_embeds = self.get_text_features(input_ids=input_ids)

        # L2-normalize so the dot product below is a cosine similarity.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                None,
                None,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=None,
            vision_model_output=None,
        )
|
|