yangwang825 committed
Commit 9b00023
1 Parent(s): c94a9ff

Create modeling_audio_spectrogram_transformer.py

modeling_audio_spectrogram_transformer.py ADDED
@@ -0,0 +1,1068 @@
+ # coding=utf-8
+ # Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch Audio Spectrogram Transformer (AST) model."""
+ 
+ import math
+ from typing import Dict, List, Optional, Set, Tuple, Union
+ 
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ 
+ from ...activations import ACT2FN
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+ from .configuration_audio_spectrogram_transformer import ASTConfig
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ # General docstring
+ _CONFIG_FOR_DOC = "ASTConfig"
+ 
+ # Base docstring
+ _CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
+ _EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
+ 
+ # Audio classification docstring
+ _SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
+ _SEQ_CLASS_EXPECTED_LOSS = 0.17
+ 
+ 
+ class ASTEmbeddings(nn.Module):
+     """
+     Construct the CLS token, distillation token, position and patch embeddings.
+     """
+ 
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+ 
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+         self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+         self.patch_embeddings = ASTPatchEmbeddings(config)
+ 
+         frequency_out_dimension, time_out_dimension = self.get_shape(config)
+         num_patches = frequency_out_dimension * time_out_dimension
+         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.config = config
+ 
+     def get_shape(self, config):
+         # see Karpathy's cs231n blog on how to calculate the output dimensions
+         # https://cs231n.github.io/convolutional-networks/#conv
+         frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
+         time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
+ 
+         return frequency_out_dimension, time_out_dimension
+ 
+     def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+         batch_size = input_values.shape[0]
+         embeddings = self.patch_embeddings(input_values)
+ 
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+         embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+         embeddings = embeddings + self.position_embeddings
+         embeddings = self.dropout(embeddings)
+ 
+         return embeddings
+ 
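As a sanity check on the token count, `get_shape` is just the standard convolution output-size formula. A minimal sketch, assuming the stock `ASTConfig` defaults (`num_mel_bins=128`, `max_length=1024`, `patch_size=16`, both strides 10), which are not spelled out in this file:

# Patch-grid arithmetic under the assumed default ASTConfig values.
num_mel_bins, max_length, patch_size = 128, 1024, 16
frequency_stride = time_stride = 10

frequency_out = (num_mel_bins - patch_size) // frequency_stride + 1  # 12
time_out = (max_length - patch_size) // time_stride + 1              # 101
num_patches = frequency_out * time_out                               # 1212

# CLS + distillation tokens are prepended, which is where the 1214-token
# sequence behind _EXPECTED_OUTPUT_SHAPE = [1, 1214, 768] comes from.
print(frequency_out, time_out, num_patches + 2)  # 12 101 1214
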
+ class ASDeiTEmbeddings(nn.Module):
+     """
+     Construct the CLS token, distillation token, position and patch embeddings.
+     """
+ 
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+ 
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+         self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+         self.patch_embeddings = ASTPatchEmbeddings(config)
+ 
+         frequency_out_dimension, time_out_dimension = self.get_shape(config)
+         num_patches = frequency_out_dimension * time_out_dimension
+         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.config = config
+ 
+     def get_shape(self, config):
+         # see Karpathy's cs231n blog on how to calculate the output dimensions
+         # https://cs231n.github.io/convolutional-networks/#conv
+         frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
+         time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
+ 
+         return frequency_out_dimension, time_out_dimension
+ 
+     def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+         batch_size = input_values.shape[0]
+         embeddings = self.patch_embeddings(input_values)
+ 
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+         embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+         embeddings = embeddings + self.position_embeddings
+         embeddings = self.dropout(embeddings)
+ 
+         return embeddings
+ 
+ 
+ class ASViTEmbeddings(nn.Module):
+     """
+     Construct the CLS token, position and patch embeddings.
+     """
+ 
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+ 
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+         self.patch_embeddings = ASTPatchEmbeddings(config)
+ 
+         frequency_out_dimension, time_out_dimension = self.get_shape(config)
+         num_patches = frequency_out_dimension * time_out_dimension
+         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.config = config
+ 
+     def get_shape(self, config):
+         # see Karpathy's cs231n blog on how to calculate the output dimensions
+         # https://cs231n.github.io/convolutional-networks/#conv
+         frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
+         time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
+ 
+         return frequency_out_dimension, time_out_dimension
+ 
+     def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+         batch_size = input_values.shape[0]
+         embeddings = self.patch_embeddings(input_values)
+ 
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+         embeddings = embeddings + self.position_embeddings
+         embeddings = self.dropout(embeddings)
+ 
+         return embeddings
+ 
+ 
+ class ASTPatchEmbeddings(nn.Module):
+     """
+     This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
+     seq_length, hidden_size)` to be consumed by a Transformer.
+     """
+ 
+     def __init__(self, config):
+         super().__init__()
+ 
+         patch_size = config.patch_size
+         frequency_stride = config.frequency_stride
+         time_stride = config.time_stride
+ 
+         self.projection = nn.Conv2d(
+             1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride)
+         )
+ 
+     def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+         input_values = input_values.unsqueeze(1)
+         input_values = input_values.transpose(2, 3)
+         embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
+         return embeddings
+ 
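A minimal shape trace for `ASTPatchEmbeddings` (illustrative only; it assumes the class above is importable from this module and that `ASTConfig` is instantiated with its library defaults):

import torch
from transformers import ASTConfig

config = ASTConfig()  # defaults assumed: 128 mel bins, 1024 frames, 16x16 patches, stride 10
patch_embed = ASTPatchEmbeddings(config)

spectrogram = torch.randn(1, config.max_length, config.num_mel_bins)  # (batch, time, freq)
# unsqueeze -> (1, 1, 1024, 128); transpose(2, 3) -> (1, 1, 128, 1024);
# Conv2d + flatten(2) + transpose(1, 2) -> (1, 12 * 101, hidden_size)
print(patch_embed(spectrogram).shape)  # torch.Size([1, 1212, 768]) under the assumed defaults
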
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
+ class ASTSelfAttention(nn.Module):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+             raise ValueError(
+                 f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                 f"heads {config.num_attention_heads}."
+             )
+ 
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+ 
+         self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ 
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ 
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(new_x_shape)
+         return x.permute(0, 2, 1, 3)
+ 
+     def forward(
+         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+         mixed_query_layer = self.query(hidden_states)
+ 
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+ 
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ 
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ 
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ 
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+ 
+         # Mask heads if we want to
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+ 
+         context_layer = torch.matmul(attention_probs, value_layer)
+ 
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(new_context_layer_shape)
+ 
+         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+ 
+         return outputs
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
+ class ASTSdpaSelfAttention(ASTSelfAttention):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+         self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+ 
+     def forward(
+         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+         mixed_query_layer = self.query(hidden_states)
+ 
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+ 
+         context_layer = torch.nn.functional.scaled_dot_product_attention(
+             query_layer,
+             key_layer,
+             value_layer,
+             head_mask,
+             self.attention_probs_dropout_prob if self.training else 0.0,
+             is_causal=False,
+             scale=None,
+         )
+ 
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(new_context_layer_shape)
+ 
+         return context_layer, None
+ 
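The SDPA variant is intended to be numerically equivalent to the eager path when dropout is inactive. A rough check, again assuming the classes above are importable and that `ASTConfig` carries its library defaults:

import torch
from transformers import ASTConfig

config = ASTConfig()
eager = ASTSelfAttention(config).eval()
sdpa = ASTSdpaSelfAttention(config).eval()
sdpa.load_state_dict(eager.state_dict())  # share query/key/value weights so outputs are comparable

hidden = torch.randn(1, 1214, config.hidden_size)
with torch.no_grad():
    out_eager = eager(hidden)[0]
    out_sdpa = sdpa(hidden)[0]
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))  # expected: True
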
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
+ class ASTSelfOutput(nn.Module):
+     """
+     The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
+     layernorm applied before each block.
+     """
+ 
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ 
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+ 
+         return hidden_states
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
+ class ASTAttention(nn.Module):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.attention = ASTSelfAttention(config)
+         self.output = ASTSelfOutput(config)
+         self.pruned_heads = set()
+ 
+     def prune_heads(self, heads: Set[int]) -> None:
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+         )
+ 
+         # Prune linear layers
+         self.attention.query = prune_linear_layer(self.attention.query, index)
+         self.attention.key = prune_linear_layer(self.attention.key, index)
+         self.attention.value = prune_linear_layer(self.attention.value, index)
+         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+ 
+         # Update hyper params and store pruned heads
+         self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+         self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+         self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+ 
+         attention_output = self.output(self_outputs[0], hidden_states)
+ 
+         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+         return outputs
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
+ class ASTSdpaAttention(ASTAttention):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+         self.attention = ASTSdpaSelfAttention(config)
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
+ class ASTIntermediate(nn.Module):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         if isinstance(config.hidden_act, str):
+             self.intermediate_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.intermediate_act_fn = config.hidden_act
+ 
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.intermediate_act_fn(hidden_states)
+ 
+         return hidden_states
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
+ class ASTOutput(nn.Module):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ 
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+ 
+         hidden_states = hidden_states + input_tensor
+ 
+         return hidden_states
+ 
+ 
+ AST_ATTENTION_CLASSES = {
+     "eager": ASTAttention,
+     "sdpa": ASTSdpaAttention,
+ }
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
+ class ASTLayer(nn.Module):
+     """This corresponds to the Block class in the timm implementation."""
+ 
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
+         self.intermediate = ASTIntermediate(config)
+         self.output = ASTOutput(config)
+         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+         self_attention_outputs = self.attention(
+             self.layernorm_before(hidden_states),  # in AST, layernorm is applied before self-attention
+             head_mask,
+             output_attentions=output_attentions,
+         )
+         attention_output = self_attention_outputs[0]
+         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+ 
+         # first residual connection
+         hidden_states = attention_output + hidden_states
+ 
+         # in AST, layernorm is also applied after self-attention
+         layer_output = self.layernorm_after(hidden_states)
+         layer_output = self.intermediate(layer_output)
+ 
+         # second residual connection is done here
+         layer_output = self.output(layer_output, hidden_states)
+ 
+         outputs = (layer_output,) + outputs
+ 
+         return outputs
+ 
+ 
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
+ class ASTEncoder(nn.Module):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ) -> Union[tuple, BaseModelOutput]:
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+ 
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+ 
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+ 
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     layer_head_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+ 
+             hidden_states = layer_outputs[0]
+ 
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ 
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+ 
+         if not return_dict:
+             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+ 
+ 
+ class ASTPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+ 
+     config_class = ASTConfig
+     base_model_prefix = "audio_spectrogram_transformer"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+     _supports_sdpa = True
+ 
+     # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
+     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+         """Initialize the weights"""
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
+             # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+             # `trunc_normal_cpu` not implemented in `half` issues
+             module.weight.data = nn.init.trunc_normal_(
+                 module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+             ).to(module.weight.dtype)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+ 
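The `supports_gradient_checkpointing` and `_supports_sdpa` flags above are what the loading machinery inspects when a user opts into those features. A hedged sketch, shown with the upstream `transformers.ASTModel` and the checkpoint named in this file's docstrings (a recent `transformers` release is assumed):

from transformers import ASTModel

model = ASTModel.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    attn_implementation="sdpa",  # selects the "sdpa" entry of AST_ATTENTION_CLASSES in each layer
)
model.gradient_checkpointing_enable()  # flips ASTEncoder.gradient_checkpointing for training
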
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
+     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+     as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+     behavior.
+ 
+     Parameters:
+         config ([`ASTConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+ 
+ AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
+     Args:
+         input_values (`torch.FloatTensor` of shape `(batch_size, max_length, num_mel_bins)`):
+             Float values of mel features extracted from the raw audio waveform. The raw audio waveform can be obtained
+             by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.*
+             via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
+             [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
+             tensor of type `torch.FloatTensor`. See [`~ASTFeatureExtractor.__call__`].
+ 
+         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ 
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+ 
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ 
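A short sketch of preparing `input_values` as described above, using `AutoFeatureExtractor` (16 kHz mono audio assumed; the checkpoint is the one used in this file's docstrings, and the random array stands in for a real waveform loaded via soundfile):

import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
waveform = np.random.randn(16000)  # stand-in for a 1-second clip
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_values"].shape)  # (1, max_length, num_mel_bins), e.g. (1, 1024, 128)
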
+ @add_start_docstrings(
+     "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
+     AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
+ )
+ class ASTModel(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+         self.config = config
+ 
+         self.embeddings = ASTEmbeddings(config)
+         self.encoder = ASTEncoder(config)
+ 
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self) -> ASTPatchEmbeddings:
+         return self.embeddings.patch_embeddings
+ 
+     def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=BaseModelOutputWithPooling,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_EXPECTED_OUTPUT_SHAPE,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if input_values is None:
+             raise ValueError("You have to specify input_values")
+ 
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ 
+         embedding_output = self.embeddings(input_values)
+ 
+         encoder_outputs = self.encoder(
+             embedding_output,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         sequence_output = self.layernorm(sequence_output)
+ 
+         pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
+ 
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+ 
+         return BaseModelOutputWithPooling(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+ 
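For reference, a self-contained sketch of running the bare model (upstream `transformers.ASTModel` is used for illustration); the sequence length matches the 1214-token arithmetic worked out earlier:

import numpy as np
import torch
from transformers import ASTModel, AutoFeatureExtractor

name = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = AutoFeatureExtractor.from_pretrained(name)
model = ASTModel.from_pretrained(name).eval()

inputs = feature_extractor(np.zeros(16000), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(input_values=inputs["input_values"])
print(outputs.last_hidden_state.shape)  # torch.Size([1, 1214, 768])
print(outputs.pooler_output.shape)      # torch.Size([1, 768]) -- mean of the CLS and distillation tokens
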
+ class ASDeiTModel(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+         self.config = config
+ 
+         self.embeddings = ASDeiTEmbeddings(config)
+         self.encoder = ASTEncoder(config)
+ 
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self) -> ASTPatchEmbeddings:
+         return self.embeddings.patch_embeddings
+ 
+     def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=BaseModelOutputWithPooling,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_EXPECTED_OUTPUT_SHAPE,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if input_values is None:
+             raise ValueError("You have to specify input_values")
+ 
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ 
+         embedding_output = self.embeddings(input_values)
+ 
+         encoder_outputs = self.encoder(
+             embedding_output,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         sequence_output = self.layernorm(sequence_output)
+ 
+         pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
+ 
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+ 
+         return BaseModelOutputWithPooling(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+ 
+ 
+ class ASViTModel(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+         self.config = config
+ 
+         self.embeddings = ASViTEmbeddings(config)
+         self.encoder = ASTEncoder(config)
+ 
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self) -> ASTPatchEmbeddings:
+         return self.embeddings.patch_embeddings
+ 
+     def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
+         class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=BaseModelOutputWithPooling,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_EXPECTED_OUTPUT_SHAPE,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         if input_values is None:
+             raise ValueError("You have to specify input_values")
+ 
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ 
+         embedding_output = self.embeddings(input_values)
+ 
+         encoder_outputs = self.encoder(
+             embedding_output,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         sequence_output = self.layernorm(sequence_output)
+ 
+         pooled_output = sequence_output[:, 0]
+ 
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+ 
+         return BaseModelOutputWithPooling(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+ 
+ 
+ class ASTMLPHead(nn.Module):
+     def __init__(self, config: ASTConfig):
+         super().__init__()
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ 
+     def forward(self, hidden_state):
+         hidden_state = self.layernorm(hidden_state)
+         hidden_state = self.dense(hidden_state)
+         return hidden_state
+ 
+ @add_start_docstrings(
+     """
+     Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
+     output) e.g. for datasets like AudioSet, Speech Commands v2.
+     """,
+     AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
+ )
+ class ASTForAudioClassification(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+ 
+         self.num_labels = config.num_labels
+         self.audio_spectrogram_transformer = ASTModel(config)
+ 
+         # Classifier head
+         self.classifier = ASTMLPHead(config)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_SEQ_CLASS_CHECKPOINT,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         outputs = self.audio_spectrogram_transformer(
+             input_values,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         pooled_output = outputs[1]
+         logits = self.classifier(pooled_output)
+ 
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
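A hedged end-to-end sketch for the classification head, again using the upstream `transformers` class and the AudioSet checkpoint named above (with random input the predicted label is arbitrary):

import numpy as np
import torch
from transformers import ASTForAudioClassification, AutoFeatureExtractor

name = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = AutoFeatureExtractor.from_pretrained(name)
model = ASTForAudioClassification.from_pretrained(name).eval()

inputs = feature_extractor(np.random.randn(16000), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = logits.argmax(-1).item()
print(model.config.id2label[predicted_id])  # one of the AudioSet labels, e.g. "Speech" for speech audio
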
+ class ASDeiTForAudioClassification(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+ 
+         self.num_labels = config.num_labels
+         self.audio_spectrogram_transformer = ASDeiTModel(config)
+ 
+         # Classifier head
+         self.classifier = ASTMLPHead(config)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_SEQ_CLASS_CHECKPOINT,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         outputs = self.audio_spectrogram_transformer(
+             input_values,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         pooled_output = outputs[1]
+         logits = self.classifier(pooled_output)
+ 
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
+ 
+ class ASViTForAudioClassification(ASTPreTrainedModel):
+     def __init__(self, config: ASTConfig) -> None:
+         super().__init__(config)
+ 
+         self.num_labels = config.num_labels
+         self.audio_spectrogram_transformer = ASViTModel(config)
+ 
+         # Classifier head
+         self.classifier = ASTMLPHead(config)
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_SEQ_CLASS_CHECKPOINT,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+     )
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         outputs = self.audio_spectrogram_transformer(
+             input_values,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         pooled_output = outputs[1]
+         logits = self.classifier(pooled_output)
+ 
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
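Finally, the `problem_type` dispatch shared by the three classification heads keys off the label dtype when `config.problem_type` is unset: integer labels select cross-entropy, float multi-hot labels select BCE-with-logits. A small illustration of the expected label formats, with dummy tensors only (527 labels chosen to mirror AudioSet):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

batch_size, num_labels = 2, 527
logits = torch.randn(batch_size, num_labels)

# single_label_classification: long/int class indices of shape (batch_size,)
hard_labels = torch.tensor([0, 3])
print(CrossEntropyLoss()(logits.view(-1, num_labels), hard_labels.view(-1)))

# multi_label_classification: float multi-hot targets with the same shape as logits
multi_hot = torch.zeros(batch_size, num_labels)
multi_hot[0, [3, 7]] = 1.0
print(BCEWithLogitsLoss()(logits, multi_hot))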