Spaces:
Running
Running
mrfakename
commited on
Commit
•
96269dd
1
Parent(s):
a8a1e8e
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- model/trainer.py +1 -1
- model/utils.py +2 -2
- scripts/eval_infer_batch.py +1 -1
- speech_edit.py +2 -1
model/trainer.py
CHANGED
@@ -140,7 +140,7 @@ class Trainer:
|
|
140 |
else:
|
141 |
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
|
142 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
143 |
-
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
|
144 |
|
145 |
if self.is_main:
|
146 |
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
|
|
140 |
else:
|
141 |
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
|
142 |
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
143 |
+
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
144 |
|
145 |
if self.is_main:
|
146 |
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
model/utils.py
CHANGED
@@ -509,7 +509,7 @@ def run_sim(args):
|
|
509 |
device = f"cuda:{rank}"
|
510 |
|
511 |
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
|
512 |
-
state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
|
513 |
model.load_state_dict(state_dict['model'], strict=False)
|
514 |
|
515 |
use_gpu=True if torch.cuda.is_available() else False
|
@@ -565,7 +565,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
|
565 |
from safetensors.torch import load_file
|
566 |
checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
-
checkpoint = torch.load(ckpt_path, map_location=device)
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|
|
|
509 |
device = f"cuda:{rank}"
|
510 |
|
511 |
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
|
512 |
+
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
|
513 |
model.load_state_dict(state_dict['model'], strict=False)
|
514 |
|
515 |
use_gpu=True if torch.cuda.is_available() else False
|
|
|
565 |
from safetensors.torch import load_file
|
566 |
checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
+
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|
scripts/eval_infer_batch.py
CHANGED
@@ -127,7 +127,7 @@ local = False
|
|
127 |
if local:
|
128 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
129 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
130 |
-
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
131 |
vocos.load_state_dict(state_dict)
|
132 |
vocos.eval()
|
133 |
else:
|
|
|
127 |
if local:
|
128 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
129 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
130 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
|
131 |
vocos.load_state_dict(state_dict)
|
132 |
vocos.eval()
|
133 |
else:
|
speech_edit.py
CHANGED
@@ -85,8 +85,9 @@ local = False
|
|
85 |
if local:
|
86 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
87 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
88 |
-
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
89 |
vocos.load_state_dict(state_dict)
|
|
|
90 |
vocos.eval()
|
91 |
else:
|
92 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
|
|
85 |
if local:
|
86 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
87 |
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
88 |
+
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
|
89 |
vocos.load_state_dict(state_dict)
|
90 |
+
|
91 |
vocos.eval()
|
92 |
else:
|
93 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|