Update models.py
Browse files
models.py
CHANGED
@@ -664,7 +664,7 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
|
|
664 |
|
665 |
def _load_model(model_config, model_path):
|
666 |
model = ASRCNN(**model_config)
|
667 |
-
params = torch.load(model_path, map_location='cpu')
|
668 |
model.load_state_dict(params)
|
669 |
return model
|
670 |
|
|
|
664 |
|
665 |
def _load_model(model_config, model_path):
|
666 |
model = ASRCNN(**model_config)
|
667 |
+
params = torch.load(model_path, map_location='cpu')['model']
|
668 |
model.load_state_dict(params)
|
669 |
return model
|
670 |
|