# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Text decoder."""

try:
    from flash_attn import flash_attn_func
    from flash_attn import flash_attn_with_kvcache
    from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
    flash_attn_func = None
    flash_attn_with_kvcache = None
    apply_rotary_emb = None

import torch
from torch import nn


class TransformerCache(nn.Module):
    """Transformer cache module."""

    def __init__(self, device=None, dtype=None):
        super(TransformerCache, self).__init__()
        self.device = device
        self.dtype = dtype
        self.start_pos = 0
        self.cache_dict = {}

    def init_seq(self, max_batch_size):
        """Allocate the per-sequence length counters used by the KV cache."""
        seq_lens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
        self.cache_dict["seq_lens"] = seq_lens

    def init_rotary(self, seq_len, dim, theta=10000.0):
        """Precompute the rotary embedding tables ``cos`` / ``sin``."""
        grid = torch.arange(seq_len, dtype=torch.float32).unsqueeze_(-1)
        freq = torch.pow(theta, torch.arange(0, dim, 2)[: dim // 2].float().div_(dim))
        broadcast_freq = grid.mul(freq.reciprocal_().unsqueeze_(0))
        cache_cos = broadcast_freq.cos().view((-1, dim // 2))
        cache_sin = broadcast_freq.sin().view((-1, dim // 2))
        self.cache_dict["cos"] = cache_cos.to(self.device, self.dtype)
        self.cache_dict["sin"] = cache_sin.to(self.device, self.dtype)

    def init_kv(self, mixer, kv_size):
        """Allocate the key/value cache of an attention layer, keyed by ``id(mixer)``."""
        cache_k = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        cache_v = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
        self.cache_dict[f"{id(mixer)}_k"] = cache_k
        self.cache_dict[f"{id(mixer)}_v"] = cache_v

    def set_seq(self, start_pos=0, end_pos=None):
        """Reset the decode position and slice the rotary tables for this step."""
        self.start_pos = start_pos
        if "seq_lens" in self.cache_dict:
            self.cache_dict["seq_lens"].fill_(start_pos)
        if "cos" in self.cache_dict and end_pos is not None:
            self.cache_dict["seq_cos"] = self.cache_dict["cos"][self.start_pos : end_pos]
            self.cache_dict["seq_sin"] = self.cache_dict["sin"][self.start_pos : end_pos]

    def forward_rotary(self, q, k, inplace=False):
        """Apply rotary position embeddings to queries and keys."""
        cos = self.cache_dict.get("seq_cos", self.cache_dict.get("cos", None))
        sin = self.cache_dict.get("seq_sin", self.cache_dict.get("sin", None))
        if cos is None or sin is None:
            return q, k
        q = apply_rotary_emb(q, cos, sin, interleaved=True, inplace=inplace)
        k = apply_rotary_emb(k, cos, sin, interleaved=True, inplace=inplace)
        return q, k

    def forward_flash(self, mixer, q, k, v):
        """Run flash attention, using the KV cache when one was allocated for ``mixer``."""
        cache_k = self.cache_dict.get(f"{id(mixer)}_k", None)
        cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
        flash_args = {"softmax_scale": mixer.scale, "causal": True}
        if cache_k is None or cache_v is None:
            return flash_attn_func(q, k, v, **flash_args)
        flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
        return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
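# A minimal sketch of how this cache is driven at inference time (not part of
# the module's API; the sizes and device below are illustrative assumptions,
# and ``forward_rotary`` / ``forward_flash`` need flash-attn on a CUDA device):
#
#     cache = TransformerCache(device="cuda", dtype=torch.float16)
#     cache.init_seq(max_batch_size=1)
#     cache.init_rotary(seq_len=1024, dim=64)  # dim is the per-head dim
#     cache.set_seq(start_pos=0, end_pos=8)    # slice cos/sin for 8 positions
#     q = torch.randn(1, 8, 4, 64, device="cuda", dtype=torch.float16)
#     k = q.clone()                            # (batch, seq, heads, head_dim)
#     q, k = cache.forward_rotary(q, k)

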
class Attention(nn.Module):
    """Self-Attention layer."""

    def __init__(self, dim, num_heads, bias=True):
        super(Attention, self).__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim**-0.5
        # Placeholder; replaced with the shared ``TransformerCache`` by
        # ``TextDecoder.reset_cache()`` before the first forward pass.
        self.cache = nn.Module()

    def forward(self, x):
        qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        q, k, v = self.qkv(x).view(qkv_shape).unbind(dim=2)
        q, k = self.cache.forward_rotary(q, k, inplace=True)
        o = self.cache.forward_flash(self, q, k, v)
        return self.proj(o.flatten(2))


class MLP(nn.Module):
    """Two-layer MLP."""

    def __init__(self, dim, mlp_dim, bias=True):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim, bias=bias)
        self.fc2 = nn.Linear(mlp_dim, dim, bias=bias)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_dim, bias=True):
        super(Block, self).__init__()
        self.attn = Attention(dim, num_heads, bias=bias)
        self.mlp = MLP(dim, mlp_dim, bias=bias)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Pre-norm residual connections, accumulated in place.
        x = self.attn(self.norm1(x)).add_(x)
        return self.mlp(self.norm2(x)).add_(x)


class Transformer(nn.Module):
    """Causal transformer decoder."""

    def __init__(self, depth, dim, num_heads, mlp_dim, vocab_size):
        super(Transformer, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.vocab_size = vocab_size
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList(Block(dim, num_heads, mlp_dim) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)
        self.text_proj = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, prompts, tokens, start_pos=0):
        # On the first step (``start_pos == 0``) the prompt embeddings are
        # prepended to the token embeddings; later steps decode from the cache,
        # so positions are offset by the prompt length.
        prompt_len = prompts.size(1)
        start_pos = start_pos + (prompt_len if start_pos > 0 else 0)
        end_pos = start_pos + tokens.size(1) + (0 if start_pos > 0 else prompt_len)
        self.cache.set_seq(start_pos, end_pos)
        x = self.tok_embeddings(tokens)
        x = x if start_pos > 0 else torch.cat([prompts, x], dim=1)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0 if start_pos > 0 else prompt_len :])
        return self.text_proj(x).float()


class TextDecoder(nn.Module):
    """Module to decode texts."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio,
        prompt_embed_dim,
        max_seq_len,
        vocab_size,
    ):
        super(TextDecoder, self).__init__()
        self.max_seq_len = max_seq_len
        self.max_text_len = self.max_seq_len - 1
        self.encoder = nn.Linear(prompt_embed_dim, embed_dim, bias=False)
        self.transformer = Transformer(
            depth=depth,
            dim=embed_dim,
            mlp_dim=embed_dim * mlp_ratio,
            num_heads=num_heads,
            vocab_size=vocab_size,
        )

    def reset_cache(self, max_batch_size=1, max_seq_len=None):
        device, dtype = self.encoder.weight.device, self.encoder.weight.dtype
        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
        num_heads, head_dim = self.transformer.num_heads, self.transformer.head_dim
        self.transformer.cache = TransformerCache(device=device, dtype=dtype)
        self.transformer.cache.init_seq(max_batch_size)
        self.transformer.cache.init_rotary(max_seq_len, head_dim, theta=10000.0)
        kv_cache_size = (max_batch_size, max_seq_len, num_heads, head_dim)
        for blk in self.transformer.blocks:
            # Store via ``__dict__`` so the shared cache is not registered
            # as a submodule of every attention layer.
            blk.attn.__dict__["cache"] = self.transformer.cache
            # KV caches are only needed for incremental decoding.
            if not self.training:
                self.transformer.cache.init_kv(blk.attn, kv_cache_size)

    def get_prompts(self, prompt_tokens):
        return self.encoder(prompt_tokens)

    def get_outputs(self, inputs, start_pos=0):
        return {"text_pred": self.transformer(inputs["prompts"], inputs["tokens"], start_pos)}

    def forward(self, inputs, start_pos=0):
        return self.get_outputs(inputs, start_pos)
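

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original interface. All
    # hyperparameters below are illustrative assumptions; running this
    # requires flash-attn and a CUDA device, since the attention kernels
    # only support fp16/bf16 on GPU.
    decoder = TextDecoder(
        depth=2,
        embed_dim=256,
        num_heads=4,
        mlp_ratio=4,
        prompt_embed_dim=256,
        max_seq_len=64,
        vocab_size=1000,
    )
    decoder = decoder.cuda().half().eval()
    decoder.reset_cache(max_batch_size=2)
    # Prefill: project prompt embeddings and feed a BOS-like token (id 0 is
    # an assumption; the real vocabulary layout is tokenizer-specific).
    prompt_tokens = torch.randn(2, 3, 256, device="cuda", dtype=torch.float16)
    inputs = {"prompts": decoder.get_prompts(prompt_tokens)}
    inputs["tokens"] = torch.zeros(2, 1, dtype=torch.int64, device="cuda")
    logits = decoder(inputs)["text_pred"]  # (2, 1, 1000)
    # One greedy incremental step, reading keys/values back from the cache.
    inputs["tokens"] = logits[:, -1].argmax(-1, keepdim=True)
    logits = decoder(inputs, start_pos=1)["text_pred"]  # (2, 1, 1000)
    print(logits.shape)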