kartynnik commited on
Commit
4ea1916
1 Parent(s): 947703f

Load checkpoints on CPU

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -188,13 +188,13 @@ def main():
188
  net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
189
  hps.train.segment_size // hps.data.hop_length,
190
  **hps.model)
191
- net_g.load_state_dict(torch.load(a.ckpt))
192
  _ = net_g.eval()
193
 
194
  text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
195
  hps.train.segment_size // hps.data.hop_length,
196
  **hps_t2w2v.model)
197
- text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
198
  text2w2v.eval()
199
 
200
  speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
 
188
  net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
189
  hps.train.segment_size // hps.data.hop_length,
190
  **hps.model)
191
+ net_g.load_state_dict(load_checkpoint(a.ckpt, device))
192
  _ = net_g.eval()
193
 
194
  text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
195
  hps.train.segment_size // hps.data.hop_length,
196
  **hps_t2w2v.model)
197
+ text2w2v.load_state_dict(load_checkpoint(a.ckpt_text2w2v, device))
198
  text2w2v.eval()
199
 
200
  speechsr = SpeechSR48(h_sr48.data.n_mel_channels,