lengyue233's picture
Init hf space integration
0a3525d verified
raw
history blame
No virus
14.2 kB
import itertools
import math
from typing import Any, Callable
import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from fish_speech.models.vqgan.modules.discriminator import Discriminator
from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
class VQGAN(L.LightningModule):
def __init__(
self,
optimizer: Callable,
lr_scheduler: Callable,
encoder: WaveNet,
quantizer: nn.Module,
decoder: WaveNet,
discriminator: Discriminator,
vocoder: nn.Module,
encode_mel_transform: nn.Module,
gt_mel_transform: nn.Module,
weight_adv: float = 1.0,
weight_vq: float = 1.0,
weight_mel: float = 1.0,
sampling_rate: int = 44100,
freeze_encoder: bool = False,
):
super().__init__()
# Model parameters
self.optimizer_builder = optimizer
self.lr_scheduler_builder = lr_scheduler
# Modules
self.encoder = encoder
self.quantizer = quantizer
self.decoder = decoder
self.vocoder = vocoder
self.discriminator = discriminator
self.encode_mel_transform = encode_mel_transform
self.gt_mel_transform = gt_mel_transform
# A simple linear layer to project quality to condition channels
self.quality_projection = nn.Linear(1, 768)
# Freeze vocoder
for param in self.vocoder.parameters():
param.requires_grad = False
# Loss weights
self.weight_adv = weight_adv
self.weight_vq = weight_vq
self.weight_mel = weight_mel
# Other parameters
self.sampling_rate = sampling_rate
# Disable strict loading
self.strict_loading = False
# If encoder is frozen
if freeze_encoder:
for param in self.encoder.parameters():
param.requires_grad = False
for param in self.quantizer.parameters():
param.requires_grad = False
self.automatic_optimization = False
def on_save_checkpoint(self, checkpoint):
# Do not save vocoder
state_dict = checkpoint["state_dict"]
for name in list(state_dict.keys()):
if "vocoder" in name:
state_dict.pop(name)
def configure_optimizers(self):
optimizer_generator = self.optimizer_builder(
itertools.chain(
self.encoder.parameters(),
self.quantizer.parameters(),
self.decoder.parameters(),
self.quality_projection.parameters(),
)
)
optimizer_discriminator = self.optimizer_builder(
self.discriminator.parameters()
)
lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
return (
{
"optimizer": optimizer_generator,
"lr_scheduler": {
"scheduler": lr_scheduler_generator,
"interval": "step",
"name": "optimizer/generator",
},
},
{
"optimizer": optimizer_discriminator,
"lr_scheduler": {
"scheduler": lr_scheduler_discriminator,
"interval": "step",
"name": "optimizer/discriminator",
},
},
)
def training_step(self, batch, batch_idx):
optim_g, optim_d = self.optimizers()
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
audios = audios.float()
audios = audios[:, None, :]
with torch.no_grad():
encoded_mels = self.encode_mel_transform(audios)
gt_mels = self.gt_mel_transform(audios)
quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
quality = quality.unsqueeze(-1)
mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
mel_masks_float_conv = mel_masks[:, None, :].float()
gt_mels = gt_mels * mel_masks_float_conv
encoded_mels = encoded_mels * mel_masks_float_conv
# Encode
encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
# Quantize
vq_result = self.quantizer(encoded_features)
loss_vq = getattr("vq_result", "loss", 0.0)
vq_recon_features = vq_result.z * mel_masks_float_conv
vq_recon_features = (
vq_recon_features + self.quality_projection(quality)[:, :, None]
)
# VQ Decode
gen_mel = (
self.decoder(
torch.randn_like(vq_recon_features) * mel_masks_float_conv,
condition=vq_recon_features,
)
* mel_masks_float_conv
)
# Discriminator
real_logits = self.discriminator(gt_mels)
fake_logits = self.discriminator(gen_mel.detach())
d_mask = F.interpolate(
mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
)
loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
loss_fake = avg_with_mask(fake_logits**2, d_mask)
loss_d = loss_real + loss_fake
self.log(
"train/discriminator/loss",
loss_d,
on_step=True,
on_epoch=False,
prog_bar=True,
logger=True,
)
# Discriminator backward
optim_d.zero_grad()
self.manual_backward(loss_d)
self.clip_gradients(
optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
)
optim_d.step()
# Mel Loss, applying l1, using a weighted sum
mel_distance = (
gen_mel - gt_mels
).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
loss_mel_mid_freq = avg_with_mask(
mel_distance[:, 40:70, :], mel_masks_float_conv
)
loss_mel_high_freq = avg_with_mask(
mel_distance[:, 70:, :], mel_masks_float_conv
)
loss_mel = (
loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
)
# Adversarial Loss
fake_logits = self.discriminator(gen_mel)
loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
# Total loss
loss = (
self.weight_vq * loss_vq
+ self.weight_mel * loss_mel
+ self.weight_adv * loss_adv
)
# Log losses
self.log(
"train/generator/loss",
loss,
on_step=True,
on_epoch=False,
prog_bar=True,
logger=True,
)
self.log(
"train/generator/loss_vq",
loss_vq,
on_step=True,
on_epoch=False,
prog_bar=False,
logger=True,
)
self.log(
"train/generator/loss_mel",
loss_mel,
on_step=True,
on_epoch=False,
prog_bar=False,
logger=True,
)
self.log(
"train/generator/loss_adv",
loss_adv,
on_step=True,
on_epoch=False,
prog_bar=False,
logger=True,
)
# Generator backward
optim_g.zero_grad()
self.manual_backward(loss)
self.clip_gradients(
optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
)
optim_g.step()
scheduler_g, scheduler_d = self.lr_schedulers()
scheduler_g.step()
scheduler_d.step()
def validation_step(self, batch: Any, batch_idx: int):
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
audios = audios.float()
audios = audios[:, None, :]
encoded_mels = self.encode_mel_transform(audios)
gt_mels = self.gt_mel_transform(audios)
mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
mel_masks_float_conv = mel_masks[:, None, :].float()
gt_mels = gt_mels * mel_masks_float_conv
encoded_mels = encoded_mels * mel_masks_float_conv
# Encode
encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
# Quantize
vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
vq_recon_features = (
vq_recon_features
+ self.quality_projection(
torch.ones(
vq_recon_features.shape[0], 1, device=vq_recon_features.device
)
* 2
)[:, :, None]
)
# VQ Decode
gen_aux_mels = (
self.decoder(
torch.randn_like(vq_recon_features) * mel_masks_float_conv,
condition=vq_recon_features,
)
* mel_masks_float_conv
)
loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
self.log(
"val/loss_mel",
loss_mel,
on_step=False,
on_epoch=True,
prog_bar=False,
logger=True,
sync_dist=True,
)
recon_audios = self.vocoder(gt_mels)
gen_aux_audios = self.vocoder(gen_aux_mels)
# only log the first batch
if batch_idx != 0:
return
for idx, (
gt_mel,
gen_aux_mel,
audio,
gen_aux_audio,
recon_audio,
audio_len,
) in enumerate(
zip(
gt_mels,
gen_aux_mels,
audios.cpu().float(),
gen_aux_audios.cpu().float(),
recon_audios.cpu().float(),
audio_lengths,
)
):
if idx > 4:
break
mel_len = audio_len // self.gt_mel_transform.hop_length
image_mels = plot_mel(
[
gt_mel[:, :mel_len],
gen_aux_mel[:, :mel_len],
],
[
"Ground-Truth",
"Auxiliary",
],
)
if isinstance(self.logger, WandbLogger):
self.logger.experiment.log(
{
"reconstruction_mel": wandb.Image(image_mels, caption="mels"),
"wavs": [
wandb.Audio(
audio[0, :audio_len],
sample_rate=self.sampling_rate,
caption="gt",
),
wandb.Audio(
gen_aux_audio[0, :audio_len],
sample_rate=self.sampling_rate,
caption="aux",
),
wandb.Audio(
recon_audio[0, :audio_len],
sample_rate=self.sampling_rate,
caption="recon",
),
],
},
)
if isinstance(self.logger, TensorBoardLogger):
self.logger.experiment.add_figure(
f"sample-{idx}/mels",
image_mels,
global_step=self.global_step,
)
self.logger.experiment.add_audio(
f"sample-{idx}/wavs/gt",
audio[0, :audio_len],
self.global_step,
sample_rate=self.sampling_rate,
)
self.logger.experiment.add_audio(
f"sample-{idx}/wavs/gen",
gen_aux_audio[0, :audio_len],
self.global_step,
sample_rate=self.sampling_rate,
)
self.logger.experiment.add_audio(
f"sample-{idx}/wavs/recon",
recon_audio[0, :audio_len],
self.global_step,
sample_rate=self.sampling_rate,
)
plt.close(image_mels)
def encode(self, audios, audio_lengths):
audios = audios.float()
mels = self.encode_mel_transform(audios)
mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
mel_masks = sequence_mask(mel_lengths, mels.shape[2])
mel_masks_float_conv = mel_masks[:, None, :].float()
mels = mels * mel_masks_float_conv
# Encode
encoded_features = self.encoder(mels) * mel_masks_float_conv
feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
return self.quantizer.encode(encoded_features), feature_lengths
def decode(self, indices, feature_lengths, return_audios=False):
factor = math.prod(self.quantizer.downsample_factor)
mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
mel_masks_float_conv = mel_masks[:, None, :].float()
z = self.quantizer.decode(indices) * mel_masks_float_conv
z = (
z
+ self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
:, :, None
]
)
gen_mel = (
self.decoder(
torch.randn_like(z) * mel_masks_float_conv,
condition=z,
)
* mel_masks_float_conv
)
if return_audios:
return self.vocoder(gen_mel)
return gen_mel