fraserlove committed
Commit 86d4301 · Parent(s): 1def9d3
Delete gpt.py
gpt.py
DELETED
@@ -1,183 +0,0 @@
"""
Full implementation of a Generative Pre-trained Transformer (GPT) model.

References
1) GPT-2 Paper:
https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
2) GPT-3 Paper:
https://arxiv.org/abs/2005.14165
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from huggingface_hub import PyTorchModelHubMixin

@dataclass
class GPTConfig:
    block_size: int = 1024 # Maximum context length
    vocab_size: int = 50257 # Number of unique tokens
    n_layer: int = 12 # Number of transformer blocks
    n_head: int = 12 # Number of self-attention heads
    n_embd: int = 768 # Embedding dimensionality

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, 'Embedding dimensionality must be divisible by number of heads'
        # Transformations for queries, keys, and values for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Autoregressive mask - not needed, as PyTorch's flash-attention implementation is used instead
        # self.register_buffer('mask', torch.tril(torch.ones(config.block_size, config.block_size))
        #                      .view(1, 1, config.block_size, config.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape # batch_size, block_size, n_embd
        # Calculate queries, keys, and values for all heads in a single pass
        # H is the number of heads and C/H is the head size, C = H * C/H
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
        # Compute attention scores ('affinities')
        # W = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5) # (B, H, T, C/H) @ (B, H, C/H, T) -> (B, H, T, T)
        # W = W.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) # Autoregressive mask
        # W = F.softmax(W, dim=-1)
        # Perform the attention-weighted sum
        # y = W @ v # (B, H, T, T) @ (B, H, T, C/H) -> (B, H, T, C/H)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # Flash-attention - https://arxiv.org/abs/2205.14135
        y = y.transpose(1, 2).contiguous().view(B, T, C) # Re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    """Single non-linear feed-forward layer."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    """Transformer block with a causal self-attention layer and a feed-forward layer."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module, PyTorchModelHubMixin):
    """A GPT model."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd), # Positional embeddings
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # Transformer blocks
            ln_f = nn.LayerNorm(config.n_embd), # Final layer norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing between embedding and output layers - https://arxiv.org/abs/1608.05859
        self.transformer.wte.weight = self.lm_head.weight

        # Initialise weights as per GPT-2
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Scale init of residual layers as std grows with depth in residual streams
        for name, param in self.named_parameters():
            if name.endswith('c_proj.weight'):
                nn.init.normal_(param, mean=0.0, std=0.02 * (2 * self.config.n_layer) ** -0.5)

    def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        B, T = x.shape # batch_size, block_size
        assert T <= self.config.block_size, f'Sequence of length {T} exceeds block size {self.config.block_size}'
        pos = torch.arange(T, dtype=torch.long, device=x.device)
        pos_embd = self.transformer.wpe(pos) # (T) -> (T, C)
        tok_embd = self.transformer.wte(x) # (B, T) -> (B, T, C)
        z = tok_embd + pos_embd
        for block in self.transformer.h:
            z = block(z)
        z = self.transformer.ln_f(z)
        logits = self.lm_head(z) # (B, T, C) -> (B, T, V) where V is vocab_size
        loss = None
        if y is not None:
            # Flatten logits and targets to (B*T, V) and (B*T) respectively, for cross-entropy loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        return logits, loss

    def configure_optimisers(self, weight_decay: float, lr: float) -> torch.optim.Optimizer:
        """Configure AdamW optimiser with weight decay and learning rate."""
        params = {name: param for name, param in self.named_parameters() if param.requires_grad}
        # Any parameter that is at least 2D has weight decay applied - i.e. all weight tensors
        # in matmuls + embeddings decay, all bias tensors don't.
        decay_params = [param for _, param in params.items() if param.dim() >= 2]
        no_decay_params = [param for _, param in params.items() if param.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0}
        ]
        # Use fused optimiser for faster training on GPU
        optimiser = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=True)
        return optimiser

    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_tokens: int = 64, n_samples: int = 1, temp: float = 1.0, top_k: int = 50, seed: int = None) -> torch.Tensor:
        """Generate sequences of tokens given an initial context."""
        rng = torch.Generator(device=x.device)
        if seed is not None:
            rng.manual_seed(seed)
        # Repeat the input context for each sample
        x = x.unsqueeze(0).repeat(n_samples, 1)
        for _ in range(max_tokens):
            # Crop the conditioning context to the last block_size tokens
            x_cond = x[:, -self.config.block_size:]
            # Forward pass
            logits, _ = self(x_cond)
            # Scale the logits by the temperature and keep only the last token prediction
            logits = logits[:, -1, :] / temp
            # Softmax for probabilities
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            # Sample from the top-k probabilities
            ix = torch.multinomial(topk_probs, 1, generator=rng)
            # Gather sampled token indices
            x_next = torch.gather(topk_indices, -1, ix)
            # Concatenate sampled token to the sequence
            x = torch.cat((x, x_next), dim=1)
        return x
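
For reference, below is a minimal sketch of how the deleted module could be exercised. It assumes tiktoken is installed for GPT-2 BPE tokenisation, that the file is importable as gpt, and that a CUDA device is available (configure_optimisers requests a fused AdamW); none of this is part of the commit, and the weights are randomly initialised rather than pre-trained.

# Usage sketch for the deleted gpt.py (assumptions: tiktoken for tokenisation,
# module importable as `gpt`, CUDA device available for the fused AdamW).
import torch
import tiktoken
from gpt import GPT, GPTConfig

device = 'cuda'
model = GPT(GPTConfig()).to(device)

# One training step on a random batch of B=4 sequences of T=32 tokens
optimiser = model.configure_optimisers(weight_decay=0.1, lr=3e-4)
xb = torch.randint(0, model.config.vocab_size, (4, 32), device=device)
yb = torch.randint(0, model.config.vocab_size, (4, 32), device=device)
logits, loss = model(xb, yb)
optimiser.zero_grad()
loss.backward()
optimiser.step()

# Sample two continuations of a prompt from the (untrained) model
enc = tiktoken.get_encoding('gpt2')
prompt = torch.tensor(enc.encode('Hello, I am a language model,'), dtype=torch.long, device=device)
samples = model.generate(prompt, max_tokens=32, n_samples=2, temp=1.0, top_k=50, seed=42)
for sample in samples:
    print(enc.decode(sample.tolist()))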