inwaves committed
Commit: 0b6a10a
1 Parent(s): 2896dec

Implemented unidirectional attention, moving on

Files changed (2):
  1. model.py +67 -16
  2. utils.py +10 -8
model.py CHANGED
@@ -3,46 +3,97 @@ import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
 import wandb
-import fancy_einsum
+from fancy_einsum import einsum
 from einops import rearrange, repeat, reduce
+from utils import OsSoluConfig


 class OsSoluModel(nn.Module):
-    def __init__(self, config) -> None:
+    def __init__(self, config: OsSoluConfig) -> None:
         super().__init__()
+        normalised_shape = None  # TODO: normalised_shape should be defined properly
         self.config = config
+        self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
         self.transformer_block = TransformerBlock(config)
+        self.final_ln = nn.LayerNorm(normalised_shape, config.ln_eps)
+        self.unembed = nn

     def forward(self, x: t.Tensor) -> t.Tensor:
-        pass
+        positional_embeddings = self.embed_positions(t.arange(x.size(1)))
+        token_embeddings = self.embed_tokens(x)
+        embeddings = positional_embeddings + token_embeddings


 class TransformerBlock(nn.Module):
-    def __init__(self, config) -> None:
-        super().__init__()
+    def __init__(self, config: OsSoluConfig) -> None:
+        super().__init__()
         self.config = config

-        self.embed = nn.Embedding(config.num_embeddings, config.d_model)
+        self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
         self.linear = nn.Sequential(
             nn.Linear(config.d_model, config.d_model),
             SoLU(),
         )
-        self.layer_norm = nn.LayerNorm(normalized_shape)
+        self.layer_norm = nn.LayerNorm(config.d_model, config.ln_eps)
         self.unembed = nn.Embedding(config.num_embeddings, config.d_model)

     def forward(self, x: t.Tensor) -> t.Tensor:
         pass


-class RotaryAttention(nn.Module):
-    def __init__(self, config) -> None:
+class UnidirectionalAttention(nn.Module):
+    def __init__(self, config: OsSoluConfig) -> None:
         super().__init__()
+        self.num_heads = config.num_heads
+        self.d_model = config.d_model
+        self.project_q = nn.Linear(config.d_model, config.d_model)
+        self.project_k = nn.Linear(config.d_model, config.d_model)
+        self.project_v = nn.Linear(config.d_model, config.d_model)
+        self.project_out = nn.Linear(config.d_model, config.d_model)
+        self.LARGE_NEGATIVE_VALUE = -1e5

-    def forward(self, x: t.Tensor, attention_mask: t.Tensor) -> t.Tensor:
-        # Compute pre-softmax attention scores
-        # Apply attention mask
-        # Compute softmax
-        # Apply final einsum
-        # Return attention output
+    def hidden_to_heads(self, tensor: t.Tensor) -> t.Tensor:
+        return rearrange(tensor, "b s (nh hs) -> b nh s hs", nh=self.num_heads)

-        pass
+    def compute_pre_softmax_attn_pattern(self, x: t.Tensor) -> t.Tensor:
+        Q = self.project_q(x)
+        K = self.project_k(x)
+
+        Q = self.hidden_to_heads(Q)
+        K = self.hidden_to_heads(K)
+        attention_pattern = einsum("batch num_heads seqlen_q head_size, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q seqlen_k", Q, K)
+
+        return attention_pattern
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        batch, seqlen, hidden_size = x.shape
+        attention_pattern = self.compute_pre_softmax_attn_pattern(x)
+        V = self.project_v(x)
+
+        # Masking attention. Since GPT is unidirectional, it should only attend to previous tokens.
+        if seqlen > 1:
+            fst_range = t.arange(seqlen, device=x.device).unsqueeze(0).T
+            snd_range = t.arange(seqlen, device=x.device).unsqueeze(0)
+            bool_array = fst_range < snd_range
+            attention_pattern[..., bool_array] = self.LARGE_NEGATIVE_VALUE
+
+        attention_pattern = attention_pattern / t.sqrt(t.tensor(self.d_model / self.num_heads))
+        attention_score = attention_pattern.softmax(dim=-1)
+
+        V = self.hidden_to_heads(V)
+        out = einsum("batch num_heads seqlen_q seqlen_k, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q head_size", attention_score, V)
+        out = rearrange(out, "b nh s hs -> b s (nh hs)")
+        out = self.project_out(out)
+
+        return out
+
+
+class RotaryAttention(nn.Module):
+    def __init__(self, config: OsSoluConfig) -> None:
+        super().__init__()
+        self.config = config
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        pass
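
The causal mask in UnidirectionalAttention.forward is built by broadcasting two arange vectors against each other and overwriting the True positions with LARGE_NEGATIVE_VALUE before the softmax, so each query can only attend to itself and earlier tokens. The snippet below is not part of the commit; it is a minimal standalone sketch (assuming only torch) of what that comparison produces and how the masked scores behave after softmax.

import torch as t

seqlen = 4
fst_range = t.arange(seqlen).unsqueeze(0).T  # query positions as a column vector, shape (seqlen, 1)
snd_range = t.arange(seqlen).unsqueeze(0)    # key positions as a row vector, shape (1, seqlen)
bool_array = fst_range < snd_range           # True wherever the key position comes after the query position
print(bool_array)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

scores = t.zeros(1, 1, seqlen, seqlen)       # stand-in for the (batch, num_heads, seqlen, seqlen) attention pattern
scores[..., bool_array] = -1e5               # same LARGE_NEGATIVE_VALUE trick as in the commit
print(scores.softmax(dim=-1)[0, 0])          # each row spreads probability only over itself and earlier positions

Dividing the pattern by sqrt(d_model / num_heads), i.e. the square root of the head size, before the softmax keeps the logits in a comparable range regardless of head size; the large negative entries stay effectively minus infinity after that scaling.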
utils.py CHANGED
@@ -1,10 +1,12 @@
 @dataclass
 class OsSoluConfig:
-    d_model: int = 512 # Hidden size of the model.
-    vocab_size: int = 65536 # Vocabulary size of the input sequence. Unsure about this.
-    learning_rate: float = 1e-3 # Learning rate for the optimiser.
-    num_embeddings: int = 1024 # Number of embeddings. Unsure about this.
-    num_blocks: int = 1 # Number of transformer blocks.
-    dropout: float = 0.1 # Probability of dropout.
-    ln_eps: float = 1e-3 # Layer norm epsilon.
-    num_heads: int = 4 # Number of attention heads in each attention layer.
+    d_model: int = 512 # Hidden size of the model.
+    vocab_size: int = 65536 # Vocabulary size of the input sequence. Unsure about this.
+    learning_rate: float = 1e-3 # Learning rate for the optimiser.
+    num_embeddings: int = 1024 # Number of embeddings. Unsure about this.
+    num_blocks: int = 1 # Number of transformer blocks.
+    dropout: float = 0.1 # Probability of dropout.
+    ln_eps: float = 1e-3 # Layer norm epsilon.
+    num_heads: int = 4 # Number of attention heads in each attention layer.
+    self_attention_type: str = "unidirectional" # What type of attention to use: rotary or unidirectional.
+    max_positional_embeddings: int = 1024 # Maximum number of positional embeddings.
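
The two new fields are what TransformerBlock now reads when choosing between UnidirectionalAttention and RotaryAttention. The sketch below is not part of the commit, just a rough usage example; it assumes utils.py also has "from dataclasses import dataclass" at the top (not shown in this hunk) and that the Q/K/V projections map d_model to d_model, as in the forward pass above.

import torch as t
from utils import OsSoluConfig
from model import UnidirectionalAttention

config = OsSoluConfig()                      # defaults: d_model=512, num_heads=4, self_attention_type="unidirectional"
attention = UnidirectionalAttention(config)

x = t.randn(2, 10, config.d_model)           # dummy residual-stream activations: (batch, seqlen, d_model)
out = attention(x)
print(out.shape)                             # torch.Size([2, 10, 512])

Switching branches is just OsSoluConfig(self_attention_type="rotary"), although RotaryAttention.forward is still a stub in this commit.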