Pull request #1: Set more precise shape to the attention weights and outputs
Opened by ivanzhouyq
File changed: modeling_backpack_gpt2.py
@@ -101,13 +101,14 @@ class BackpackWeightNetwork(nn.Module):
|
|
101 |
super().__init__()
|
102 |
self.n_embd = embed_dim
|
103 |
self.num_senses = num_senses
|
104 |
-
self.c_attn = nn.Linear(embed_dim, 2 * embed_dim)
|
|
|
105 |
self.softmax_scale = None
|
106 |
|
107 |
def forward(self, encoded):
|
108 |
b, s, d = encoded.shape
|
109 |
encoded = self.c_attn(encoded) # (b, s, 2*d)
|
110 |
-
encoded = encoded.reshape(b, s, 2, self.num_senses, -1)
|
111 |
batch_size, seqlen = encoded.shape[0], encoded.shape[1]
|
112 |
|
113 |
# compute scores & mask
|
|
|
101 |
super().__init__()
|
102 |
self.n_embd = embed_dim
|
103 |
self.num_senses = num_senses
|
104 |
+
self.embed_per_sense = embed_dim // num_senses
|
105 |
+
self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
|
106 |
self.softmax_scale = None
|
107 |
|
108 |
def forward(self, encoded):
|
109 |
b, s, d = encoded.shape
|
110 |
encoded = self.c_attn(encoded) # (b, s, 2*d)
|
111 |
+
encoded = encoded.reshape(b, s, 2, self.num_senses, self.embed_per_sense) #(b, s, 2, nv, d//nv)
|
112 |
batch_size, seqlen = encoded.shape[0], encoded.shape[1]
|
113 |
|
114 |
# compute scores & mask
|