ED-small / extractors.py
Lakoc's picture
Upload JointCTCAttentionEncoderDecoder
8b3008e verified
import torch
from torch import nn
from transformers.activations import ACT2FN
class Conv2dFeatureExtractor(nn.Module):
    """Stack of 2-D convolutions followed by a linear projection.

    The input gets a singleton channel axis, is passed through
    ``len(config.conv_dim)`` ``Conv2d`` + activation layers, and the channel
    axis is then merged with the last spatial axis and projected to
    ``config.hidden_size``.

    Config attributes read: ``conv_dim``, ``conv_kernel``, ``conv_stride``,
    ``feat_extract_activation``, ``second_dim_input_size``, ``hidden_size``.
    """

    def __init__(self, config):
        """Build the convolution stack and the output projection.

        Args:
            config: object exposing the attributes listed in the class
                docstring. ``conv_kernel`` / ``conv_stride`` entries are used
                as square kernel sizes / strides (one int per layer).

        Raises:
            ValueError: if ``config.second_dim_input_size`` is too small to
                survive the configured convolutions.
        """
        super().__init__()
        # zip() stops at the shortest sequence, so this is the exact list of
        # layers that get built; reuse it for the size computation below.
        layer_specs = list(
            zip([1, *config.conv_dim], config.conv_dim, config.conv_kernel, config.conv_stride)
        )
        self.conv = torch.nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(
                        conv_in,
                        out_channels=conv_out,
                        kernel_size=(conv_kernel, conv_kernel),
                        stride=(conv_stride, conv_stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
                for conv_in, conv_out, conv_kernel, conv_stride in layer_specs
            ],
        )
        # Derive the size of the last axis after the stack from the actual
        # kernels/strides instead of assuming two kernel-3/stride-2 layers.
        reduced_size = self._conv_output_size(
            config.second_dim_input_size,
            [kernel for _, _, kernel, _ in layer_specs],
            [stride for _, _, _, stride in layer_specs],
        )
        if reduced_size <= 0:
            raise ValueError(
                f"second_dim_input_size={config.second_dim_input_size} is too small for "
                f"conv_kernel={list(config.conv_kernel)} and conv_stride={list(config.conv_stride)}"
            )
        linear_in_dim = config.conv_dim[-1] * reduced_size
        self.out = torch.nn.Linear(linear_in_dim, config.hidden_size, bias=True)

    @staticmethod
    def _conv_output_size(size, kernels, strides):
        """Return the length of one axis after successive unpadded convolutions.

        Applies ``(size - kernel) // stride + 1`` per layer, i.e. the Conv2d
        output-size formula for ``padding=0`` and ``dilation=1`` (the settings
        used by the layers built in ``__init__``).
        """
        for kernel, stride in zip(kernels, strides):
            size = (size - kernel) // stride + 1
        return size

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Run the conv stack and project to the hidden size.

        Args:
            input_values: 3-D tensor ``(batch, first_dim, second_dim)`` where
                ``second_dim`` must equal ``config.second_dim_input_size``
                (presumably time x feature bins - confirm against callers).

        Returns:
            Tensor of shape ``(batch, hidden_size, reduced_first_dim)``.
        """
        # Insert the single input channel expected by the first Conv2d.
        hidden_states = self.conv(input_values[:, None, ...])
        # (B, C, D1', D2') -> (B, D1', C * D2') so Linear acts per D1' position.
        hidden_states = self.out(hidden_states.transpose(1, 2).flatten(2, 3))
        return hidden_states.transpose(1, 2)