from typing import Optional, List, Union

import torch
from torch import nn

from .encoders import TextEncoder, PosteriorEncoder
from .generators import Generator
from .nsf import NSFGenerator
from .residuals import ResidualCouplingBlock
from .utils import (
    slice_on_last_dim,
    rand_slice_segments_on_last_dim,
)


def _strip_weight_norm_if_hooked(module: nn.Module) -> None:
    """Remove weight norm from *module* if it carries a ``WeightNorm`` pre-hook.

    The hook we want to remove is an instance of the WeightNorm class, so
    normally we would use ``isinstance(...)``, but that class is not accessible
    because of shadowing, so the module and class names are compared instead.
    https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
    """
    # Decide first, remove afterwards: `remove_weight_norm` deletes an entry
    # from `_forward_pre_hooks`, which must not happen while iterating over it.
    hooked = any(
        hook.__module__ == "torch.nn.utils.weight_norm"
        and hook.__class__.__name__ == "WeightNorm"
        for hook in module._forward_pre_hooks.values()
    )
    if hooked:
        torch.nn.utils.remove_weight_norm(module)


class SynthesizerTrnMsNSFsid(nn.Module):
    """Synthesizer built from a text encoder, posterior encoder, flow and decoder.

    Sub-modules created here:

    * ``enc_p``: ``TextEncoder`` (prior encoder over ``phone``/``pitch``).
    * ``enc_q``: ``PosteriorEncoder`` over the spectrogram ``y``.
    * ``flow``: ``ResidualCouplingBlock``.
    * ``dec``: ``NSFGenerator`` when ``use_f0`` is true, otherwise ``Generator``.
    * ``emb_g``: ``nn.Embedding(spk_embed_dim, gin_channels)`` speaker table.

    Args:
        spec_channels: Input channels of the posterior encoder.
        segment_size: Length of the random segment decoded during training.
        inter_channels: Latent channel count shared by encoders, flow, decoder.
        hidden_channels: Hidden width passed to the encoders and the flow.
        filter_channels: Forwarded to ``TextEncoder``.
        n_heads: Forwarded to ``TextEncoder``.
        n_layers: Forwarded to ``TextEncoder``.
        kernel_size: Forwarded to ``TextEncoder``.
        p_dropout: Dropout probability; converted with ``float()`` before use.
        resblock: Forwarded to the decoder.
        resblock_kernel_sizes: Forwarded to the decoder.
        resblock_dilation_sizes: Forwarded to the decoder.
        upsample_rates: Forwarded to the decoder.
        upsample_initial_channel: Forwarded to the decoder.
        upsample_kernel_sizes: Forwarded to the decoder.
        spk_embed_dim: Number of rows in the speaker embedding table.
        gin_channels: Width of the speaker embedding / conditioning input.
        sr: Sample rate handed to ``NSFGenerator``. The strings ``"32k"``,
            ``"40k"`` and ``"48k"`` are mapped to 32000, 40000 and 48000.
            Not used by the decoder when ``use_f0`` is false.
        encoder_dim: First argument of ``TextEncoder``.
        use_f0: Selects the pitch-conditioned ``NSFGenerator`` decoder and is
            passed to ``TextEncoder`` as ``f0``.

    Raises:
        KeyError: If ``sr`` is a string other than ``"32k"``/``"40k"``/``"48k"``.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Optional[Union[str, int]],
        encoder_dim: int,
        use_f0: bool,
    ):
        super().__init__()
        if isinstance(sr, str):
            # Translate the symbolic sample-rate names into Hz.
            sr = {
                "32k": 32000,
                "40k": 40000,
                "48k": 48000,
            }[sr]

        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim

        self.enc_p = TextEncoder(
            encoder_dim,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=use_f0,
        )
        if use_f0:
            self.dec = NSFGenerator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
                sr=sr,
            )
        else:
            self.dec = Generator(
                inter_channels,
                resblock,
                resblock_kernel_sizes,
                resblock_dilation_sizes,
                upsample_rates,
                upsample_initial_channel,
                upsample_kernel_sizes,
                gin_channels=gin_channels,
            )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Call ``remove_weight_norm()`` on the decoder, flow and posterior encoder.

        ``enc_q`` is only touched when the attribute still exists on the
        instance (it is guarded with ``hasattr``; presumably it can be deleted
        for inference-only use -- confirm against callers).
        """
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Strip ``WeightNorm`` pre-hooks from dec/flow/enc_q and return ``self``.

        Only hooks registered directly on ``self.dec``, ``self.flow`` and
        ``self.enc_q`` are inspected; ``enc_q`` is skipped when absent.
        """
        _strip_weight_norm_if_hooked(self.dec)
        _strip_weight_norm_if_hooked(self.flow)
        if hasattr(self, "enc_q"):
            _strip_weight_norm_if_hooked(self.enc_q)
        return self

    @torch.jit.ignore
    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        y: torch.Tensor,
        y_lengths: torch.Tensor,
        ds: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
    ):
        """Training pass: encode, run the flow, decode a random latent segment.

        Args:
            phone: Input features for the prior encoder ``enc_p``.
            phone_lengths: Lengths passed to ``enc_p`` alongside ``phone``.
            y: Input to the posterior encoder ``enc_q``.
            y_lengths: Lengths passed to ``enc_q`` and used to pick the segment.
            ds: Speaker ids looked up in ``emb_g``. Although it defaults to
                ``None``, it is fed straight into the embedding, so callers
                have to supply it. Per the original note its shape is [bs, 1].
            pitch: Passed to ``enc_p`` (original note: shape [bs, t]).
            pitchf: When given, it is sliced with the same segment ids as the
                latent and passed to the decoder.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
            where ``o`` is the decoder output for the sampled segment.
        """
        # Original note: [b, 256, 1]; the trailing 1 is the time axis and is
        # broadcast.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        # Decode only a `segment_size`-long window of the posterior latent.
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        if pitchf is not None:
            # Keep pitchf aligned with the latent window that was sampled.
            pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
            o = self.dec(z_slice, pitchf, g=g)
        else:
            o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        sid: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,  # nsff0
        skip_head: Optional[int] = None,
        return_length: Optional[int] = None,
        return_length2: Optional[int] = None,
    ):
        """Inference pass: sample from the prior, invert the flow, decode.

        Args:
            phone: Input features for the prior encoder ``enc_p``.
            phone_lengths: Lengths passed to ``enc_p`` alongside ``phone``.
            sid: Speaker ids looked up in ``emb_g``.
            pitch: Passed to ``enc_p``.
            pitchf: When given, passed to the decoder (after cropping if
                ``skip_head``/``return_length`` are set).
            skip_head: Number of leading steps to drop from the output. Only
                honoured together with ``return_length``.
            return_length: Number of steps to keep after ``skip_head``.
            return_length2: Forwarded to the decoder as ``n_res``.

        Returns:
            The decoder output.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            head = int(skip_head)
            length = int(return_length)
            # The encoder is asked to skip at most `head - 24` steps; whatever
            # remains of `head` (up to 24 steps) is cropped after the flow.
            # NOTE(review): presumably this keeps left context for the flow --
            # confirm against TextEncoder's fourth argument.
            flow_head = head - 24
            if flow_head < 0:
                flow_head = 0
            dec_head = head - flow_head
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
            # Sample around the prior mean with the noise scaled by 0.66666.
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
            if pitchf is not None:
                pitchf = pitchf[:, head : head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        # Drop references to intermediates before decoding (presumably to
        # lower peak memory).
        del z_p, m_p, logs_p
        if pitchf is not None:
            o = self.dec(
                z * x_mask,
                pitchf,
                g=g,
                n_res=return_length2,
            )
        else:
            o = self.dec(z * x_mask, g=g, n_res=return_length2)
        del x_mask, z
        return o  # , x_mask, (z, z_p, m_p, logs_p)


class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` preset: ``encoder_dim=256``, ``use_f0=True``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr=sr,
            encoder_dim=256,
            use_f0=True,
        )


class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` preset: ``encoder_dim=768``, ``use_f0=True``."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Union[str, int],
    ):
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr=sr,
            encoder_dim=768,
            use_f0=True,
        )


class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` preset: ``encoder_dim=256``, ``use_f0=False``.

    ``sr`` is accepted so the constructor signature matches the f0 variants;
    with ``use_f0=False`` the base class builds a ``Generator`` decoder, which
    is not given a sample rate.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Optional[Union[str, int]] = None,
    ):
        # `sr` must be forwarded explicitly: the base constructor expects it
        # between `gin_channels` and `encoder_dim`, and omitting it shifts the
        # remaining positional arguments and leaves `use_f0` unfilled.
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr=sr,
            encoder_dim=256,
            use_f0=False,
        )


class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` preset: ``encoder_dim=768``, ``use_f0=False``.

    ``sr`` is accepted so the constructor signature matches the f0 variants;
    with ``use_f0=False`` the base class builds a ``Generator`` decoder, which
    is not given a sample rate.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: Optional[Union[str, int]] = None,
    ):
        # `sr` must be forwarded explicitly: the base constructor expects it
        # between `gin_channels` and `encoder_dim`, and omitting it shifts the
        # remaining positional arguments and leaves `use_f0` unfilled.
        super().__init__(
            spec_channels,
            segment_size,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            spk_embed_dim,
            gin_channels,
            sr=sr,
            encoder_dim=768,
            use_f0=False,
        )