import re
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)

"""
HF architecture mapping
"""

_HF_ARCH_DICT = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    'roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    'xlm-roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    'mt5': {
        'config_names': {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            'context_length': '',
            'vocab_size': 'vocab_size',
            'width': 'd_model',
            'heads': 'num_heads',
            'layers': 'num_layers',
            'layer_attr': 'block',
            'token_embeddings_attr': 'embed_tokens',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    'bert': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
        },
        'pooler': 'cls_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/m2m_100
    'm2m_100': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'd_model',
            'heads': 'encoder_attention_heads',
            'layers': 'encoder_layers',
        },
        'pooler': 'cls_pooler',
    },
}


"""
Pooling functions
"""

_POOLERS = {}


def _camel2snake(s):
    # convert a CamelCase class name to the snake_case key used in _POOLERS
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
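# The 'pooler' names referenced in _HF_ARCH_DICT ('mean_pooler', 'cls_pooler') are
# expected to be keys of _POOLERS. A minimal sketch of how the registry is usually
# populated follows; the register_pooler decorator and the two pooler modules below
# are illustrative assumptions, not necessarily the exact implementations used here.


def register_pooler(cls):
    """Decorator registering a pooler class under its snake_case name."""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean of the token embeddings, weighted by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: torch.Tensor):
        masked = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked.sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)


@register_pooler
class ClsPooler(nn.Module):
    """Embedding of the first ([CLS]) token."""

    def forward(self, x: BaseModelOutput, attention_mask: torch.Tensor):
        return x.last_hidden_state[:, 0, :]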
    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.pool_type == 'avg':
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            pooled = tokens = x
        return pooled, tokens

    def forward(self, x: torch.Tensor):
        # take the final hidden states from the HF output, pool, then project
        x = self.transformer(x)[0]
        pooled, tokens = self._global_pool(x)
        projected = self.proj(pooled)
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_bn_stats: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (
                    (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
                )
            return

        # TODO: make it work if unlocked_layers != 0
        encoder = (
            self.transformer.encoder
            if hasattr(self.transformer, 'encoder')
            else self.transformer
        )
        layer_list = getattr(
            encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
        )
        print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
        embeddings = getattr(
            self.transformer,
            _HF_ARCH_DICT[self.config.model_type]['config_names'][
                'token_embeddings_attr'
            ],
        )
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (
                    (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
                )

    @torch.jit.ignore
    def set_grad_checkpointing(self, *_, **__):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
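# Usage sketch: how the config_names mapping in _HF_ARCH_DICT is typically consumed
# to read tower dimensions from a Hugging Face config. `load_hf_text_config` is a
# hypothetical helper written for illustration; only AutoConfig.from_pretrained and
# the attribute names listed in _HF_ARCH_DICT come from the code above.


def load_hf_text_config(model_name_or_path: str) -> Dict[str, int]:
    config = AutoConfig.from_pretrained(model_name_or_path)
    names = _HF_ARCH_DICT[config.model_type]['config_names']
    return {
        'width': getattr(config, names['width']),
        'heads': getattr(config, names['heads']),
        'layers': getattr(config, names['layers']),
    }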