# NOTE: removed stray "Spaces: / Sleeping / Sleeping" text — a Hugging Face
# Spaces status banner accidentally pasted into the source; it was not Python.
#!/usr/bin/env python
import os
import sys
import time

from xtts_fine_tune.xtts_v2_data_formattor import Data_Pipeline
from xtts_fine_tune.xtts_v2_model_utils import xtts_v2_Model
def Train_XTTS_V2(audio_directory, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language):
    """
    Train the XTTS V2 model with the given parameters.

    Initializes the data pipeline, waits (polling every 20 s) until the
    combined audio length passes the length gate, formats the data, and
    trains the XTTS V2 model.

    Args:
        audio_directory (str): Path to the directory containing audio files.
            Example: "path/to/audio_files/"
        num_epochs (int): Number of training epochs. Example: 50
        batch_size (int): Size of each training batch. Example: 16
        grad_acumm (int): Gradient accumulation steps. Example: 4
        output_path (str): Path to save the trained model outputs.
            Example: "path/to/output/"
        max_audio_length (int): Audio-length threshold for training, in
            seconds. Example: 3600
        language (str): Language of the audio files, 'en' or 'es'.
            Example: "en"

    Returns:
        tuple: (config_path, vocab_path, ft_xtts_checkpoint, speaker_wav) —
        paths to the configuration file, vocabulary file, fine-tuned XTTS
        checkpoint, and speaker wav file.
        Example: ("config_path.json", "vocab_path.json", "checkpoint.pth", "speaker.wav")

    Example usage:
        config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = Train_XTTS_V2(
            "path/to/audio_files/", 50, 16, 4, "path/to/output/", 3600, "en"
        )
    """
    Data_class = Data_Pipeline(audio_directory, language)
    # Poll until the length gate passes. The original version recursed here,
    # discarded the recursive call's return value, AND fell through after it,
    # so training ran once per recursion level (and a long wait would
    # eventually hit RecursionError). A loop waits without either problem.
    # NOTE(review): the message says "not long enough" but the test fires
    # when the audio is LONGER than max_audio_length; comparison direction
    # kept exactly as the original — confirm intent against the caller.
    while Data_class.get_combined_wav_lengths() > max_audio_length:
        print("The audio is not long enough to be fine tuned. Waiting....")
        time.sleep(20)
    # Parent directory of audio_directory. Equivalent to the original
    # split("/")[:-1] + join for both "a/b/c" and trailing-slash "a/b/c/".
    audio_directory_parent = os.path.dirname(audio_directory)
    _, train_meta, eval_meta = Data_class.data_formatter(audio_directory_parent)
    xtts_v2 = xtts_v2_Model(train_meta, eval_meta, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language)
    _, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2.train_model()
    return config_path, vocab_path, ft_xtts_checkpoint, speaker_wav
if __name__ == "__main__":
    # CLI entry point — expects exactly 7 positional arguments. Fail early
    # with a usage message instead of an IndexError/ValueError traceback.
    if len(sys.argv) != 8:
        print(
            "Usage: python train_xtts_v2.py <audio_directory> <num_epochs> "
            "<batch_size> <grad_acumm> <output_path> <max_audio_length> <language>"
        )
        sys.exit(1)
    audio_directory = sys.argv[1]
    num_epochs = int(sys.argv[2])
    batch_size = int(sys.argv[3])
    grad_acumm = int(sys.argv[4])
    output_path = sys.argv[5]
    max_audio_length = int(sys.argv[6])
    language = sys.argv[7]
    config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = Train_XTTS_V2(
        audio_directory, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language
    )
    # Report the artifact paths produced by training.
    print(config_path, vocab_path, ft_xtts_checkpoint, speaker_wav)