File size: 5,152 Bytes
4e1467d d97c361 4e1467d 0b6a10a 4e1467d c13ef0b 4e1467d c13ef0b 0b6a10a 4e1467d 0b6a10a 405f5b1 c13ef0b 4e1467d d97c361 0b6a10a 405f5b1 c13ef0b 4e1467d 405f5b1 c13ef0b 405f5b1 c13ef0b 405f5b1 4e1467d 405f5b1 0b6a10a 4e1467d c13ef0b 0b6a10a c13ef0b 405f5b1 c13ef0b 405f5b1 c13ef0b 405f5b1 2896dec 4e1467d 405f5b1 4e1467d 0b6a10a 4e1467d 0b6a10a d97c361 0b6a10a 4e1467d 0b6a10a 4e1467d 0b6a10a d97c361 0b6a10a d97c361 0b6a10a d97c361 0b6a10a d97c361 0b6a10a 405f5b1 0b6a10a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch as t
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import wandb
from fancy_einsum import einsum
from einops import rearrange, repeat, reduce
from utils import OsSoluConfig
class OsSoluModel(nn.Module):
    """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
    where the nonlinearity in the MLP block is replaced with SoLU(x) = x * softmax(x)."""

    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
        self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Map token ids to logits over the vocabulary.

        Args:
            x: integer token ids of shape (batch, seq); ``x.size(1)`` is taken as the
                sequence length and must not exceed ``config.max_positional_embeddings``.

        Returns:
            Logits of shape (batch, seq, vocab_size).
        """
        positional_embeddings = self.embed_positions(t.arange(x.size(1), device=x.device))
        token_embeddings = self.embed_tokens(x)
        out = self.dropout(positional_embeddings + token_embeddings)
        for block in self.transformer_blocks:
            out = block(out)
        # Normalise the residual stream before unembedding. final_ln is created in
        # __init__ and must be applied here; otherwise its parameters never
        # receive gradients and the logits come from an un-normalised stream.
        out = self.final_ln(out)
        # Unembedding is not separate (weights are tied), so we einsum with the
        # token embedding matrix.
        return einsum("vocab hidden, batch seq hidden -> batch seq vocab", self.embed_tokens.weight, out)
class SoLU(nn.Module):
    """Softmax Linear Unit as a layer: SoLU(x) = x * softmax(x), softmax over the last dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x: t.Tensor) -> t.Tensor:
        # Softmax along the final axis gives per-element gates that sum to 1
        # over that axis; the input is then scaled elementwise by those gates.
        gates = t.softmax(x, dim=-1)
        return x * gates
class GPT2Block(nn.Module):
    """A pre-LayerNorm transformer block: an attention sub-layer followed by an MLP sub-layer,
    each added back onto the residual stream.

    The attention variant comes from ``config.self_attention_type`` (``"unidirectional"``
    selects UnidirectionalAttention, anything else RotaryAttention). The MLP activation
    comes from ``config.nonlinearity`` (``"solu"`` selects SoLU, anything else ReLU).
    """

    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        width = config.d_model
        expanded = 4 * config.d_model
        self.layer_norm1 = nn.LayerNorm(width, config.ln_eps)
        # Build attention before the MLP linears so parameter initialisation
        # consumes the RNG in the same order.
        if config.self_attention_type == "unidirectional":
            self.attention = UnidirectionalAttention(config)
        else:
            self.attention = RotaryAttention(config)
        if config.nonlinearity == "solu":
            activation = SoLU()
        else:
            activation = nn.ReLU()
        # The MLP carries its own LayerNorm as its first layer (pre-norm).
        self.MLP = nn.Sequential(
            nn.LayerNorm(width, config.ln_eps),
            nn.Linear(width, expanded),
            activation,
            nn.Linear(expanded, width),
            nn.Dropout(config.dropout),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        after_attention = x + self.attention(self.layer_norm1(x))
        return after_attention + self.MLP(after_attention)
class UnidirectionalAttention(nn.Module):
    """Causal multi-head self-attention.

    Each query position attends only to itself and earlier key positions. The Q, K, V
    and output projections are all ``d_model -> d_model`` linear layers, and the
    hidden dimension is split evenly into ``num_heads`` heads.
    """

    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        # Fail at construction rather than inside rearrange() on the first forward.
        if config.d_model % config.num_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by num_heads ({config.num_heads})"
            )
        self.num_heads = config.num_heads
        self.d_model = config.d_model
        self.head_size = config.d_model // config.num_heads
        self.project_q = nn.Linear(config.d_model, config.d_model)
        self.project_k = nn.Linear(config.d_model, config.d_model)
        self.project_v = nn.Linear(config.d_model, config.d_model)
        self.project_out = nn.Linear(config.d_model, config.d_model)
        # Score written into masked (future) positions before the softmax.
        self.LARGE_NEGATIVE_VALUE = -1e5

    def hidden_to_heads(self, tensor: t.Tensor) -> t.Tensor:
        """Split the last (hidden) dimension into heads: (b, s, nh*hs) -> (b, nh, s, hs)."""
        return rearrange(tensor, "b s (nh hs) -> b nh s hs", nh=self.num_heads)

    def compute_pre_softmax_attn_pattern(self, x: t.Tensor) -> t.Tensor:
        """Return raw Q·K^T scores of shape (batch, num_heads, seqlen_q, seqlen_k).

        The scores are neither scaled nor causally masked; forward() does both.
        """
        Q = self.hidden_to_heads(self.project_q(x))
        K = self.hidden_to_heads(self.project_k(x))
        return einsum(
            "batch num_heads seqlen_q head_size, "
            "batch num_heads seqlen_k head_size ->"
            "batch num_heads seqlen_q seqlen_k",
            Q, K)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Apply causal self-attention to ``x`` of shape (batch, seq, d_model); returns the same shape."""
        seqlen = x.shape[1]
        # Scale first, then mask (the standard order). Masking first would also
        # divide the mask constant by sqrt(head_size). A Python float avoids
        # allocating a new CPU tensor on every call.
        attention_pattern = self.compute_pre_softmax_attn_pattern(x) / (self.head_size ** 0.5)
        # Since GPT is unidirectional, mask entries whose key index exceeds the
        # query index (strict upper triangle) so tokens only attend to the past.
        future = t.triu(t.ones(seqlen, seqlen, dtype=t.bool, device=x.device), diagonal=1)
        attention_pattern[..., future] = self.LARGE_NEGATIVE_VALUE
        attention_score = attention_pattern.softmax(dim=-1)
        V = self.hidden_to_heads(self.project_v(x))
        out = einsum(
            "batch num_heads seqlen_q seqlen_k,"
            "batch num_heads seqlen_k head_size ->"
            "batch num_heads seqlen_q head_size",
            attention_score, V)
        out = rearrange(out, "b nh s hs -> b s (nh hs)")
        return self.project_out(out)
class RotaryAttention(nn.Module):
    """Placeholder for rotary self-attention; the forward pass is not implemented yet."""

    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config

    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO: implement rotary self-attention.
        # Raise instead of silently returning None. Otherwise a residual add such
        # as `x + attention(x)` fails later with an unrelated-looking TypeError.
        raise NotImplementedError(
            "RotaryAttention is not implemented yet; use self_attention_type='unidirectional'"
        )