File size: 11,739 Bytes
3a83cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
"""
Originally taken verbatim from the xformers library
https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py

The difference is that xformers seems to assume the inputs to be
(bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim)

"""
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step

import math
from typing import List, Optional, Tuple, Dict, Union

import torch
import dataclasses
from transformers.utils import logging

from transformers import PretrainedConfig

# Optional dependency probe: `dacite` is used by LongRopeConfig.from_dict when
# it can be imported; otherwise from_dict falls back to a manual field loop.
is_dacite_available = False
try:
    import dacite
    is_dacite_available = True
except ImportError:
    # dacite is optional; leave the flag False and continue without it.
    pass

# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

@dataclasses.dataclass
class LongRopeConfig:
    """Settings for LongRoPE-style rescaling of rotary frequencies.

    Instances are normally built from a plain dictionary via `from_dict`.
    """

    # Per-frequency rescale factors used for the short-sequence regime.
    short_factor: List[float]
    # Per-frequency rescale factors used for the long-sequence regime.
    long_factor: List[float]
    # Sequence-length threshold separating the short and long regimes.
    original_max_position_embeddings: int
    # Scaling scheme name; only "longrope" and "su" are accepted.
    type: str = "longrope"
    # Explicit magnitude scales; non-positive values mean "not provided".
    short_mscale: float = -1
    long_mscale: float = -1

    def __post_init__(self):
        # Reject scaling schemes this class does not implement.
        assert self.type in ("longrope", "su"), f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
        """Build a config from a dictionary keyed by field name.

        When `dacite` is importable it performs the conversion (and type
        checks the input). Otherwise the fields are copied by hand.

        :param config_dict: mapping from field names to values
        :return: the populated LongRopeConfig
        :raises ValueError: (manual path) if a required field is missing, a
            non-init field is supplied, or an unknown key is present
        """
        if is_dacite_available:
            # Preferred since we can also type check the input
            return dacite.from_dict(data_class=cls, data=config_dict)

        init_args = {}
        for spec in dataclasses.fields(cls):
            name = spec.name
            if name not in config_dict:
                # Absent keys are fine only when the field carries a default.
                if spec.default is dataclasses.MISSING:
                    raise ValueError(f"Field {name} is required")
                continue
            if not spec.init:
                raise ValueError(f"Field {name} is not initiable")
            init_args[name] = config_dict[name]

        # Anything left over does not correspond to a known field.
        unknown = set(config_dict.keys()) - set(init_args.keys())
        if unknown:
            for key in unknown:
                logger.error(f"Unrecognized key {key} in config_dict")
            raise ValueError("Unrecognized keys in config_dict")
        return cls(**init_args)

def rotate_half(x):
    """Split the last dimension at its midpoint and return cat(-second, first)."""
    half = x.shape[-1] // 2
    first = x[..., :half]
    second = x[..., half:]
    return torch.cat((-second, first), dim=-1)



@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
    """Rotate `x` with the given cos/sin tables.

    The tables are truncated to the length of `x` along `seq_dimension` and
    given singleton axes at the other leading positions so they broadcast
    against a 4-D `x`. For a `seq_dimension` other than 0, 1 or 2 the tables
    are used as passed in.
    """
    # NOTE: This could probably be moved to Triton
    if seq_dimension == 0:
        length = x.shape[0]
        cos = cos[:length, None, None, :]
        sin = sin[:length, None, None, :]
    elif seq_dimension == 1:
        # Truncating also absorbs a possible length mismatch between q and k
        length = x.shape[1]
        cos = cos[None, :length, None, :]
        sin = sin[None, :length, None, :]
    elif seq_dimension == 2:
        length = x.shape[2]
        cos = cos[None, None, :length, :]
        sin = sin[None, None, :length, :]

    rotated = rotate_half(x)
    return (x * cos) + (rotated * sin)



class RotaryEmbedding(torch.nn.Module):
    """
    Adapted from the xformers library

    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis

    # Arguments
    :param dim_model: size of the rotated (last) dimension, i.e. the head dimension
    :param max_seq_len: with LongRoPE, the length of the precomputed position vector
        (the largest sequence length that can be served); otherwise the initial
        length of the cos/sin cache, which grows on demand in `forward`.
        If omitted without LongRoPE, the cache starts empty.
    :param dtype: dtype of the cached cos/sin tables (used without LongRoPE only)
    :param base: base of the geometric progression of rotary frequencies
    :param position_scale: multiplier applied to position indices (used without LongRoPE only)
    :param device: device on which the initial buffers are created
    :param longrope_config: when given, enables LongRoPE frequency rescaling
    """

    def __init__(
        self,
        dim_model: int,
        *,
        max_seq_len: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        base=10000,
        position_scale=1,
        device: Optional[torch.device] = None,
        longrope_config: Optional["LongRopeConfig"] = None,
    ):
        super().__init__()
        self.base = base
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.longrope_config = longrope_config
        self.position_scale = position_scale

        if self.is_longrope:
            # Keep the maximum range vector, and slice from it as needed
            self.register_buffer(
                "range_vector",
                torch.arange(max_seq_len, device=device, dtype=torch.float32),
                persistent=False
            )
            self.register_buffer(
                "short_factors",
                torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
                persistent=False
            )
            self.register_buffer(
                "long_factors",
                torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
                persistent=False
            )
        else:
            # Generate and save the inverse frequency buffer (non trainable)
            inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
            self.register_buffer("inv_freq", inv_freq)
            # Without a length hint start from an empty cache; forward()
            # rebuilds it as soon as a longer sequence is seen.
            self._set_cos_sin_cache(
                seq_len=max_seq_len if max_seq_len is not None else 0,
                device=self.inv_freq.device,
                dtype=dtype or torch.get_default_dtype(),
            )

    @property
    def is_longrope(self):
        """True when a LongRoPE configuration was supplied."""
        return self.longrope_config is not None

    @property
    def original_max_seq_len(self):
        """Original maximum length from the LongRoPE config, else `max_seq_len` (with a warning)."""
        if self.longrope_config is not None:
            return self.longrope_config.original_max_position_embeddings
        logger.warning_once(
            (
                "``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
                "Please only do this if you are sure about the context."
            )
        )
        return self.max_seq_len

    @staticmethod
    def _autocast_device_type(device: Optional[torch.device]) -> str:
        """Return the device type string to pass to torch.autocast ("cpu" for None or "mps")."""
        device_type = device.type if device is not None else "cpu"
        return device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

    def get_range_vector(self, seq_len: int, device: torch.device):
        """Return float32 positions ``0 .. seq_len - 1`` on `device`.

        With LongRoPE the positions are sliced from the precomputed buffer
        (which is moved to `device` if necessary); otherwise a fresh tensor
        is created.

        :raises AssertionError: with LongRoPE, if `seq_len` exceeds the buffer length
        """
        if self.is_longrope:
            # A request for exactly the full buffer is valid, hence `<=`.
            assert seq_len <= self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
            if self.range_vector.device != device:
                self.range_vector = self.range_vector.to(device)
            return self.range_vector[:seq_len]
        return torch.arange(seq_len, device=device, dtype=torch.float32)

    def _calc_mscale(self, scale: float) -> float:
        """Magnitude scale for a context extension ratio `scale` (1.0 when `scale` <= 1)."""
        if scale <= 1.0:
            return 1.0
        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))

    def _set_cos_sin_cache(
        self,
        seq_len: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """(Re)build the `cos_cached` / `sin_cached` buffers for `seq_len` positions.

        :param seq_len: number of positions to cache; stored as `max_seq_len_cached`
        :param device: device on which the position indices are created
        :param dtype: dtype of the stored tables (default: torch default dtype)
        """
        dtype = dtype or torch.get_default_dtype()
        self.max_seq_len_cached = seq_len
        t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
        # Autocast is disabled so the angles are computed in inv_freq's dtype.
        with torch.autocast(device_type=self._autocast_device_type(device), enabled=False):
            # shape: (seq_len, dim_model // 2)
            freqs = torch.outer(t, self.inv_freq)
            # shape: (seq_len, dim_model)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
        self.register_buffer("sin_cached", sin.to(dtype), persistent=False)

    def forward(
        self, q: torch.Tensor,
        k: torch.Tensor,
        seq_dimension: int = 1,
        seqlen_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to `q` and `k`.

        q, k does not include `seqlen_offset`
        q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)

        :param seq_dimension: index of the sequence axis (negative values count from the end of `k`)
        :param seqlen_offset: number of positions preceding the first element of `q` / `k`
        :return: the rotated `(q, k)`, each cast back to its input dtype
        """
        if seq_dimension < 0:
            seq_dimension = k.ndim + seq_dimension
        assert seq_dimension in (0, 1, 2)
        seq_len = k.shape[seq_dimension] + seqlen_offset

        if self.is_longrope:
            # Pick the long or short factor set depending on whether the total
            # length exceeds the original context window.
            if seq_len > self.original_max_seq_len:
                t = self.get_range_vector(seq_len, device=q.device)
                rescale_factors = self.long_factors.to(q.device)
                long_mscale = self.longrope_config.long_mscale
                mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
            else:
                t = self.get_range_vector(self.original_max_seq_len, device=q.device)
                rescale_factors = self.short_factors.to(q.device)
                short_mscale = self.longrope_config.short_mscale
                mscale = short_mscale if short_mscale > 0 else 1.0
            assert rescale_factors.shape == (self.dim_model // 2, ), (
                f"misaligned shape for LongRoPE rescale factors:\n"
                f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
            )
            inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
            with torch.autocast(device_type=self._autocast_device_type(q.device), enabled=False):
                freqs = torch.outer(t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * mscale
                sin = emb.sin() * mscale
            cos_cached = cos.to(q.dtype)
            sin_cached = sin.to(q.dtype)
        else:
            # Grow the cache when the requested length exceeds what is stored.
            if seq_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(
                    seq_len=seq_len,
                    device=k.device,
                    dtype=k.dtype,
                )
            cos_cached = self.cos_cached
            sin_cached = self.sin_cached

        # Only the positions covered by this call are needed.
        cos_window = cos_cached[seqlen_offset:seq_len]
        sin_window = sin_cached[seqlen_offset:seq_len]
        return (
            apply_rotary_pos_emb(q, cos_window, sin_window, seq_dimension=seq_dimension).to(q.dtype),
            apply_rotary_pos_emb(k, cos_window, sin_window, seq_dimension=seq_dimension).to(k.dtype),
        )

    @classmethod
    def from_config(cls, config: "PretrainedConfig") -> "RotaryEmbedding":
        """Build the embedding from a model config.

        Reads `hidden_size`, `num_attention_heads`, `max_position_embeddings`,
        `rope_embedding_base`, `rope_position_scale` and `rope_scaling`; a
        non-None `rope_scaling` is parsed into a LongRopeConfig.
        """
        kwargs = dict(
            dim_model=config.hidden_size // config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            base=config.rope_embedding_base,
            position_scale=config.rope_position_scale,
        )
        if config.rope_scaling is not None:
            kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
        return cls(**kwargs)