Hecheng0625 commited on
Commit
47464b7
1 Parent(s): ab8b9b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -223,12 +223,12 @@ def load_models():
223
  semantic_code_ckpt = hf_hub_download(
224
  "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
225
  )
226
- codec_encoder_ckpt = hf_hub_download(
227
- "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
228
- )
229
- codec_decoder_ckpt = hf_hub_download(
230
- "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
231
- )
232
  t2s_model_ckpt = hf_hub_download(
233
  "amphion/MaskGCT", filename="t2s_model/model.safetensors"
234
  )
@@ -240,8 +240,10 @@ def load_models():
240
  )
241
 
242
  safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
243
- safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
244
- safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
 
 
245
  safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
246
  safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
247
  safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
 
223
  semantic_code_ckpt = hf_hub_download(
224
  "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
225
  )
226
+ # codec_encoder_ckpt = hf_hub_download(
227
+ # "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
228
+ # )
229
+ # codec_decoder_ckpt = hf_hub_download(
230
+ # "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
231
+ # )
232
  t2s_model_ckpt = hf_hub_download(
233
  "amphion/MaskGCT", filename="t2s_model/model.safetensors"
234
  )
 
240
  )
241
 
242
  safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
243
+ # safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
244
+ # safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
245
+ accelerate.load_checkpoint_and_dispatch(codec_encoder, "./acoustic_codec/model.safetensors")
246
+ accelerate.load_checkpoint_and_dispatch(codec_decoder, "./acoustic_codec/model_1.safetensors")
247
  safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
248
  safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
249
  safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)