ddd commited on
Commit
a3411b4
1 Parent(s): 853fd97

fix hparam

Browse files
Files changed (1) hide show
  1. modules/diffsinger_midi/fs2.py +0 -109
modules/diffsinger_midi/fs2.py CHANGED
@@ -117,112 +117,3 @@ class FastSpeech2MIDI(FastSpeech2):
117
 
118
  return ret
119
 
120
- def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
121
- decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
122
- pitch_padding = mel2ph == 0
123
- if hparams['pitch_ar']:
124
- ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
125
- if f0 is None:
126
- f0 = pitch_pred[:, :, 0]
127
- else:
128
- ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
129
- if f0 is None:
130
- f0 = pitch_pred[:, :, 0]
131
- if hparams['use_uv'] and uv is None:
132
- uv = pitch_pred[:, :, 1] > 0
133
-
134
- # here f0_denorm for pitch prediction
135
- ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
136
-
137
- # here f0_denorm for mel prediction
138
- if self.training:
139
- mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
140
- masked_uv = torch.bernoulli(mask).bool().to(f0.device) # prob 的概率吐出一个随机uv.
141
- uv_masked = uv.bool() | masked_uv
142
- # print((uv.float()-uv_masked.float()).mean(dim=1))
143
- f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
144
- else:
145
- f0_denorm = ret['f0_denorm']
146
-
147
- if pitch_padding is not None:
148
- f0[pitch_padding] = 0
149
-
150
- pitch = f0_to_coarse(f0_denorm) # start from 0
151
- pitch_embed = self.pitch_embed(pitch)
152
- return pitch_embed
153
-
154
-
155
- class FastSpeech2MIDIMasked(FastSpeech2MIDI):
156
- def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
157
- ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
158
- spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
159
- ret = {}
160
-
161
- midi_dur_embedding, slur_embedding = 0, 0
162
- if kwargs.get('midi_dur') is not None:
163
- midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
164
- if kwargs.get('is_slur') is not None:
165
- slur_embedding = self.is_slur_embed(kwargs['is_slur'])
166
- encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding) # [B, T, C]
167
- src_nonpadding = (txt_tokens > 0).float()[:, :, None]
168
-
169
- # add ref style embed
170
- # Not implemented
171
- # variance encoder
172
- var_embed = 0
173
-
174
- # encoder_out_dur denotes encoder outputs for duration predictor
175
- # in speech adaptation, duration predictor use old speaker embedding
176
- if hparams['use_spk_embed']:
177
- spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
178
- elif hparams['use_spk_id']:
179
- spk_embed_id = spk_embed
180
- if spk_embed_dur_id is None:
181
- spk_embed_dur_id = spk_embed_id
182
- if spk_embed_f0_id is None:
183
- spk_embed_f0_id = spk_embed_id
184
- spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
185
- spk_embed_dur = spk_embed_f0 = spk_embed
186
- if hparams['use_split_spk_id']:
187
- spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
188
- spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
189
- else:
190
- spk_embed_dur = spk_embed_f0 = spk_embed = 0
191
-
192
- # add dur
193
- dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
194
-
195
- mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
196
-
197
- decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
198
-
199
- mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
200
- decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
201
-
202
- # expanded midi
203
- midi_embedding = self.midi_embed(kwargs['pitch_midi'])
204
- midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
205
- midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
206
- print(midi_embedding.shape, decoder_inp.shape)
207
- midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
208
- midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device) # prob 的概率吐出一个随机uv.
209
-
210
- tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
211
-
212
- decoder_inp += midi_embedding
213
- decoder_inp_origin = decoder_inp
214
- # add pitch and energy embed
215
- pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
216
- if hparams['use_pitch_embed']:
217
- pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
218
- decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
219
- if hparams['use_energy_embed']:
220
- decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
221
-
222
- ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
223
-
224
- if skip_decoder:
225
- return ret
226
- ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
227
-
228
- return ret
 
117
 
118
  return ret
119