Text Generation
Transformers
PyTorch
French
pagnolxl
pagnol
custom_code
pagnol-xl / modeling_pagnolxl.py
wissamantoun's picture
Upload folder using huggingface_hub
ea4fdbf verified
raw
history blame
No virus
30.5 kB
# coding=utf-8
# TODO: Add license
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PagnolXl model."""
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_pagnolxl import PagnolXlConfig
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Checkpoint identifier and configuration class name used for documentation.
_CHECKPOINT_FOR_DOC = "XXXX/pagnol-xl"
_CONFIG_FOR_DOC = "PagnolXlConfig"

# Identifiers of the known pretrained PagnolXl checkpoints.
PAGNOLXL_PRETRAINED_MODEL_ARCHIVE_LIST = [_CHECKPOINT_FOR_DOC]
class PagnolXlEmbeddings(nn.Module):
    """Token embedding layer of PagnolXl.

    Wraps a single ``nn.Embedding`` table with ``config.vocab_size`` rows of
    width ``config.d_model``.

    Parameters
    ----------
    config: PagnolXlConfig,
        Model configuration; only ``vocab_size`` and ``d_model`` are read here.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        vocab_size, width = config.vocab_size, config.d_model
        self.embedding = nn.Embedding(vocab_size, width)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """Return the embedding vector of every token id in ``input_ids``."""
        embedded = self.embedding(input_ids)
        return embedded
# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2`` into ``[a, b]`` and
    the result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
class PagnoXlRotaryEmbeddings(nn.Module):
    """Implementation of RotaryEmbedding from GPT-NeoX and Falcon.

    This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
    n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).

    The cos/sin tables are cached on the module and rebuilt whenever a longer
    sequence, a different dtype or a different device is requested.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert (
            config.d_model % config.n_heads == 0
        ), "d_model must be divisible by n_heads. Currently d_model: {}, n_heads: {}".format(
            config.d_model, config.n_heads
        )
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # Rotary base; falls back to 10000 when the config has no "base" entry.
        self.base = config.to_dict().get("base", 10000)
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Cache state: number of positions covered and the tables themselves.
        self.seq_len_cached = -1
        self.cos_cached: Optional[torch.Tensor] = None
        self.sin_cached: Optional[torch.Tensor] = None

    def cos_sin(
        self,
        seq_len: int,
        past_key_values_length: int,
        device="cpu",
        dtype=torch.bfloat16,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` tables for the current positions.

        Parameters
        ----------
        seq_len: int,
            Number of new positions.
        past_key_values_length: int,
            Number of positions already held in the key/value cache; the
            returned slice starts at this offset.
        device, dtype:
            Device and dtype the returned tables must have.

        Returns
        -------
        Tuple of two tensors of shape ``[1, seq_len, head_dim]``.
        """
        total_length = seq_len + past_key_values_length
        # The cache is only valid for the dtype/device it was built with, so a
        # change of either must trigger a rebuild, not just a longer sequence.
        # NOTE: an index-less device such as "cuda" never compares equal to a
        # tensor's "cuda:0" and therefore rebuilds on every call; `forward`
        # always passes `query.device`, which carries the index.
        cache_is_stale = (
            self.cos_cached is None
            or self.sin_cached is None
            or total_length > self.seq_len_cached
            or self.cos_cached.dtype != dtype
            or self.cos_cached.device != torch.device(device)
        )
        if cache_is_stale:
            # Never shrink the table when rebuilding for a new dtype/device.
            table_length = max(total_length, self.seq_len_cached)
            self.seq_len_cached = table_length
            inv_freq = self.inv_freq.to(device)
            t = torch.arange(table_length, device=device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)
            # cos/sin are evaluated in float32 for half-precision requests and
            # cast back afterwards.
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()
            self.cos_cached = emb.cos()[None, :, :].type(dtype)
            self.sin_cached = emb.sin()[None, :, :].type(dtype)
        return (
            self.cos_cached[:, past_key_values_length:total_length],
            self.sin_cached[:, past_key_values_length:total_length],
        )

    def forward(self, query, key, past_key_values_length=0):
        """Apply the rotary embedding to ``query`` and ``key``.

        Both tensors are unpacked as ``[batch, num_heads, seq_len, head_dim]``;
        positions start at ``past_key_values_length``.
        """
        batch, num_heads, seq_len, head_dim = query.shape
        cos, sin = self.cos_sin(
            seq_len, past_key_values_length, query.device, query.dtype
        )
        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (
            rotate_half(key) * sin
        )
def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Build the causal mask used for self-attention. The existing attention mask is not taken into account: the
    result only blocks tokens from attending forwards in the sequence. `True` marks a blocked position and the
    output shape is `[batch_size, 1, target_length, target_length + past_key_values_length]`.
    """
    batch_size, target_length = input_ids_shape
    total_length = target_length + past_key_values_length
    # Query row i sits at absolute position i + past_key_values_length and must not see any key column beyond it.
    # Comparing index grids needs no branch on past_key_values_length, so the expression stays free of
    # data-dependent conditionals.
    key_positions = torch.arange(total_length, device=device)[None, :]
    query_positions = (
        torch.arange(target_length, device=device)[:, None] + past_key_values_length
    )
    mask = key_positions > query_positions
    return mask[None, None, :, :].expand(batch_size, 1, target_length, total_length)
def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
    """
    Turn an attention mask of shape `[batch_size, total_length]` into an inverted boolean mask of shape
    `[batch_size, 1, seq_length, total_length]`, where `seq_length = total_length - past_key_values_length`.
    `True` in the result marks a position whose input mask value is zero/False.
    """
    batch_size, total_length = mask.shape
    if past_key_values_length is None:
        seq_length = total_length
    else:
        seq_length = total_length - past_key_values_length
    blocked = torch.logical_not(mask.to(torch.bool))
    blocked = blocked[:, None, None, :]
    return blocked.expand(batch_size, 1, seq_length, total_length)
class PagnolXlAttention(nn.Module):
    """Implementation of Pagnol's MultiHeadAttention following `Karpathy's MinGPT <https://github.com/karpathy/minGPT>`_.

    The internals are easier to modify with respect to the native Pytorch version. Masking is driven by a single
    boolean ``attention_mask`` passed to ``forward`` (``True`` = blocked position).
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.dropout = config.dropout
        self.sigma = config.sigma
        self.n_layers = config.n_layers
        # key, query, value projections for all heads
        self.key = nn.Linear(config.d_model, config.d_model)
        self.query = nn.Linear(config.d_model, config.d_model)
        self.value = nn.Linear(config.d_model, config.d_model)
        # regularization
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        # output projection
        self.proj = nn.Linear(config.d_model, config.d_model)
        self.rotary_embedding = PagnoXlRotaryEmbeddings(config)

    def init_weights(self):
        """Initialise projections: ``sigma`` for q/k/v, a depth-scaled std for the output projection."""
        # Megatron params
        std = self.sigma / math.sqrt(2.0 * self.n_layers)
        torch.nn.init.normal_(self.key.weight, mean=0.0, std=self.sigma)
        torch.nn.init.normal_(self.query.weight, mean=0.0, std=self.sigma)
        torch.nn.init.normal_(self.value.weight, mean=0.0, std=self.sigma)
        torch.nn.init.constant_(self.key.bias, 0.0)
        torch.nn.init.constant_(self.query.bias, 0.0)
        torch.nn.init.constant_(self.value.bias, 0.0)
        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=std)
        torch.nn.init.constant_(self.proj.bias, 0.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Run multi-head self-attention.

        Parameters
        ----------
        hidden_states:
            Input of shape ``(N, L, D)`` (batch, context, d_model).
        layer_past:
            Optional ``(key, value)`` pair from earlier steps, each of shape
            ``(N, n_heads, past_length, head_dim)``.
        attention_mask:
            Optional boolean mask applied to the attention scores; ``True``
            entries are excluded from attention.
        head_mask:
            Optional multiplicative mask applied to the attention probabilities.
        use_cache:
            When true, the concatenated ``(key, value)`` pair is returned.
        output_attentions:
            When true, the head-averaged attention probabilities are returned.

        Returns
        -------
        ``(outputs, present)`` or ``(outputs, present, attentions)``;
        ``present`` is ``None`` when ``use_cache`` is false.
        """
        N, L, D = hidden_states.size()  # Batch_size, Context_size, d_model
        head_dim = D // self.n_heads

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (N, L, D) -> (N, nh, L, hs): move the head axis forward.
            return projected.view(N, L, self.n_heads, head_dim).transpose(1, 2)

        key = split_heads(self.key(hidden_states))
        query = split_heads(self.query(hidden_states))
        value = split_heads(self.value(hidden_states))

        if self.rotary_embedding is not None:
            # Cached keys are (N, nh, past_length, hs): the past length is on
            # dim -2. Dim 1 is the head count and would give a wrong offset.
            past_kv_length = 0 if layer_past is None else layer_past[0].shape[-2]
            query, key = self.rotary_embedding(query, key, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size, n_heads, kv_length, head_dim]
            # - value: [batch_size, n_heads, kv_length, head_dim]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache else None

        # causal self-attention; Self-attend: (N, nh, L, hs) x (N, nh, hs, L) -> (N, nh, L, L)
        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        if attention_mask is not None:
            # Use the dtype minimum, not -inf: a fully masked row (e.g. a
            # left-padding query) would otherwise softmax to NaN, and those NaNs
            # leak into later layers through the value product.
            scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
        attn_probs = F.softmax(scores, dim=-1)
        attn_probs = self.attn_drop(attn_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attn_probs = attn_probs * head_mask

        outputs = attn_probs @ value  # (N, nh, L, L) x (N, nh, L, hs) -> (N, nh, L, hs)
        # re-assemble all head outputs side by side
        outputs = outputs.transpose(1, 2).contiguous().view(N, L, D)
        # output projection
        outputs = self.resid_drop(self.proj(outputs))

        if output_attentions:
            # Attention probabilities averaged over heads.
            return outputs, present, attn_probs.sum(dim=1) / self.n_heads
        return outputs, present
class PagnolXlStandardMLP(nn.Module):
    """Pagnol's feed-forward block: linear, activation, linear.

    The three stages live in ``self.mlp`` (an ``nn.Sequential``) at indices
    0, 1 and 2.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.d_feedforward = config.d_feedforward
        self.n_layers = config.n_layers
        self.activation = ACT2FN[config.activation_function]
        up_proj = nn.Linear(config.d_model, config.d_feedforward, bias=True)
        down_proj = nn.Linear(config.d_feedforward, config.d_model, bias=True)
        self.mlp = nn.Sequential(up_proj, self.activation, down_proj)
        self.init_weights()

    def init_weights(self):
        """Draw both weights from a normal law and zero both biases.

        The first linear layer uses ``config.sigma``; the second uses
        ``config.sigma / sqrt(2 * n_layers)``.
        """
        sigma = self.config.sigma
        up_proj, down_proj = self.mlp[0], self.mlp[2]
        torch.nn.init.normal_(up_proj.weight, mean=0.0, std=sigma)
        torch.nn.init.zeros_(up_proj.bias)
        torch.nn.init.normal_(
            down_proj.weight, mean=0.0, std=sigma / math.sqrt(2.0 * self.n_layers)
        )
        torch.nn.init.zeros_(down_proj.bias)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the feed-forward stack to ``hidden_states``."""
        projected = self.mlp(hidden_states)
        return projected
class PagnolXlLayerNorm(nn.Module):
    """Pagnol's layer normalisation: an ``nn.LayerNorm`` over ``d_model`` features.

    The epsilon is taken from ``config.layer_norm_epsilon``.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.init_weights()

    def init_weights(self):
        """Reset the affine parameters: scale to one, shift to zero."""
        layer_norm = self.norm
        nn.init.zeros_(layer_norm.bias)
        nn.init.ones_(layer_norm.weight)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Normalise ``hidden_states`` over its last dimension."""
        normalised = self.norm(hidden_states)
        return normalised
class PagnoXlBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward module.

    Each sub-module is applied to a layer-normalised copy of the hidden
    states, passed through dropout and added back to the residual stream.
    Implemented as a decoder layer of GPT-3."""

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_layers = config.n_layers
        # Attention sub-layer.
        self.self_attention = PagnolXlAttention(config)
        self.attn_norm = PagnolXlLayerNorm(config)
        self.attn_dropout = nn.Dropout(config.dropout)
        # Feed-forward sub-layer.
        self.mlp = PagnolXlStandardMLP(config)
        self.mlp_norm = PagnolXlLayerNorm(config)
        self.mlp_dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def init_weights(self):
        """Delegate initialisation to the attention and feed-forward modules."""
        self.self_attention.init_weights()
        self.mlp.init_weights()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        """Run the block and return ``(hidden_states, [present], [attentions])``.

        ``present`` is included only when ``use_cache`` is true; ``attentions``
        only when the attention module returned them.
        """
        # The attention module yields: output, present, (attentions).
        attn_hidden, *extras = self.self_attention(
            self.attn_norm(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.attn_dropout(attn_hidden)

        mlp_hidden = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + self.mlp_dropout(mlp_hidden)

        # Without caching, the leading `present` entry is dropped.
        kept = extras if use_cache else extras[1:]
        return (hidden_states, *kept)  # hidden_states, present, attentions
class PagnolXlPreTrainedModel(PreTrainedModel):
    """Base class wiring PagnolXl models into the `PreTrainedModel` machinery.

    Declares the configuration class, the base-model prefix, and the weight
    initialisation and gradient-checkpointing hooks shared by all PagnolXl
    models.
    """

    config_class = PagnolXlConfig
    base_model_prefix = "pagnolxl"
    supports_gradient_checkpointing = True
    # Must match the class name of the transformer block defined in this file,
    # which is spelled `PagnoXlBlock`.
    _no_split_modules = ["PagnoXlBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialise one sub-module according to its type.

        Embeddings and linear layers draw weights from a normal law of std
        ``config.sigma``; biases and the embedding padding row are zeroed.
        Layer norms are reset to unit scale and zero shift.
        """
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.sigma)
            if module.bias is not None:
                module.bias.data.zero_()
            # TODO: attention out_proj weights are initialized with sigma / sqrt(2.0 * n_layers)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing
    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        """Toggle gradient checkpointing on the module that actually uses it.

        The `gradient_checkpointing` flag is read in `PagnolXlTransformer.forward`,
        so it has to be set on that module.
        """
        if isinstance(module, PagnolXlTransformer):
            module.gradient_checkpointing = value
            # The forward pass calls `_gradient_checkpointing_func`; provide the
            # plain torch implementation when nothing has installed one.
            if not hasattr(module, "_gradient_checkpointing_func"):
                module._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
class PagnolXlTransformer(PagnolXlPreTrainedModel):
    """Pagnol's Transformer model: a stack of `config.n_layers` decoder blocks.

    Operates on already-embedded inputs; embedding, final normalisation and
    the output projection are not part of this module.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [PagnoXlBlock(config) for _ in range(config.n_layers)]
        )
        self.gradient_checkpointing = False
        self.init_weights()

    def init_weights(self):
        """Run the block-level initialisation of every layer."""
        for layer in self.layers:
            layer.init_weights()

    @staticmethod
    def _prepare_attn_mask(
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, int],
        past_key_values_length: int,
    ) -> torch.BoolTensor:
        """Combine the padding mask with a causal mask.

        Returns a boolean tensor of shape
        `[batch_size, 1, seq_length, seq_length + past_key_values_length]`
        where `True` marks a position that must not be attended to.

        Raises:
            ValueError: if `attention_mask` does not cover
                `seq_length + past_key_values_length` positions.
        """
        # Create a causal mask
        # The attention mask we receive as input should cover the whole extended sequence, including any past
        # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
        # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
            raise ValueError(
                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
                f" {past_key_values_length}."
            )
        combined_attention_mask = None
        device = attention_mask.device
        _, seq_length = input_shape

        # A single query token may attend to every cached position, so the
        # causal part is only needed for longer inputs.
        if seq_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
        expanded_attn_mask = _expand_mask(
            attention_mask, past_key_values_length=past_key_values_length
        )
        combined_attention_mask = (
            expanded_attn_mask
            if combined_attention_mask is None
            else expanded_attn_mask | combined_attention_mask
        )
        return combined_attention_mask

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack over `inputs_embeds`.

        Args:
            inputs_embeds: embedded inputs, unpacked as `(batch_size, seq_length, hidden)`.
            past_key_values: one `(key, value)` pair per layer from earlier steps, or `None`.
            attention_mask: mask over `seq_length + past_length` positions; defaults to all ones.
            head_mask: optional head mask, expanded per layer through `get_head_mask`.
            use_cache / output_attentions / output_hidden_states / return_dict:
                fall back to the corresponding config values when `None`.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or the tuple of its
            non-`None` fields when `return_dict` is false.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length, _ = inputs_embeds.shape

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.layers))
        else:
            past_length = past_key_values[0][0].size(-2)

        hidden_states = inputs_embeds

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_length),
                device=hidden_states.device,
            )
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_length,
        )

        checkpointing = self.gradient_checkpointing and self.training
        # Resolve the cache flag before the output containers are created, so
        # that `presents` is None (not an empty tuple) when caching is disabled.
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if checkpointing:
            # `_gradient_checkpointing_func` is only present when something has
            # installed it; otherwise use the plain torch implementation.
            checkpoint_fn = getattr(
                self, "_gradient_checkpointing_func", torch.utils.checkpoint.checkpoint
            )

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if checkpointing:
                # Positional arguments: hidden_states, layer_past (always None
                # here), attention_mask, head_mask, use_cache, output_attentions.
                outputs = checkpoint_fn(
                    layer.__call__,
                    hidden_states,
                    None,
                    causal_mask,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = layer(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The attention entry follows `present` only when caching is on.
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class PagnolXlModel(PagnolXlPreTrainedModel):
    """Bare PagnolXl model: token embeddings followed by the transformer stack.

    `final_norm` and `projector` are instantiated as sub-modules but are not
    applied in `forward`, which returns the transformer outputs as they are.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the underlying `nn.Embedding` table."""
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        """Replace the underlying `nn.Embedding` table."""
        self.embedding.embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Embed the inputs (if needed) and run the transformer stack.

        Exactly one of `input_ids` and `inputs_embeds` must be given.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds`
                are provided.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # The embedding module is registered as `self.embedding`; there is
            # no `word_embeddings` attribute on this class.
            inputs_embeds = self.embedding(input_ids)

        transformer_outputs = self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return transformer_outputs
class PagnolXlLMHead(nn.Module):
    """Pagnol's Language Model head Projector.

    A bias-free linear map from ``config.d_model`` to ``config.vocab_size``.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__()
        # Kept so that `init_weights` can read `sigma`.
        self.config = config
        self.proj = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def init_weights(self):
        """Draw the projection weights from a normal law of std ``config.sigma``."""
        torch.nn.init.normal_(self.proj.weight, mean=0.0, std=self.config.sigma)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Project ``hidden_states`` onto the vocabulary dimension."""
        return self.proj(hidden_states)
class PagnolXlForCausalLM(PagnolXlPreTrainedModel):
    """PagnolXl with a language-modelling head.

    Pipeline: token embeddings -> transformer stack -> final layer norm ->
    linear projection onto the vocabulary.
    """

    def __init__(self, config: PagnolXlConfig):
        super().__init__(config)
        self.config = config
        self.embedding = PagnolXlEmbeddings(config)
        self.transformer = PagnolXlTransformer(config)
        self.final_norm = PagnolXlLayerNorm(config)
        self.projector = PagnolXlLMHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the underlying `nn.Embedding` table."""
        return self.embedding.embedding

    def set_input_embeddings(self, value):
        """Replace the underlying `nn.Embedding` table."""
        self.embedding.embedding = value

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the `forward` keyword arguments for one generation step.

        When a cache is present, the tokens it already covers are removed
        from `input_ids`.
        """
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx: torch.LongTensor):
        """Reorder the key/value cache along the batch axis for beam search.

        Every cached tensor has the batch on dimension 0, so beam selection
        is an `index_select` on that dimension.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        """Compute next-token logits and, when `labels` is given, the LM loss.

        Exactly one of `input_ids` and `inputs_embeds` must be given. The
        loss is a cross-entropy between the logits at position `t` and the
        label at position `t + 1`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds`
                are provided.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        transformer_outputs = self.transformer(
            inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        hidden_states = self.final_norm(hidden_states)
        lm_logits = self.projector(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length),
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )