# Spaces: Build error
import torch | |
# from inference.tts.fs import FastSpeechInfer | |
# from modules.tts.fs2_orig import FastSpeech2Orig | |
from inference.svs.base_svs_infer import BaseSVSInfer | |
from utils import load_ckpt | |
from utils.hparams import hparams | |
from modules.diff.shallow_diffusion_tts import GaussianDiffusion | |
from tasks.svs.diffsinger_task import DIFF_DECODERS | |
from modules.fastspeech.pe import PitchExtractor | |
import utils | |
class DiffSingerE2EInfer(BaseSVSInfer):
    """End-to-end DiffSinger singing-voice inference.

    Builds a ``GaussianDiffusion`` acoustic model (plus, optionally, a
    standalone ``PitchExtractor``) from the global ``hparams`` and turns a raw
    user input dict into a waveform. The helpers ``input_to_batch`` and
    ``run_vocoder`` and the attributes ``ph_encoder``, ``device`` and ``model``
    are presumably provided by ``BaseSVSInfer`` (not visible here).
    """

    @staticmethod
    def _pe_enabled():
        """Return True when ``hparams['pe_enable']`` is present and truthy."""
        # hparams.get() yields None for a missing key, which is already falsy,
        # so a separate ``is not None`` check is unnecessary.
        return bool(hparams.get('pe_enable'))

    def build_model(self):
        """Build the diffusion model, load its weights and optionally the PE.

        Returns:
            The ``GaussianDiffusion`` model, switched to eval mode, with weights
            loaded from ``hparams['work_dir']`` under the ``'model'`` prefix.

        Side effects:
            When the pitch extractor is enabled, creates ``self.pe`` on
            ``self.device``, loads it strictly from ``hparams['pe_ckpt']`` and
            switches it to eval mode.
        """
        model = GaussianDiffusion(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')
        if self._pe_enabled():
            self.pe = PitchExtractor().to(self.device)
            load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()
        return model

    def forward_model(self, inp):
        """Synthesize audio for one raw input dict.

        Args:
            inp: raw user input, converted to a batch by ``input_to_batch``.

        Returns:
            NumPy array with the vocoder output for the first batch item.
        """
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')
        # Run every network (acoustic model, pitch extractor, vocoder) under
        # no_grad. Outside it, the PE/vocoder outputs can require grad, and
        # Tensor.numpy() refuses tensors that require grad.
        with torch.no_grad():
            output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
                                pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
                                is_slur=sample['is_slur'])
            mel_out = output['mel_out']  # [B, T,80]
            if self._pe_enabled():
                # Re-estimate f0 from the predicted mel with the standalone PE.
                f0_pred = self.pe(mel_out)['f0_denorm_pred']
            else:
                f0_pred = output['f0_denorm']
            wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]
if __name__ == '__main__':
    # Two example inputs are provided; only one is passed to example_run().
    #
    # Word-level input: 'text' holds Chinese characters (plus 'AP'), and
    # 'notes' / 'notes_duration' hold one ' | '-separated group per unit of
    # 'text'. Space-separated entries inside a group are several notes sung on
    # the same character. Swap it into example_run() below to try it.
    inp_word = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
        'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
        'input_type': 'word'
    }  # user input: Chinese characters

    # Phoneme-level input, like the Opencpop dataset: 'ph_seq', 'note_seq',
    # 'note_dur_seq' and 'is_slur_seq' each contain one space-separated entry
    # per phoneme (29 entries each in this example).
    inp_phoneme = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
        'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
        'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
        'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
        'input_type': 'phoneme'
    }  # input like Opencpop dataset.

    DiffSingerE2EInfer.example_run(inp_phoneme)

# Example invocation:
# CUDA_VISIBLE_DEVICES=3 python inference/svs/ds_e2e.py --config egs/egs_bases/svs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel