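"""HuggingFace text tower for the Jina CLIP implementation.

Wraps an arbitrary ``transformers`` text model with a pooling layer and a
projection head so it can serve as the text encoder of a CLIP-style model.
"""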
import re
from typing import Dict, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)

# Map a HuggingFace ``model_type`` to the config attribute names and the
# default pooler used for that architecture.
_HF_ARCH_DICT = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
'roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
'xlm-roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/bert
'bert': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
},
'pooler': 'cls_pooler',
},
}

# Registry of pooler classes, populated by the ``@register_pooler`` decorator
# and keyed by the snake_case class name (e.g. ``MeanPooler`` -> 'mean_pooler').
_POOLERS = {}


def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling over the valid (non-padding) token embeddings"""
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the valid (non-padding) token embeddings"""
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        # fill padding positions (attention_mask == 0) with -inf so they never win the max
        masked_output = x.last_hidden_state.masked_fill(
            attention_mask.unsqueeze(-1) == 0, -torch.inf
        )
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling, preferring the model's pooler output when available"""
def __init__(self, use_pooler_output: bool = True):
super().__init__()
self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, _: torch.Tensor):
if (
self.use_pooler_output
and isinstance(
x,
(
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
),
)
and (x.pooler_output is not None)
):
return x.pooler_output
        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace text-model adapter: transformer backbone + pooler + projection to ``output_dim``"""

    output_tokens: torch.jit.Final[bool]

    def __init__(
self,
model_name_or_path: str,
output_dim: int,
        config: Optional[PretrainedConfig] = None,
        pooler_type: Optional[str] = None,
        proj_type: Optional[str] = None,
proj_bias: bool = False,
pretrained: bool = True,
output_tokens: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
model_config_kwargs: Optional[Dict] = None,
):
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
model_config_kwargs = model_config_kwargs or {}
if config is None:
self.config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
)
self.config.update(model_config_kwargs)
create_func, model_args = (
(AutoModel.from_pretrained, model_name_or_path)
if pretrained
else (AutoModel.from_config, self.config)
)
if (
hasattr(self.config, 'is_encoder_decoder')
and self.config.is_encoder_decoder
):
self.transformer = create_func(
model_args,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
**model_config_kwargs,
)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(
model_args,
trust_remote_code=trust_remote_code,
revision=revision,
add_pooling_layer=False,
code_revision=code_revision,
**model_config_kwargs,
)
else:
self.config = config
self.config.update(model_config_kwargs)
self.transformer = AutoModel.from_config(
self.config,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
)
self.vocab_size = getattr(self.config, 'vocab_size', 0)
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(
self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
)
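        # Projection head: identity when the transformer width already matches
        # ``output_dim`` and no projection is requested, a single linear layer
        # whenever the widths differ (or ``proj_type == 'linear'``), and a
        # two-layer GELU MLP only when explicitly requested with matching widths.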
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
self.proj = nn.Identity()
elif (d_model != output_dim) or proj_type == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
elif proj_type == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=proj_bias),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=proj_bias),
)
self._task_instructions = {}
self._lora_adaptation_map = {}
self._supports_task_instructions = False
self._supports_lora = False
if (
hasattr(self.transformer, '_adaptation_map')
and len(self.transformer._adaptation_map) > 0
):
self._lora_adaptation_map = self.transformer._adaptation_map
self._supports_lora = True
if (
hasattr(self.transformer, '_task_instructions')
and len(self.transformer._task_instructions) > 0
):
self._task_instructions = self.transformer._task_instructions
            self._supports_task_instructions = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, _=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass

    def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
        # treat every position that is not the pad token as a valid token
attn_mask = (x != self.config.pad_token_id).long()
kwargs = {}
if adapter_mask is not None:
kwargs['adapter_mask'] = adapter_mask
out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)
seqlen = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[
:, torch.arange(seqlen) != self.pooler.cls_token_position, :
]
if isinstance(self.pooler, ClsPooler)
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers:
for n, p in self.transformer.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)
return
encoder = (
self.transformer.encoder
if hasattr(self.transformer, 'encoder')
else self.transformer
)
layer_list = getattr(
encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
)
print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
embeddings = getattr(
self.transformer,
_HF_ARCH_DICT[self.config.model_type]['config_names'][
'token_embeddings_attr'
],
)
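        # freeze the token embeddings plus all but the last ``unlocked_layers`` transformer blocks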
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)
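if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the backbone name, output
    # dimension and example texts below are arbitrary assumptions, not part
    # of the implementation above. Any architecture listed in _HF_ARCH_DICT
    # should work the same way.
    from transformers import AutoTokenizer

    model_name = 'xlm-roberta-base'  # assumed backbone
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = HFTextEncoder(
        model_name_or_path=model_name,
        output_dim=512,  # arbitrary projection size for this sketch
        pooler_type='mean_pooler',  # default for xlm-roberta, made explicit here
        proj_type='linear',
        pretrained=True,
    )
    encoder.eval()

    texts = ['a photo of a cat', 'ein Foto einer Katze']
    batch = tokenizer(texts, padding=True, return_tensors='pt')
    with torch.no_grad():
        # pass output_tokens=True at construction time to also get per-token states
        embeddings = encoder(batch['input_ids'])
    print(embeddings.shape)  # -> torch.Size([2, 512])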