yizhilll committed on
Commit
d73369c
1 Parent(s): ed76739

Upload 2 files

Files changed (2)
  1. configuration_MERT.py +131 -0
  2. modeling_MERT.py +409 -0
configuration_MERT.py ADDED
@@ -0,0 +1,131 @@
+"""
+MERT model configuration
+"""
+
+import functools
+import operator
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class MERTConfig(PretrainedConfig):
+    r"""
+    Configuration class for the MERT model. It mirrors the HuBERT configuration and adds the
+    MERT-specific options (auxiliary CQT features, DeepNorm, attention relaxation).
+    """
+    model_type = "mert_model"
+
+    def __init__(
+        self,
+        vocab_size=32,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout=0.1,
+        activation_dropout=0.1,
+        attention_dropout=0.1,
+        feat_proj_layer_norm=True,
+        feat_proj_dropout=0.0,
+        final_dropout=0.1,
+        layerdrop=0.1,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        feat_extract_norm="group",
+        feat_extract_activation="gelu",
+        conv_dim=(512, 512, 512, 512, 512, 512, 512),
+        conv_stride=(5, 2, 2, 2, 2, 2, 2),
+        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+        conv_bias=False,
+        num_conv_pos_embeddings=128,
+        num_conv_pos_embedding_groups=16,
+        do_stable_layer_norm=False,
+        apply_spec_augment=True,
+        mask_time_prob=0.05,
+        mask_time_length=10,
+        mask_time_min_masks=2,
+        mask_feature_prob=0.0,
+        mask_feature_length=10,
+        mask_feature_min_masks=0,
+        ctc_loss_reduction="sum",
+        ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=256,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        feature_extractor_cqt=False,
+        feature_extractor_cqt_bins=336,
+        deepnorm=False,
+        attention_relax=-1.0,
+        **kwargs
+    ):
+        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+        self.hidden_size = hidden_size
+        self.feat_extract_norm = feat_extract_norm
+        self.feat_extract_activation = feat_extract_activation
+        self.conv_dim = list(conv_dim)
+        self.conv_stride = list(conv_stride)
+        self.conv_kernel = list(conv_kernel)
+        self.conv_bias = conv_bias
+        self.num_conv_pos_embeddings = num_conv_pos_embeddings
+        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.num_feat_extract_layers = len(self.conv_dim)
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.num_attention_heads = num_attention_heads
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.feat_proj_layer_norm = feat_proj_layer_norm
+        self.feat_proj_dropout = feat_proj_dropout
+        self.final_dropout = final_dropout
+        self.layerdrop = layerdrop
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        self.vocab_size = vocab_size
+        self.do_stable_layer_norm = do_stable_layer_norm
+        self.use_weighted_layer_sum = use_weighted_layer_sum
+        self.classifier_proj_size = classifier_proj_size
+
+        if (
+            (len(self.conv_stride) != self.num_feat_extract_layers)
+            or (len(self.conv_kernel) != self.num_feat_extract_layers)
+            or (len(self.conv_dim) != self.num_feat_extract_layers)
+        ):
+            raise ValueError(
+                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+            )
+
+        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+        self.apply_spec_augment = apply_spec_augment
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks
+
+        # CTC loss
+        self.ctc_loss_reduction = ctc_loss_reduction
+        self.ctc_zero_infinity = ctc_zero_infinity
+
+        # CQT feature extractor
+        self.feature_extractor_cqt = feature_extractor_cqt
+        self.feature_extractor_cqt_bins = feature_extractor_cqt_bins
+
+        # DeepNorm: up-scale the residual connections and down-scale the initialization of the Transformer encoder
+        self.deepnorm = deepnorm
+
+        self.attention_relax = attention_relax
+
+    @property
+    def inputs_to_logits_ratio(self):
+        return functools.reduce(operator.mul, self.conv_stride, 1)
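
A minimal usage sketch for the configuration above (an illustration, not part of the commit; it assumes configuration_MERT.py is importable from the working directory, and `sample_rate` is an extra kwarg that PretrainedConfig stores as an attribute, as modeling_MERT.py below expects when CQT features are enabled):

from configuration_MERT import MERTConfig

# Default feature encoder: 7 conv layers with strides (5, 2, 2, 2, 2, 2, 2),
# so inputs_to_logits_ratio = 5 * 2**6 = 320 input samples per output frame.
config = MERTConfig(sample_rate=16000)  # 16 kHz is an assumed sampling rate
print(config.num_feat_extract_layers)                        # 7
print(config.inputs_to_logits_ratio)                         # 320
print(config.sample_rate // config.inputs_to_logits_ratio)   # 50 frames per second at 16 kHz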
modeling_MERT.py ADDED
@@ -0,0 +1,409 @@
+"""
+MERT model definition.
+We largely adapt code from:
+1. https://github.com/huggingface/transformers/blob/main/src/transformers/models/hubert/modeling_hubert.py
+2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
+"""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.models.hubert.modeling_hubert import (
+    HubertAttention,
+    HubertEncoder,
+    HubertEncoderLayer,
+    HubertEncoderStableLayerNorm,
+    HubertFeatureEncoder,
+    HubertFeedForward,
+    HubertModel,
+    HubertPositionalConvEmbedding,
+)
+
+try:
+    from nnAudio import features as nnAudioFeatures
+    NNAUDIO_INSTALLED = True
+except ImportError:
+    print("WARNING: feature_extractor_cqt requires the library 'nnAudio'")
+    NNAUDIO_INSTALLED = False
+
+from .configuration_MERT import MERTConfig
+
+
+class MERTFeatureProjection(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.feat_proj_layer_norm = config.feat_proj_layer_norm
+        self.feature_extractor_cqt = config.feature_extractor_cqt
+
+        if self.feature_extractor_cqt:
+            # v3: concatenate CQT features with the convolutional features
+            self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
+            print(f"feature dimension: {self.feature_dimension}")
+        else:
+            self.feature_dimension = config.conv_dim[-1]
+        if self.feat_proj_layer_norm:
+            self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
+        self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
+        self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+    def forward(self, hidden_states):
+        # non-projected hidden states are needed for quantization
+        if self.feat_proj_layer_norm:
+            hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class MERTModel(HubertModel):
+    # overwrite config class
+    config_class = MERTConfig
+    base_model_prefix = "mert_model"
+
+    def __init__(
+        self,
+        config: MERTConfig,
+    ) -> None:
+        """
+        Initialize with the grandparent method HubertPreTrainedModel.__init__()
+        and modify HubertModel.__init__().
+        """
+        super(HubertModel, self).__init__(config)
+
+        self.config = config
+
+        self.feature_extractor = HubertFeatureEncoder(config)
+        self.feature_projection = MERTFeatureProjection(config)  # replace the feature projection to introduce the new features
+
+        if self.config.feature_extractor_cqt:
+            assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio', try again after `pip install nnAudio`"
+            print('initializing cqt extractor for MERT')
+            self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(
+                sr=self.config.sample_rate, hop_length=self.config.sample_rate // 50, fmin=32.7,
+                fmax=None, n_bins=self.config.feature_extractor_cqt_bins,
+                bins_per_octave=self.config.feature_extractor_cqt_bins // 7,
+                filter_scale=1, norm=1, window='hann', center=True,
+                pad_mode='constant', trainable=False,
+                output_format='Magnitude', verbose=True)
+
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+        if config.do_stable_layer_norm:
+            assert not config.deepnorm, "must use post-layer_norm with deepnorm"
+            self.encoder = HubertEncoderStableLayerNorm(config)
+        else:
+            if config.deepnorm:
+                self.encoder = HubertEncoder_extend(config)
+            else:
+                self.encoder = HubertEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+
+        # add additional CQT features to the transformer input
+        if self.config.feature_extractor_cqt:
+            features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
+            features_cqt = features_cqt[:, :extract_features.shape[1], :]  # align shape
+            # # v2
+            # features_cqt = self.post_cqt_feature_proj(features_cqt)
+            # extract_features = self.feature_projection.layer_norm(extract_features) + self.feature_projection.layer_norm(features_cqt)  # v2
+            # v3
+            extract_features = torch.cat([extract_features, features_cqt], 2)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+        hidden_states = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]  # take last_hidden from encoder output
+
+        if not return_dict:
+            return (hidden_states,) + encoder_outputs[1:]
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class HubertEncoder_extend(HubertEncoder):
+    def __init__(self, config):
+        # call nn.Module initialization directly instead of HubertEncoder.__init__()
+        nn.Module.__init__(self)
+
+        self.config = config
+        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+
+        self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])
+
+        self.gradient_checkpointing = False
+
+        if config.deepnorm:
+            # DeepNorm: down-scale the initialization of the value/output/FFN projections
+            init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
+            for name, p in self.named_parameters():
+                if (
+                    "feed_forward.intermediate_dense" in name
+                    or "feed_forward.output_dense" in name
+                    or "out_proj" in name
+                    or "v_proj" in name
+                ):
+                    p.data.div_(init_scale)
+
+
+class HubertEncoderLayerExtend(HubertEncoderLayer):
+    def __init__(self, config):
+        # call nn.Module initialization directly instead of HubertEncoderLayer.__init__()
+        nn.Module.__init__(self)
+        if config.attention_relax > 0:
+            self.attention = HubertAttention_extend(
+                embed_dim=config.hidden_size,
+                num_heads=config.num_attention_heads,
+                dropout=config.attention_dropout,
+                is_decoder=False,
+                attention_relax=config.attention_relax,
+            )
+        else:
+            self.attention = HubertAttention(
+                embed_dim=config.hidden_size,
+                num_heads=config.num_attention_heads,
+                dropout=config.attention_dropout,
+                is_decoder=False,
+            )
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.feed_forward = HubertFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        if config.deepnorm:
+            # DeepNorm: up-scale the residual branch by (2N)^(1/4)
+            self.residual_alpha = math.pow(2.0 * config.num_hidden_layers, 0.25)
+        else:
+            self.residual_alpha = 1.0
+
+    def residual_connection(self, x, residual):
+        """
+        residual: input before f()
+        x: output of f(residual)
+        """
+        return residual * self.residual_alpha + x
+
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+        attn_residual = hidden_states
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+        )
+        hidden_states = self.dropout(hidden_states)
+
+        # hidden_states = attn_residual + hidden_states
+        hidden_states = self.residual_connection(hidden_states, attn_residual)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # hidden_states = hidden_states + self.feed_forward(hidden_states)
+        ffn_residual = hidden_states
+        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.residual_connection(hidden_states, ffn_residual)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class HubertAttention_extend(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        attention_relax: float = -1.0,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        # relaxation temperature for the attention logits; values <= 0 disable relaxation in forward()
+        self.attention_relax = attention_relax
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if self.attention_relax > 0:
+            # attention relaxation: divide the logits by the relaxation temperature, subtract the
+            # per-query maximum, and scale back before the softmax
+            # => (bsz*self.num_heads, tgt_len, src_len)
+            attn_weights_relax = attn_weights / self.attention_relax
+
+            # => (bsz*self.num_heads, tgt_len, 1)
+            attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).values.unsqueeze(2)
+            attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` stored on the class rather than `hidden_states` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
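
A hedged end-to-end sketch of running the model defined above (an illustration, not part of the commit). It builds a randomly initialized MERTModel from the default configuration and feeds it one second of dummy audio; the 16 kHz sampling rate is an assumption, and the imports assume the two files live inside a package so the relative import in modeling_MERT.py resolves (pretrained checkpoints on the Hub would instead be loaded through AutoModel.from_pretrained(..., trust_remote_code=True)):

import torch
from mert.configuration_MERT import MERTConfig   # hypothetical package name "mert"
from mert.modeling_MERT import MERTModel

config = MERTConfig()                  # sample_rate is only needed when feature_extractor_cqt=True
model = MERTModel(config).eval()       # randomly initialized, no pretrained weights

waveform = torch.randn(1, 16000)       # (batch, samples): one second of fake 16 kHz audio
with torch.no_grad():
    outputs = model(waveform, output_hidden_states=True)

print(outputs.last_hidden_state.shape)   # torch.Size([1, 49, 768]) with the default conv stack (~320x downsampling)
print(len(outputs.hidden_states))        # num_hidden_layers + 1 = 13 hidden-state tensors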