# coding=utf-8
# TODO: Add license
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PagnolXl model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)

from .configuration_pagnolxl import PagnolXlConfig


logger = logging.get_logger(__name__)

PAGNOLXL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "XXXX/pagnol-xl",
]
_CHECKPOINT_FOR_DOC = "XXXX/pagnol-xl"
_CONFIG_FOR_DOC = "PagnolXlConfig"


class PagnolXlEmbeddings(nn.Module):
    """Implementation of the PagnolXl Embedding layer.

    Parameters
    ----------
    config: PagnolXlConfig, the model configuration. `config.vocab_size` and
        `config.d_model` set the number of rows and the width of the embedding
        table.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """Look up the embedding vector of every token id in `input_ids`."""
        return self.embedding(input_ids)


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
    """Split the last dimension of `x` in two halves `(x1, x2)` and return `(-x2, x1)`."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


class PagnoXlRotaryEmbeddings(nn.Module):
    """Implementation of RotaryEmbedding from GPT-NeoX and Falcon.

    This implementation is designed to operate on queries and keys that are compatible with
    `[batch_size, n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert (
            config.d_model % config.n_heads == 0
        ), "d_model must be divisible by n_heads. Currently d_model: {}, n_heads: {}".format(
            config.d_model, config.n_heads
        )

        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # `base` is optional in the config; fall back to 10000 when it is absent.
        self.base = config.to_dict().get("base", 10000)

        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Cache of the cos/sin tables, rebuilt lazily by `cos_sin`.
        self.seq_len_cached = -1
        self.cos_cached: Optional[torch.Tensor] = None
        self.sin_cached: Optional[torch.Tensor] = None

    def cos_sin(
        self,
        seq_len: int,
        past_key_values_length: int,
        device="cpu",
        dtype=torch.bfloat16,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the `(cos, sin)` tables for the positions of the current tokens.

        The tables are cached and rebuilt only when a longer sequence is requested
        or when the requested device or dtype differs from the cached tables.

        Args:
            seq_len: number of new positions.
            past_key_values_length: number of positions already held in the cache;
                the returned slice starts at this offset.
            device: device on which the tables are built.
            dtype: dtype of the returned tables.

        Returns:
            Two tensors of shape `[1, seq_len, head_dim]`.
        """
        total_length = seq_len + past_key_values_length
        cache_is_stale = (
            self.cos_cached is None
            or total_length > self.seq_len_cached
            # Without these two checks a table built before a `.to(...)` call
            # would be reused on the wrong device or with the wrong dtype.
            or self.cos_cached.device != torch.device(device)
            or self.cos_cached.dtype != dtype
        )
        if cache_is_stale:
            # Never shrink the cached table when only device/dtype changed.
            table_length = max(total_length, self.seq_len_cached)
            self.seq_len_cached = table_length
            t = torch.arange(table_length, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            # Evaluate cos/sin in float32 for half-precision targets.
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :].type(dtype)
            self.sin_cached = emb.sin()[None, :, :].type(dtype)

        return (
            self.cos_cached[:, past_key_values_length:total_length],
            self.sin_cached[:, past_key_values_length:total_length],
        )

    def forward(self, query, key, past_key_values_length=0):
        """Apply the rotary embedding to `query` and `key`.

        Both tensors are expected as `[batch_size, n_heads, seq_len, head_dim]`.
        Returns the rotated `(query, key)` pair.
        """
        _, _, seq_len, _ = query.shape
        cos, sin = self.cos_sin(
            seq_len, past_key_values_length, query.device, query.dtype
        )
        rotated_query = (query * cos) + (rotate_half(query) * sin)
        rotated_key = (key * cos) + (rotate_half(key) * sin)
        return rotated_query, rotated_key


def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
    just blocks tokens from attending forwards in the sequence. `True` marks a position that must NOT be attended.
    The output shape will be `[batch_size, 1, target_length, target_length+past_key_values_length]`.
    """
    batch_size, target_length = input_ids_shape

    mask = torch.triu(
        torch.ones((target_length, target_length), dtype=torch.bool, device=device),
        diagonal=1,
    )
    # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
    # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
    # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
    past_mask = torch.zeros(
        (target_length, past_key_values_length), dtype=torch.bool, device=device
    )
    mask = torch.cat([past_mask, mask], dim=-1)
    expanded_mask = mask[None, None, :, :].expand(
        batch_size, 1, target_length, target_length + past_key_values_length
    )
    return expanded_mask


def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
    """
    Expands attention_mask from `[batch_size, seq_length + past_length]` to
    `[batch_size, 1, seq_length, seq_length + past_length]`. The result is inverted with respect to the input:
    `True` marks a position that must NOT be attended.
    """
    batch_size, total_length = mask.shape
    seq_length = (
        total_length - past_key_values_length
        if past_key_values_length is not None
        else total_length
    )

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, seq_length, total_length)


class PagnolXlAttention(nn.Module):
    """Implementation of Pagnol's MultiHeadAttention following
    `Karpathy's MinGPT <https://github.com/karpathy/minGPT>`_.

    The internals are easier to modify with respect to the native Pytorch version. Masking is driven by a single
    boolean `attention_mask` passed to `forward`, where `True` marks positions that must not be attended.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.dropout = config.dropout
        self.sigma = config.sigma
        self.n_layers = config.n_layers

        # key, query, value projections for all heads
        self.key = nn.Linear(config.d_model, config.d_model)
        self.query = nn.Linear(config.d_model, config.d_model)
        self.value = nn.Linear(config.d_model, config.d_model)
        # regularization
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        # output projection
        self.proj = nn.Linear(config.d_model, config.d_model)

        self.rotary_embedding = PagnoXlRotaryEmbeddings(config)

    def init_weights(self):
        """Initialize the projections; the output projection uses a depth-scaled std."""
        # Megatron params
        std = self.sigma / math.sqrt(2.0 * self.n_layers)

        for projection in (self.key, self.query, self.value):
            torch.nn.init.normal_(projection.weight, mean=0.0, std=self.sigma)
            torch.nn.init.constant_(projection.bias, 0.0)

        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=std)
        torch.nn.init.constant_(self.proj.bias, 0.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Run multi-head self-attention over `hidden_states`.

        Args:
            hidden_states: `(batch_size, seq_len, d_model)` input.
            layer_past: optional `(key, value)` cache, each `(batch_size, n_heads, past_len, head_dim)`.
            attention_mask: optional boolean mask broadcastable to
                `(batch_size, n_heads, seq_len, past_len + seq_len)`; `True` entries are masked out.
            head_mask: optional multiplicative mask applied to the attention probabilities.
            use_cache: whether to return the updated `(key, value)` cache.
            output_attentions: whether to also return the head-averaged attention probabilities.

        Returns:
            `(outputs, present)` or `(outputs, present, attentions)`. `present` is `None` when
            `use_cache` is falsy.
        """
        N, L, D = hidden_states.size()  # Batch_size, Context_size, d_model
        head_dim = D // self.n_heads

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (N, L, D) -> (N, nh, L, hs): move heads forward to act as a batch dim
            return projected.view(N, L, self.n_heads, head_dim).transpose(1, 2)

        key = split_heads(self.key(hidden_states))
        query = split_heads(self.query(hidden_states))
        value = split_heads(self.value(hidden_states))

        if self.rotary_embedding is not None:
            # Cached keys are (N, nh, past_len, hs): the sequence axis is -2, not 1.
            past_kv_length = 0 if layer_past is None else layer_past[0].shape[-2]
            query, key = self.rotary_embedding(query, key, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size, n_heads, kv_length, head_dim]
            #  - value: [batch_size, n_heads, kv_length, head_dim]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache else None

        # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, KV) -> (N, nh, L, KV)
        attn_output = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        if attention_mask is not None:
            # Use the most negative finite value rather than -inf: a row that is
            # entirely masked (e.g. a left-padded query position) would otherwise
            # turn into NaN after the softmax and contaminate later layers.
            attn_output = attn_output.masked_fill(
                attention_mask, torch.finfo(attn_output.dtype).min
            )
        attn_output = F.softmax(attn_output, dim=-1)
        attn_output = self.attn_drop(attn_output)

        # Mask heads if we want to
        if head_mask is not None:
            attn_output = attn_output * head_mask

        outputs = attn_output @ value  # (N, nh, L, KV) x (N, nh, KV, hs) -> (N, nh, L, hs)
        # re-assemble all head outputs side by side
        outputs = outputs.transpose(1, 2).contiguous().view(N, L, D)

        # output projection
        outputs = self.resid_drop(self.proj(outputs))

        if output_attentions:
            # Attention probabilities averaged over heads.
            return outputs, present, attn_output.sum(dim=1) / self.n_heads
        return outputs, present


class PagnolXlStandardMLP(nn.Module):
    """Implementation of Pagnol's StandardMLP"""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.d_feedforward = config.d_feedforward
        self.n_layers = config.n_layers
        self.activation = ACT2FN[config.activation_function]
        self.mlp = nn.Sequential(
            nn.Linear(config.d_model, config.d_feedforward, bias=True),
            self.activation,
            nn.Linear(config.d_feedforward, config.d_model, bias=True),
        )
        self.init_weights()

    def init_weights(self):
        """Initialize both linear layers; the second one uses a depth-scaled std."""
        std = self.config.sigma / math.sqrt(2.0 * self.n_layers)
        torch.nn.init.normal_(self.mlp[0].weight, mean=0.0, std=self.config.sigma)
        torch.nn.init.zeros_(self.mlp[0].bias)
        torch.nn.init.normal_(self.mlp[2].weight, mean=0.0, std=std)
        torch.nn.init.zeros_(self.mlp[2].bias)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply linear -> activation -> linear to `hidden_states`."""
        return self.mlp(hidden_states)


class PagnolXlLayerNorm(nn.Module):
    """Implementation of Pagnol's LayerNorm"""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.norm = nn.LayerNorm(self.d_model, eps=config.layer_norm_epsilon)
        self.init_weights()

    def init_weights(self):
        """Reset the scale to ones and the bias to zeros."""
        nn.init.ones_(self.norm.weight)
        nn.init.zeros_(self.norm.bias)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Layer-normalize `hidden_states` over its last dimension."""
        return self.norm(hidden_states)


class PagnoXlBlock(nn.Module):
    """Transformer block containing the self-attention module and the feedforward module.
    Implemented as a decoder layer of GPT-3 (pre-norm residual branches)."""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_layers = config.n_layers
        self.self_attention = PagnolXlAttention(config)
        self.attn_norm = PagnolXlLayerNorm(config)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.mlp = PagnolXlStandardMLP(config)
        self.mlp_norm = PagnolXlLayerNorm(config)
        self.mlp_dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def init_weights(self):
        """Initialize the attention and feed-forward sub-modules."""
        self.self_attention.init_weights()
        self.mlp.init_weights()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        """Apply the attention and feed-forward residual branches.

        Returns a tuple starting with the new hidden states, followed by the
        `(key, value)` cache when `use_cache` is truthy, followed by the attention
        probabilities when `output_attentions` is truthy.
        """
        attn_outputs = self.self_attention(
            self.attn_norm(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        extras = attn_outputs[1:]  # present, (attentions)

        hidden_states = hidden_states + self.attn_dropout(attn_output)

        feed_forward_hidden_states = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + self.mlp_dropout(feed_forward_hidden_states)

        if not use_cache:
            # Drop the `present` slot, which is `None` in this case.
            extras = extras[1:]

        return (hidden_states,) + extras  # hidden_states, present, attentions


class PagnolXlPreTrainedModel(PreTrainedModel):
    """Base class wiring the PagnolXl models into the `PreTrainedModel` machinery."""

    config_class = PagnolXlConfig
    base_model_prefix = "pagnolxl"
    supports_gradient_checkpointing = True
    # Must match the actual class name of the decoder block defined in this file.
    _no_split_modules = ["PagnoXlBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module."""
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (PagnolXlAttention, PagnolXlStandardMLP)):
            # `nn.Module.apply` visits children before their parent, so the plain
            # `nn.Linear` branch above has already overwritten the output
            # projections with std=sigma. Re-run the module's own init to restore
            # the sigma / sqrt(2.0 * n_layers) scaling of those projections.
            module.init_weights()

    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        """Toggle gradient checkpointing on the module that actually reads the flag."""
        # `gradient_checkpointing` is defined and consumed by PagnolXlTransformer.
        if isinstance(module, PagnolXlTransformer):
            module.gradient_checkpointing = value


class PagnolXlTransformer(PagnolXlPreTrainedModel):
    """Pagnol's Transformer model"""

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [PagnoXlBlock(config) for _ in range(config.n_layers)]
        )
        self.gradient_checkpointing = False
        self.init_weights()

    def init_weights(self):
        """Initialize every decoder block."""
        for layer in self.layers:
            layer.init_weights()

    @staticmethod
    def _prepare_attn_mask(
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, int],
        past_key_values_length: int,
    ) -> torch.BoolTensor:
        """Combine the padding mask with a causal mask.

        Returns a boolean tensor of shape
        `[batch_size, 1, seq_length, seq_length + past_key_values_length]` where `True`
        marks positions that must not be attended.

        Raises:
            ValueError: if `attention_mask` does not cover the past plus the new tokens.
        """
        # Create a causal mask
        # The attention mask we receive as input should cover the whole extended sequence, including any past
        # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
        # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
            raise ValueError(
                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
                f" {past_key_values_length}."
            )
        combined_attention_mask = None
        device = attention_mask.device
        _, seq_length = input_shape

        # A single query token may attend to every cached position, so no causal
        # mask is needed in that case.
        if seq_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
        expanded_attn_mask = _expand_mask(
            attention_mask, past_key_values_length=past_key_values_length
        )
        if combined_attention_mask is None:
            return expanded_attn_mask
        return expanded_attn_mask | combined_attention_mask

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the stack of decoder blocks over already-embedded inputs.

        Args:
            inputs_embeds: `(batch_size, seq_length, d_model)` embedded tokens.
            past_key_values: optional per-layer `(key, value)` caches.
            attention_mask: optional `(batch_size, past_length + seq_length)` padding mask
                (non-zero = keep). Defaults to all ones.
            head_mask: optional head mask, expanded through `get_head_mask`.
            use_cache, output_attentions, output_hidden_states, return_dict: default to
                the corresponding config values when `None`.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or a tuple of its non-`None`
            fields when `return_dict` is falsy.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length, _ = inputs_embeds.shape

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.layers))
        else:
            past_length = past_key_values[0][0].size(-2)

        hidden_states = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_length),
                device=hidden_states.device,
            )
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_length,
        )

        checkpointing = self.gradient_checkpointing and self.training
        # Resolve `use_cache` BEFORE creating `presents`, otherwise an empty tuple
        # would be returned as the cache after caching has been disabled.
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # `_gradient_checkpointing_func` is only installed by the newer
        # `gradient_checkpointing_enable` code path; fall back to torch's
        # checkpoint when the legacy `_set_gradient_checkpointing` hook was used.
        checkpoint_fn = getattr(
            self, "_gradient_checkpointing_func", torch.utils.checkpoint.checkpoint
        )

        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if checkpointing:
                # Positional arguments follow PagnoXlBlock.forward's signature.
                outputs = checkpoint_fn(
                    layer.__call__,
                    hidden_states,
                    None,
                    causal_mask,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = layer(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The block omits the cache slot when `use_cache` is falsy.
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class PagnolXlModel(PagnolXlPreTrainedModel):
    """Bare PagnolXl model returning the outputs of the transformer stack.

    NOTE(review): `final_norm` and `projector` are instantiated here but are not
    applied in `forward` — presumably kept so that checkpoints share the same
    parameter names as `PagnolXlForCausalLM`; confirm before removing.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        self.embedding.embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Embed `input_ids` (unless `inputs_embeds` is given) and run the transformer.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # The embedding layer is stored as `self.embedding` (see __init__).
            inputs_embeds = self.embedding(input_ids)

        transformer_outputs = self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return transformer_outputs


class PagnolXlLMHead(nn.Module):
    """Pagnol's Language Model head Projector"""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        # Kept so that `init_weights` can read `sigma`.
        self.config = config
        self.proj = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def init_weights(self):
        """Draw the projection weights from a normal distribution with std `sigma`."""
        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=self.config.sigma)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Project hidden states of width `d_model` to vocabulary logits."""
        return self.proj(hidden_states)


class PagnolXlForCausalLM(PagnolXlPreTrainedModel):
    """PagnolXl transformer followed by a final layer norm and a language-model head."""

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        self.embedding.embedding = value

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the keyword arguments of `forward` for one generation step.

        When a cache is present, the tokens it already covers are removed from
        `input_ids`.
        """
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": kwargs.get("attention_mask", None),
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        """Compute next-token logits and, when `labels` is given, the causal LM loss.

        `labels` are shifted inside this method: position `i` of the logits is
        scored against `labels[..., i + 1]`.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        transformer_outputs = self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        hidden_states = self.final_norm(hidden_states)

        lm_logits = self.projector(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            vocab_size = shift_logits.shape[-1]
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )