import torch def init_t_xy(end_x: int, end_y: int): t = torch.arange(end_x * end_y, dtype=torch.float32) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() return t_x, t_y def compute_axial_cis( dim: int, end_x: int, end_y: int, theta: float = 100.0, norm_coeff: int = 1 ): freqs_x = ( 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) * norm_coeff ) freqs_y = ( 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) * norm_coeff ) t_x, t_y = init_t_xy(end_x, end_y) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim freqs_cis = freqs_cis[:, x.shape[1], ...] if freqs_cis.shape == (x.shape[-2], x.shape[-1]): shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]): shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor): with torch.cuda.amp.autocast(enabled=False): x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) # freqs_cis = reshape_for_broadcast(freqs_cis, x).to(x_in.device) freqs_cis = freqs_cis[None, :, : x.shape[2], ...].to(x_in.device) x_out = torch.view_as_real(x * freqs_cis).flatten(3) return x_out.type_as(x_in)