PhyscalX's picture
Add code
3d2142b
raw
history blame
8.22 kB
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Text decoder."""
try:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_with_kvcache
from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
flash_attn_func = None
flash_attn_with_kvcache = None
apply_rotary_emb = None
import torch
from torch import nn
class TransformerCache(nn.Module):
    """Shared decoding state for the attention layers.

    All tensors live in ``cache_dict``, a plain ``dict`` keyed by string:

    * ``"seq_lens"``: int32 vector, one entry per batch slot (``init_seq``).
    * ``"cos"`` / ``"sin"``: rotary tables of shape ``(seq_len, dim // 2)``
      (``init_rotary``).
    * ``"seq_cos"`` / ``"seq_sin"``: the slice of those tables selected by the
      most recent ``set_seq`` call that supplied an ``end_pos``.
    * ``"<id(mixer)>_k"`` / ``"<id(mixer)>_v"``: zero-initialised key/value
      buffers owned by one attention module (``init_kv``).

    ``forward_rotary`` and ``forward_flash`` call into the optional
    ``flash_attn`` package imported at module level.
    """

    def __init__(self, device=None, dtype=None):
        """Create an empty cache whose tensors will use ``device`` and ``dtype``."""
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.start_pos = 0
        self.cache_dict = {}

    def init_seq(self, max_batch_size):
        """Allocate the zeroed int32 ``seq_lens`` vector with ``max_batch_size`` entries."""
        self.cache_dict["seq_lens"] = torch.zeros(
            max_batch_size, dtype=torch.int32, device=self.device
        )

    def init_rotary(self, seq_len, dim, theta=10000.0):
        """Precompute the rotary cos/sin tables for positions ``0 .. seq_len - 1``.

        The angle at (position ``p``, column ``i``) is ``p / theta ** (2 * i / dim)``.
        Tables are built in float32 and then moved to ``self.device`` / ``self.dtype``.
        """
        half_dim = dim // 2
        positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1)
        exponents = torch.arange(0, dim, 2)[:half_dim].float() / dim
        inv_freq = torch.pow(theta, exponents).reciprocal()
        angles = positions * inv_freq.unsqueeze(0)
        for name, table in (("cos", angles.cos()), ("sin", angles.sin())):
            table = table.view((-1, half_dim))
            self.cache_dict[name] = table.to(self.device, self.dtype)

    def init_kv(self, mixer, kv_size):
        """Allocate zeroed key and value buffers of shape ``kv_size`` for ``mixer``.

        The buffers are stored under keys derived from ``id(mixer)``, so each
        attention module object gets its own pair.
        """
        owner = id(mixer)
        # Allocate both buffers before touching the dict.
        key_buf, value_buf = (
            torch.zeros(*kv_size, dtype=self.dtype, device=self.device) for _ in range(2)
        )
        self.cache_dict.update({f"{owner}_k": key_buf, f"{owner}_v": value_buf})

    def set_seq(self, start_pos=0, end_pos=None):
        """Record the current position and select the matching rotary window.

        ``seq_lens`` (if allocated) is filled in place with ``start_pos``. When
        rotary tables exist and ``end_pos`` is given, ``seq_cos`` / ``seq_sin``
        become the rows ``[start_pos:end_pos]``; otherwise any previously stored
        window is left untouched.
        """
        self.start_pos = start_pos
        state = self.cache_dict
        if "seq_lens" in state:
            state["seq_lens"].fill_(start_pos)
        if end_pos is None or "cos" not in state:
            return
        window = slice(self.start_pos, end_pos)
        state["seq_cos"] = state["cos"][window]
        state["seq_sin"] = state["sin"][window]

    def forward_rotary(self, q, k, inplace=False):
        """Apply rotary embedding to ``q`` and ``k``; return them unchanged if no table exists.

        The windowed tables (``seq_cos`` / ``seq_sin``) take precedence over the
        full ones.
        """
        state = self.cache_dict
        cos = state["seq_cos"] if "seq_cos" in state else state.get("cos")
        sin = state["seq_sin"] if "seq_sin" in state else state.get("sin")
        if cos is None or sin is None:
            return q, k
        # NOTE(review): interleaved=True is flash-attn's layout flag for rotating
        # adjacent (even, odd) channel pairs -- confirm against the installed version.
        rotated_q = apply_rotary_emb(q, cos, sin, interleaved=True, inplace=inplace)
        rotated_k = apply_rotary_emb(k, cos, sin, interleaved=True, inplace=inplace)
        return rotated_q, rotated_k

    def forward_flash(self, mixer, q, k, v):
        """Run causal flash attention for ``mixer``, using its KV buffers when allocated.

        Without buffers for ``mixer`` this is a plain ``flash_attn_func`` call.
        With buffers, ``flash_attn_with_kvcache`` is called with the first
        ``q.shape[0]`` entries of ``seq_lens`` as ``cache_seqlens``.
        """
        owner = id(mixer)
        scale = mixer.scale
        key_buf = self.cache_dict.get(f"{owner}_k")
        value_buf = self.cache_dict.get(f"{owner}_v")
        if key_buf is None or value_buf is None:
            return flash_attn_func(q, k, v, softmax_scale=scale, causal=True)
        # NOTE(review): flash-attn is expected to write k/v into the buffers at the
        # offsets given by cache_seqlens -- verify against its API documentation.
        seq_lens = self.cache_dict["seq_lens"][: q.shape[0]]
        return flash_attn_with_kvcache(
            q,
            key_buf,
            value_buf,
            k,
            v,
            softmax_scale=scale,
            causal=True,
            cache_seqlens=seq_lens,
        )
class Attention(nn.Module):
    """Multi-head self-attention whose rotary and attention kernels live on ``self.cache``.

    ``cache`` starts as an empty ``nn.Module`` placeholder; ``forward`` requires
    it to expose ``forward_rotary`` and ``forward_flash``, so a real cache must
    be attached before the layer is called.
    """

    def __init__(self, dim, num_heads, bias=True):
        """Build the fused QKV projection and the output projection.

        Args:
            dim: Channel width of the input and output.
            num_heads: Number of attention heads; ``head_dim`` is ``dim // num_heads``.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        # Softmax scaling 1 / sqrt(head_dim), read by the cache as ``mixer.scale``.
        self.scale = self.head_dim**-0.5
        self.cache = nn.Module()

    def forward(self, x):
        """Project ``x`` to q/k/v, apply rotary embedding and attention, then project back."""
        # Split the fused projection into (-1, x.size(1), 3, heads, head_dim).
        split_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
        fused = self.qkv(x).view(split_shape)
        q, k, v = fused.unbind(dim=2)
        q, k = self.cache.forward_rotary(q, k, inplace=True)
        attended = self.cache.forward_flash(self, q, k, v)
        # Merge the trailing (heads, head_dim) axes before the output projection.
        return self.proj(attended.flatten(2))
class MLP(nn.Module):
    """Feed-forward network: linear -> GELU -> linear."""

    def __init__(self, dim, mlp_dim, bias=True):
        """Create the two linear layers.

        Args:
            dim: Input and output width.
            mlp_dim: Hidden width between the two layers.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim, bias=bias)
        self.fc2 = nn.Linear(mlp_dim, dim, bias=bias)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, dim, num_heads, mlp_dim, bias=True):
        """Create the attention layer, the MLP and their two layer norms.

        Args:
            dim: Channel width of the block.
            num_heads: Number of attention heads.
            mlp_dim: Hidden width of the MLP.
            bias: Forwarded to the attention and MLP linear layers.
        """
        super().__init__()
        self.attn = Attention(dim, num_heads, bias=bias)
        self.mlp = MLP(dim, mlp_dim, bias=bias)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual branches."""
        # Each residual is added in place into the branch output, not into ``x``.
        attn_out = self.attn(self.norm1(x))
        x = attn_out.add_(x)
        mlp_out = self.mlp(self.norm2(x))
        return mlp_out.add_(x)
class Transformer(nn.Module):
    """Causal transformer decoder.

    ``forward`` reads ``self.cache``, which this constructor does not create;
    it must be attached from outside before the first call.
    """

    def __init__(self, depth, dim, num_heads, mlp_dim, vocab_size):
        """Build the token embedding, ``depth`` blocks, final norm and vocabulary head.

        Args:
            depth: Number of transformer blocks.
            dim: Channel width.
            num_heads: Number of attention heads per block.
            mlp_dim: Hidden width of each block's MLP.
            vocab_size: Size of the token vocabulary (embedding rows and logits).
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.vocab_size = vocab_size
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([Block(dim, num_heads, mlp_dim) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.text_proj = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, prompts, tokens, start_pos=0):
        """Return float32 vocabulary logits for ``tokens``.

        With ``start_pos == 0`` the prompt embeddings are prepended to the token
        embeddings and the prompt positions are dropped from the output. With
        ``start_pos > 0`` only ``tokens`` are fed, and positions are offset by
        the prompt length.

        Args:
            prompts: Prompt embeddings; only ``size(1)`` is read when ``start_pos > 0``.
            tokens: Token ids passed to the embedding table.
            start_pos: Number of text tokens already processed.
        """
        prompt_len = prompts.size(1)
        # Adding the (non-negative) prompt length never changes the sign of
        # start_pos, so this flag is valid both before and after the offset.
        resuming = start_pos > 0
        if resuming:
            start_pos = start_pos + prompt_len
            end_pos = start_pos + tokens.size(1)
        else:
            end_pos = start_pos + tokens.size(1) + prompt_len
        self.cache.set_seq(start_pos, end_pos)

        x = self.tok_embeddings(tokens)
        if not resuming:
            x = torch.cat([prompts, x], dim=1)
        for blk in self.blocks:
            x = blk(x)
        if not resuming:
            # Keep only the positions that correspond to text tokens.
            x = x[:, prompt_len:]
        return self.text_proj(self.norm(x)).float()
class TextDecoder(nn.Module):
    """Decode text from prompt embeddings with a causal transformer."""

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio,
        prompt_embed_dim,
        max_seq_len,
        vocab_size,
    ):
        """Build the prompt projection and the transformer.

        Args:
            depth: Number of transformer blocks.
            embed_dim: Transformer channel width.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
            prompt_embed_dim: Width of the incoming prompt tokens.
            max_seq_len: Default sequence capacity used by ``reset_cache``.
            vocab_size: Size of the token vocabulary.
        """
        super().__init__()
        self.max_seq_len = max_seq_len
        self.max_text_len = self.max_seq_len - 1
        self.encoder = nn.Linear(prompt_embed_dim, embed_dim, bias=False)
        self.transformer = Transformer(
            depth=depth,
            dim=embed_dim,
            mlp_dim=embed_dim * mlp_ratio,
            num_heads=num_heads,
            vocab_size=vocab_size,
        )

    def reset_cache(self, max_batch_size=1, max_seq_len=None):
        """Create a fresh cache and share it with every attention layer.

        The cache takes its device and dtype from the encoder weight. Rotary
        tables are always built; per-layer KV buffers of shape
        ``(max_batch_size, max_seq_len, num_heads, head_dim)`` are allocated only
        when the module is not in training mode.

        Args:
            max_batch_size: Number of batch slots to reserve.
            max_seq_len: Sequence capacity; defaults to ``self.max_seq_len``.
        """
        weight = self.encoder.weight
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        transformer = self.transformer
        num_heads, head_dim = transformer.num_heads, transformer.head_dim

        cache = TransformerCache(device=weight.device, dtype=weight.dtype)
        transformer.cache = cache
        cache.init_seq(max_batch_size)
        cache.init_rotary(max_seq_len, head_dim, theta=10000.0)

        kv_cache_size = (max_batch_size, max_seq_len, num_heads, head_dim)
        allocate_kv = not self.training
        for blk in transformer.blocks:
            # Written straight into the instance __dict__, bypassing
            # nn.Module.__setattr__ -- presumably so the shared cache is not
            # registered as a submodule of every attention layer.
            blk.attn.__dict__["cache"] = cache
            if allocate_kv:
                cache.init_kv(blk.attn, kv_cache_size)

    def get_prompts(self, prompt_tokens):
        """Project ``prompt_tokens`` to the transformer width with the bias-free encoder."""
        return self.encoder(prompt_tokens)

    def get_outputs(self, inputs, start_pos=0):
        """Run the transformer on ``inputs["prompts"]`` and ``inputs["tokens"]``.

        Returns:
            A dict with the logits under ``"text_pred"``.
        """
        logits = self.transformer(inputs["prompts"], inputs["tokens"], start_pos)
        return {"text_pred": logits}

    def forward(self, inputs, start_pos=0):
        """Alias for ``get_outputs``."""
        return self.get_outputs(inputs, start_pos)