# S3TVR-Demo / xtts_fine_tune / xtts_v2_data_formattor.py
# yalsaffar's picture
# init
# aa7cb02
# raw
# history blame contribute delete
# No virus
# 10.9 kB
import pandas as pd
import numpy as np
import os
import sys
from models.parakeet import parakeet_ctc_model, parakeet_ctc_process
from models.es_fastconformer import stt_es_model, stt_es_process
from pydub import AudioSegment
import json
class Data_Pipeline():
    """
    Collect .wav files, obtain a transcription for each one (from an ASR
    model or from a JSON file) and write train/eval CSV files.

    Args:
        wav_dir (str): Directory containing the .wav files.
            Example: "path/to/wav_files/"
        lang (str): Language of the audio files, either 'en' for English or 'es' for Spanish.
            Example: "en"
    """

    # Messages returned (not raised) by the public methods; callers compare
    # against these exact strings, so they must not change.
    _UNSUPPORTED_LANG = "Error: Language not supported."
    _NO_TRANSCRIPTION = "Error: No transcription found."
    # Real-time factor used for the transcription time estimate.
    _REAL_TIME_FACTOR = 1.5
    _SUPPORTED_LANGS = ("en", "es")
    _COLUMNS = ['wav_file', 'transcription', 'transcription2', 'speaker_name']
    _SPEAKER_NAME = "user0"

    def __init__(self, wav_dir, lang):
        self.wav_dir = wav_dir
        self.lang = lang

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _first_text(result):
        """Return the text of a transcription given either a plain string or
        a sequence whose first element is the text (the model output shape
        the original code indexed with ``[0]``)."""
        if isinstance(result, str):
            return result
        return result[0]

    def _row(self, wav_path, text):
        """Build one output row; the text is stored in both transcription columns."""
        return {
            'wav_file': wav_path,
            'transcription': text,
            'transcription2': text,
            'speaker_name': self._SPEAKER_NAME,
        }

    def _transcribe(self, audio_file):
        """Transcribe with the model matching ``self.lang``; None if unsupported."""
        if self.lang == "en":
            return self.en_transcribe(audio_file)
        if self.lang == "es":
            return self.es_transcribe(audio_file)
        return None

    # ------------------------------------------------------------------
    # File discovery and duration
    # ------------------------------------------------------------------
    def get_wav_files(self):
        """
        Retrieve all .wav files under the specified directory (recursively).

        The result is also stored in ``self.wav_files``.

        Returns:
            list: A list of paths to the .wav files.
                Example: ["path/to/file1.wav", "path/to/sub/file2.wav"]
        """
        self.wav_files = []
        for root, dirs, files in os.walk(self.wav_dir):
            # Sort in place so the traversal (and therefore the dataset order
            # fed to the seeded train/eval split) is reproducible.
            dirs.sort()
            for file in sorted(files):
                if file.endswith('.wav'):
                    # Join with the directory the file was found in, not with
                    # wav_dir, so files in sub-directories get valid paths.
                    self.wav_files.append(os.path.join(root, file))
        return self.wav_files

    def get_combined_wav_lengths(self):
        """
        Calculate the total length of all .wav files in the directory.

        Uses the list from a previous ``get_wav_files`` call when there is
        one, otherwise scans the directory first.

        Returns:
            float: The total length of all .wav files in seconds.
                Example: 123.45
        """
        wav_files = getattr(self, "wav_files", None)
        if wav_files is None:
            wav_files = self.get_wav_files()
        # len() of a pydub AudioSegment is its duration in milliseconds.
        total_length = sum(len(AudioSegment.from_file(file)) for file in wav_files)
        return total_length / 1000

    def estimated_transcription_time(self):
        """
        Estimate the transcription time based on the real-time factor (1.5).

        Returns:
            float: The estimated transcription time in seconds.
                Example: 185.175
            str: Error message if the language is not supported.
                Example: "Error: Language not supported."
        """
        if self.lang not in self._SUPPORTED_LANGS:
            return self._UNSUPPORTED_LANG
        return self.get_combined_wav_lengths() * self._REAL_TIME_FACTOR

    # ------------------------------------------------------------------
    # Models
    # ------------------------------------------------------------------
    def load_models(self):
        """
        Load the appropriate ASR model based on the language.

        The model is stored on the instance (``self.parakeet_model`` for
        'en', ``self.es_model`` for 'es').

        Returns:
            object: The loaded ASR model.
            str: Error message if the language is not supported.
                Example: "Error: Language not supported."
        """
        if self.lang == "en":
            self.parakeet_model = parakeet_ctc_model()
            return self.parakeet_model
        if self.lang == "es":
            self.es_model = stt_es_model()
            return self.es_model
        return self._UNSUPPORTED_LANG

    def en_transcribe(self, audio_file):
        """
        Transcribe an English audio file using the Parakeet CTC model.

        The model is loaded on first use if ``load_models`` was not called.

        Args:
            audio_file (str): Path to the audio file.
                Example: "path/to/audio_file.wav"
        Returns:
            list: A list containing the transcribed text.
                Example: ["transcribed text"]
        """
        if getattr(self, "parakeet_model", None) is None:
            self.parakeet_model = parakeet_ctc_model()
        return parakeet_ctc_process(self.parakeet_model, audio_file)

    def es_transcribe(self, audio_file):
        """
        Transcribe a Spanish audio file using the FastConformer model.

        The model is loaded on first use if ``load_models`` was not called.

        Args:
            audio_file (str): Path to the audio file.
                Example: "path/to/audio_file.wav"
        Returns:
            list: A list containing the transcribed text.
                Example: ["transcribed text"]
        """
        if getattr(self, "es_model", None) is None:
            self.es_model = stt_es_model()
        return stt_es_process(self.es_model, audio_file)

    # ------------------------------------------------------------------
    # Transcriptions stored in JSON
    # ------------------------------------------------------------------
    def read_transcriptions(self, json_path):
        """
        Read transcriptions from a JSON file and keep them in ``self.data``.

        Args:
            json_path (str): Path to the JSON file.
                Example: "path/to/data.json"
        Returns:
            dict: The data read from the JSON file.
                Example: {"text": ["text1", "text2"], "original_path": ["path1", "path2"]}
        """
        # Explicit UTF-8 so non-ASCII (e.g. Spanish) text does not depend on
        # the platform's default encoding.
        with open(json_path, encoding="utf-8") as f:
            self.data = json.load(f)
        return self.data

    def get_transcription(self, file_path):
        """
        Get the transcription for a specific file from the JSON data.

        ``read_transcriptions`` must have been called (or ``self.data`` set)
        beforehand. The entry whose "original_path" equals ``file_path``
        supplies the result from the parallel "original_text" list.

        Args:
            file_path (str): Path to the original audio file.
                Example: "path/to/audio_file.wav"
        Returns:
            str: The transcription for the specified file.
                Example: "This is the transcription."
            str: Error message if no transcription is found.
                Example: "Error: No transcription found."
        """
        entries = zip(self.data['original_path'], self.data['original_text'])
        for original_path, original_text in entries:
            if original_path == file_path:
                return original_text
        return self._NO_TRANSCRIPTION

    # ------------------------------------------------------------------
    # Dataset construction
    # ------------------------------------------------------------------
    def data_formatter_with_models(self):
        """
        Format data by transcribing every .wav file with the model for ``self.lang``.

        The result is also stored in ``self.transcriptions_df``.

        Returns:
            pd.DataFrame: A DataFrame containing the transcriptions.
                Example: pd.DataFrame({'wav_file': ["file1", "file2"], 'transcription': ["text1", "text2"], 'transcription2': ["text1", "text2"], 'speaker_name': ["user0", "user0"]})
            str: Error message if the language is not supported.
                Example: "Error: Language not supported."
        """
        self.transcriptions_df = pd.DataFrame(columns=self._COLUMNS)
        if self.lang not in self._SUPPORTED_LANGS:
            return self._UNSUPPORTED_LANG

        self.load_models()
        # Collect rows in a list and build the frame once: DataFrame.append
        # no longer exists in pandas >= 2.0 and was quadratic anyway.
        rows = [
            self._row(wav_path, self._first_text(self._transcribe(wav_path)))
            for wav_path in self.get_wav_files()
        ]
        self.transcriptions_df = pd.DataFrame(rows, columns=self._COLUMNS)
        return self.transcriptions_df

    def data_formatter_without_models(self):
        """
        Format data by retrieving transcriptions from the JSON data and
        transcribing, with the model for ``self.lang``, any file that has none.

        The result is also stored in ``self.transcriptions_df``.

        Returns:
            pd.DataFrame: A DataFrame containing the transcriptions.
                Example: pd.DataFrame({'wav_file': ["file1", "file2"], 'transcription': ["text1", "text2"], 'transcription2': ["text1", "text2"], 'speaker_name': ["user0", "user0"]})
            str: Error message if a file has no stored transcription and the
                language is not supported.
                Example: "Error: Language not supported."
        """
        self.transcriptions_df = pd.DataFrame(columns=self._COLUMNS)
        rows = []
        for wav_path in self.get_wav_files():
            stored = self.get_transcription(wav_path)
            if stored == self._NO_TRANSCRIPTION:
                # Fall back to the ASR model (loaded lazily on first need).
                result = self._transcribe(wav_path)
                if result is None:
                    # No model for this language: report it instead of
                    # writing an error message into the dataset.
                    return self._UNSUPPORTED_LANG
                text = self._first_text(result)
            else:
                # Stored transcriptions are used whole; indexing a string
                # with [0] would keep only its first character.
                text = self._first_text(stored)
            rows.append(self._row(wav_path, text))
        self.transcriptions_df = pd.DataFrame(rows, columns=self._COLUMNS)
        return self.transcriptions_df

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------
    def save_transcriptions(self, output_file):
        """
        Save the transcriptions to CSV files, splitting into training and evaluation datasets.

        The split is 80/20 with a fixed random seed; files are '|'-separated
        and written without the index. One of the data_formatter_* methods
        must have populated ``self.transcriptions_df`` first.

        Args:
            output_file (str): Base path for the output CSV files.
                Example: "path/to/output"
        Returns:
            tuple: A tuple containing a success message and paths to the training and evaluation CSV files.
                Example: ("Data saved successfully.", "path/to/output_train.csv", "path/to/output_eval.csv")
        """
        # Imported here so scikit-learn is only required when saving.
        from sklearn.model_selection import train_test_split

        train_data, eval_data = train_test_split(
            self.transcriptions_df, test_size=0.2, random_state=42
        )
        self.path_to_train_data = output_file + "_train.csv"
        self.path_to_eval_data = output_file + "_eval.csv"
        train_data.to_csv(self.path_to_train_data, index=False, sep='|')
        eval_data.to_csv(self.path_to_eval_data, index=False, sep='|')
        return "Data saved successfully.", self.path_to_train_data, self.path_to_eval_data

    def get_paths(self):
        """
        Retrieve the paths to the training and evaluation CSV files and the .wav files directory.

        Only valid after ``save_transcriptions`` has set the CSV paths.

        Returns:
            tuple: A tuple containing paths to the training data, evaluation data, and .wav files directory.
                Example: ("path/to/train.csv", "path/to/eval.csv", "path/to/wav_files/")
        """
        return self.path_to_train_data, self.path_to_eval_data, self.wav_dir