import re
import warnings
from typing import Dict, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)

_HF_ARCH_DICT = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    'roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    'xlm-roberta': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
            'layer_attr': 'layer',
            'token_embeddings_attr': 'embeddings',
        },
        'pooler': 'mean_pooler',
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    'bert': {
        'config_names': {
            'context_length': 'max_position_embeddings',
            'vocab_size': 'vocab_size',
            'width': 'hidden_size',
            'heads': 'num_attention_heads',
            'layers': 'num_hidden_layers',
        },
        'pooler': 'cls_pooler',
    },
}

_POOLERS = {}


def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling over non-padding tokens"""

    @staticmethod
    def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output: bool = True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, _: torch.Tensor):
        if (
            self.use_pooler_output
            and isinstance(
                x,
                (
                    BaseModelOutputWithPooling,
                    BaseModelOutputWithPoolingAndCrossAttentions,
                ),
            )
            and (x.pooler_output is not None)
        ):
            return x.pooler_output
        return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
    """HuggingFace text encoder with pooling and an optional projection head"""

    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        model_name_or_path: str,
        output_dim: int,
        config: PretrainedConfig = None,
        pooler_type: str = None,
        proj_type: str = None,
        proj_bias: bool = False,
        pretrained: bool = True,
        output_tokens: bool = False,
        trust_remote_code: bool = False,
        revision: Optional[str] = None,
        model_config_kwargs: Optional[Dict] = None,
        default_instruction_task: Optional[str] = None,
        default_lora_task: Optional[str] = None,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # the HF pooling head is only needed when the CLS pooler is used
        uses_transformer_pooler = pooler_type == 'cls_pooler'
        model_config_kwargs = model_config_kwargs or {}

        if config is None:
            self.config = AutoConfig.from_pretrained(
                model_name_or_path,
                trust_remote_code=trust_remote_code,
                code_revision=revision,
            )
            self.config.update(model_config_kwargs)
            create_func, model_args = (
                (AutoModel.from_pretrained, model_name_or_path)
                if pretrained
                else (AutoModel.from_config, self.config)
            )
            if (
                hasattr(self.config, 'is_encoder_decoder')
                and self.config.is_encoder_decoder
            ):
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(
                    model_args,
                    trust_remote_code=trust_remote_code,
                    add_pooling_layer=uses_transformer_pooler,
                    code_revision=revision,
                )
        else:
            self.config = config
            self.config.update(model_config_kwargs)
            self.transformer = AutoModel.from_config(self.config)

        if pooler_type is None:  # get default arch pooler
            pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']

        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(
            self.config,
            _HF_ARCH_DICT[self.config.model_type]['config_names']['width'],
        )
        if (d_model == output_dim) and (proj_type is None):
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=proj_bias),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=proj_bias),
            )

        # task instruction and LoRA adapter support advertised by the
        # underlying transformer implementation
        self._task_instructions = {}
        self._lora_adaptation_map = {}
        self._supports_task_instructions = False
        self._supports_lora = False
        if (
            hasattr(self.transformer, '_adaptation_map')
            and len(self.transformer._adaptation_map) > 0
        ):
            self._lora_adaptation_map = self.transformer._adaptation_map
            self._supports_lora = True
        if (
            hasattr(self.transformer, '_task_instructions')
            and len(self.transformer._task_instructions) > 0
        ):
            self._task_instructions = self.transformer._task_instructions
            self._supports_task_instructions = True

        self.default_instruction_task = None
        self.default_lora_task = None
        self.default_instruction = None
        self.default_loraid = None
        if default_instruction_task is not None:
            self.default_instruction_task = default_instruction_task
            self.default_instruction = self.get_instruction_from_task(
                default_instruction_task
            )
        if default_lora_task is not None:
            self.default_lora_task = default_lora_task
            self.default_loraid = self.get_loraid_from_task(default_lora_task)

    def get_instruction_from_task(self, task: str) -> Optional[str]:
        if self._supports_task_instructions:
            if task not in self._task_instructions:
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._task_instructions)} or set to None to disable '
                    f'task instructions completely'
                )
            return self._task_instructions[task]
        else:
            warnings.warn(
                'Model does not support task instructions, ignoring instruction '
                f"task '{task}'"
            )
        return None

    def get_loraid_from_task(self, task: str) -> Optional[int]:
        if self._supports_lora:
            if task not in self._lora_adaptation_map:
                raise ValueError(
                    f'Unsupported task \'{task}\'. Choose one of the following: '
                    f'{", ".join(self._lora_adaptation_map)} or set to None to disable '
                    f'the LoRA adapters completely'
                )
            return self._lora_adaptation_map[task]
        else:
            warnings.warn(
                f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
            )
        return None

    @torch.jit.ignore
    def set_grad_checkpointing(self, _=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass

    def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
        # build the attention mask from the padding token positions
        attn_mask = (x != self.config.pad_token_id).long()
        kwargs = {}
        if adapter_mask is not None:
            kwargs['adapter_mask'] = adapter_mask
        out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)
        seqlen = out.last_hidden_state.shape[1]
        # when CLS pooling is used, drop the CLS position from the token sequence
        tokens = (
            out.last_hidden_state[
                :, torch.arange(seqlen) != self.pooler.cls_token_position, :
            ]
            if isinstance(self.pooler, ClsPooler)
            else out.last_hidden_state
        )
        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:
            # freeze the whole transformer, optionally keeping LayerNorm trainable
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )
            return

        encoder = (
            self.transformer.encoder
            if hasattr(self.transformer, 'encoder')
            else self.transformer
        )
        layer_list = getattr(
            encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
        )
        print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
        embeddings = getattr(
            self.transformer,
            _HF_ARCH_DICT[self.config.model_type]['config_names'][
                'token_embeddings_attr'
            ],
        )
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze everything except the last `unlocked_layers` transformer layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (
                    (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
                )
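

# Illustrative usage sketch: a minimal example of how HFTextEncoder can be
# constructed and called. The checkpoint name ('xlm-roberta-base'), the output
# dimension, and the use of AutoTokenizer here are assumptions chosen for
# demonstration; the real checkpoints and dimensions come from the surrounding
# CLIP configuration.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    # 'xlm-roberta-base' has hidden_size 768, so with output_dim=768 and no
    # proj_type the projection head reduces to nn.Identity()
    encoder = HFTextEncoder(
        model_name_or_path='xlm-roberta-base',
        output_dim=768,
        pooler_type='mean_pooler',
        pretrained=True,
    )
    tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
    batch = tokenizer(
        ['a photo of a cat', 'ein Foto einer Katze'],
        padding=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        embeddings = encoder(batch['input_ids'])
    print(embeddings.shape)  # torch.Size([2, 768])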