|
import torch |
|
from torch import nn |
|
from transformers.activations import ACT2FN |
|
|
|
|
|
class Conv2dFeatureExtractor(nn.Module):
    """Convolutional front-end that maps raw 2D input features to hidden states.

    Builds a stack of Conv2d + activation pairs (one per entry in
    ``config.conv_dim``), then projects the flattened channel/frequency axes
    to ``config.hidden_size`` with a single linear layer.
    """

    def __init__(self, config):
        super().__init__()

        # Input channel of each conv layer: 1 for the first, then the previous
        # layer's output channels.
        in_channels = [1, *config.conv_dim]

        stages = []
        for index, (out_channels, kernel, stride) in enumerate(
            zip(config.conv_dim, config.conv_kernel, config.conv_stride)
        ):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels[index],
                        out_channels=out_channels,
                        kernel_size=(kernel, kernel),
                        stride=(stride, stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
            )
        self.conv = torch.nn.Sequential(*stages)

        # NOTE(review): ((d - 1) // 2 - 1) // 2 is the per-layer size formula
        # for a no-padding kernel-3 / stride-2 conv applied twice — this
        # presumably assumes exactly two such layers in the stack; confirm
        # against the configs this model ships with.
        linear_in_dim = config.conv_dim[-1] * (((config.second_dim_input_size - 1) // 2 - 1) // 2)
        self.out = torch.nn.Linear(linear_in_dim, config.hidden_size, bias=True)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        # Insert a singleton channel axis, run the conv stack, then fold the
        # channel and last spatial axes together before the linear projection.
        features = self.conv(input_values.unsqueeze(1))
        projected = self.out(features.transpose(1, 2).flatten(2, 3))
        return projected.transpose(1, 2)
|
|