ammarnasr committed on
Commit
65880ee
1 Parent(s): 41aaa4e

Upload model

Browse files
Files changed (1) hide show
  1. modeling_t5mimo.py +64 -74
modeling_t5mimo.py CHANGED
@@ -125,6 +125,69 @@ class T5LayerFF(nn.Module):
125
  return hidden_states
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  class T5Attention(nn.Module):
129
  def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
130
  super().__init__()
@@ -1265,7 +1328,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
1265
  self.decoder = T5Stack(decoder_config, self.shared)
1266
 
1267
 
1268
- self.conv_block = MultivariateConvBlock(config.num_seqs, config.d_model, num_filters=config.num_filters)
1269
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1270
 
1271
  # Initialize weights and apply final processing
@@ -1676,76 +1739,3 @@ class T5MIMOEncoderModel(T5PreTrainedModel):
1676
 
1677
 
1678
 
1679
-
1680
class MultivariateConvBlock(nn.Module):
    """Conv2d stack that mixes information across the parallel input sequences.

    The sequence axis (``num_seqs``) is used as the convolution channel axis:
    it is widened to ``num_filters`` channels, refined by a second
    convolution, then projected back to ``num_seqs`` channels.
    """

    def __init__(self, num_seqs, model_dim, kernel_size=3, num_filters=64, stride=1, padding=1):
        """
        Build the three convolutions and the two batch-norm layers.

        Args:
            num_seqs (int): Number of sequences; input and output channel count.
            model_dim (int): Feature dimension of the input. Accepted for
                interface compatibility; no layer size depends on it.
            kernel_size (int): Side of the square kernel of the first convolution
                and extent of the second convolution along its first spatial axis.
            num_filters (int): Number of intermediate channels. Default is 64.
            stride (int): Stride of the second convolution along its first
                spatial axis. The first convolution always uses stride 1.
            padding (int): Zero padding paired with ``kernel_size`` in the first
                two convolutions. With an odd kernel and ``stride=1``, passing
                ``kernel_size // 2`` keeps the output the same shape as the input.
        """
        super().__init__()

        # Square (kernel_size x kernel_size) convolution over the last two axes,
        # taking the sequences as input channels.
        self.conv1 = nn.Conv2d(
            in_channels=num_seqs,
            out_channels=num_filters,
            kernel_size=kernel_size,
            stride=1,
            # Follows the `padding` argument. It used to be hard-coded to 1,
            # which only preserved the shape for kernel_size == 3.
            padding=padding,
        )

        # Batch normalization after the first convolution.
        self.bn1 = nn.BatchNorm2d(num_filters)

        # Second convolution: the kernel extends along the first spatial axis only.
        # NOTE(review): forward() permutes the input so that this axis is the
        # model_dim axis (assuming the documented input layout), not the time
        # axis - confirm this is intended. Left unchanged because swapping the
        # kernel axes would change the shape of the learned weights.
        self.conv2 = nn.Conv2d(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )

        # Batch normalization after the second convolution.
        self.bn2 = nn.BatchNorm2d(num_filters)

        # 1x1 convolution projecting the channels back to num_seqs.
        self.conv3 = nn.Conv2d(
            in_channels=num_filters,
            out_channels=num_seqs,
            kernel_size=(1, 1),
        )

    def forward(self, x):
        """
        Forward pass of the multivariate convolutional block.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].

        Returns:
            torch.Tensor: Tensor of shape [batch_size, num_seqs, seq_len, model_dim]
            when the kernel/padding/stride settings preserve the spatial size
            (true for the defaults).
        """
        # [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
        x = x.permute(0, 1, 3, 2)

        # Convolution -> batch norm -> ReLU, twice.
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.relu(self.bn2(self.conv2(x)))

        # Project the channel axis back to num_seqs (no activation afterwards).
        x = self.conv3(x)

        # Undo the initial permute.
        x = x.permute(0, 1, 3, 2)

        return x
 
125
  return hidden_states
126
 
127
 
128
+
129
class MultivariateConvBlock(nn.Module):
    """Conv2d stack that mixes information across the parallel input sequences.

    The sequence axis (``config.num_seqs``) is used as the convolution channel
    axis: it is widened to ``config.num_filters`` channels, refined by a second
    convolution, then projected back to ``config.num_seqs`` channels.

    Args:
        config: Model config; only ``num_seqs`` and ``num_filters`` are read.
        kernel_size (int): Side of the square kernel of the first convolution
            and extent of the second convolution along its first spatial axis.
        stride (int): Stride of the second convolution along its first spatial
            axis. The first convolution always uses stride 1.
        padding (int): Zero padding paired with ``kernel_size`` in the first
            two convolutions. With an odd kernel and ``stride=1``, passing
            ``kernel_size // 2`` keeps the output the same shape as the input.
    """

    def __init__(self, config: "T5MIMOConfig", kernel_size=3, stride=1, padding=1):
        super().__init__()

        # Square (kernel_size x kernel_size) convolution over the last two axes,
        # taking the sequences as input channels.
        self.conv1 = nn.Conv2d(
            in_channels=config.num_seqs,
            out_channels=config.num_filters,
            kernel_size=kernel_size,
            stride=1,
            # Follows the `padding` argument. It used to be hard-coded to 1,
            # which only preserved the shape for kernel_size == 3.
            padding=padding,
        )

        # Batch normalization after the first convolution.
        self.bn1 = nn.BatchNorm2d(config.num_filters)

        # Second convolution: the kernel extends along the first spatial axis only.
        # NOTE(review): forward() permutes the input so that this axis is the
        # model_dim axis (assuming the documented input layout), not the time
        # axis - confirm this is intended. Left unchanged because swapping the
        # kernel axes would change the shape of the learned weights.
        self.conv2 = nn.Conv2d(
            in_channels=config.num_filters,
            out_channels=config.num_filters,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )

        # Batch normalization after the second convolution.
        self.bn2 = nn.BatchNorm2d(config.num_filters)

        # 1x1 convolution projecting the channels back to num_seqs.
        self.conv3 = nn.Conv2d(
            in_channels=config.num_filters,
            out_channels=config.num_seqs,
            kernel_size=(1, 1),
        )

    def forward(self, x):
        """
        Forward pass of the multivariate convolutional block.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].

        Returns:
            torch.Tensor: Tensor of shape [batch_size, num_seqs, seq_len, model_dim]
            when the kernel/padding/stride settings preserve the spatial size
            (true for the defaults).
        """
        # [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
        x = x.permute(0, 1, 3, 2)

        # Convolution -> batch norm -> ReLU, twice.
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.relu(self.bn2(self.conv2(x)))

        # Project the channel axis back to num_seqs (no activation afterwards).
        x = self.conv3(x)

        # Undo the initial permute.
        x = x.permute(0, 1, 3, 2)

        return x
188
+
189
+
190
+
191
  class T5Attention(nn.Module):
192
  def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
193
  super().__init__()
 
1328
  self.decoder = T5Stack(decoder_config, self.shared)
1329
 
1330
 
1331
+ self.conv_block = MultivariateConvBlock(config)
1332
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1333
 
1334
  # Initialize weights and apply final processing
 
1739
 
1740
 
1741