"""Training entry point for FishDiffusion.

Wraps a DiffSinger diffusion model (plus a frozen vocoder for validation
audio rendering) in a PyTorch Lightning module and launches training from
an mmengine config file.
"""
from argparse import ArgumentParser | |
import matplotlib.pyplot as plt | |
import pytorch_lightning as pl | |
import torch | |
import wandb | |
from loguru import logger | |
from mmengine import Config | |
from mmengine.optim import OPTIMIZERS | |
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger | |
from torch.utils.data import DataLoader | |
from fish_diffusion.archs.diffsinger import DiffSinger | |
from fish_diffusion.datasets import DATASETS | |
from fish_diffusion.datasets.repeat import RepeatDataset | |
from fish_diffusion.utils.scheduler import LR_SCHEUDLERS | |
from fish_diffusion.utils.viz import viz_synth_sample | |
from fish_diffusion.vocoders import VOCODERS | |
class FishDiffusion(pl.LightningModule):
    """Lightning wrapper around a DiffSinger diffusion model.

    Computes the diffusion loss for training/validation, and during
    validation additionally samples mels from the diffusion model and logs
    mel figures plus vocoded audio to the active logger (WandB or
    TensorBoard).
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.model = DiffSinger(config.model)
        self.config = config

        # Vocoder: converts mel spectrograms back to audio. Frozen — it is
        # only used for rendering validation samples, never trained.
        self.vocoder = VOCODERS.build(config.model.vocoder)
        self.vocoder.freeze()

    def configure_optimizers(self):
        # The mmengine registries build optimizer/scheduler from the config;
        # the runtime-only fields (params, optimizer) are injected here.
        self.config.optimizer.params = self.parameters()
        optimizer = OPTIMIZERS.build(self.config.optimizer)

        self.config.scheduler.optimizer = optimizer
        scheduler = LR_SCHEUDLERS.build(self.config.scheduler)

        # Scheduler steps every optimizer step (not every epoch).
        return [optimizer], dict(scheduler=scheduler, interval="step")

    def _step(self, batch, batch_idx, mode):
        """Shared train/valid step.

        Args:
            batch: collated batch dict with speakers, contents, mels,
                pitches and their lengths (shapes assumed per the project's
                collate_fn — pitches and mels must agree on the time axis).
            batch_idx: index of the batch (unused except by Lightning).
            mode: "train" or "valid"; controls loss-key naming and whether
                sampling/visualization runs.

        Returns:
            The diffusion loss tensor.
        """
        assert batch["pitches"].shape[1] == batch["mels"].shape[1]

        # Clone before the forward pass so visualization below sees the
        # original pitch values even if the model mutates the tensor.
        pitches = batch["pitches"].clone()
        batch_size = batch["speakers"].shape[0]

        output = self.model(
            speakers=batch["speakers"],
            contents=batch["contents"],
            src_lens=batch["content_lens"],
            max_src_len=batch["max_content_len"],
            mels=batch["mels"],
            mel_lens=batch["mel_lens"],
            max_mel_len=batch["max_mel_len"],
            pitches=batch["pitches"],
        )

        self.log(f"{mode}_loss", output["loss"], batch_size=batch_size, sync_dist=True)

        if mode != "valid":
            return output["loss"]

        # Validation only: sample mels from the diffusion model for
        # qualitative inspection.
        x = self.model.diffusion(output["features"])

        for idx, (gt_mel, gt_pitch, predict_mel, predict_mel_len) in enumerate(
            zip(batch["mels"], pitches, x, batch["mel_lens"])
        ):
            image_mels, wav_reconstruction, wav_prediction = viz_synth_sample(
                gt_mel=gt_mel,
                gt_pitch=gt_pitch,
                predict_mel=predict_mel,
                predict_mel_len=predict_mel_len,
                vocoder=self.vocoder,
                return_image=False,
            )

            wav_reconstruction = wav_reconstruction.to(torch.float32).cpu().numpy()
            wav_prediction = wav_prediction.to(torch.float32).cpu().numpy()

            # WandB logger
            # NOTE(review): sample rate 44100 is hard-coded here and below —
            # presumably matches the vocoder config; verify if configs change.
            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                wav_reconstruction,
                                sample_rate=44100,
                                caption="reconstruction (gt)",
                            ),
                            wandb.Audio(
                                wav_prediction,
                                sample_rate=44100,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            # TensorBoard logger
            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    wav_reconstruction,
                    self.global_step,
                    sample_rate=44100,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    wav_prediction,
                    self.global_step,
                    sample_rate=44100,
                )

            # Close matplotlib figures explicitly to avoid leaking them
            # across validation epochs.
            if isinstance(image_mels, plt.Figure):
                plt.close(image_mels)

        return output["loss"]

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, mode="valid")
if __name__ == "__main__":
    pl.seed_everything(42, workers=True)

    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Use tensorboard logger, default is wandb.",
    )
    parser.add_argument("--resume-id", type=str, default=None, help="Wandb run id.")
    parser.add_argument("--entity", type=str, default=None, help="Wandb entity.")
    parser.add_argument("--name", type=str, default=None, help="Wandb run name.")
    parser.add_argument(
        "--pretrained", type=str, default=None, help="Pretrained model."
    )
    parser.add_argument(
        "--only-train-speaker-embeddings",
        action="store_true",
        default=False,
        help="Only train speaker embeddings.",
    )

    args = parser.parse_args()

    cfg = Config.fromfile(args.config)
    model = FishDiffusion(cfg)

    # We only load the state_dict of the model, not the optimizer.
    if args.pretrained:
        state_dict = torch.load(args.pretrained, map_location="cpu")

        # Lightning checkpoints nest the weights under "state_dict".
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        result = model.load_state_dict(state_dict, strict=False)

        missing_keys = set(result.missing_keys)
        unexpected_keys = set(result.unexpected_keys)

        # Make sure incorrect keys are just noise predictor keys.
        unexpected_keys = unexpected_keys - set(
            i.replace(".naive_noise_predictor.", ".") for i in missing_keys
        )

        # Explicit raise instead of a bare assert: asserts vanish under
        # `python -O`, and we want the offending keys in the error message.
        if unexpected_keys:
            raise ValueError(
                f"Unexpected keys in pretrained checkpoint: {sorted(unexpected_keys)}"
            )

    if args.only_train_speaker_embeddings:
        # Freeze everything except the speaker encoder.
        for name, param in model.named_parameters():
            if "speaker_encoder" not in name:
                param.requires_grad = False

        logger.info(
            "Only train speaker embeddings, all other parameters are frozen."
        )

    # Named `train_logger` (not `logger`) so it does not shadow the loguru
    # `logger` imported at module level.
    train_logger = (
        TensorBoardLogger("logs", name=cfg.model.type)
        if args.tensorboard
        else WandbLogger(
            project=cfg.model.type,
            save_dir="logs",
            log_model=True,
            name=args.name,
            entity=args.entity,
            resume="must" if args.resume_id else False,
            id=args.resume_id,
        )
    )

    trainer = pl.Trainer(
        logger=train_logger,
        **cfg.trainer,
    )

    train_dataset = DATASETS.build(cfg.dataset.train)
    train_loader = DataLoader(
        train_dataset,
        collate_fn=train_dataset.collate_fn,
        **cfg.dataloader.train,
    )

    # Repeat the validation set per device so every DDP rank gets a copy.
    valid_dataset = DATASETS.build(cfg.dataset.valid)
    valid_dataset = RepeatDataset(
        valid_dataset, repeat=trainer.num_devices, collate_fn=valid_dataset.collate_fn
    )

    valid_loader = DataLoader(
        valid_dataset,
        collate_fn=valid_dataset.collate_fn,
        **cfg.dataloader.valid,
    )

    trainer.fit(model, train_loader, valid_loader, ckpt_path=args.resume)