pt-sk committed on
Commit
581227b
1 Parent(s): 911d9a2

Upload 2 files

Files changed (2)
  1. Files/inference.py +168 -0
  2. Files/model.py +275 -0
Files/inference.py ADDED
@@ -0,0 +1,168 @@
+ from typing import Optional
+ import torch
+ import time
+ from pathlib import Path
+ import json
+ from sentencepiece import SentencePieceProcessor
+ from tqdm import tqdm
+
+ from model import ModelArgs, Transformer
+
+ class LLaMA:
+
+     def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args: ModelArgs):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.args = model_args
+
+     @staticmethod
+     def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):
+         prev_time = time.time()
+         if load_model:
+             checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
+             assert len(checkpoints) > 0, f"No checkpoint files found in {checkpoints_dir}"
+             ckpt_path = checkpoints[0]
+             print(f'Loading checkpoint "{ckpt_path}"')
+             checkpoint = torch.load(ckpt_path, map_location="cpu")
+             print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
+             prev_time = time.time()
+         with open(Path(checkpoints_dir) / "params.json", "r") as f:
+             params = json.loads(f.read())
+
+         model_args: ModelArgs = ModelArgs(
+             max_seq_len=max_seq_len,
+             max_batch_size=max_batch_size,
+             device=device,
+             **params
+         )
+
+         tokenizer = SentencePieceProcessor()
+         tokenizer.load(tokenizer_path)
+         model_args.vocab_size = tokenizer.vocab_size()
+
+         if device == "cuda":
+             torch.set_default_tensor_type(torch.cuda.HalfTensor)
+         else:
+             torch.set_default_tensor_type(torch.BFloat16Tensor)
+
+         model = Transformer(model_args).to(device)
+
+         if load_model:
+             # The only unmatched key in the checkpoint is rope.freqs, so remove it
+             del checkpoint['rope.freqs']
+             model.load_state_dict(checkpoint, strict=True)
+             print(f"Loaded state dict in {time.time() - prev_time:.2f}s")
+
+         return LLaMA(model, tokenizer, model_args)
+
+     def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
+         if max_gen_len is None:
+             max_gen_len = self.args.max_seq_len - 1
+         # Convert each prompt into tokens
+         prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
+         # Make sure the batch size is not too large
+         batch_size = len(prompt_tokens)
+         assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}"
+         max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
+         # Make sure the prompt length is not larger than the maximum sequence length
+         assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}"
+         total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)
+
+         # Create the tensor that will contain the generated tokens, along with the initial prompt tokens
+         device = self.args.device  # use the device the model was built with
+         pad_id = self.tokenizer.pad_id()
+         tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
+         for k, t in enumerate(prompt_tokens):
+             # Populate the initial tokens with the prompt tokens
+             tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
+
+         eos_reached = torch.tensor([False] * batch_size, device=device)
+         prompt_tokens_mask = tokens != pad_id  # True if the token is a prompt token, False otherwise
+         cur_iterator = tqdm(range(1, total_len), desc="Generating tokens")
+         for cur_pos in cur_iterator:
+             with torch.no_grad():
+                 logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos)
+             if temperature > 0:
+                 # The temperature is applied before the softmax
+                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+                 next_token = self._sample_top_p(probs, top_p)
+             else:
+                 # Greedily select the token with the highest probability
+                 next_token = torch.argmax(logits[:, -1], dim=-1)
+
+             next_token = next_token.reshape(-1)
+             # Only replace the token if it is a padding token
+             next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+             tokens[:, cur_pos] = next_token
+             # EOS is reached only if we found an EOS token at a padding position
+             eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id())
+             if all(eos_reached):
+                 break
+
+         out_tokens = []
+         out_text = []
+         for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
+             # Cut to the EOS token, if present
+             if self.tokenizer.eos_id() in current_prompt_tokens:
+                 eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
+                 current_prompt_tokens = current_prompt_tokens[:eos_idx]
+             out_tokens.append(current_prompt_tokens)
+             out_text.append(self.tokenizer.decode(current_prompt_tokens))
+         return (out_tokens, out_text)
+
+     def _sample_top_p(self, probs, p):
+         # (B, vocab_size)
+         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+         # (B, vocab_size)
+         probs_sum = torch.cumsum(probs_sort, dim=-1)
+         # (B, vocab_size)
+         # (Subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking)
+         mask = probs_sum - probs_sort > p
+         # Zero out the probabilities of all tokens that are not selected by the Top P
+         probs_sort[mask] = 0.0
+         # Redistribute the probabilities so that they sum up to 1
+         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+         # Sample a token (its index) from the top p distribution
+         next_token = torch.multinomial(probs_sort, num_samples=1)
+         # Get the token position in the vocabulary corresponding to the sampled index
+         next_token = torch.gather(probs_idx, -1, next_token)
+         return next_token
+
+
+ if __name__ == '__main__':
+     torch.manual_seed(0)
+
+     allow_cuda = False
+     device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'
+
+     prompts = [
+         "Simply put, the theory of relativity states that ",
+         "If Google was an Italian company founded in Milan, it would",
+         # Few-shot prompt
+         """Translate English to French:
+
+         sea otter => loutre de mer
+         peppermint => menthe poivrée
+         plush girafe => girafe peluche
+         cheese =>""",
+         # Zero-shot prompt
+         """Tell me if the following person is actually Doraemon disguised as human:
+         Name: Umar Jamil
+         Decision:
+         """
+     ]
+
+     model = LLaMA.build(
+         checkpoints_dir='llama-2-7b/',
+         tokenizer_path='tokenizer.model',
+         load_model=True,
+         max_seq_len=1024,
+         max_batch_size=len(prompts),
+         device=device
+     )
+
+     out_tokens, out_texts = model.text_completion(prompts, max_gen_len=64)
+     assert len(out_texts) == len(prompts)
+     for i in range(len(out_texts)):
+         print(f'{out_texts[i]}')
+         print('-' * 50)
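For reference, the snippet below replays the same top-p (nucleus) sampling logic as `_sample_top_p` on a tiny hand-made distribution, outside the model. It is a minimal standalone sketch and not part of the commit; the function name `sample_top_p` and the toy probabilities are illustrative only.

import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort the probabilities in descending order, remembering the original token ids
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Shift the cumulative sum one position to the right, then mask everything outside the nucleus
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize, sample, and map the sampled index back to a vocabulary id
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

torch.manual_seed(0)
toy_probs = torch.tensor([[0.05, 0.50, 0.10, 0.30, 0.05]])
# With p=0.7 only the two most likely tokens (ids 1 and 3) survive,
# renormalized to 0.625 and 0.375, so every draw returns one of those two ids.
print([sample_top_p(toy_probs.clone(), p=0.7).item() for _ in range(5)])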
Files/model.py ADDED
@@ -0,0 +1,275 @@
+ from dataclasses import dataclass
+ from typing import Optional
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 4096
+     n_layers: int = 32
+     n_heads: int = 32
+     n_kv_heads: Optional[int] = None
+     vocab_size: int = -1  # Later set in the build method
+     multiple_of: int = 256
+     ffn_dim_multiplier: Optional[float] = None
+     norm_eps: float = 1e-5
+
+     # Needed for KV cache
+     max_batch_size: int = 32
+     max_seq_len: int = 2048
+
+     device: str = None
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         # The gamma parameter
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x: torch.Tensor):
+         # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
+         # rsqrt: 1 / sqrt(x)
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x: torch.Tensor):
+         # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
+         return self.weight * self._norm(x.float()).type_as(x)
+
+
+ def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
+     # As written in paragraph 3.2.2 of the paper:
+     # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
+     assert head_dim % 2 == 0, "Dimension must be divisible by 2"
+     # Build the theta parameter
+     # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ..., dim/2]
+     # Shape: (Head_Dim / 2)
+     theta_numerator = torch.arange(0, head_dim, 2).float()
+     # Shape: (Head_Dim / 2)
+     theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)
+     # Construct the positions (the "m" parameter)
+     # Shape: (Seq_Len)
+     m = torch.arange(seq_len, device=device)
+     # Multiply each theta by each position using the outer product.
+     # Shape: (Seq_Len) outer_product (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
+     freqs = torch.outer(m, theta).float()
+     # We can compute complex numbers in the polar form c = R * exp(i * m * theta), where R = 1, as follows:
+     # (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2)
+     freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
+     return freqs_complex
+
+
+ def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
+     # Separate the last dimension into pairs of two values, representing the real and imaginary parts of a complex number
+     # Two consecutive values become a single complex number
+     # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
+     x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+     # Reshape the freqs_complex tensor to match the shape of the x_complex tensor, i.e. add the batch and head dimensions
+     # (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
+     freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
+     # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor,
+     # which results in the rotation of the complex number as shown in Figure 1 of the paper
+     # (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
+     x_rotated = x_complex * freqs_complex
+     # Convert the complex numbers back to real numbers
+     # (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
+     x_out = torch.view_as_real(x_rotated)
+     # (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
+     x_out = x_out.reshape(*x.shape)
+     return x_out.type_as(x).to(device)
+
+
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     batch_size, seq_len, n_kv_heads, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         # (B, Seq_Len, N_KV_Heads, 1, Head_Dim)
+         x[:, :, :, None, :]
+         # (B, Seq_Len, N_KV_Heads, N_Rep, Head_Dim)
+         .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
+         # (B, Seq_Len, N_KV_Heads * N_Rep, Head_Dim)
+         .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
+     )
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         # Indicates the number of heads for the Keys and Values
+         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+         # Indicates the number of heads for the Queries
+         self.n_heads_q = args.n_heads
+         # Indicates how many times the Keys and Values should be repeated
+         self.n_rep = self.n_heads_q // self.n_kv_heads
+         # Indicates the dimension of each head, that is, the part of the embedding each head is responsible for
+         self.head_dim = args.dim // args.n_heads
+
+         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
+         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+
+         self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
+         self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         start_pos: int,
+         freqs_complex: torch.Tensor
+     ):
+         batch_size, seq_len, _ = x.shape  # (B, 1, Dim)
+
+         # (B, 1, Dim) -> (B, 1, H_Q * Head_Dim)
+         xq = self.wq(x)
+         # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim)
+         xk = self.wk(x)
+         # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim)
+         xv = self.wv(x)
+
+         # (B, 1, H_Q * Head_Dim) -> (B, 1, H_Q, Head_Dim)
+         xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
+         # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
+         xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
+         # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
+         xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
+
+         # (B, 1, H_Q, Head_Dim) --> (B, 1, H_Q, Head_Dim)
+         xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
+         # (B, 1, H_KV, Head_Dim) --> (B, 1, H_KV, Head_Dim)
+         xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)
+
+         # Replace the entries in the cache
+         self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
+         self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv
+
+         # (B, Seq_Len_KV, H_KV, Head_Dim)
+         keys = self.cache_k[:batch_size, : start_pos + seq_len]
+         # (B, Seq_Len_KV, H_KV, Head_Dim)
+         values = self.cache_v[:batch_size, : start_pos + seq_len]
+
+         # Since every group of Q shares the same K and V heads, just repeat the K and V heads for every Q in the same group.
+
+         # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
+         keys = repeat_kv(keys, self.n_rep)
+         # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim)
+         values = repeat_kv(values, self.n_rep)
+
+         # (B, 1, H_Q, Head_Dim) -> (B, H_Q, 1, Head_Dim)
+         xq = xq.transpose(1, 2)
+         # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
+         keys = keys.transpose(1, 2)
+         # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim)
+         values = values.transpose(1, 2)
+
+         # (B, H_Q, 1, Head_Dim) @ (B, H_Q, Head_Dim, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         # (B, H_Q, 1, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+
+         # (B, H_Q, 1, Seq_Len_KV) @ (B, H_Q, Seq_Len_KV, Head_Dim) -> (B, H_Q, 1, Head_Dim)
+         output = torch.matmul(scores, values)
+         # (B, H_Q, 1, Head_Dim) -> (B, 1, H_Q, Head_Dim) -> (B, 1, Dim)
+         output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+         return self.wo(output)  # (B, 1, Dim) -> (B, 1, Dim)
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         args: ModelArgs
+     ):
+         super().__init__()
+
+         hidden_dim = 4 * args.dim
+         hidden_dim = int(2 * hidden_dim / 3)
+         if args.ffn_dim_multiplier is not None:
+             hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
+         # Round hidden_dim up to the nearest multiple of the multiple_of parameter
+         hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
+
+         self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
+         self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
+
+     def forward(self, x: torch.Tensor):
+         # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
+         swish = F.silu(self.w1(x))
+         # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
+         x_V = self.w3(x)
+         # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim)
+         x = swish * x_V
+         # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
+         x = self.w2(x)
+         return x
+
+
+ class EncoderBlock(nn.Module):
+
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+
+         self.attention = SelfAttention(args)
+         self.feed_forward = FeedForward(args)
+
+         # Normalization BEFORE the attention block
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         # Normalization BEFORE the feed forward block
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
+         # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
+         h = x + self.attention.forward(
+             self.attention_norm(x), start_pos, freqs_complex
+         )
+         # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim)
+         out = h + self.feed_forward.forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         assert args.vocab_size != -1, "Vocab size must be set"
+
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.n_layers = args.n_layers
+         self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+
+         self.layers = nn.ModuleList()
+         for layer_id in range(args.n_layers):
+             self.layers.append(EncoderBlock(args))
+
+         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.output = nn.Linear(args.dim, self.vocab_size, bias=False)
+
+         self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)
+
+     def forward(self, tokens: torch.Tensor, start_pos: int):
+         # (B, Seq_Len)
+         batch_size, seq_len = tokens.shape
+         assert seq_len == 1, "Only one token at a time can be processed"
+
+         # (B, Seq_Len) -> (B, Seq_Len, Dim)
+         h = self.tok_embeddings(tokens)
+
+         # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
+         freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
+
+         # Consecutively apply all the encoder layers
+         for layer in self.layers:
+             h = layer(h, start_pos, freqs_complex)
+         h = self.norm(h)
+         output = self.output(h).float()
+         return output
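Not part of the commit, but a quick sanity check of the rotary-embedding math above: since `apply_rotary_embeddings` only rotates each (real, imaginary) channel pair by a unit-modulus complex factor, it should leave the norm of every head vector unchanged. The sketch below assumes the same tensor layout as Files/model.py, namely (B, Seq_Len, H, Head_Dim); the helper names `precompute_freqs` and `apply_rope` are illustrative stand-ins for the functions defined in the file.

import torch

def precompute_freqs(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # theta_i = 10000^(-2(i-1)/dim), one frequency per pair of channels
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    m = torch.arange(seq_len).float()                     # positions
    freqs = torch.outer(m, inv_freq)                      # (Seq_Len, Head_Dim / 2)
    return torch.polar(torch.ones_like(freqs), freqs)     # unit-modulus complex rotations

def apply_rope(x: torch.Tensor, freqs_complex: torch.Tensor) -> torch.Tensor:
    # (B, Seq_Len, H, Head_Dim) -> complex (B, Seq_Len, H, Head_Dim/2), rotate, convert back
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs_complex.unsqueeze(0).unsqueeze(2)
    return torch.view_as_real(x_rotated).reshape(*x.shape).type_as(x)

x = torch.randn(2, 8, 4, 64)                              # (B, Seq_Len, H, Head_Dim)
x_rot = apply_rope(x, precompute_freqs(head_dim=64, seq_len=8))
# Rotations preserve length: the per-head-vector norms match before and after
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-4))  # expected: True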