Upload 8 files
- .gitattributes +3 -0
- lightning_logs/predictor/checkpoints/predictor.ckpt +3 -0
- lightning_logs/predictor/hparams.yaml +6 -0
- lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt +3 -0
- lightning_logs/version_0/checkpoints/frn.onnx +3 -0
- lightning_logs/version_0/hparams.yaml +6 -0
- models/__init__.py +0 -0
- models/blocks.py +142 -0
- models/frn.py +220 -0
.gitattributes
ADDED
@@ -0,0 +1,3 @@
+lightning_logs/predictor/checkpoints/predictor.ckpt filter=lfs diff=lfs merge=lfs -text
+lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt filter=lfs diff=lfs merge=lfs -text
+lightning_logs/version_0/checkpoints/frn.onnx filter=lfs diff=lfs merge=lfs -text
lightning_logs/predictor/checkpoints/predictor.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f3679c9431666575eb7899e556d040073aa74956c48f122b16b30b9efa2e93b
+size 14985163
lightning_logs/predictor/hparams.yaml
ADDED
@@ -0,0 +1,6 @@
+batch_size: 90
+input: mag
+lstm_dim: 512
+lstm_layers: 1
+output: mag
+window_size: 960
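
Aside (not part of the commit): these logged hyperparameters line up only partially with the Predictor class added below in models/blocks.py. batch_size is a training setting, and input/output: mag appear to record that the model maps magnitude frames to magnitude frames rather than being constructor arguments. A minimal sketch of rebuilding the module from the overlapping keys, assuming this repository layout:

import yaml

from models.blocks import Predictor

with open('lightning_logs/predictor/hparams.yaml') as f:
    hparams = yaml.safe_load(f)

# Keep only the keys Predictor's constructor actually accepts (an assumption
# about how the training script used this file).
kwargs = {k: hparams[k] for k in ('window_size', 'lstm_dim', 'lstm_layers')}
predictor = Predictor(**kwargs)  # Predictor(window_size=960, lstm_dim=512, lstm_layers=1)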
lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4061bb0f6e669315e00878009440dab749f60f823d5bf863bfa4b8172d96d073
+size 109184745
lightning_logs/version_0/checkpoints/frn.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdf07d992ff655e5ab32074d4d7b747986cd79fed16b499ed11b120c7042a666
+size 36527867
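
Aside (not part of the commit): a sketch of running the exported graph frame by frame with onnxruntime, assuming the default PLCModel hyperparameters used below in models/frn.py (window_size=960, enc_layers=4, enc_in_dim=384, pred_dim=512, pred_layers=1) and the tensor names declared by OnnxWrapper there; the random frames stand in for real STFT frames of shape (1, 481, 2).

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('lightning_logs/version_0/checkpoints/frn.onnx')

# Zero-initialised caches carried across frames.
mag = np.zeros((1, 1, 480, 1), dtype=np.float32)        # previous magnitude frame
pred_state = np.zeros((2, 1, 1, 512), dtype=np.float32)  # predictor LSTM (h, c)
mlp_state = np.zeros((4, 2, 1, 1, 384), dtype=np.float32)  # one (h, c) per encoder block

frames = [np.random.randn(1, 481, 2).astype(np.float32) for _ in range(10)]
for frame in frames:
    out, mag, pred_state, mlp_state = session.run(
        None, {'input': frame, 'mag_in_cached_': mag,
               'pred_state_in_cached_': pred_state,
               'mlp_state_in_cached_': mlp_state})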
lightning_logs/version_0/hparams.yaml
ADDED
@@ -0,0 +1,6 @@
+batch_size: 90
+cnn_dim: 64
+cnn_layers: 5
+lstm_dim: 512
+lstm_layers: 1
+window_size: 960
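
Aside (not part of the commit): these keys (cnn_dim, cnn_layers) do not match the hyperparameters that PLCModel in models/frn.py saves (window_size, enc_layers, enc_in_dim, enc_dim, pred_dim, pred_layers), so this yaml appears to come from a different revision of the training script. Restoring the model does not depend on it, since Lightning reads the hyperparameters embedded in the checkpoint itself; a sketch, assuming the predictor checkpoint above is also present at its default path:

from models.frn import PLCModel

model = PLCModel.load_from_checkpoint(
    'lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt')
model.eval()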
models/__init__.py
ADDED
File without changes
models/blocks.py
ADDED
@@ -0,0 +1,142 @@
+import librosa
+import pytorch_lightning as pl
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+
+
+class Aff(nn.Module):
+    """Learnable per-feature affine transform (scale alpha, shift beta)."""
+
+    def __init__(self, dim):
+        super().__init__()
+
+        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
+        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
+
+    def forward(self, x):
+        x = x * self.alpha + self.beta
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout=0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class MLPBlock(nn.Module):
+    """ResMLP-style block whose token-mixing step is a unidirectional LSTM,
+    so the block can run causally with an explicit (h, c) state."""
+
+    def __init__(self, dim, mlp_dim, dropout=0., init_values=1e-4):
+        super().__init__()
+
+        self.pre_affine = Aff(dim)
+        self.inter = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=1,
+                             bidirectional=False, batch_first=True)
+        self.ff = nn.Sequential(
+            FeedForward(dim, mlp_dim, dropout),
+        )
+        self.post_affine = Aff(dim)
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
+
+    def forward(self, x, state=None):
+        x = self.pre_affine(x)
+        if state is None:
+            inter, _ = self.inter(x)
+        else:
+            inter, state = self.inter(x, (state[0], state[1]))
+        x = x + self.gamma_1 * inter
+        x = self.post_affine(x)
+        x = x + self.gamma_2 * self.ff(x)
+        if state is None:
+            return x
+        state = torch.stack(state, 0)  # pack (h, c) into one tensor for caching
+        return x, state
+
+
+class Encoder(nn.Module):
+
+    def __init__(self, in_dim, dim, depth, mlp_dim):
+        super().__init__()
+        self.in_dim = in_dim
+        self.dim = dim
+        self.depth = depth
+        self.mlp_dim = mlp_dim
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c f t -> b t (c f)'),
+            nn.Linear(in_dim, dim),
+            nn.GELU()
+        )
+
+        self.mlp_blocks = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.mlp_blocks.append(MLPBlock(self.dim, mlp_dim, dropout=0.15))
+
+        self.affine = nn.Sequential(
+            Aff(self.dim),
+            nn.Linear(dim, in_dim),
+            Rearrange('b t (c f) -> b c f t', c=2),
+        )
+
+    def forward(self, x_in, states=None):
+        x = self.to_patch_embedding(x_in)
+        if states is not None:
+            out_states = []
+        for i, mlp_block in enumerate(self.mlp_blocks):
+            if states is None:
+                x = mlp_block(x)
+            else:
+                x, state = mlp_block(x, states[i])
+                out_states.append(state)
+        x = self.affine(x)
+        x = x + x_in  # global residual over the whole encoder
+        if states is None:
+            return x
+        else:
+            return x, torch.stack(out_states, 0)
+
+
+class Predictor(pl.LightningModule):  # predicts the next magnitude frame in the mel domain
+    def __init__(self, window_size=1536, sr=48000, lstm_dim=256, lstm_layers=3, n_mels=64):
+        super(Predictor, self).__init__()
+        self.window_size = window_size
+        self.hop_size = window_size // 2
+        self.lstm_dim = lstm_dim
+        self.n_mels = n_mels
+        self.lstm_layers = lstm_layers
+
+        # Mel filterbank without the DC bin; kept as a plain tensor and moved
+        # to the input's device on each forward.
+        fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
+        self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
+        self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
+                            num_layers=self.lstm_layers, batch_first=True)
+        self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
+        self.inv_mel = nn.Linear(self.n_mels, self.hop_size)
+
+    def forward(self, x, state=None):  # x: (B, 1, F, T) linear magnitude
+
+        self.fb = self.fb.to(x.device)
+        x = torch.log(torch.matmul(self.fb, x) + 1e-8)  # project to log-mel
+        B, C, F, T = x.shape
+        x = x.reshape(B, F * C, T)
+        x = x.permute(0, 2, 1)
+        if state is None:
+            x, _ = self.lstm(x)
+        else:
+            x, state = self.lstm(x, (state[0], state[1]))
+        x = self.expand_dim(x)
+        x = torch.abs(self.inv_mel(torch.exp(x)))  # learned inverse mel, back to linear frequency
+        x = x.permute(0, 2, 1)
+        x = x.reshape(B, C, -1, T)
+        if state is None:
+            return x
+        else:
+            return x, torch.stack(state, 0)
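
Aside (not part of the commit): a quick shape check for the two modules above, using the defaults that models/frn.py passes in (window_size=960, so the encoder sees 2 x 480 = 960 features per frame and the predictor works on 480-bin magnitude frames):

import torch

from models.blocks import Encoder, Predictor

enc = Encoder(in_dim=960, dim=384, depth=4, mlp_dim=768)
pred = Predictor(window_size=960, lstm_dim=512, lstm_layers=1)

spec = torch.randn(2, 2, 480, 10)      # (B, re/im, F, T)
assert enc(spec).shape == spec.shape   # encoder is residual: shape-preserving

mag = torch.rand(2, 1, 480, 10)        # (B, 1, F, T) magnitude
assert pred(mag).shape == mag.shape    # predictor returns a magnitude estimate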
models/frn.py
ADDED
@@ -0,0 +1,220 @@
+import os
+
+import librosa
+import pytorch_lightning as pl
+import soundfile as sf
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ
+from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility as STOI
+
+from PLCMOS.plc_mos import PLCMOSEstimator
+from config import CONFIG
+from loss import Loss
+from models.blocks import Encoder, Predictor
+from utils.utils import visualize, LSD
+
+plcmos = PLCMOSEstimator()
+
+
+class PLCModel(pl.LightningModule):
+    def __init__(self, train_dataset=None, val_dataset=None, window_size=960, enc_layers=4, enc_in_dim=384, enc_dim=768,
+                 pred_dim=512, pred_layers=1, pred_ckpt_path='lightning_logs/predictor/checkpoints/predictor.ckpt'):
+        super(PLCModel, self).__init__()
+        self.window_size = window_size
+        self.hop_size = window_size // 2
+        self.learning_rate = CONFIG.TRAIN.lr
+        self.hparams.batch_size = CONFIG.TRAIN.batch_size
+
+        self.enc_layers = enc_layers
+        self.enc_in_dim = enc_in_dim
+        self.enc_dim = enc_dim
+        self.pred_dim = pred_dim
+        self.pred_layers = pred_layers
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.stoi = STOI(48000)
+        self.pesq = PESQ(16000, 'wb')
+
+        if pred_ckpt_path is not None:
+            self.predictor = Predictor.load_from_checkpoint(pred_ckpt_path)
+        else:
+            self.predictor = Predictor(window_size=self.window_size, lstm_dim=self.pred_dim,
+                                       lstm_layers=self.pred_layers)
+        # Fuses the encoder output (2 channels) with the predicted magnitude
+        # (1 channel) into a real-imaginary estimate.
+        self.joiner = nn.Sequential(
+            nn.Conv2d(3, 48, kernel_size=(9, 1), stride=1, padding=(4, 0), padding_mode='reflect',
+                      groups=3),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(48, 2, kernel_size=1, stride=1, padding=0, groups=2),
+        )
+
+        self.encoder = Encoder(in_dim=self.window_size, dim=self.enc_in_dim, depth=self.enc_layers,
+                               mlp_dim=self.enc_dim)
+
+        self.loss = Loss()
+        self.window = torch.sqrt(torch.hann_window(self.window_size))
+        self.save_hyperparameters('window_size', 'enc_layers', 'enc_in_dim', 'enc_dim', 'pred_dim', 'pred_layers')
+
+    def forward(self, x):
+        """
+        Input: real-imaginary spectrogram without the DC bin; shape (B, 2, F, T); F = hop_size
+        Output: real-imaginary spectrogram of the same shape
+        """
+
+        B, C, F, T = x.shape
+
+        x = x.permute(3, 0, 1, 2).unsqueeze(-1)  # iterate over time: (T, B, 2, F, 1)
+        prev_mag = torch.zeros((B, 1, F, 1), device=x.device)
+        predictor_state = torch.zeros((2, self.predictor.lstm_layers, B, self.predictor.lstm_dim), device=x.device)
+        mlp_state = torch.zeros((self.encoder.depth, 2, 1, B, self.encoder.dim), device=x.device)
+        result = []
+        for step in x:
+            feat, mlp_state = self.encoder(step, mlp_state)
+            prev_mag, predictor_state = self.predictor(prev_mag, predictor_state)
+            feat = torch.cat((feat, prev_mag), 1)
+            feat = self.joiner(feat)
+            feat = feat + step
+            result.append(feat)
+            prev_mag = torch.linalg.norm(feat, dim=1, ord=1, keepdim=True)  # L1 norm over (re, im) as the magnitude fed back to the predictor
+        output = torch.cat(result, -1)
+        return output
+
+    def forward_onnx(self, x, prev_mag, predictor_state=None, mlp_state=None):
+        prev_mag, predictor_state = self.predictor(prev_mag, predictor_state)
+        feat, mlp_state = self.encoder(x, mlp_state)
+
+        feat = torch.cat((feat, prev_mag), 1)
+        feat = self.joiner(feat)
+        prev_mag = torch.linalg.norm(feat, dim=1, ord=1, keepdim=True)
+        feat = feat + x
+        return feat, prev_mag, predictor_state, mlp_state
+
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, shuffle=False, batch_size=self.hparams.batch_size,
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_dataset, shuffle=False, batch_size=self.hparams.batch_size,
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
+
+    def training_step(self, batch, batch_idx):
+        x_in, y = batch
+        f_0 = x_in[:, :, 0:1, :]  # DC bin is passed through untouched
+        x = x_in[:, :, 1:, :]
+
+        x = self(x)
+        x = torch.cat([f_0, x], dim=2)
+
+        loss = self.loss(x, y)
+        self.log('train_loss', loss, logger=True)
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, y = val_batch
+        f_0 = x[:, :, 0:1, :]
+        x_in = x[:, :, 1:, :]
+
+        pred = self(x_in)
+        pred = torch.cat([f_0, pred], dim=2)
+
+        loss = self.loss(pred, y)
+        self.window = self.window.to(pred.device)
+        pred = torch.view_as_complex(pred.permute(0, 2, 3, 1).contiguous())
+        pred = torch.istft(pred, self.window_size, self.hop_size, window=self.window)
+        y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())
+        y = torch.istft(y, self.window_size, self.hop_size, window=self.window)
+
+        self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)
+
+        if batch_idx == 0:
+            i = torch.randint(0, x.shape[0], (1,)).item()
+            x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
+            x = torch.istft(x[i], self.window_size, self.hop_size, window=self.window)
+
+            self.trainer.logger.log_spectrogram(y[i], x, pred[i], self.current_epoch)
+            self.trainer.logger.log_audio(y[i], x, pred[i], self.current_epoch)
+
+    def test_step(self, test_batch, batch_idx):
+        inp, tar, inp_wav, tar_wav = test_batch
+        inp_wav = inp_wav.squeeze()
+        tar_wav = tar_wav.squeeze()
+        f_0 = inp[:, :, 0:1, :]
+        x = inp[:, :, 1:, :]
+        pred = self(x)
+        pred = torch.cat([f_0, pred], dim=2)
+        pred = torch.istft(pred.squeeze(0).permute(1, 2, 0), self.window_size, self.hop_size,
+                           window=self.window.to(pred.device))
+        stoi = self.stoi(pred, tar_wav)
+
+        tar_wav = tar_wav.cpu().numpy()
+        inp_wav = inp_wav.cpu().numpy()
+        pred = pred.detach().cpu().numpy()
+        lsd, _ = LSD(tar_wav, pred)
+
+        if batch_idx in [5, 7, 9]:
+            sample_path = os.path.join(CONFIG.LOG.sample_path)
+            path = os.path.join(sample_path, 'sample_' + str(batch_idx))
+            visualize(tar_wav, inp_wav, pred, path)
+            sf.write(os.path.join(path, 'enhanced_output.wav'), pred, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
+            sf.write(os.path.join(path, 'lossy_input.wav'), inp_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
+            sf.write(os.path.join(path, 'target.wav'), tar_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
+        if CONFIG.DATA.sr != 16000:
+            # PLC-MOS and wideband PESQ both expect 16 kHz input
+            pred = librosa.resample(pred, orig_sr=48000, target_sr=16000)
+            tar_wav = librosa.resample(tar_wav, orig_sr=48000, target_sr=16000, res_type='kaiser_fast')
+        ret = plcmos.run(pred, tar_wav)
+        pesq = self.pesq(torch.tensor(pred), torch.tensor(tar_wav))
+        metrics = {
+            'Intrusive': ret[0],
+            'Non-intrusive': ret[1],
+            'LSD': lsd,
+            'STOI': stoi,
+            'PESQ': pesq,
+        }
+        self.log_dict(metrics)
+        return metrics
+
+    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
+        f_0 = batch[:, :, 0:1, :]
+        x = batch[:, :, 1:, :]
+        pred = self(x)
+        pred = torch.cat([f_0, pred], dim=2)
+        pred = torch.istft(pred.squeeze(0).permute(1, 2, 0), self.window_size, self.hop_size,
+                           window=self.window.to(pred.device))
+        return pred
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=CONFIG.TRAIN.patience,
+                                                                  factor=CONFIG.TRAIN.factor, verbose=True)
+
+        scheduler = {
+            'scheduler': lr_scheduler,
+            'reduce_on_plateau': True,
+            'monitor': 'val_loss'
+        }
+        return [optimizer], [scheduler]
+
+
+class OnnxWrapper(pl.LightningModule):
+    """Flattens PLCModel's per-frame loop into single-frame I/O with explicit
+    caches, so the model can be exported to ONNX for streaming inference."""
+
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = model
+        batch_size = 1
+        pred_states = torch.zeros((2, 1, batch_size, model.predictor.lstm_dim))
+        mlp_states = torch.zeros((model.encoder.depth, 2, 1, batch_size, model.encoder.dim))
+        mag = torch.zeros((batch_size, 1, model.hop_size, 1))
+        x = torch.randn(batch_size, model.hop_size + 1, 2)
+        self.sample = (x, mag, pred_states, mlp_states)
+        self.input_names = ['input', 'mag_in_cached_', 'pred_state_in_cached_', 'mlp_state_in_cached_']
+        self.output_names = ['output', 'mag_out_cached_', 'pred_state_out_cached_', 'mlp_state_out_cached_']
+
+    def forward(self, x, prev_mag, predictor_state=None, mlp_state=None):
+        x = x.permute(0, 2, 1).unsqueeze(-1)
+        f_0 = x[:, :, 0:1, :]
+        x = x[:, :, 1:, :]
+
+        output, prev_mag, predictor_state, mlp_state = self.model.forward_onnx(x, prev_mag, predictor_state, mlp_state)
+        output = torch.cat([f_0, output], dim=2)
+        output = output.squeeze(-1).permute(0, 2, 1)
+        return output, prev_mag, predictor_state, mlp_state
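
Aside (not part of the commit): this upload does not include an export script, but OnnxWrapper above carries everything torch.onnx.export needs. A plausible sketch of how frn.onnx was produced (the opset choice is an assumption):

import torch

from models.frn import PLCModel, OnnxWrapper

model = PLCModel.load_from_checkpoint(
    'lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt')
wrapper = OnnxWrapper(model).eval()

torch.onnx.export(wrapper, wrapper.sample, 'frn.onnx',
                  input_names=wrapper.input_names,
                  output_names=wrapper.output_names,
                  do_constant_folding=True, opset_version=12)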