myownskyW7 committed
Commit • 8662ae5 · 1 Parent(s): 4ef8d94

Upload modeling_InternLM.py
Remove the dependency of flash-attention and rotary_emb

modeling_InternLM.py  +43 -44

modeling_InternLM.py  CHANGED
@@ -2,12 +2,10 @@ import math
 from typing import List, Union
 from typing import Optional, Tuple
 
-import rotary_emb
 import torch
 import torch.utils.checkpoint
 import torch.utils.checkpoint
 from einops import rearrange
-from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
@@ -23,51 +21,70 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "InternLMXComposerConfig"
 
 
-
-
-
-
+def rotary_embed(x1, x2, cos, sin, conj):
+    x1, x2 = x1.float(), x2.float()
+    if conj:
+        x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
+    else:
+        x1, x2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
+    return x1, x2
+
+
+class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
+
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """
-            qkv: (
+            qkv: (batch_size, seqlen, 3, nheads, headdim)
             cos, sin: (seqlen, rotary_dim / 2)
             cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+                1st half and 2nd half (GPT-NeoX style).
             rotary_dim must be <= headdim
             Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
-
+        batch, seqlen, three, nheads, headdim = qkv.shape
         assert three == 3
         rotary_seqlen, rotary_dim = cos.shape
         rotary_dim *= 2
         assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
         cos_k = cos if cos_k is None else cos_k
         sin_k = sin if sin_k is None else sin_k
-        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
-
-        q1, q2 =
-        rotary_emb.apply_rotary(q1, q2, rearrange(cos,
-
-
-
-
-
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+        q_ro = qkv[:, :, 0, :, :rotary_dim]
+        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
+        # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
+        q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
+        k_ro = qkv[:, :, 1, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
+        # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.interleaved = interleaved
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
         cos, sin, cos_k, sin_k = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
-
-
-
-
-
-
-
-
+        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
+        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None, None
 
 
 class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
@@ -120,23 +137,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
         self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def forward(self,
-                qkv: torch.Tensor,
-                indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
-        if self.scale is None:
-            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
-                                         self._sin_cached[indexes]).to(
-                                             qkv.dtype)
-        else:
-            return apply_rotary_emb_qkv_(
-                qkv,
-                self._cos_cached[indexes],
-                self._sin_cached[indexes],
-                self._cos_k_cached[indexes],
-                self._sin_k_cached[indexes],
-            ).to(qkv.dtype)
-
     def eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -157,7 +157,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         )
 
 
-apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
 legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
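The new rotary_embed helper inlines, in plain PyTorch, the rotation that the rotary_emb.apply_rotary CUDA kernel used to perform: each pair (x1, x2) is rotated to (x1*cos - x2*sin, x1*sin + x2*cos). Two details of the diff as committed are worth keeping in mind: the conj=True branch returns x1*sin + x2*cos as its second component, whereas the exact inverse rotation gives x2*cos - x1*sin, and backward still calls rotary_emb.apply_rotary even though hunk 1 drops that import, so the dependency only stays removed as long as no gradient flows through this op (for example under torch.no_grad during generation). Below is a minimal, self-contained sketch of the rotation and its exact inverse with a round-trip check; apply_rotation and invert_rotation are illustrative names, not functions from this file.

import torch

def apply_rotation(x1, x2, cos, sin):
    # Forward RoPE rotation, matching the conj=False branch of rotary_embed:
    # (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos

def invert_rotation(y1, y2, cos, sin):
    # Exact inverse (rotation by -theta):
    # (y1, y2) -> (y1*cos + y2*sin, y2*cos - y1*sin)
    return y1 * cos + y2 * sin, y2 * cos - y1 * sin

torch.manual_seed(0)
x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
theta = torch.rand(4, 8)
cos, sin = torch.cos(theta), torch.sin(theta)
y1, y2 = apply_rotation(x1, x2, cos, sin)
z1, z2 = invert_rotation(y1, y2, cos, sin)
# The round trip should recover the inputs up to float error.
assert torch.allclose(z1, x1, atol=1e-6) and torch.allclose(z2, x2, atol=1e-6)

A backward written in the same style as the converted forward would apply this inverse to the dq and dk slices and write the results back the way the forward does.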
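The added docstring distinguishes the GPT-NeoX style pairing (first half with second half, the default) from the GPT-J style interleaved pairing (even with odd positions). The converted forward selects the pairs accordingly, but it always writes the rotated halves back with torch.cat([q1, q2], dim=-1), which reproduces the original element order only in the non-interleaved case. A layout-preserving write-back for interleaved=True could look like the sketch below; split_pairs and merge_pairs are illustrative helpers, not part of the file.

import torch

def split_pairs(x, interleaved):
    # GPT-J style pairs even with odd positions; GPT-NeoX style pairs the two halves.
    return (x[..., ::2], x[..., 1::2]) if interleaved else x.chunk(2, dim=-1)

def merge_pairs(x1, x2, interleaved):
    if interleaved:
        # Put the components back at even/odd positions instead of concatenating halves.
        return torch.stack((x1, x2), dim=-1).flatten(-2)
    return torch.cat((x1, x2), dim=-1)

x = torch.arange(8.0)
x1, x2 = split_pairs(x, interleaved=True)
assert torch.equal(merge_pairs(x1, x2, interleaved=True), x)   # layout is preserved

This only matters if the op is ever invoked with interleaved=True; the default path is unaffected.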
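For reference, here is a hypothetical smoke test of the exported op, using the shapes stated in the added docstring: qkv of shape (batch_size, seqlen, 3, nheads, headdim) and cos/sin caches of shape (seqlen, rotary_dim / 2). The module name modeling_InternLM and the base of 10000 for the frequency table are assumptions made for the example, not something this diff prescribes.

import torch
# Assumption: the uploaded file is importable as a local module named modeling_InternLM.
from modeling_InternLM import legacy_apply_rotary_embed_qkv

batch, seqlen, nheads, headdim = 2, 16, 32, 128
rotary_dim = headdim                                   # rotate the full head dimension

qkv = torch.randn(batch, seqlen, 3, nheads, headdim)
v_before = qkv[:, :, 2].clone()

# cos/sin caches of shape (seqlen, rotary_dim / 2); base 10000 is the usual RoPE choice (assumed).
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
cos, sin = freqs.cos(), freqs.sin()

out = legacy_apply_rotary_embed_qkv(qkv, cos, sin)     # rotates q and k in place
assert out.shape == (batch, seqlen, 3, nheads, headdim)
assert torch.equal(out[:, :, 2], v_before)             # the value slice is left untouched

Because the forward mutates qkv in place and rotary_embed upcasts to float32 before the result is assigned back, the q and k slices come back rotated in the tensor's original dtype while v is unchanged.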
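Hunks 3 and 4 remove the training-time forward of ConvertedInternLMRotaryEmbedding, which wrapped the now-deleted ApplyRotaryEmbQKV_.apply, and keep eval_forward, whose seqlen_offset argument supports incremental decoding. As a general illustration of that pattern (the cache construction below is a sketch, not this file's own code), the cached cos/sin tables cover absolute positions and the offset selects the slice for the tokens currently being fed in:

import torch

def build_cos_sin_cache(max_seqlen, rotary_dim, base=10000.0, dtype=torch.float32):
    # Frequency table for absolute positions 0 .. max_seqlen-1 (illustrative helper).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(torch.arange(max_seqlen, dtype=torch.float32), inv_freq)
    return freqs.cos().to(dtype), freqs.sin().to(dtype)

cos_cached, sin_cached = build_cos_sin_cache(max_seqlen=2048, rotary_dim=128)

# Prefill: the whole prompt is passed in, so positions 0 .. prompt_len-1 are used.
prompt_len = 11
cos, sin = cos_cached[:prompt_len], sin_cached[:prompt_len]

# Decoding step: only the newest token is passed in, so its table row starts at the offset.
seqlen_offset = prompt_len
cos_step = cos_cached[seqlen_offset:seqlen_offset + 1]
sin_step = sin_cached[seqlen_offset:seqlen_offset + 1]
assert cos_step.shape == (1, 64) and sin_step.shape == (1, 64)

During prefill the offset is 0 and the whole prompt slice is used; at each decoding step only the row for the newest position is needed.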