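"""HuggingFace text tower for CLIP-style models.

Wraps a `transformers` AutoModel as a text encoder, adds a configurable
pooler and projection head, and exposes task-instruction and LoRA-adapter
lookups when the underlying model provides them.
"""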
import re
import warnings
from typing import Dict, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
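# Maps a HuggingFace `model_type` to the config attribute names and the
# default pooler used by HFTextEncoder below.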
_HF_ARCH_DICT = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
'roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
'xlm-roberta': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
'layer_attr': 'layer',
'token_embeddings_attr': 'embeddings',
},
'pooler': 'mean_pooler',
},
# https://huggingface.co/docs/transformers/model_doc/bert
'bert': {
'config_names': {
'context_length': 'max_position_embeddings',
'vocab_size': 'vocab_size',
'width': 'hidden_size',
'heads': 'num_attention_heads',
'layers': 'num_hidden_layers',
},
'pooler': 'cls_pooler',
},
}
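# Registry of pooler classes, keyed by the snake_case name of each class.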
_POOLERS = {}
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
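    """Mean-pools the token embeddings, ignoring padded positions via the attention mask."""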
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
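    """Max-pools the token embeddings, with padded positions masked out."""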
@staticmethod
def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        # fill padded positions with -inf so they never win the max
        masked_output = x.last_hidden_state.masked_fill(
            (attention_mask == 0).unsqueeze(-1), -torch.inf
        )
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
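    """Returns the model's pooler output when available, otherwise the
    hidden state at the CLS token position."""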
def __init__(self, use_pooler_output: bool = True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x: BaseModelOutput, _: torch.Tensor):
if (
self.use_pooler_output
and isinstance(
x,
(
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
),
)
and (x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
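    """Text encoder built on top of a HuggingFace `transformers` model.

    Loads the backbone with AutoModel, applies a registered pooler and an
    optional linear or MLP projection to `output_dim`, and can additionally
    return per-token embeddings. Task instructions and LoRA adapters are
    exposed when the underlying model provides them.

    Illustrative usage (the checkpoint name is only an example)::

        encoder = HFTextEncoder('xlm-roberta-base', output_dim=768)
        embeddings = encoder(token_ids)  # token_ids: (batch, seq_len) LongTensor
    """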
output_tokens: torch.jit.Final[bool]
def __init__(
self,
model_name_or_path: str,
output_dim: int,
        config: Optional[PretrainedConfig] = None,
        pooler_type: Optional[str] = None,
        proj_type: Optional[str] = None,
proj_bias: bool = False,
pretrained: bool = True,
output_tokens: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
default_instruction_task: Optional[str] = None,
default_lora_task: Optional[str] = None,
model_config_kwargs: Optional[Dict] = None,
):
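        """
        :param model_name_or_path: HuggingFace model name or local path of the
            text backbone.
        :param output_dim: dimensionality of the projected text embeddings.
        :param pooler_type: name of a registered pooler; defaults to the
            architecture-specific entry in _HF_ARCH_DICT.
        :param proj_type: 'linear', 'mlp', or None to skip the projection when
            the backbone width already matches output_dim.
        :param default_instruction_task: task name used to look up a default
            instruction, if the backbone exposes task instructions.
        :param default_lora_task: task name used to look up a default LoRA
            adapter id, if the backbone exposes LoRA adapters.
        :param model_config_kwargs: extra keyword arguments merged into the
            backbone configuration.
        """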
super().__init__()
self.output_tokens = output_tokens
self.output_dim = output_dim
model_config_kwargs = model_config_kwargs or {}
if config is None:
if pretrained:
self.transformer = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
revision=revision,
add_pooling_layer=False,
code_revision=code_revision,
**model_config_kwargs,
)
self.config = self.transformer.config
else:
self.config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
code_revision=code_revision,
)
self.config.update(model_config_kwargs)
self.transformer = AutoModel.from_config(
self.config,
trust_remote_code=trust_remote_code,
add_pooling_layer=False,
code_revision=code_revision,
)
if (
hasattr(self.config, 'is_encoder_decoder')
and self.config.is_encoder_decoder
):
self.transformer = self.transformer.encoder
else:
self.config = config
self.config.update(model_config_kwargs)
self.transformer = AutoModel.from_config(
self.config,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
)
self.vocab_size = getattr(self.config, 'vocab_size', 0)
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
pooler_type = pooler_type or _HF_ARCH_DICT[self.config.model_type]['pooler']
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(
self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
)
        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj_type == 'mlp':
            # honour an explicit MLP projection before the width-mismatch fallback
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=proj_bias),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=proj_bias),
            )
        elif (d_model != output_dim) or proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
self._task_instructions = {}
self._lora_adaptation_map = {}
self._supports_task_instructions = False
self._supports_lora = False
if (
hasattr(self.transformer, '_adaptation_map')
and len(self.transformer._adaptation_map) > 0
):
self._lora_adaptation_map = self.transformer._adaptation_map
self._supports_lora = True
if (
hasattr(self.transformer, '_task_instructions')
and len(self.transformer._task_instructions) > 0
):
self._task_instructions = self.transformer._task_instructions
self._supports_task_instructions = True
self.default_instruction_task = None
self.default_lora_task = None
self.default_instruction = None
self.default_loraid = None
if default_instruction_task is not None:
self.default_instruction_task = default_instruction_task
self.default_instruction = self.get_instruction_from_task(
default_instruction_task
)
if default_lora_task is not None:
self.default_lora_task = default_lora_task
self.default_loraid = self.get_loraid_from_task(default_lora_task)
def get_instruction_from_task(self, task: str) -> Optional[str]:
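        """Look up the instruction registered for `task`; returns None (with a
        warning) if the backbone does not support task instructions."""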
if self._supports_task_instructions:
if task not in self._task_instructions:
raise ValueError(
f'Unsupported task \'{task}\'. Choose one of the following: '
f'{", ".join(self._task_instructions)} or set to None to disable '
f'task instructions completely'
)
return self._task_instructions[task]
else:
warnings.warn(
'Model does not support task instructions, ignoring instruction '
f"task '{task}'"
)
return None
def get_loraid_from_task(self, task: str) -> Optional[int]:
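        """Look up the LoRA adapter id registered for `task`; returns None (with
        a warning) if the backbone does not support LoRA adapters."""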
if self._supports_lora:
if task not in self._lora_adaptation_map:
raise ValueError(
f'Unsupported task \'{task}\'. Choose one of the following: '
f'{", ".join(self._task_instructions)} or set to None to disable '
f'the LoRA adapters completely'
)
return self._lora_adaptation_map[task]
else:
warnings.warn(
f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
)
return None
@torch.jit.ignore
def set_grad_checkpointing(self, _=True):
self.transformer.gradient_checkpointing_enable()
def init_parameters(self):
pass
def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
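        """Encode token ids into pooled, projected embeddings.

        The attention mask is derived from `config.pad_token_id`. If
        `output_tokens` is set, the per-token hidden states (excluding the CLS
        position when a CLS pooler is used) are returned as well.
        """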
attn_mask = (x != self.config.pad_token_id).long()
kwargs = {}
if adapter_mask is not None:
kwargs['adapter_mask'] = adapter_mask
out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)
seqlen = out.last_hidden_state.shape[1]
tokens = (
out.last_hidden_state[
:, torch.arange(seqlen) != self.pooler.cls_token_position, :
]
if isinstance(self.pooler, ClsPooler)
else out.last_hidden_state
)
if self.output_tokens:
return projected, tokens
return projected
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
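        """Freeze the transformer, leaving the last `unlocked_layers` layers
        trainable and, if `freeze_layer_norm` is False, all LayerNorm
        parameters as well."""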
if not unlocked_layers:
for n, p in self.transformer.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)
return
encoder = (
self.transformer.encoder
if hasattr(self.transformer, 'encoder')
else self.transformer
)
layer_list = getattr(
encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
)
print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
embeddings = getattr(
self.transformer,
_HF_ARCH_DICT[self.config.model_type]['config_names'][
'token_embeddings_attr'
],
)
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (
(not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
)