fraserlove committed on
Commit 86d4301
1 Parent(s): 1def9d3

Delete gpt.py

Files changed (1)
  1. gpt.py +0 -183
gpt.py DELETED
@@ -1,183 +0,0 @@
- """
- Full implementation of a Generative Pre-trained Transformer (GPT) model.
-
- References
- 1) GPT-2 Paper:
- https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
- 2) GPT-3 Paper:
- https://arxiv.org/abs/2005.14165
- """
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from dataclasses import dataclass
- from huggingface_hub import PyTorchModelHubMixin
-
- @dataclass
- class GPTConfig:
-     block_size: int = 1024 # Maximum context length
-     vocab_size: int = 50257 # Number of unique tokens
-     n_layer: int = 12 # Number of transformer blocks
-     n_head: int = 12 # Number of self-attention heads
-     n_embd: int = 768 # Embedding dimensionality
-
- class CausalSelfAttention(nn.Module):
-     """Multi-head causal self-attention."""
-
-     def __init__(self, config: GPTConfig):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0, 'Embedding dimensionality must be divisible by number of heads'
-         # Transformations for queries, keys, and values for all heads
-         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
-         # Output projection
-         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-         # Autoregressive mask - not needed as we use PyTorch's flash-attention implementation
-         # self.register_buffer('mask', torch.tril(torch.ones(config.block_size, config.block_size))
-         #     .view(1, 1, config.block_size, config.block_size))
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         B, T, C = x.shape # batch_size, block_size, n_embd
-         # Calculate queries, keys, and values for all heads in a single pass
-         # H is the number of heads and C/H is the head size, C = H * C/H
-         qkv = self.c_attn(x)
-         q, k, v = qkv.split(self.n_embd, dim=2)
-         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
-         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
-         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, H, T, C/H)
-         # Compute attention scores ('affinities')
-         # W = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5) # (B, H, T, C/H) @ (B, H, C/H, T) -> (B, H, T, T)
-         # W = W.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) # Autoregressive mask
-         # W = F.softmax(W, dim=-1)
-         # Perform the attention-weighted sum
-         # y = W @ v # (B, H, T, T) @ (B, H, T, C/H) -> (B, H, T, C/H)
-         y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # Flash-attention - https://arxiv.org/abs/2205.14135
-         y = y.transpose(1, 2).contiguous().view(B, T, C) # Re-assemble all head outputs side by side
-         y = self.c_proj(y)
-         return y
-
- class MLP(nn.Module):
-     """Position-wise feed-forward network: linear expansion, GELU non-linearity, linear projection."""
-
-     def __init__(self, config: GPTConfig):
-         super().__init__()
-         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
-         self.gelu = nn.GELU(approximate='tanh')
-         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.c_fc(x)
-         x = self.gelu(x)
-         x = self.c_proj(x)
-         return x
-
- class Block(nn.Module):
-     """Transformer block with a causal self-attention layer and a feed-forward layer."""
-
-     def __init__(self, config: GPTConfig):
-         super().__init__()
-         self.ln_1 = nn.LayerNorm(config.n_embd)
-         self.attn = CausalSelfAttention(config)
-         self.ln_2 = nn.LayerNorm(config.n_embd)
-         self.mlp = MLP(config)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = x + self.attn(self.ln_1(x))
-         x = x + self.mlp(self.ln_2(x))
-         return x
-
- class GPT(nn.Module, PyTorchModelHubMixin):
-     """A GPT model."""
-
-     def __init__(self, config: GPTConfig):
-         super().__init__()
-         self.config = config
-
-         self.transformer = nn.ModuleDict(dict(
-             wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embeddings
-             wpe = nn.Embedding(config.block_size, config.n_embd), # Positional embeddings
-             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # Transformer blocks
-             ln_f = nn.LayerNorm(config.n_embd), # Final layer norm
-         ))
-         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-         # Weight sharing between embedding and output layers - https://arxiv.org/abs/1608.05859
-         self.transformer.wte.weight = self.lm_head.weight
-
-         # Initialise weights as per GPT-2
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             nn.init.normal_(module.weight, mean=0.0, std=0.02)
-         # Scale init of residual layers as std grows with depth in residual streams
-         for name, param in self.named_parameters():
-             if name.endswith('c_proj.weight'):
-                 nn.init.normal_(param, mean=0.0, std=0.02 * (2 * self.config.n_layer) ** -0.5)
-
-     def forward(self, x: torch.Tensor, y: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
-         B, T = x.shape # batch_size, block_size
-         assert T <= self.config.block_size, f'Sequence of length {T} exceeds block size {self.config.block_size}'
-         pos = torch.arange(T, dtype=torch.long, device=x.device)
-         pos_embd = self.transformer.wpe(pos) # (T) -> (T, C)
-         tok_embd = self.transformer.wte(x) # (B, T) -> (B, T, C)
-         z = tok_embd + pos_embd
-         for block in self.transformer.h:
-             z = block(z)
-         z = self.transformer.ln_f(z)
-         logits = self.lm_head(z) # (B, T, C) -> (B, T, V) where V is vocab_size
-         loss = None
-         if y is not None:
-             # Flatten batch and sequence dimensions: logits to (B*T, V) and targets to (B*T), for cross-entropy loss
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
-         return logits, loss
-
-     def configure_optimisers(self, weight_decay: float, lr: float) -> torch.optim.Optimizer:
-         """Configure AdamW optimiser with weight decay and learning rate."""
-         params = {name: param for name, param in self.named_parameters() if param.requires_grad}
-         # Any parameter that is at least 2D has weight decay applied - i.e. all weight tensors
-         # in matmuls + embeddings decay; biases and LayerNorm parameters don't.
-         decay_params = [param for _, param in params.items() if param.dim() >= 2]
-         no_decay_params = [param for _, param in params.items() if param.dim() < 2]
-         optim_groups = [
-             {'params': decay_params, 'weight_decay': weight_decay},
-             {'params': no_decay_params, 'weight_decay': 0.0}
-         ]
-         # Use fused optimiser for faster training on GPU
-         optimiser = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=True)
-         return optimiser
-
-     @torch.no_grad()
-     def generate(self, x: torch.Tensor, max_tokens: int = 64, n_samples: int = 1, temp: float = 1.0, top_k: int = 50, seed: int = None) -> torch.Tensor:
-         """Generate sequences of tokens given an initial context."""
-         rng = torch.Generator(device=x.device)
-         if seed is not None:
-             rng.manual_seed(seed)
-         # Repeat the input context for each sample
-         x = x.unsqueeze(0).repeat(n_samples, 1)
-         for _ in range(max_tokens):
-             # Crop the conditioning context to the last block_size tokens, without truncating the output sequence
-             x_cond = x[:, -self.config.block_size:]
-             # Forward pass
-             logits, _ = self(x_cond)
-             # Scale the logits by the temperature and keep only the last token prediction
-             logits = logits[:, -1, :] / temp
-             # Softmax for probabilities
-             probs = F.softmax(logits, dim=1)
-             # Top-k sampling
-             topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
-             # Sample from the top-k probabilities
-             ix = torch.multinomial(topk_probs, 1, generator=rng)
-             # Gather sampled token indices
-             x_next = torch.gather(topk_indices, -1, ix)
-             # Concatenate sampled token to the sequence
-             x = torch.cat((x, x_next), dim=1)
-         return x
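
For reference, a minimal usage sketch of the deleted module, not part of the commit itself: it assumes gpt.py is importable, that the tiktoken GPT-2 tokeniser is installed, and that the model is freshly (randomly) initialised; in practice trained weights would be loaded instead, e.g. via the from_pretrained method provided by PyTorchModelHubMixin with a real repo id.

# Usage sketch (assumptions: gpt.py is on the path, tiktoken is installed, weights are untrained)
import torch
import tiktoken
from gpt import GPT, GPTConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build a GPT-2 small sized model from the default config (randomly initialised weights)
model = GPT(GPTConfig()).to(device)
model.eval()

# Encode a prompt with the GPT-2 BPE tokeniser
enc = tiktoken.get_encoding('gpt2')
tokens = torch.tensor(enc.encode('Hello, I am a language model,'), dtype=torch.long, device=device)

# Sample two continuations of 32 tokens each with temperature and top-k sampling
out = model.generate(tokens, max_tokens=32, n_samples=2, temp=1.0, top_k=50, seed=42)
for sample in out:
    print(enc.decode(sample.tolist()))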