Upload 2 files
- Files/inference.py +168 -0
- Files/model.py +275 -0
Files/inference.py
ADDED
@@ -0,0 +1,168 @@
from typing import Optional
import torch
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm

from model import ModelArgs, Transformer

class LLaMA:

    def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args: ModelArgs):
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args

    @staticmethod
    def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):
        prev_time = time.time()
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
            ckpt_path = checkpoints[0]
            print(f'Loading checkpoint "{ckpt_path}"')
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
            prev_time = time.time()
        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            device=device,
            **params
        )

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args).to(device)

        if load_model:
            # The only unmatched key in the checkpoint is rope.freqs. Remove it.
            del checkpoint['rope.freqs']
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f}s")

        return LLaMA(model, tokenizer, model_args)

    def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1
        # Convert each prompt into tokens
        prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
        # Make sure the batch size is not too large
        batch_size = len(prompt_tokens)
        assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}"
        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        # Make sure the prompt length is not larger than the maximum sequence length
        assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}"
        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        # Create the list that will contain the generated tokens, along with the initial prompt tokens
        pad_id = self.tokenizer.pad_id()
        device = self.args.device  # Use the device the model was built with (do not rely on a module-level global)
        tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            # Populate the initial tokens with the prompt tokens
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

        eos_reached = torch.tensor([False] * batch_size, device=device)
        prompt_tokens_mask = tokens != pad_id  # True if the token is a prompt token, False otherwise
        cur_iterator = tqdm(range(1, total_len), desc="Generating tokens")
        for cur_pos in cur_iterator:
            with torch.no_grad():
                logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos)
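            # Only the newest token is passed to the model: the attention layers read
            # K and V for all previous positions from their KV cache (see model.py).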
            if temperature > 0:
                # The temperature is applied before the softmax
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = self._sample_top_p(probs, top_p)
            else:
                # Greedily select the token with the max probability
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # Only use the sampled token if the current position is a padding position; otherwise keep the prompt token
            next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            # EOS is reached only if we found an EOS token for a padding position
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id())
            if all(eos_reached):
                break

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            # Cut to the EOS token, if present
            if self.tokenizer.eos_id() in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(self.tokenizer.decode(current_prompt_tokens))
        return (out_tokens, out_text)

    def _sample_top_p(self, probs, p):
        # (B, vocab_size)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # (B, vocab_size)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # (B, vocab_size)
        # (Subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking)
        mask = probs_sum - probs_sort > p
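        # Worked example (illustrative): if probs_sort = [0.5, 0.3, 0.2] and p = 0.7,
        # then probs_sum = [0.5, 0.8, 1.0] and the shifted sum is [0.0, 0.5, 0.8],
        # so mask = [False, False, True]: the top two tokens are kept and renormalized to [0.625, 0.375].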
        # Zero out all the probabilities of tokens that are not selected by the Top P
        probs_sort[mask] = 0.0
        # Redistribute the probabilities so that they sum up to 1.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        # Sample a token (its index) from the top p distribution
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # Get the token position in the vocabulary corresponding to the sampled index
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token


if __name__ == '__main__':
    torch.manual_seed(0)

    allow_cuda = False
    device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'

    prompts = [
        "Simply put, the theory of relativity states that ",
        "If Google was an Italian company founded in Milan, it would",
        # Few-shot prompt
        """Translate English to French:

        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
        # Zero-shot prompt
        """Tell me if the following person is actually Doraemon disguised as human:
        Name: Umar Jamil
        Decision:
        """
    ]

    model = LLaMA.build(
        checkpoints_dir='llama-2-7b/',
        tokenizer_path='tokenizer.model',
        load_model=True,
        max_seq_len=1024,
        max_batch_size=len(prompts),
        device=device
    )

    out_tokens, out_texts = model.text_completion(prompts, max_gen_len=64)
    assert len(out_texts) == len(prompts)
    for i in range(len(out_texts)):
        print(f'{out_texts[i]}')
        print('-' * 50)
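
Note: the top-p masking used in _sample_top_p can be sanity-checked in isolation by running the same tensor operations on a toy distribution. The snippet below is an illustrative sketch, not part of the uploaded files; the probabilities and the p value are made up.

# top_p_check.py -- illustrative sanity check of the top-p masking logic (not part of the upload)
import torch

probs = torch.tensor([[0.05, 0.50, 0.30, 0.15]])  # toy distribution over a 4-token vocabulary
p = 0.7

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)  # [[0.50, 0.30, 0.15, 0.05]]
probs_sum = torch.cumsum(probs_sort, dim=-1)                        # [[0.50, 0.80, 0.95, 1.00]]
mask = probs_sum - probs_sort > p                                   # [[False, False, True, True]]
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))               # [[0.625, 0.375, 0.0, 0.0]]

print(probs_sort)        # only the two most likely tokens keep probability mass
print(probs_idx[:, :2])  # their original vocabulary indices: tensor([[1, 2]])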
Files/model.py
ADDED
@@ -0,0 +1,275 @@
from dataclasses import dataclass
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # Later set in the build method
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: str = None


class RMSNorm(nn.Module):
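    # RMSNorm(x) = x / sqrt(mean(x_i^2) + eps) * gamma
    # Unlike LayerNorm, there is no re-centering (no mean subtraction) and no bias term.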
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    # As written in paragraph 3.2.2 of the paper:
    # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # Build the theta parameter
    # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ..., dim/2]
    # Shape: (Head_Dim / 2)
    theta_numerator = torch.arange(0, head_dim, 2).float()
    # Shape: (Head_Dim / 2)
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)  # (Dim / 2)
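    # Worked example (illustrative): with head_dim = 8, theta_numerator = [0, 2, 4, 6],
    # so theta = [10000^0, 10000^(-1/4), 10000^(-1/2), 10000^(-3/4)] = [1.0, 0.1, 0.01, 0.001]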
    # Construct the positions (the "m" parameter)
    # Shape: (Seq_Len)
    m = torch.arange(seq_len, device=device)
    # Multiply each theta by each position using the outer product.
    # Shape: (Seq_Len) outer_product (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
    freqs = torch.outer(m, theta).float()
    # We can compute complex numbers in the polar form c = R * exp(i * m * theta), where R = 1, as follows:
    # (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex


def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # Separate the last dimension into pairs of two values, representing the real and imaginary parts of a complex number.
    # Two consecutive values become a single complex number.
    # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Reshape the freqs_complex tensor to match the shape of the x_complex tensor, i.e. add the batch dimension and the head dimension.
    # (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor,
    # which results in the rotation of the complex number as shown in Figure 1 of the paper.
    # (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
    x_rotated = x_complex * freqs_complex
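    # For each pair (x1, x2) at position m with frequency theta, this product is the 2D rotation:
    # (x1 + i*x2) * e^(i*m*theta) = (x1*cos(m*theta) - x2*sin(m*theta)) + i*(x1*sin(m*theta) + x2*cos(m*theta))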
    # Convert the complex numbers back to real numbers
    # (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
    x_out = torch.view_as_real(x_rotated)
    # (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
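    # Example (illustrative): with n_kv_heads = 8 and n_rep = 4, each K/V head is repeated 4 times
    # so that every group of 4 query heads attends over its own copy of the same K/V head.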
    if n_rep == 1:
        return x
    return (
        # (B, Seq_Len, N_KV_Heads, 1, Head_Dim)
        x[:, :, :, None, :]
        # (B, Seq_Len, N_KV_Heads, N_Rep, Head_Dim)
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        # (B, Seq_Len, N_KV_Heads * N_Rep, Head_Dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
    )


class SelfAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        # Indicates the number of heads for the Keys and Values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the Queries
        self.n_heads_q = args.n_heads
        # Indicates how many times the Keys and Values should be repeated
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
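        # Pre-allocated KV cache: keys and values for every position up to max_seq_len are kept
        # across forward calls, so each decoding step only computes K and V for the new token.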

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_complex: torch.Tensor
    ):
        batch_size, seq_len, _ = x.shape  # (B, 1, Dim)

        # (B, 1, Dim) -> (B, 1, H_Q * Head_Dim)
        xq = self.wq(x)
        # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim)
        xk = self.wk(x)
        # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim)
        xv = self.wv(x)

        # (B, 1, H_Q * Head_Dim) -> (B, 1, H_Q, Head_Dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # (B, 1, H_Q, Head_Dim) --> (B, 1, H_Q, Head_Dim)
        xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
        # (B, 1, H_KV, Head_Dim) --> (B, 1, H_KV, Head_Dim)
        xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # Replace the entry in the cache
        self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

        # (B, Seq_Len_KV, H_KV, Head_Dim)
        keys = self.cache_k[:batch_size, : start_pos + seq_len]
        # (B, Seq_Len_KV, H_KV, Head_Dim)
        values = self.cache_v[:batch_size, : start_pos + seq_len]

        # Since every group of Q shares the same K and V heads, just repeat the K and V heads for every Q in the same group.

        # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
        keys = repeat_kv(keys, self.n_rep)
        # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
        values = repeat_kv(values, self.n_rep)

        # (B, 1, H_Q, Head_Dim) -> (B, H_Q, 1, Head_Dim)
        xq = xq.transpose(1, 2)
        # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
        keys = keys.transpose(1, 2)
        # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
        values = values.transpose(1, 2)

        # (B, H_Q, 1, Head_Dim) @ (B, H_Q, Head_Dim, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # (B, H_Q, 1, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H_Q, 1, Seq_Len_KV) @ (B, H_Q, Seq_Len_KV, Head_Dim) -> (B, H_Q, 1, Head_Dim)
        output = torch.matmul(scores, values)
        # (B, H_Q, 1, Head_Dim) -> (B, 1, H_Q, Head_Dim) -> (B, 1, Dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)  # (B, 1, Dim) -> (B, 1, Dim)


class FeedForward(nn.Module):
    def __init__(
        self,
        args: ModelArgs
    ):
        super().__init__()

        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Round the hidden_dim up to the nearest multiple of the multiple_of parameter
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
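        # Example with the default 7B config (dim=4096, multiple_of=256, no multiplier):
        # 4 * 4096 = 16384 -> int(2/3 * 16384) = 10922 -> rounded up to a multiple of 256 -> 11008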

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        swish = F.silu(self.w1(x))
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x_V = self.w3(x)
        # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim)
        x = swish * x_V
        # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
        x = self.w2(x)
        return x


class EncoderBlock(nn.Module):

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # Normalization BEFORE the attention block
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # Normalization BEFORE the feed forward block
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_complex
        )
        # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):

    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.vocab_size != -1, "Vocab size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        # (B, Seq_Len)
        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "Only one token at a time can be processed"

        # (B, Seq_Len) -> (B, Seq_Len, Dim)
        h = self.tok_embeddings(tokens)

        # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        # Consecutively apply all the encoder layers
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h).float()
        return output
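
Note: the module can be exercised without the 7B checkpoint by instantiating Transformer with a small, made-up configuration and decoding a few positions one token at a time. The snippet below is a minimal smoke test, not part of the uploaded files; all sizes and the token id are illustrative.

# smoke_test.py -- illustrative; assumes model.py is importable, sizes are made up
import torch
from model import ModelArgs, Transformer

args = ModelArgs(
    dim=64, n_layers=2, n_heads=4, n_kv_heads=2,
    vocab_size=100, max_batch_size=1, max_seq_len=16, device="cpu",
)
model = Transformer(args)

tokens = torch.tensor([[7]])           # (B=1, Seq_Len=1), a single made-up token id
with torch.no_grad():
    for pos in range(4):               # decode 4 positions, one token per forward call
        logits = model(tokens, pos)    # (1, 1, vocab_size)
        tokens = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy next token, shape (1, 1)
        print(pos, logits.shape, tokens.item())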