Spaces:
Runtime error
Runtime error
# This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{yao2021wenet, | |
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
# booktitle={Proc. Interspeech}, | |
# year={2021}, | |
# address={Brno, Czech Republic }, | |
# organization={IEEE} | |
# } | |
# @article{zhang2022wenet, | |
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
# journal={arXiv preprint arXiv:2203.15455}, | |
# year={2022} | |
# } | |
# | |
"""Conv2d Module with Valid Padding""" | |
import torch.nn.functional as F | |
from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional | |
class Conv2dValid(_ConvNd): | |
""" | |
Conv2d operator for VALID mode padding. | |
""" | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: _size_2_t, | |
stride: _size_2_t = 1, | |
padding: Union[str, _size_2_t] = 0, | |
dilation: _size_2_t = 1, | |
groups: int = 1, | |
bias: bool = True, | |
padding_mode: str = "zeros", # TODO: refine this type | |
device=None, | |
dtype=None, | |
valid_trigx: bool = False, | |
valid_trigy: bool = False, | |
) -> None: | |
factory_kwargs = {"device": device, "dtype": dtype} | |
kernel_size_ = _pair(kernel_size) | |
stride_ = _pair(stride) | |
padding_ = padding if isinstance(padding, str) else _pair(padding) | |
dilation_ = _pair(dilation) | |
super(Conv2dValid, self).__init__( | |
in_channels, | |
out_channels, | |
kernel_size_, | |
stride_, | |
padding_, | |
dilation_, | |
False, | |
_pair(0), | |
groups, | |
bias, | |
padding_mode, | |
**factory_kwargs, | |
) | |
self.valid_trigx = valid_trigx | |
self.valid_trigy = valid_trigy | |
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): | |
validx, validy = 0, 0 | |
if self.valid_trigx: | |
validx = ( | |
input.size(-2) * (self.stride[-2] - 1) - 1 + self.kernel_size[-2] | |
) // 2 | |
if self.valid_trigy: | |
validy = ( | |
input.size(-1) * (self.stride[-1] - 1) - 1 + self.kernel_size[-1] | |
) // 2 | |
return F.conv2d( | |
input, | |
weight, | |
bias, | |
self.stride, | |
(validx, validy), | |
self.dilation, | |
self.groups, | |
) | |
def forward(self, input: Tensor) -> Tensor: | |
return self._conv_forward(input, self.weight, self.bias) | |