File size: 39,348 Bytes
012b849 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 |
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.utils import logging
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
# this import has to be relative, otherwise, when setting trust_remote_code=True
# huggingface transformers won't be able to load the module correctly
from .configuration_openelm import OpenELMConfig, make_divisible
class OpenELMRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable per-feature scale."""

    def __init__(self, num_features: int, eps: float = 1e-6):
        """
        Build the RMS normalization layer.

        Args:
            num_features (int): Length of the learnable ``weight`` vector (initialized to ones).
            eps (float, optional): Value added to the mean square before the reciprocal square
                root is taken. Default is 1e-6.

        Attributes:
            eps (float): The stabilizing constant described above.
            weight (nn.Parameter): Learnable scaling parameter of shape ``(num_features,)``.
            num_features (int): Stored copy of the constructor argument, used by ``extra_repr``.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.num_features = num_features

    def _norm(self, x: Tensor) -> Tensor:
        """
        Divide ``x`` by the root mean square of its last dimension.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: ``x`` scaled by ``rsqrt(mean(x ** 2, dim=-1) + eps)``.
        """
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        """
        Normalize ``x`` and apply the learnable scale.

        The normalization itself runs on a float32 copy of the input; the result is cast back to
        the input's type before being multiplied by ``weight``.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized and scaled tensor.
        """
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

    def extra_repr(self) -> str:
        """Return the layer's configuration for the module's printed representation."""
        details = f"num_features={self.num_features}, eps={self.eps}"
        return super().extra_repr() + details
class OpenELMPreTrainedModel(PreTrainedModel):
    """Common base class for the OpenELM models, built on `PreTrainedModel`."""

    config_class = OpenELMConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenELMDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs) -> None:
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution whose standard
        deviation is ``config.initializer_range``; linear biases and the embedding row at
        ``padding_idx`` are zeroed; RMS-norm weights are set to one. Other modules are untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, OpenELMRMSNorm):
            module.weight.data.fill_(1.0)
def _rotate_half(x: Tensor) -> Tensor:
    """Split the last dimension in two chunks and return ``cat(-second, first)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
    """Return ``x * pos_cos + _rotate_half(x) * pos_sin``."""
    cos_term = x * pos_cos
    sin_term = _rotate_half(x) * pos_sin
    return cos_term + sin_term
class OpenELMRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.

    The module keeps cached cosine and sine tables and applies them to query and key tensors.
    The tables are rebuilt lazily whenever a longer sequence, another device or another dtype
    is requested.

    Args:
        model_dim: Size of the last dimension of the tensors that are rotated.
        max_seq_length: Number of positions for which the tables are precomputed at construction.
        freq_constant: Base constant used for computing the rotation frequencies.
    """

    def __init__(
        self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
    ) -> None:
        super().__init__()
        exponents = torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim
        inv_freq = 1.0 / (freq_constant**exponents)

        self.model_dim = model_dim
        self.freq_constant = freq_constant
        self.max_seq_length = max_seq_length
        # Non-persistent: the buffer follows the module across devices but is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # The tables are plain attributes (not buffers); they are filled by the call below.
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_length = max_seq_length
        self._compute_sin_cos_embeddings(max_seq_length)

    def extra_repr(self) -> str:
        """Return the constructor arguments for the module's printed representation."""
        return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"

    def _compute_sin_cos_embeddings(
        self,
        key_len: int,
        key_device: torch.device = torch.device("cpu"),
        key_dtype: torch.dtype = torch.float32,
    ) -> None:
        """
        Make sure the cached sine and cosine tables can serve ``key_len`` positions.

        Args:
            key_len: Number of tokens in the key embeddings in the transformer model.
            key_device: Device where the key embeddings are stored.
            key_dtype: Data type of the key embeddings.

        Returns:
            None

        ...note:
            The tables are recomputed if any of the following conditions is met:
                1. ``key_len`` is greater than the cached sequence length.
                2. The sine or cosine cache is empty.
                3. The device or data type of a cached table differs from the requested one.
        """
        cos_cache = self._cached_cos
        sin_cache = self._cached_sin
        cache_is_usable = (
            key_len <= self._cached_seq_length
            and cos_cache is not None
            and sin_cache is not None
            and cos_cache.device == key_device
            and cos_cache.dtype == key_dtype
            and sin_cache.device == key_device
            and sin_cache.dtype == key_dtype
        )
        if cache_is_usable:
            return

        # The cached length never shrinks.
        self._cached_seq_length = max(key_len, self._cached_seq_length)

        # One entry per position: [number of key tokens]
        positions = torch.arange(
            self._cached_seq_length,
            dtype=torch.float32,
            device=self.inv_freq.device,
        )
        # Outer product of positions and inverse frequencies: [number of key tokens, len(inv_freq)]
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # The angles are repeated along the last dimension: [number of key tokens, 2 * len(inv_freq)]
        emb = torch.cat((angles, angles), dim=-1)

        cos_table = emb.cos().to(dtype=key_dtype, device=key_device)
        sin_table = emb.sin().to(dtype=key_dtype, device=key_device)

        # Two leading singleton dimensions are added so the tables broadcast against 4D inputs.
        self._cached_cos = cos_table[None, None, :, :]
        self._cached_sin = sin_table[None, None, :, :]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the rotary embedding to the query and key tensors.

        Args:
            query: Query embeddings in the transformer model. The shape of query embeddings is
                [Batch, number of query heads, number of query tokens, model dimension].
            key: Key embeddings in the transformer model. The shape of key embeddings is
                [Batch, number of key heads, number of key tokens, model dimension].

        Returns:
            A tuple with the rotated query and key embeddings, each cast back to the type of the
            corresponding input.

        ...note:
            The rotation is computed on float32 copies of the inputs.
        """
        head_dim = key.shape[-1]
        key_len = key.shape[2]
        query_len = query.shape[2]

        assert head_dim == self.model_dim
        assert key.device == query.device
        assert key.dtype == query.dtype

        # In the context of self-attention, the lengths of keys and queries are equal.
        # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
        # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
        # represent embeddings of previous tokens and the current token, while the query corresponds
        # to the embedding of the current token only.
        assert (
            key_len >= query_len
        ), "Number of keys has to be greater than or equal to number of queries."

        query_float = query.float()
        key_float = key.float()

        self._compute_sin_cos_embeddings(
            key_len, key_device=key_float.device, key_dtype=key_float.dtype
        )

        # Queries take the last `query_len` of the first `key_len` table rows.
        query_start = key_len - query_len
        rotated_query = _apply_rotary_pos_emb(
            x=query_float,
            pos_sin=self._cached_sin[..., query_start:key_len, :],
            pos_cos=self._cached_cos[..., query_start:key_len, :],
        )
        # Keys take the first `key_len` table rows.
        rotated_key = _apply_rotary_pos_emb(
            x=key_float,
            pos_sin=self._cached_sin[..., :key_len, :],
            pos_cos=self._cached_cos[..., :key_len, :],
        )

        return rotated_query.type_as(query), rotated_key.type_as(key)
class OpenELMMultiHeadCausalAttention(nn.Module):
    """Self-attention layer with a fused QKV projection, rotary embeddings and grouped key/value heads."""

    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        """
        Args:
            config: Model configuration; the per-layer head counts are read at index ``layer_idx``.
            layer_idx: Index of this layer; also passed to the KV cache on update.
        """
        super().__init__()
        self.layer_idx = layer_idx

        head_dim = config.head_dim
        q_heads = config.num_query_heads[layer_idx]
        # Keys and values use the same number of heads.
        kv_heads = config.num_kv_heads[layer_idx]

        # A single projection produces queries, keys and values together.
        self.qkv_proj = nn.Linear(
            in_features=config.model_dim,
            out_features=(q_heads + 2 * kv_heads) * head_dim,
            bias=False,
        )

        self.pos_embedding = OpenELMRotaryEmbedding(
            model_dim=head_dim,
            max_seq_length=config.rope_max_length,
            freq_constant=config.rope_freq_constant,
        )

        # Optional per-head normalization of queries and keys.
        use_qk_norm = config.normalize_qk_projections
        self.q_norm = OpenELMRMSNorm(num_features=head_dim) if use_qk_norm else None
        self.k_norm = OpenELMRMSNorm(num_features=head_dim) if use_qk_norm else None

        self.out_proj = nn.Linear(
            in_features=q_heads * head_dim,
            out_features=config.model_dim,
            bias=False,
        )

        self.head_dim = head_dim
        self.num_q_heads = q_heads
        self.num_k_heads = kv_heads
        self.num_v_heads = kv_heads
        self.transformer_dim = config.model_dim
        # Number of query heads that share one key/value head.
        self.num_groups = self.num_q_heads // self.num_k_heads

    def extra_repr(self) -> str:
        """Return the head counts for the module's printed representation."""
        heads = f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
        return super().extra_repr() + heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass of multi-head self-attention.

        Args:
            hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
            attention_mask: Optional mask forwarded to ``F.scaled_dot_product_attention``. When
                ``cache_position`` is also given, the mask is indexed with it along dim 2 and cut
                to the key length along the last dim.
            past_key_value: Cache object storing the keys and values; updated in place when given.
            output_attentions: Accepted for interface compatibility; attention weights are never
                returned because ``scaled_dot_product_attention`` does not produce them.
            use_cache: Accepted for interface compatibility; not read by this method.
            cache_position: Passed to the cache update and used to index ``attention_mask``.

        Returns:
            A tuple ``(attn_output, None, past_key_value)`` where ``attn_output`` has the same
            shape as ``hidden_states``.
        """
        batch_size, seq_length, _ = hidden_states.size()
        total_heads = self.num_q_heads + self.num_k_heads + self.num_v_heads

        # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
        qkv = self.qkv_proj(hidden_states).reshape(
            batch_size, seq_length, total_heads, self.head_dim
        )
        # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
        qkv = qkv.transpose(1, 2)
        # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h]
        queries, keys, values = qkv.split(
            [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
        )

        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys = self.k_norm(keys)

        # A cache attached to the module as `self.past_key_value` takes precedence over the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            keys, values = past_key_value.update(
                keys, values, self.layer_idx, {"cache_position": cache_position}
            )

        # Positional embedding is applied after the cache update.
        queries, keys = self.pos_embedding(queries, keys)

        if self.num_groups != 1:
            # GQA: repeat each key/value head so their count matches the query heads.
            # [B, k_h, S, h] --> [B, q_h, S, h]
            keys = keys.repeat_interleave(self.num_groups, dim=1)
            # [B, v_h, S, h] --> [B, q_h, S, h]
            values = values.repeat_interleave(self.num_groups, dim=1)

        if attention_mask is None or cache_position is None:
            causal_mask = attention_mask
        else:
            causal_mask = attention_mask[:, :, cache_position, : keys.shape[-2]]

        attn_output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=causal_mask,
            dropout_p=0,
        )

        # [B, q_h, S, h] --> [B, S, q_h, h] --> [B, S, q_h * h]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(
            batch_size, seq_length, self.num_q_heads * self.head_dim
        )
        attn_output = self.out_proj(attn_output)

        # scaled_dot_product_attention does not return attention weights, so none are reported.
        attn_weights = None
        return attn_output, attn_weights, past_key_value
class OpenELMFeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network, optionally using a gated linear unit."""

    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        """
        Args:
            config: Model configuration; the FFN multiplier is read at index ``layer_idx``.
            layer_idx: Index of the layer this FFN belongs to.
        """
        super().__init__()

        model_dim = config.model_dim
        intermediate_dim = int(
            make_divisible(
                config.ffn_multipliers[layer_idx] * model_dim,
                divisor=config.ffn_dim_divisor,
            )
        )

        self.ffn_with_glu = bool(config.ffn_with_glu)
        # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1, needs
        # a first projection twice as wide (it is split into two halves in `forward`);
        # the standard FFN, as described in https://arxiv.org/abs/1706.03762, does not.
        proj_1_width = 2 * intermediate_dim if self.ffn_with_glu else intermediate_dim

        self.proj_1 = nn.Linear(
            in_features=model_dim,
            out_features=proj_1_width,
            bias=False,
        )
        self.proj_2 = nn.Linear(
            in_features=intermediate_dim,
            out_features=model_dim,
            bias=False,
        )
        self.act = ACT2FN[config.activation_fn_name]

    def extra_repr(self) -> str:
        """Report whether the gated variant is used."""
        return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of FFN layer.

        Args:
            x: Input tensor of the shape [batch size, sequence length, model dimension].

        Returns:
            A tensor of the same shape as the input.
        """
        projected = self.proj_1(x)
        if self.ffn_with_glu:
            # The activated first half is multiplied by the second half.
            gate, value = projected.chunk(2, dim=-1)
            hidden = self.act(gate) * value
        else:
            hidden = self.act(projected)
        return self.proj_2(hidden)
class OpenELMDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention and pre-norm feed-forward, each with a residual connection."""

    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        """
        Args:
            config: Model configuration shared by all sub-modules.
            layer_idx: Index of this layer, forwarded to the attention and FFN modules.
        """
        super().__init__()
        self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
        self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
        self.ffn_norm = OpenELMRMSNorm(num_features=config.model_dim)
        self.attn_norm = OpenELMRMSNorm(num_features=config.model_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                accepted for interface compatibility; not read by this method.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to include the attention module's weights entry in the returned tuple.
            use_cache (`bool`, *optional*):
                If set to `True`, the key value states returned by the attention module are appended to the
                returned tuple and can be used to speed up decoding (see `past_key_values`).
            cache_position (`torch.LongTensor`, *optional*): forwarded to the attention module.
            **kwargs: forwarded unchanged to the attention module.

        Returns:
            A tuple starting with the output hidden states, followed by the attention weights entry
            (if `output_attentions`) and the key value states (if `use_cache`).
        """
        # Self Attention
        attn_input = self.attn_norm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        after_attn = hidden_states + attn_output

        # Fully Connected
        ffn_output = self.ffn(self.ffn_norm(after_attn))
        layer_output = after_attn + ffn_output

        results = [layer_output]
        if output_attentions:
            results.append(self_attn_weights)
        if use_cache:
            results.append(present_key_value)
        return tuple(results)
class OpenELMModel(OpenELMPreTrainedModel):
    """Token embedding, a stack of `OpenELMDecoderLayer` modules and a final RMS norm."""

    config_class = OpenELMConfig

    def __init__(self, config: OpenELMConfig):
        """
        Args:
            config: Model configuration describing the embedding size, vocabulary size,
                number of layers and maximum context length.
        """
        super().__init__(config)
        self.config = config

        self.token_embeddings = nn.Embedding(
            embedding_dim=config.model_dim,
            num_embeddings=config.vocab_size,
        )

        self.layers = nn.ModuleList(
            OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
            for layer_idx in range(config.num_transformer_layers)
        )
        self.norm = OpenELMRMSNorm(num_features=config.model_dim)
        # A separate output projection is only created when input and output layers are not shared.
        if config.share_input_output_layers:
            self.classifier = None
        else:
            self.classifier = nn.Linear(
                in_features=config.model_dim,
                out_features=config.vocab_size,
                bias=False,
            )
        self.num_transformer_layers = config.num_transformer_layers
        self.gradient_checkpointing = False

        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
        causal_mask = torch.full(
            (config.max_context_length, config.max_context_length),
            fill_value=True,
            dtype=torch.bool,
        )
        # Only the strict upper triangle (future positions) stays True.
        self.register_buffer(
            "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
        )

        # Initialize weights and apply final processing
        self.post_init()
        # Runs after `post_init`, so this initialization scheme is the one that remains in effect.
        self.reset_parameters(config=config)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.token_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        """Replace the token embedding module."""
        self.token_embeddings = new_embeddings

    def reset_parameters(self, config: OpenELMConfig) -> None:
        """Initialize the layers in Language Model

        The initialization scheme is followed, following `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.

        Linear and embedding weights are drawn from a zero-mean normal distribution whose standard
        deviation is the inverse square root of the input (respectively embedding) dimension, biases
        are zeroed and RMS-norm weights are set to one. The attention and FFN output projections are
        then re-drawn with a standard deviation additionally scaled by ``(2 * n_layers) ** -0.5``.

        Args:
            config: Model configuration; ``model_dim`` and ``num_transformer_layers`` are read.

        Returns:
            None
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                std = module.in_features**-0.5
                torch.nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                std = module.embedding_dim**-0.5
                torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            elif isinstance(module, OpenELMRMSNorm):
                if module.weight is not None:
                    torch.nn.init.ones_(module.weight)
                if hasattr(module, "bias") and module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

        model_dim = config.model_dim
        n_layers = config.num_transformer_layers
        std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
        # Output projections of the attention and FFN blocks get the depth-scaled std.
        scaled_suffixes = ("out_proj.weight", "ffn.proj_2.weight")
        for param_name, param in self.named_parameters():
            if param_name.endswith(scaled_suffixes):
                torch.nn.init.normal_(param, mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the decoder stack.

        Args:
            input_ids: Token ids; exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
            attention_mask: Optional padding mask, merged with the causal mask by `_update_causal_mask`.
            position_ids: Optional position ids; derived from ``cache_position`` when omitted.
            past_key_values: Cached keys and values. Unless a `StaticCache` is passed, the value is
                converted with `DynamicCache.from_legacy_cache` when ``use_cache`` is set.
            inputs_embeds: Precomputed input embeddings used instead of ``input_ids``.
            use_cache: Whether to return the updated cache. Defaults to ``config.use_cache``.
            output_attentions: Defaults to ``config.output_attentions``.
            output_hidden_states: Defaults to ``config.output_hidden_states``.
            return_dict: Defaults to ``config.use_return_dict``.
            cache_position: Positions of the current tokens; built from the number of cached
                tokens and the input length when omitted.

        Returns:
            A `BaseModelOutputWithPast`, or a tuple of its non-``None`` fields when ``return_dict`` is false.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # True when both inputs are given and also when neither is given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.token_embeddings(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the order of `OpenELMDecoderLayer.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry comes after the attention entry when the latter is present.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, Cache)
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(self, attention_mask, input_tensor):
        """
        Build the additive attention mask for the current batch.

        Args:
            attention_mask: Optional padding mask. A 2D mask is merged into the causal mask.
            input_tensor: The input embeddings; batch size, sequence length, dtype and device
                are taken from it.

        Returns:
            For the ``flash_attention_2`` implementation: ``attention_mask`` if it contains a
            zero, otherwise ``None``. For other implementations: a tensor of shape
            ``(batch_size, 1, N, N)``, where ``N`` is the size of the cached causal mask, holding
            ``0`` for visible positions and the minimum value of the input dtype for masked ones.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        batch_size, seq_length = input_tensor.shape[:2]
        dtype = input_tensor.dtype
        device = input_tensor.device

        # support going beyond cached `max_position_embedding`
        cached_size = self.causal_mask.shape[-1]
        if seq_length > cached_size:
            # Keep doubling until the mask covers the sequence; a single doubling is not
            # enough when `seq_length` exceeds twice the cached size.
            new_size = max(2 * cached_size, 1)
            while new_size < seq_length:
                new_size *= 2
            # Built like the initial buffer (bool), then placed on the device of the buffer it replaces.
            grown_mask = torch.full(
                (new_size, new_size), fill_value=True, dtype=torch.bool
            )
            self.register_buffer(
                "causal_mask",
                torch.triu(grown_mask, diagonal=1).to(self.causal_mask.device),
                persistent=False,
            )

        # We use the current dtype to avoid any overflows
        min_dtype = torch.finfo(dtype).min
        causal_mask = (
            self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
            * min_dtype
        )

        causal_mask = causal_mask.to(dtype=dtype, device=device)
        if attention_mask is not None and attention_mask.dim() == 2:
            mask_length = attention_mask.shape[-1]
            # Positions that are causally visible (0) but padded in `attention_mask` get masked as well.
            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                :, None, None, :
            ].eq(0.0)
            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
            # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
            is_tracing = (
                torch.jit.is_tracing()
                or isinstance(input_tensor, torch.fx.Proxy)
                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
            )
            if not is_tracing and torch.any(attention_mask != 1):
                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                # Details: https://github.com/pytorch/pytorch/issues/110213
                causal_mask = causal_mask.mul(
                    ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
                ).to(dtype)

        return causal_mask
class OpenELMForCausalLM(OpenELMPreTrainedModel):
    """`OpenELMModel` with a language-modeling head (a separate linear layer or the shared token embedding)."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: OpenELMConfig):
        """
        Args:
            config: Model configuration. When ``share_input_output_layers`` is set, no separate
                head is created and the token embedding matrix is reused for the output projection.
        """
        super().__init__(config)
        self.transformer = OpenELMModel(config)
        self.vocab_size = config.vocab_size
        if config.share_input_output_layers:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.transformer.token_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.transformer.token_embeddings = value

    def get_output_embeddings(self):
        """Return the output head (``None`` when input and output layers are shared)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder model."""
        self.transformer = decoder

    def get_decoder(self):
        """Return the wrapped decoder model."""
        return self.transformer

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Compute next-token logits and, when ``labels`` is given, the language-modeling loss.

        Args:
            labels: Optional target token ids. Logits at position ``i`` are scored against the
                label at position ``i + 1`` with `CrossEntropyLoss`.
            All other arguments are forwarded to the wrapped `OpenELMModel`; ``output_attentions``,
            ``output_hidden_states`` and ``return_dict`` default to the corresponding config values.

        Returns:
            A `CausalLMOutputWithPast`, or a tuple ``(loss?, logits, *rest)`` when ``return_dict`` is false.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.lm_head is None:
            # shared: project with the token embedding matrix
            logits = F.linear(
                hidden_states, weight=self.transformer.token_embeddings.weight
            )
        else:
            logits = self.lm_head(hidden_states)
        # Keep only the first `vocab_size` entries of the vocabulary (last) axis. Slicing with
        # `[:, :vocab_size]` would cut the sequence axis of the [batch, seq, vocab] logits instead.
        logits = logits[..., : self.config.vocab_size]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        Tokens already covered by ``past_key_values`` are dropped from ``input_ids``, position ids
        are derived from ``attention_mask`` when not supplied through ``kwargs``, and a matching
        ``cache_position`` range is built.

        Returns:
            A dict with ``input_ids`` (or ``inputs_embeds`` on the first step when embeddings are
            given and no cache exists), ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` and ``attention_mask``.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: the length is read from dim 2 of the first cached tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if self.generation_config.cache_implementation == "static":
            # generation with static cache
            cache_position = kwargs.get("cache_position", None)
            if cache_position is None:
                past_length = 0
            else:
                past_length = cache_position[-1] + 1
            input_ids = input_ids[:, past_length:]
            position_ids = position_ids[:, past_length:]

        # we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(
            past_length,
            past_length + position_ids.shape[-1],
            device=position_ids.device,
        )

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # We could use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids.contiguous(),
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """
        Reorder a legacy tuple cache along dim 0 according to ``beam_idx``.

        Returns:
            A tuple with one tuple per layer; every cached tensor is re-indexed with
            ``index_select(0, beam_idx)`` on its own device.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )
|