File size: 5,152 Bytes
4e1467d
 
 
 
 
d97c361
4e1467d
0b6a10a
4e1467d
 
c13ef0b
4e1467d
c13ef0b
 
0b6a10a
4e1467d
 
0b6a10a
 
405f5b1
 
c13ef0b
4e1467d
 
d97c361
0b6a10a
 
405f5b1
c13ef0b
 
 
 
 
 
4e1467d
405f5b1
c13ef0b
405f5b1
c13ef0b
405f5b1
 
 
4e1467d
405f5b1
0b6a10a
 
4e1467d
 
c13ef0b
0b6a10a
c13ef0b
405f5b1
c13ef0b
405f5b1
c13ef0b
405f5b1
 
2896dec
4e1467d
 
405f5b1
 
 
 
4e1467d
 
0b6a10a
 
4e1467d
0b6a10a
 
d97c361
 
 
0b6a10a
 
4e1467d
0b6a10a
 
4e1467d
0b6a10a
 
 
 
 
 
d97c361
 
 
 
 
0b6a10a
 
 
 
 
 
 
 
 
 
d97c361
 
0b6a10a
d97c361
0b6a10a
 
 
 
 
 
d97c361
 
 
 
 
 
 
0b6a10a
 
 
 
 
 
 
 
 
 
 
405f5b1
0b6a10a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch as t
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import wandb
from fancy_einsum import einsum
from einops import rearrange, repeat, reduce
from utils import OsSoluConfig



class OsSoluModel(nn.Module):
    """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
    where the nonlinearity in the MLP block is replaced with SoLU(x) = x * softmax(x).

    The token-embedding matrix is reused (tied) as the unembedding, so the returned
    logits are ``final_ln(h) @ embed_tokens.weight.T`` where ``h`` is the output of
    the last transformer block.
    """
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
        self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Map integer token ids to vocabulary logits.

        Args:
            x: integer token ids; dimension 1 is treated as the sequence axis
               (presumably shape (batch, seq) — confirm against callers).

        Returns:
            Logits with a trailing ``vocab_size`` dimension.

        Raises:
            IndexError: if the sequence is longer than ``max_positional_embeddings``.
        """
        seqlen = x.size(1)
        max_positions = self.embed_positions.num_embeddings
        if seqlen > max_positions:
            # Checked up front: on CUDA an out-of-range embedding index surfaces
            # only as an opaque device-side assert.
            raise IndexError(
                f"sequence length {seqlen} exceeds max_positional_embeddings ({max_positions})"
            )

        positions = t.arange(seqlen, device=x.device)
        out = self.dropout(self.embed_positions(positions) + self.embed_tokens(x))
        for block in self.transformer_blocks:
            out = block(out)

        # The blocks normalize only the inputs of their sub-layers, so the residual
        # stream itself is un-normalized here; apply the final LayerNorm first.
        out = self.final_ln(out)

        # Unembedding is not separate: project onto the (tied) token embedding rows.
        return out @ self.embed_tokens.weight.T

class SoLU(nn.Module):
    """Softmax Linear Unit as a layer: ``SoLU(x) = x * softmax(x, dim=-1)``.

    Each element is scaled by its softmax weight computed over the last dimension,
    so the output has the same shape as the input. The layer has no parameters.
    """

    def forward(self, x: t.Tensor) -> t.Tensor:
        gate = t.softmax(x, dim=-1)
        return x * gate

class GPT2Block(nn.Module):
    """A single pre-LayerNorm transformer block with two residual sub-layers.

    1. Attention: ``x + attention(layer_norm1(x))``. The attention module is
       ``UnidirectionalAttention`` when ``config.self_attention_type == "unidirectional"``,
       otherwise ``RotaryAttention``.
    2. MLP: ``x + MLP(x)``. ``MLP`` is LayerNorm -> Linear(d_model, 4*d_model) ->
       nonlinearity -> Linear(4*d_model, d_model) -> Dropout. The nonlinearity is SoLU
       when ``config.nonlinearity == "solu"``, otherwise ReLU.
    """
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        width = config.d_model
        expanded = 4 * width

        self.layer_norm1 = nn.LayerNorm(width, config.ln_eps)
        if config.self_attention_type == "unidirectional":
            self.attention = UnidirectionalAttention(config)
        else:
            self.attention = RotaryAttention(config)

        if config.nonlinearity == "solu":
            activation = SoLU()
        else:
            activation = nn.ReLU()
        # Order matters: it fixes the Sequential indices (state_dict keys) and the
        # order in which the Linear layers draw their random initial weights.
        mlp_layers = [
            nn.LayerNorm(width, config.ln_eps),
            nn.Linear(width, expanded),
            activation,
            nn.Linear(expanded, width),
            nn.Dropout(config.dropout),
        ]
        self.MLP = nn.Sequential(*mlp_layers)

    def forward(self, x: t.Tensor) -> t.Tensor:
        after_attention = x + self.attention(self.layer_norm1(x))
        return after_attention + self.MLP(after_attention)
        


class UnidirectionalAttention(nn.Module):
    """Causal (GPT-style) multi-head self-attention.

    Queries, keys and values come from separate ``d_model -> d_model`` projections and
    are split into ``num_heads`` heads of size ``d_model // num_heads``. Each query
    position attends only to itself and earlier key positions. The per-head outputs
    are merged back and passed through ``project_out``.
    """
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.num_heads = config.num_heads
        self.d_model = config.d_model
        self.project_q = nn.Linear(config.d_model, config.d_model)
        self.project_k = nn.Linear(config.d_model, config.d_model)
        self.project_v = nn.Linear(config.d_model, config.d_model)
        self.project_out = nn.Linear(config.d_model, config.d_model)
        # Fill value for masked (future) logits; the softmax in forward drives
        # the corresponding attention weights to ~0.
        self.LARGE_NEGATIVE_VALUE = -1e5

    def hidden_to_heads(self, tensor: t.Tensor) -> t.Tensor:
        """Split the last dimension into heads: (b, s, nh*hs) -> (b, nh, s, hs)."""
        return rearrange(tensor, "b s (nh hs) -> b nh s hs", nh=self.num_heads)

    def compute_pre_softmax_attn_pattern(self, x: t.Tensor) -> t.Tensor:
        """Return raw (unscaled, unmasked) query-key dot products.

        Output shape is (batch, num_heads, seqlen_q, seqlen_k).
        """
        Q = self.project_q(x)
        K = self.project_k(x)

        Q = self.hidden_to_heads(Q)
        K = self.hidden_to_heads(K)
        attention_pattern = einsum(
            "batch num_heads seqlen_q head_size, "
            "batch num_heads seqlen_k head_size ->"
            "batch num_heads seqlen_q seqlen_k",
            Q, K)

        return attention_pattern

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, seq, d_model), as required by ``hidden_to_heads``.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        seqlen = x.size(1)
        head_size = self.d_model // self.num_heads

        # Scale by a Python float so no helper tensor is allocated on every call.
        attention_pattern = self.compute_pre_softmax_attn_pattern(x) / head_size ** 0.5

        # Causal mask: True strictly above the diagonal, i.e. key position > query
        # position. masked_fill is out-of-place, so the einsum output is not mutated.
        # A 1x1 mask is all False, so seqlen == 1 needs no special case.
        future = t.ones(seqlen, seqlen, dtype=t.bool, device=x.device).triu(diagonal=1)
        attention_pattern = attention_pattern.masked_fill(future, self.LARGE_NEGATIVE_VALUE)
        attention_score = attention_pattern.softmax(dim=-1)

        V = self.hidden_to_heads(self.project_v(x))
        out = einsum(
            "batch num_heads seqlen_q seqlen_k,"
            "batch num_heads seqlen_k head_size ->"
            "batch num_heads seqlen_q head_size",
            attention_score, V)

        out = rearrange(out, "b nh s hs -> b s (nh hs)")
        return self.project_out(out)

class RotaryAttention(nn.Module):
    """Placeholder for rotary-position self-attention; not implemented yet.

    ``GPT2Block`` instantiates this class whenever ``config.self_attention_type`` is
    anything other than ``"unidirectional"``.
    """
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Not implemented.

        Raises:
            NotImplementedError: always. Returning ``None`` would instead surface later
                as an opaque ``x + None`` TypeError in the caller's residual connection.
        """
        # TODO: implement rotary self-attention
        raise NotImplementedError(
            "RotaryAttention is not implemented yet; "
            'set self_attention_type="unidirectional".'
        )