import os
import shutil
import traceback

import numpy as np
import torch
import torchaudio
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path):
    """
    Load the fine-tuned XTTS v2 model and compute speaker latents.

    Args:
        config_path (str): Path to the configuration file.
            Example: "path/to/config.json"
        checkpoint_path (str): Path to the checkpoint directory.
            Example: "path/to/checkpoint/"
        reference_audio_path (str): Path to the reference audio file.
            Example: "path/to/reference.wav"

    Returns:
        tuple: A tuple containing the model, gpt_cond_latent, and speaker_embedding.
            Example: (model, gpt_cond_latent, speaker_embedding)
    """
print("Loading model...") | |
config = XttsConfig() | |
config.load_json(config_path) | |
model = Xtts.init_from_config(config) | |
model.load_checkpoint(config, checkpoint_dir=checkpoint_path, use_deepspeed=True) | |
model.cuda() | |
print("Computing speaker latents...") | |
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[reference_audio_path]) | |
return model, gpt_cond_latent, speaker_embedding | |
def Inference(model, gpt_cond_latent, speaker_embedding, path_to_save, text, language="en", temperature=0.7):
    """
    Perform inference using the fine-tuned XTTS v2 model.

    Args:
        model (Xtts): The XTTS v2 model.
            Example: model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path)
        gpt_cond_latent (torch.Tensor): GPT conditioning latent vectors.
        speaker_embedding (torch.Tensor): Speaker embedding vectors.
        path_to_save (str): Path to save the generated audio.
            Example: "path/to/output.wav"
        text (str): The input text for synthesis.
            Example: "Hello, world!"
        language (str, optional): Language code of the input text. Default is "en".
            Example: "en"
        temperature (float, optional): Sampling temperature. Default is 0.7.
            Example: 0.7

    Returns:
        None
    """
    print("Inference...")
    out = model.inference(
        text,
        language,  # Xtts.inference expects the language code before the conditioning latents
        gpt_cond_latent,
        speaker_embedding,
        temperature=temperature,  # Add custom parameters here
    )
    # XTTS generates audio at a 24 kHz sample rate
    torchaudio.save(path_to_save, torch.tensor(out["wav"]).unsqueeze(0), 24000)
#model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2("C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/config.json", "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/best_model_72.pth", "old_man_segments/wavs/segment_10.wav")
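# A matching synthesis call for the commented-out example above (a sketch with a
# hypothetical output path and text; assumes the load call above ran successfully):
#Inference(model, gpt_cond_latent, speaker_embedding, "output.wav", "Hello, world!", language="en", temperature=0.7)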
class xtts_v2_Model:
    """
    A class to handle training of the XTTS v2 model.

    Args:
        train_csv_path (str): Path to the training CSV file.
            Example: "path/to/train.csv"
        eval_csv_path (str): Path to the evaluation CSV file.
            Example: "path/to/eval.csv"
        num_epochs (int): Number of training epochs.
            Example: 10
        batch_size (int): Size of each training batch.
            Example: 4
        grad_acumm (int): Gradient accumulation steps.
            Example: 1
        output_path (str): Path to save the trained model outputs.
            Example: "path/to/output/"
        max_audio_length (int): Maximum allowed length of audio for training, in seconds.
            Example: 10
        language (str, optional): Language of the audio files, either 'en' for English or 'es' for Spanish. Default is "en".
            Example: "en"
    """
    def __init__(self, train_csv_path, eval_csv_path, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language="en"):
        self.train_csv_path = train_csv_path
        self.eval_csv_path = eval_csv_path
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.grad_acumm = grad_acumm
        self.output_path = output_path
        self.max_audio_length = max_audio_length
        self.language = language
        # populated by train_model()
        self.config_path = None
        self.original_xtts_checkpoint = None
        self.vocab_file = None
        self.exp_path = None
        self.speaker_wav = None
    def train_model(self):
        """
        Train the XTTS v2 model.

        Returns:
            tuple: A tuple containing a status message, config_path, vocab_file, fine-tuned XTTS checkpoint, and speaker wav file.
                Example: ("Model training done!", "path/to/config.json", "path/to/vocab.json", "path/to/best_model.pth", "path/to/speaker.wav")
        """
        #clear_gpu_cache()
        if not self.train_csv_path or not self.eval_csv_path:
            return "You need to run the data processing step or manually set the `Train CSV` and `Eval CSV` fields!", "", "", "", ""
        try:
            # convert seconds to waveform frames (22.05 kHz sample rate)
            max_audio_length = int(self.max_audio_length * 22050)
            self.config_path, self.original_xtts_checkpoint, self.vocab_file, self.exp_path, self.speaker_wav = train_gpt(
                self.language,
                self.num_epochs,
                self.batch_size,
                self.grad_acumm,
                self.train_csv_path,
                self.eval_csv_path,
                output_path=self.output_path,
                max_audio_length=max_audio_length,
            )
        except Exception:
            traceback.print_exc()
            error = traceback.format_exc()
            return f"The training was interrupted due to an error! Please check the console for the full error message.\nError summary: {error}", "", "", "", ""
        # copy the original config and vocab files into the experiment folder to avoid
        # issues with later parameter changes (shutil.copy is portable, unlike shelling out to `cp`)
        shutil.copy(self.config_path, self.exp_path)
        shutil.copy(self.vocab_file, self.exp_path)
        ft_xtts_checkpoint = os.path.join(self.exp_path, "best_model.pth")
        print("Model training done!")
        #clear_gpu_cache()
        return "Model training done!", self.config_path, self.vocab_file, ft_xtts_checkpoint, self.speaker_wav
# example
#train_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/train.csv"
#eval_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/eval.csv"
#num_epochs = 10
#batch_size = 4
#grad_acumm = 1
#out_path = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000"
#max_audio_length = 10
#lang = "en"
#xtts_v2 = xtts_v2_Model(train_meta, eval_meta, num_epochs, batch_size, grad_acumm, out_path, max_audio_length, lang)
#_, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2.train_model()
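# The values returned by train_model() can be passed straight to the helpers above.
# A sketch, assuming training succeeded: load_fine_tuned_xtts_v2 expects the
# checkpoint *directory*, hence os.path.dirname on the checkpoint file path.
#model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2(config_path, os.path.dirname(ft_xtts_checkpoint), speaker_wav)
#Inference(model, gpt_cond_latent, speaker_embedding, "sample.wav", "Hello, world!", language=lang)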