# Source: S3TVR-Demo / xtts_fine_tune / xtts_v2_fine_tuning.py
# Repository page header: yalsaffar, commit "init" (aa7cb02)
from xtts_fine_tune.xtts_v2_data_formattor import Data_Pipeline
from xtts_fine_tune.xtts_v2_model_utils import xtts_v2_Model
import sys
import time
#!/usr/bin/env python
def Train_XTTS_V2(audio_directory, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language):
    """
    Train the XTTS V2 model with the given parameters.

    This function builds the data pipeline, polls the combined audio length
    until the length check passes, formats the data, and trains the XTTS V2
    model exactly once.

    Args:
        audio_directory (str): Path to the directory containing audio files,
            using "/" as the separator.
            Example: "path/to/audio_files/"
        num_epochs (int): Number of training epochs.
            Example: 50
        batch_size (int): Size of each training batch.
            Example: 16
        grad_acumm (int): Gradient accumulation steps.
            Example: 4
        output_path (str): Path to save the trained model outputs.
            Example: "path/to/output/"
        max_audio_length (int): Threshold compared against the combined wav
            length reported by the data pipeline; it is also forwarded to the
            model wrapper.
            Example: 3600
        language (str): Language of the audio files, either 'en' for English
            or 'es' for Spanish.
            Example: "en"

    Returns:
        tuple: A tuple containing paths to the configuration file, vocabulary
        file, fine-tuned XTTS checkpoint, and speaker wav file.
        Example: ("config_path.json", "vocab_path.json", "checkpoint.pth", "speaker.wav")

    Example usage:
        config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = Train_XTTS_V2(
            "path/to/audio_files/", 50, 16, 4, "path/to/output/", 3600, "en"
        )
    """
    # Poll in a loop rather than recursing. The previous recursive retry
    # discarded the nested call's result and then trained a second time in
    # every outer frame, and a long wait could exhaust the recursion limit.
    # A fresh pipeline is built on each pass, as each recursive call did.
    # NOTE(review): the message says "not long enough" but the wait happens
    # while the combined length EXCEEDS max_audio_length -- confirm the
    # intended direction of this comparison with the caller.
    data_pipeline = Data_Pipeline(audio_directory, language)
    while data_pipeline.get_combined_wav_lengths() > max_audio_length:
        print("The audio is not long enough to be fine tuned. Waiting....")
        time.sleep(20)
        data_pipeline = Data_Pipeline(audio_directory, language)

    # Everything before the last "/" (empty string when there is no "/").
    # For a path with a trailing slash this is the directory itself; without
    # one it is the parent directory.
    audio_directory_parent = audio_directory.rpartition("/")[0]

    _, train_meta, eval_meta = data_pipeline.data_formatter(audio_directory_parent)

    xtts_v2 = xtts_v2_Model(
        train_meta,
        eval_meta,
        num_epochs,
        batch_size,
        grad_acumm,
        output_path,
        max_audio_length,
        language,
    )
    # train_model() returns five values; the first is not part of this
    # function's result.
    _, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2.train_model()
    return config_path, vocab_path, ft_xtts_checkpoint, speaker_wav
if __name__ == "__main__":
    # Command-line entry point. Expects seven positional arguments:
    #   audio_directory num_epochs batch_size grad_acumm output_path
    #   max_audio_length language
    # Fail with a usage message instead of a bare IndexError when arguments
    # are missing; any extra arguments are ignored, as before.
    if len(sys.argv) < 8:
        sys.exit(
            "Usage: {} AUDIO_DIRECTORY NUM_EPOCHS BATCH_SIZE GRAD_ACUMM "
            "OUTPUT_PATH MAX_AUDIO_LENGTH LANGUAGE".format(sys.argv[0])
        )

    audio_directory = sys.argv[1]
    num_epochs = int(sys.argv[2])
    batch_size = int(sys.argv[3])
    grad_acumm = int(sys.argv[4])
    output_path = sys.argv[5]
    max_audio_length = int(sys.argv[6])
    language = sys.argv[7]

    config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = Train_XTTS_V2(
        audio_directory,
        num_epochs,
        batch_size,
        grad_acumm,
        output_path,
        max_audio_length,
        language,
    )

    # Report the produced artifact paths.
    print(config_path, vocab_path, ft_xtts_checkpoint, speaker_wav)