# S3TVR-Demo / xtts_fine_tune / xtts_v2_model_utils.py
# Author: yalsaffar — commit aa7cb02 ("init"), 6.6 kB
import os
import shutil
import traceback

import numpy as np
import torch
import torchaudio
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Torch device used by this module: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path, use_deepspeed=None):
    """
    Load the fine-tuned XTTS v2 model and compute speaker latents.

    Args:
        config_path (str): Path to the configuration file.
            Example: "path/to/config.json"
        checkpoint_path (str): Path to the checkpoint directory, or to a
            checkpoint file inside that directory (e.g. the ``best_model.pth``
            returned by ``xtts_v2_Model.train_model``).
            Example: "path/to/checkpoint/" or "path/to/checkpoint/best_model.pth"
        reference_audio_path (str): Path to the reference audio file.
            Example: "path/to/reference.wav"
        use_deepspeed (bool, optional): Whether to load the model with DeepSpeed.
            Default is None, which enables DeepSpeed only when the module-level
            ``device`` is CUDA (the previous behaviour on GPU machines).

    Returns:
        tuple: A tuple containing the model, gpt_cond_latent, and speaker_embedding.
            Example: (model, gpt_cond_latent, speaker_embedding)
    """
    if use_deepspeed is None:
        use_deepspeed = device.type == "cuda"

    print("Loading model...")
    config = XttsConfig()
    config.load_json(config_path)
    model = Xtts.init_from_config(config)

    if os.path.isfile(checkpoint_path):
        # A checkpoint file was given: load that file explicitly and use its
        # parent as checkpoint_dir. NOTE(review): Xtts.load_checkpoint appears to
        # resolve vocab.json relative to checkpoint_dir — confirm for the
        # installed TTS version.
        checkpoint_dir = os.path.dirname(os.path.abspath(checkpoint_path))
        checkpoint_file = checkpoint_path
    else:
        checkpoint_dir = checkpoint_path
        checkpoint_file = None  # let load_checkpoint pick its default file name

    model.load_checkpoint(
        config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_path=checkpoint_file,
        use_deepspeed=use_deepspeed,
    )
    # Move to the module-level device instead of hard-coding .cuda(), so the
    # loader also works on CPU-only machines.
    model.to(device)

    print("Computing speaker latents...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[reference_audio_path])
    return model, gpt_cond_latent, speaker_embedding
def Inference(model, gpt_cond_latent, speaker_embedding, path_to_save, text, temperature=0.7, language="en"):
    """
    Perform inference using the fine-tuned XTTS v2 model and save the result.

    Args:
        model (Xtts): The XTTS v2 model.
            Example: model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2(config_path, checkpoint_path, reference_audio_path)
        gpt_cond_latent (torch.Tensor): GPT conditioning latent vectors.
        speaker_embedding (torch.Tensor): Speaker embedding vectors.
        path_to_save (str): Path to save the generated audio.
            Example: "path/to/output.wav"
        text (str): The input text for synthesis.
            Example: "Hello, world!"
        temperature (float, optional): Sampling temperature. Default is 0.7.
            Example: 0.7
        language (str, optional): Language code of ``text``. Default is "en".
            Example: "es"

    Returns:
        None
    """
    print("Inference...")
    # Pass every argument by keyword. Xtts.inference expects `language` right
    # after `text`. The previous positional call omitted it, which shifted the
    # latent, the embedding and the temperature into the wrong parameters.
    out = model.inference(
        text=text,
        language=language,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=temperature,
    )
    # unsqueeze(0) adds a leading channel axis for torchaudio.save; this assumes
    # out["wav"] is a 1-D mono waveform (TODO confirm). The 24000 Hz rate is
    # hard-coded and assumed to match the model's output rate.
    wav = torch.as_tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(path_to_save, wav, 24000)
#model, gpt_cond_latent, speaker_embedding = load_fine_tuned_xtts_v2("C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/config.json", "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/best_model_72.pth", "old_man_segments/wavs/segment_10.wav")
class xtts_v2_Model:
    """
    A class to handle training of the XTTS v2 model.

    Args:
        train_csv_path (str): Path to the training CSV file.
            Example: "path/to/train.csv"
        eval_csv_path (str): Path to the evaluation CSV file.
            Example: "path/to/eval.csv"
        num_epochs (int): Number of training epochs.
            Example: 10
        batch_size (int): Size of each training batch.
            Example: 4
        grad_acumm (int): Gradient accumulation steps.
            Example: 1
        output_path (str): Path to save the trained model outputs.
            Example: "path/to/output/"
        max_audio_length (int): Maximum allowed length of audio for training in seconds.
            Example: 10
        language (str, optional): Language of the audio files, either 'en' for English or 'es' for Spanish. Default is "en".
            Example: "en"
    """

    # Sample rate used to convert max_audio_length (seconds) into the
    # waveform-frame count that is passed to train_gpt.
    _TRAINING_SAMPLE_RATE = 22050

    def __init__(self, train_csv_path, eval_csv_path, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language="en"):
        self.train_csv_path = train_csv_path
        self.eval_csv_path = eval_csv_path
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.grad_acumm = grad_acumm
        self.output_path = output_path
        self.max_audio_length = max_audio_length  # seconds
        self.language = language
        # The following are filled in by train_model() from train_gpt's results.
        self.config_path = None
        self.original_xtts_checkpoint = None
        self.vocab_file = None
        self.exp_path = None
        self.speaker_wav = None

    def train_model(self):
        """
        Train the XTTS v2 model.

        Errors from training are not raised. They are printed to the console
        and returned as the status message, with the other four fields set to "".

        Returns:
            tuple: A tuple containing a status message, config_path, vocab_file, fine-tuned XTTS checkpoint, and speaker wav file.
                Example: ("Model training done!", "path/to/config.json", "path/to/vocab.json", "path/to/best_model.pth", "path/to/speaker.wav")
        """
        if not self.train_csv_path or not self.eval_csv_path:
            return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
        try:
            # Convert seconds to waveform frames. Read the limit from the
            # instance: assigning to a bare `max_audio_length` made the name
            # local and raised UnboundLocalError on every call.
            max_audio_frames = int(self.max_audio_length * self._TRAINING_SAMPLE_RATE)
            (
                self.config_path,
                self.original_xtts_checkpoint,
                self.vocab_file,
                self.exp_path,
                self.speaker_wav,
            ) = train_gpt(
                self.language,
                self.num_epochs,
                self.batch_size,
                self.grad_acumm,
                self.train_csv_path,
                self.eval_csv_path,
                output_path=self.output_path,
                max_audio_length=max_audio_frames,
            )
        except Exception:  # report instead of raising, but let KeyboardInterrupt/SystemExit through
            traceback.print_exc()
            error = traceback.format_exc()
            return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

        # Copy the original config and vocab into the experiment directory to
        # avoid parameter-change issues. shutil replaces shelling out to `cp`,
        # which does not exist on Windows, breaks on paths with spaces and
        # silently ignored failures.
        for src in (self.config_path, self.vocab_file):
            try:
                shutil.copy(src, self.exp_path)
            except shutil.SameFileError:
                pass  # already in the experiment directory

        ft_xtts_checkpoint = os.path.join(self.exp_path, "best_model.pth")
        print("Model training done!")
        return "Model training done!", self.config_path, self.vocab_file, ft_xtts_checkpoint, self.speaker_wav
# example
#train_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/train.csv"
#eval_meta = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000/eval.csv"
#num_epochs = 10
#batch_size = 4
#grad_acumm = 1
#out_path = "C:/tmp/xtts_ft/run/training/GPT_XTTS_FT-April-02-2024_05+08PM-0000000"
#max_audio_length = 10
#lang = "en"
#xtts_v2 = xtts_v2_Model(train_meta, eval_meta, num_epochs, batch_size, grad_acumm, out_path, max_audio_length, lang)
#_, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2.train_model()