# NOTE: hosting-page header residue ("Spaces: Starting on T4") was here; it is not part of the module.
# coding=utf-8 | |
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" PyTorch CPMAnt""" | |
import math | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from torch import nn | |
from torch.nn import CrossEntropyLoss | |
from ...activations import ACT2FN | |
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | |
from ...modeling_utils import PreTrainedModel | |
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging | |
from .configuration_cpmant import CpmAntConfig | |
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Names presumably consumed by the docstring helpers imported above — TODO confirm usage.
_CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b"
_CONFIG_FOR_DOC = "CpmAntConfig"

# Hub identifiers of the pretrained checkpoints known for this architecture.
CPMANT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openbmb/cpm-ant-10b",
    # See all CPMAnt models at https://huggingface.co/models?filter=cpmant
]
class CpmAntLayerNorm(nn.Module):
    """
    Root Mean Square (RMS) layer normalization with a learned per-feature scale.

    See https://arxiv.org/abs/1910.07467 for details.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.eps = config.eps
        self.dim_norm = hidden_size
        # Allocated uninitialized; the values must be filled in before use.
        self.weight = nn.Parameter(torch.empty(hidden_size))

    def forward(self, hidden_states: torch.Tensor):
        """
        Scale `hidden_states` by the reciprocal RMS of its last dimension, then by `self.weight`.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)

        Raises:
            AssertionError: if the last dimension differs from the configured hidden size.
        """
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")

        input_dtype = hidden_states.dtype
        # The statistic is accumulated in float32 regardless of the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (hidden_states * inv_rms).to(input_dtype)
        return normalized * self.weight
class CpmAntAttention(nn.Module):
    """Multi-head attention with an additive position bias and an optional key/value cache."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head

        self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)

        self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        if config.dropout_p is not None:
            self.dropout = torch.nn.Dropout(p=config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_q (`torch.Tensor` of shape `(batch, len_q, dim_model)`):
                States the queries are projected from.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
                States the keys and values are projected from.
            attention_mask (`torch.Tensor` viewable as `(batch, 1, len_q, len_k)`):
                Positions that evaluate to `False`/`0` are excluded from attention. `len_k` includes the cached
                length when `past_key_values` is given.
            position_bias (`torch.Tensor` broadcastable to `(batch, num_heads, len_q, len_k)`):
                Added to the scaled attention scores before masking.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention probabilities.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached key and value projections; the new keys/values are appended along the sequence axis.
            use_cache (`bool`, *optional*):
                If set to `True`, the (concatenated) key and value states are returned so they can be reused to
                speed up decoding.

        Returns:
            `(output, attn_weights, past_key_values)`: the projected attention output of shape
            `(batch, len_q, dim_model)`, the attention probabilities (or `None`), and the key/value cache (or
            `None`).
        """
        batch_size = hidden_q.size(0)
        len_q = hidden_q.size(1)
        len_k = hidden_kv.size(1)

        query = self.project_q(hidden_q)
        key = self.project_k(hidden_kv)
        value = self.project_v(hidden_kv)

        # (batch, len, num_heads * dim_head) -> (batch, num_heads, len, dim_head)
        query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        if past_key_values is not None:
            key = torch.cat([past_key_values[0], key], dim=-2)
            value = torch.cat([past_key_values[1], value], dim=-2)
            len_k = key.size(-2)

        # (batch, num_heads, len_q, dim_head) @ (batch, num_heads, dim_head, len_k) -> (batch, num_heads, len_q, len_k)
        score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        score = score + position_bias

        # Build the "excluded positions" mask once; `logical_not` handles both boolean and 0/1 integer masks
        # and stays on the mask's own device.
        excluded = attention_mask.view(batch_size, 1, len_q, len_k).logical_not()

        score = score.masked_fill(excluded, float("-inf"))
        score = self.softmax(score)
        # A fully excluded row is NaN after the softmax; force every excluded position back to zero.
        score = score.masked_fill(excluded, 0.0)

        attn_weights = score if output_attentions else None

        if self.dropout is not None:
            score = self.dropout(score)

        # (batch, num_heads, len_q, len_k) @ (batch, num_heads, len_k, dim_head) -> (batch, num_heads, len_q, dim_head)
        context = torch.matmul(score, value)

        # Merge the heads back: (batch, num_heads, len_q, dim_head) -> (batch, len_q, num_heads * dim_head)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)

        output = self.attention_out(context)

        past_key_values = (key, value) if use_cache else None

        return output, attn_weights, past_key_values
class CpmAntSelfAttentionBlock(nn.Module):
    """Pre-norm self-attention sub-layer with a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_attention = CpmAntLayerNorm(config)
        self.self_attention = CpmAntAttention(config)
        # Dropout is only created for a truthy probability (neither `None` nor `0`).
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Normalize the input, attend over it, and add the result back onto the input.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of the self-attention block. It can be the raw embedding of a batch of sequences.
            attention_mask (`torch.Tensor`):
                Mask forwarded unchanged to the attention module.
            position_bias (`torch.Tensor`):
                Positional bias forwarded unchanged to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, key/value states are returned and can be used to speed up decoding.

        Returns:
            `(hidden_states, attn_weights, current_key_value)`
        """
        normed = self.layernorm_before_attention(hidden_states)
        # The same normalized tensor feeds queries and keys/values.
        attn_output, attn_weights, current_key_value = self.self_attention(
            normed, normed, attention_mask, position_bias, output_attentions, past_key_values, use_cache
        )

        if self.dropout is not None:
            attn_output = self.dropout(attn_output)

        return hidden_states + attn_output, attn_weights, current_key_value
class CpmAntDenseGatedACT(nn.Module):
    """GELU-gated linear unit: `gelu(w_0(x)) * w_1(x)`."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        in_features, out_features = config.hidden_size, config.dim_ff
        self.w_0 = nn.Linear(in_features, out_features, bias=False)
        self.w_1 = nn.Linear(in_features, out_features, bias=False)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """Transform an input tensor from one feature space to another via a nonlinear operation

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        gate = self.act(self.w_0(hidden_states))
        projected = self.w_1(hidden_states)
        return gate * projected
class CpmAntFeedForward(nn.Module):
    """Gated feed-forward network: gated up-projection, optional dropout, linear down-projection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_in = CpmAntDenseGatedACT(config)
        # A dropout module is created whenever a probability is configured, including `0`.
        self.dropout = None if config.dropout_p is None else torch.nn.Dropout(config.dropout_p)
        self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        expanded = self.w_in(hidden_states)
        if self.dropout is not None:
            expanded = self.dropout(expanded)
        return self.w_out(expanded)
class CpmAntFFNBlock(nn.Module):
    """Pre-norm feed-forward sub-layer with a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmAntLayerNorm(config)
        self.ffn = CpmAntFeedForward(config)
        # Dropout is only created for a truthy probability (neither `None` nor `0`).
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.
        """
        ffn_output = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            ffn_output = self.dropout(ffn_output)
        return hidden_states + ffn_output
class CpmAntTransformerBlock(nn.Module):
    """One transformer layer: a self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.self_att = CpmAntSelfAttentionBlock(config)
        self.ffn = CpmAntFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Mask forwarded to the self-attention sub-layer.
            position_bias (`torch.Tensor`):
                Position information forwarded to the self-attention sub-layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, key/value states are returned and can be used to speed up decoding.

        Returns:
            `(hidden_states, attn_weights, current_key_value)`
        """
        attended, attn_weights, current_key_value = self.self_att(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        return self.ffn(attended), attn_weights, current_key_value
class CpmAntEncoder(nn.Module):
    """Stack of transformer blocks followed by a final layer normalization."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for _ in range(self.num_layers)])

        self.output_layernorm = CpmAntLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the first layer, of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Mask forwarded to every layer.
            position_bias (`torch.Tensor`):
                Position information forwarded to every layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of all layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (sequence indexed by layer, *optional*):
                Per-layer cached key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, per-layer key/value states are returned and can be used to speed up decoding.

        Returns:
            `(hidden_states, current_key_values, all_hidden_states, all_self_attns)`; each of the last three is
            `None` unless the corresponding flag is truthy.
        """
        collected_hidden = () if output_hidden_states else None
        collected_attns = () if output_attentions else None
        collected_cache = () if use_cache else None

        for index, block in enumerate(self.layers):
            if output_hidden_states:
                # Record the input of each layer (pre-normalization).
                collected_hidden = collected_hidden + (hidden_states,)

            layer_past = past_key_values[index] if past_key_values else None
            hidden_states, attn_weights, layer_cache = block(
                hidden_states,
                attention_mask,
                position_bias,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )

            if output_attentions:
                collected_attns = collected_attns + (attn_weights,)
            if layer_cache is not None:
                collected_cache = collected_cache + (layer_cache,)

        hidden_states = self.output_layernorm(hidden_states)

        if output_hidden_states:
            # The final entry is the normalized output of the last layer.
            collected_hidden = collected_hidden + (hidden_states,)

        return hidden_states, collected_cache, collected_hidden, collected_attns
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
class CpmAntIntermediate(nn.Module):
    """Linear projection to the intermediate size followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in the activation registry; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class CpmAntSegmentPositionEmbedding(nn.Module):
    """
    Learned per-head attention bias indexed by segment pair and by bucketed relative position.

    The bias table has `segment_types**2 + position_bias_num_buckets` rows: the first `num_buckets` rows are
    used for query/key pairs in the same segment (indexed by relative-position bucket), the remaining rows for
    pairs in different segments (indexed by `query_segment * num_segments + key_segment`).
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_buckets = config.position_bias_num_buckets
        self.max_distance = config.position_bias_max_distance
        self.num_segments = config.segment_types

        # Allocated uninitialized; the values must be filled in before use.
        self.relative_attention_bias = nn.Parameter(
            torch.empty(
                config.segment_types * config.segment_types + config.position_bias_num_buckets,
                config.num_attention_heads,
            )
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """
        Args:
            key_pos (`torch.Tensor` of shape `(batch, len_k)`):
                Key positions. Only the shape is used: relative distances are derived from `arange(len_k)`.
            query_pos (`torch.Tensor` of shape `(batch, len_q)`):
                Query positions. Only the shape is used: relative distances are derived from `arange(len_q)`.
            key_segment (`torch.Tensor` of shape `(batch, len_k)`):
                Integer segment id of every key position.
            query_segment (`torch.Tensor` of shape `(batch, len_q)`):
                Integer segment id of every query position.

        Returns:
            `torch.Tensor` of shape `(batch, num_heads, len_q, len_k)` holding the attention bias.

        Raises:
            AssertionError: if the batch sizes differ or a segment tensor does not match its position tensor
                in length.
        """
        # Bucket indices are pure integer bookkeeping, so no graph is recorded for them.
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            if key_pos.size(0) != query_pos.size(0):
                raise AssertionError(
                    f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
                )
            if keylen != key_segment.size(1):
                raise AssertionError(
                    f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
                )
            if querylen != query_segment.size(1):
                raise AssertionError(
                    f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
                )

            # Lay keys along the last axis and queries along the middle axis so they broadcast to
            # (batch, len_q, len_k).
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            # Cross-segment rows sit after the `num_buckets` rows reserved for relative positions.
            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (len_q, len_k)
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same-segment pairs use the relative-position bucket; other pairs keep the segment-pair index.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )

        # The table lookup happens outside `no_grad` so the bias table receives gradients.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Map every (query segment, key segment) pair to a unique index in `[0, num_segments**2)`."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """
        Bucket signed relative positions (key index minus query index) into `[0, num_buckets)`.

        Half of the buckets serve positive offsets and half serve non-positive ones. Within each half, offsets
        below `max_exact` get their own bucket and larger ones share logarithmically spaced buckets, clamped
        to the last bucket of the half.
        """
        # always bidirectional in CPMAnt
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # For a distance of 0 the logarithm is -inf; those entries are discarded by the `where` below.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
class CpmAntOutput(nn.Module):
    """Project back to the hidden size, apply dropout, add the residual input and layer-normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection is added before the normalization.
        return self.LayerNorm(projected + input_tensor)
class CpmAntPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that handles weight initialization and provides the common interface for downloading
    and loading pretrained models.
    """

    config_class = CpmAntConfig
    base_model_prefix = "cpmant"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type; unknown module types are left untouched."""
        init_std = self.config.init_std

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                # Keep the padding row at zero.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntLayerNorm):
            # The RMS norm only has a scale parameter.
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntSegmentPositionEmbedding):
            module.relative_attention_bias.data.normal_(mean=0.0, std=init_std)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` attribute on encoder modules; other modules are ignored."""
        if not isinstance(module, CpmAntEncoder):
            return
        module.gradient_checkpointing = value
# Shared documentation fragments for the model classes. They are plain strings; nothing in this span
# executes beyond the two assignments.
CPMANT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CPMANT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CpmAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class CpmAntModel(CpmAntPreTrainedModel):
    """
    Bare CPMAnt transformer: prepends a block of prompt tokens to the input, embeds tokens and segments, runs the
    encoder stack and returns the hidden states without any task head.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # The embedding table covers the vocabulary plus the ids reserved for prompt tokens.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size

        self.post_init()

    def get_input_embeddings(self):
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        """
        Build the `(batch, seqlen, seqlen)` attention mask.

        Args:
            input_ids: `(batch, seqlen)` token ids including the prepended prompt; only shape and device are used.
            span: `(batch, seqlen)` span ids; two positions may attend to each other only when their ids match.
            context: `(batch, seqlen)` flags; a non-zero key column is visible to every query, other columns are
                governed by the directional mask (key index not greater than query index).
            length: `(batch,)` number of non-padding tokens per row (prompt excluded).
        """
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device

        positions = torch.arange(seqlen, device=device)
        directional_mask_2d = positions <= positions.view(-1, 1)
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])

        # mask for left padding: a position is valid when its distance from the end is below the row length
        distance_from_end = torch.arange(seqlen - self.prompt_length, device=device).flip(0)
        mask_1d = distance_from_end[None, :] < length[:, None]
        # Prompt positions are always valid.
        mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        """
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Token ids; id `0` is treated as padding. When `past_key_values` is given this must still be the
                full sequence so far: only positions beyond the cached length are run through the encoder.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of all layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (per-layer tuple of `(key, value)`, *optional*):
                Cached key/value states from a previous call.
            use_cache (`bool`, *optional*):
                If set to `True`, key/value states are returned for reuse.
            return_dict (`bool`, *optional*):
                Whether to return a `BaseModelOutputWithPast` instead of a plain tuple.

        Raises:
            ValueError: if `input_ids` is not provided.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # add prompts ahead
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Non-padding tokens get segment id 2, padding gets 0.
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        prompt_ids = torch.arange(
            self.prompt_length * 2 + self.vocab_size,
            self.prompt_length * 3 + self.vocab_size,
            dtype=dtype,
            device=device,
        ).repeat(input_ids.size(0), 1)
        input_ids = torch.cat((prompt_ids, input_ids), dim=1)
        batch, seq_length = input_ids.size()
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if past_key_values is None:
            past_length = 0
            past_key_values = (None,) * self.encoder.num_layers
            input_ids = input_ids.contiguous()
            hidden_states = self.input_embedding(input_ids)
            segment_states = self.segment_embedding(segment)
            hidden_states = hidden_states + segment_states
        else:
            past_length = past_key_values[0][0].size(-2)
            segment_states = self.segment_embedding(segment)
            # Every position receives the segment embedding of the last position.
            hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :]

        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)

        # Only the positions not covered by the cache act as queries.
        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
        )

        if past_length == 0:
            # drop the prompt
            hidden_states = hidden_states[:, self.prompt_length :, :]
            if all_attentions is not None:
                all_attentions = tuple(
                    attention[:, :, self.prompt_length :, self.prompt_length :] for attention in all_attentions
                )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    hidden_state[:, self.prompt_length :, :] for hidden_state in all_hidden_states
                )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class CpmAntForCausalLM(CpmAntPreTrainedModel):
    """CPMAnt model with a linear language-modeling head on top of the hidden states."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.cpmant = CpmAntModel(config)

        # lm_head.weight is tied to cpmant.input_embedding.weight
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,  # dummy parameter for text-generation pipeline
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`CpmAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. They are compared position-by-position with the
                logits; no shifting is applied here.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
                text-generation pipeline.

        Example:

        Text Generation with CpmAntForCausalLM.
        ```python
        >>> from transformers import CpmAntTokenizer, CpmAntForCausalLM

        >>> texts = "今天天气不错,"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CpmAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_output = self.cpmant(
            input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            return_dict=return_dict,
        )
        hidden_states = model_output.last_hidden_state if return_dict else model_output[0]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_func = CrossEntropyLoss()
            loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    def get_input_embeddings(self):
        return self.cpmant.input_embedding

    def set_input_embeddings(self, embeddings):
        self.cpmant.input_embedding = embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """
        Assemble the keyword arguments for one generation step.

        Returns a dict with `input_ids` (cast to int32), `use_cache` and `past_key_values`; the latter two are
        `None` when the caller did not supply them. The attention mask is never forwarded.
        """
        input_ids = input_ids.int()
        # save the memory usage of dummy attention mask
        # NOTE(review): `kwargs` is this call's own dict, so this replacement is not visible to the caller and
        # the value is not part of the returned inputs — confirm whether it is still needed.
        if "attention_mask" in kwargs:
            kwargs["attention_mask"] = torch.zeros(1, 1)

        return {
            "input_ids": input_ids,
            # `.get` so that callers which omit `use_cache` fall back to the model's configured default
            # instead of failing with a KeyError.
            "use_cache": kwargs.get("use_cache"),
            "past_key_values": kwargs.get("past_key_values", None),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every layer's cached key and value along the batch axis according to `beam_idx`."""
        past_key_values = [list(each) if each is not None else each for each in past_key_values]
        for key_value_layer in past_key_values:
            key_value_layer[0] = key_value_layer[0][beam_idx]
            key_value_layer[1] = key_value_layer[1][beam_idx]
        return past_key_values