import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from fancy_einsum import einsum
from einops import rearrange, repeat, reduce
from utils import OsSoluConfig


class OsSoluModel(nn.Module):
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
        # LayerNorms act over the model (residual stream) dimension.
        self.final_ln = nn.LayerNorm(config.d_model, eps=config.ln_eps)
        # Unembedding projects the residual stream back to vocabulary logits.
        self.unembed = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, x: t.Tensor) -> t.Tensor:
        # x: (batch, seqlen) of token ids.
        positional_embeddings = self.embed_positions(t.arange(x.size(1), device=x.device))
        token_embeddings = self.embed_tokens(x)
        embeddings = positional_embeddings + token_embeddings
        out = self.dropout(embeddings)
        # nn.ModuleList is not callable, so apply the blocks one at a time.
        for block in self.transformer_blocks:
            out = block(out)
        out = self.final_ln(out)
        logits = self.unembed(out)          # (batch, seqlen, vocab_size)
        return logits

class SoLU(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: t.Tensor) -> t.Tensor:
        return x * x.softmax(dim=-1)
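
# Illustrative note (not part of the original file): SoLU rescales each activation by its
# softmax weight along the last dimension. For example, for x = [0., 1., 2.],
# softmax(x) ≈ [0.09, 0.24, 0.67], so SoLU()(t.tensor([0., 1., 2.])) ≈ [0., 0.24, 1.33].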

class GPT2Block(nn.Module):
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__() 
        self.config = config

        self.layer_norm1 = nn.LayerNorm(config.d_model, eps=config.ln_eps)
        self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
        self.MLP = nn.Sequential(
            nn.LayerNorm(config.d_model, eps=config.ln_eps),
            nn.Linear(config.d_model, 4 * config.d_model),
            SoLU(),
            nn.Linear(4 * config.d_model, config.d_model),
            nn.Dropout(config.dropout),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = x + self.attention(self.layer_norm1(x))
        x = x + self.MLP(x)
        return x
        


class UnidirectionalAttention(nn.Module):
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.num_heads = config.num_heads
        self.d_model = config.d_model
        # The projections act on the residual stream, which has width d_model.
        self.project_q = nn.Linear(config.d_model, config.d_model)
        self.project_k = nn.Linear(config.d_model, config.d_model)
        self.project_v = nn.Linear(config.d_model, config.d_model)
        self.project_out = nn.Linear(config.d_model, config.d_model)
        self.LARGE_NEGATIVE_VALUE = -1e5

    def hidden_to_heads(self, tensor: t.Tensor) -> t.Tensor:
        return rearrange(tensor, "b s (nh hs) -> b nh s hs", nh=self.num_heads)

    def compute_pre_softmax_attn_pattern(self, x: t.Tensor) -> t.Tensor:
        Q = self.project_q(x)
        K = self.project_k(x)

        Q = self.hidden_to_heads(Q)
        K = self.hidden_to_heads(K)
        attention_pattern = einsum(
            "batch num_heads seqlen_q head_size, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q seqlen_k",
            Q, K,
        )

        return attention_pattern

    def forward(self, x: t.Tensor) -> t.Tensor:
        batch, seqlen, hidden_size = x.shape
        attention_pattern = self.compute_pre_softmax_attn_pattern(x)
        V = self.project_v(x)

        # Masking attention. Since GPT is unidirectional, each query should only attend to
        # itself and earlier positions, so future (key index > query index) entries get a
        # large negative score before the softmax.
        if seqlen > 1:
            fst_range = t.arange(seqlen, device=x.device).unsqueeze(0).T
            snd_range = t.arange(seqlen, device=x.device).unsqueeze(0)
            bool_array = fst_range < snd_range
            attention_pattern[..., bool_array] = self.LARGE_NEGATIVE_VALUE

        # Scale by sqrt(head_size) before the softmax.
        head_size = self.d_model // self.num_heads
        attention_pattern = attention_pattern / (head_size ** 0.5)
        attention_score = attention_pattern.softmax(dim=-1)

        V = self.hidden_to_heads(V)
        out = einsum("batch num_heads seqlen_q seqlen_k, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q head_size", attention_score, V)
        out = rearrange(out, "b nh s hs -> b s (nh hs)")
        out = self.project_out(out)

        return out
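
# Quick sanity-check sketch for UnidirectionalAttention (illustrative, not part of the
# original file; assumes a config object with d_model=64 and num_heads=4):
#
#   attn = UnidirectionalAttention(config)
#   x = t.randn(2, 10, 64)                       # (batch, seqlen, d_model)
#   assert attn(x).shape == (2, 10, 64)          # attention preserves the residual width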

class RotaryAttention(nn.Module):
    def __init__(self, config: OsSoluConfig) -> None:
        super().__init__()
        self.config = config
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        # TODO: implement rotary self-attention.
        raise NotImplementedError("Rotary self-attention is not implemented yet.")
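

# Minimal smoke-test sketch (illustrative, not part of the original file). It assumes
# OsSoluConfig can be constructed with the attributes referenced above
# (d_model, num_heads, num_blocks, vocab_size, max_positional_embeddings, dropout,
# ln_eps, self_attention_type); adjust to the real constructor signature in utils.py.
if __name__ == "__main__":
    config = OsSoluConfig(
        d_model=64,
        num_heads=4,
        num_blocks=2,
        vocab_size=1000,
        max_positional_embeddings=128,
        dropout=0.1,
        ln_eps=1e-5,
        self_attention_type="unidirectional",
    )
    model = OsSoluModel(config)
    tokens = t.randint(0, config.vocab_size, (2, 16))    # (batch, seqlen) of token ids
    logits = model(tokens)
    print(logits.shape)                                  # expected: (2, 16, 1000)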