File size: 656 Bytes
45916af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
class STFTMag(nn.Module):
    """Magnitude spectrogram computed with a Hann-windowed STFT.

    Args:
        nfft: FFT size; also the length of the Hann analysis window.
        hop: Hop length between successive STFT frames, in samples.
    """

    def __init__(self,
                 nfft=1024,
                 hop=256):
        super().__init__()
        self.nfft = nfft
        self.hop = hop
        # Non-persistent buffer: moves with the module (.to()/.cuda()) but is
        # not stored in the state_dict, since it is fully determined by nfft.
        self.register_buffer('window', torch.hann_window(nfft), False)

    @torch.no_grad()
    def forward(self, x):
        """Return the STFT magnitude of ``x``.

        Args:
            x: Real waveform of shape [B, T] or [T].

        Returns:
            Non-negative magnitudes of shape [B, F, TT] (or [F, TT] for a
            1-D input), where F = nfft // 2 + 1. The transform is computed
            on the CPU, so the result is a CPU tensor whatever x's device is.
        """
        x = x.cpu()
        # The window buffer lives wherever the module was moved; bring it to
        # the input's device and dtype so torch.stft does not reject a
        # CPU-input / CUDA-window (or float64/float32) mismatch.
        window = self.window.to(device=x.device, dtype=x.dtype)
        # return_complex must be given for real inputs on PyTorch >= 2.0.
        # The complex modulus equals the L2 norm over the legacy [..., 2]
        # (real, imag) axis that the previous implementation reduced.
        stft = torch.stft(x,
                          self.nfft,
                          self.hop,
                          window=window,
                          return_complex=True)  # [B, F, TT], complex
        return stft.abs()  # [B, F, TT]
|