# train-tts / app.py
# Author: ovieyra21
# Commit 4bbc150 ("Update app.py", verified)
import os
import tempfile
import time
from collections import deque
from html import escape
from urllib.parse import parse_qs, urlparse

import torch
import gradio as gr
import yt_dlp as youtube_dl
import numpy as np
from datasets import Dataset, Audio
from scipy.io import wavfile
from huggingface_hub import whoami
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
import demucs.api
def hello(profile: gr.OAuthProfile | None) -> str:
    """Greet the logged-in Hugging Face user by name.

    Args:
        profile: OAuth profile injected by Gradio, or ``None`` when nobody is
            logged in.

    Returns:
        A greeting string.
    """
    return "I don't know you." if profile is None else f"Hello {profile.name}"
def list_organizations(oauth_token: gr.OAuthToken | None) -> str:
    """Describe the Hugging Face organizations the logged-in user belongs to.

    Args:
        oauth_token: OAuth token injected by Gradio, or ``None`` when nobody
            is logged in.

    Returns:
        A human-readable sentence listing the organization names.
    """
    if oauth_token is not None:
        user_info = whoami(oauth_token.token)
        names = ", ".join([org["name"] for org in user_info["orgs"]])
        return f"You belong to {names}."
    return "Please log in to list organizations."
# Whisper checkpoint used for transcription (also shown in the UI text).
MODEL_NAME = "openai/whisper-large-v3"
# Demucs model used by `separator` for vocal separation.
DEMUCS_MODEL_NAME = "htdemucs_ft"
# Batch size passed to every `pipe(...)` call.
BATCH_SIZE = 8
# NOTE(review): not referenced anywhere in the code visible here.
FILE_LIMIT_MB = 1000
YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files

# Run on the first CUDA device when one is available, otherwise on CPU.
device = 0 if torch.cuda.is_available() else "cpu"

# Both models below are loaded once, at import time, and reused by every request.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    # chunked long-form inference with 30 s windows
    chunk_length_s=30,
    device=device,
)

separator = demucs.api.Separator(model = DEMUCS_MODEL_NAME, )
def separate_vocal(path):
    """Overwrite the audio file at *path* with its isolated vocal stem.

    The file is split by the module-level Demucs ``separator``; only the
    ``"vocals"`` stem is kept and written back over the input file at the
    separator's sample rate.

    Args:
        path: Path of the audio file to clean in place.

    Returns:
        The same *path*, which now holds the vocals-only audio.
    """
    _, stems = separator.separate_audio_file(path)
    vocals = stems["vocals"]
    demucs.api.save_audio(vocals, path, samplerate=separator.samplerate)
    return path
def transcribe(inputs_path, task, use_demucs, dataset_name, oauth_token: gr.OAuthToken | None, progress=gr.Progress()):
    """Transcribe a local audio file, turn it into a TTS dataset and push it to the Hub.

    Args:
        inputs_path: Path of the uploaded or recorded audio file
            (``gr.Audio(type="filepath")``).
        task: Whisper task, ``"transcribe"`` or ``"translate"``.
        use_demucs: ``"separate-audio"`` to keep only the Demucs vocal stem of
            every chunk; any other value keeps the raw chunk audio.
        dataset_name: Hub repository id the dataset is pushed to.
        oauth_token: Token of the logged-in user, injected by Gradio.
        progress: Gradio progress tracker.

    Returns:
        ``(samples, text)``. *samples* is a list of one-element lists, one per
        dataset row, for ``gr.Dataset``. *text* is the full Whisper transcription.

    Raises:
        gr.Error: if no audio file or no dataset name was submitted.
    """
    if inputs_path is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    # gr.Textbox yields "" (not None) when left empty, so reject blank values too.
    if not dataset_name or not dataset_name.strip():
        raise gr.Error("No dataset name submitted! Please submit a dataset name. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
    dataset_name = dataset_name.strip()
    if oauth_token is None:
        gr.Warning("Make sure to click and login before using this demo.")
        return [["transcripts will appear here"]], ""

    total_step = 4
    current_step = 0

    current_step += 1
    progress((current_step, total_step), desc="Transcribe using Whisper.")
    try:
        sampling_rate, inputs = wavfile.read(inputs_path)
    except ValueError:
        # scipy only reads (uncompressed) WAV; uploads may be mp3/flac/etc.,
        # so decode those with ffmpeg at Whisper's sampling rate instead.
        sampling_rate = pipe.feature_extractor.sampling_rate
        with open(inputs_path, "rb") as f:
            inputs = ffmpeg_read(f.read(), sampling_rate)
    out = pipe(inputs_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
    text = out["text"]

    current_step += 1
    progress((current_step, total_step), desc="Merge chunks.")
    chunks = naive_postprocess_whisper_chunks(out["chunks"], inputs, sampling_rate)

    current_step += 1
    progress((current_step, total_step), desc="Create dataset.")
    transcripts = []
    audios = []
    with tempfile.TemporaryDirectory() as tmpdirname:
        for i, chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for)")):
            # TODO: make sure 1D or 2D?
            arr = chunk["audio"]
            path = os.path.join(tmpdirname, f"{i}.wav")
            wavfile.write(path, sampling_rate, arr)
            if use_demucs == "separate-audio":
                # use demucs to separate vocals
                print(f"Separating vocals #{i}")
                path = separate_vocal(path)
            audios.append(path)
            transcripts.append(chunk["text"])

        # The dataset only stores paths into tmpdirname; the files are read when
        # pushing, so building and pushing must happen before the directory is removed.
        dataset = Dataset.from_dict({"audio": audios, "text": transcripts}).cast_column("audio", Audio())

        current_step += 1
        progress((current_step, total_step), desc="Push dataset.")
        dataset.push_to_hub(dataset_name, token=oauth_token.token)

    return [[transcript] for transcript in transcripts], text
def _return_yt_html_embed(yt_url):
video_id = yt_url.split("?v=")[-1]
HTML_str = (
f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
" </center>"
)
return HTML_str
def download_yt_audio(yt_url, filename):
    """Download the YouTube video at *yt_url* to *filename*, enforcing a length limit.

    The metadata is fetched first, so videos longer than ``YT_LENGTH_LIMIT_S``
    are rejected before anything is downloaded.

    Args:
        yt_url: URL of the YouTube video.
        filename: Output path, passed to yt-dlp as its ``outtmpl``.

    Raises:
        gr.Error: if the metadata cannot be fetched, the length cannot be
            determined, the video is too long, or the download fails.
    """
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err)) from err

    # Prefer the numeric duration (seconds); "duration_string" is absent for
    # some entries (e.g. live streams) and only a display format otherwise.
    file_length_s = info.get("duration")
    if file_length_s is None:
        duration_string = info.get("duration_string")
        if not duration_string:
            raise gr.Error("Could not determine the length of this YouTube video (is it a live stream?).")
        # "[[H:]M:]S" -> seconds
        file_length_s = 0
        for part in duration_string.split(":"):
            file_length_s = file_length_s * 60 + int(part)

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        # YoutubeDL.download reports failures as DownloadError (extractor
        # errors are wrapped in it), so catch both.
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            raise gr.Error(str(err)) from err
def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token: gr.OAuthToken | None, max_filesize=75.0, dataset_sampling_rate = 24000,
                  progress=gr.Progress()):
    """Transcribe a YouTube video, turn it into a TTS dataset and push it to the Hub.

    Args:
        yt_url: URL of the YouTube video.
        task: Whisper task, ``"transcribe"`` or ``"translate"``.
        use_demucs: ``"separate-audio"`` to keep only the Demucs vocal stem of
            every chunk; any other value keeps the raw chunk audio.
        dataset_name: Hub repository id the dataset is pushed to.
        oauth_token: Token of the logged-in user, injected by Gradio.
        max_filesize: Unused; kept for backward compatibility.
        dataset_sampling_rate: Sampling rate the audio is decoded at for the
            dataset chunks (Whisper itself uses its feature extractor's rate).
        progress: Gradio progress tracker.

    Returns:
        ``(html_embed, samples, text)``. *html_embed* is the video ``<iframe>``.
        *samples* is a list of one-element lists, one per dataset row. *text* is
        the full Whisper transcription.

    Raises:
        gr.Error: if no URL or no dataset name was submitted, or the video
            cannot be downloaded.
    """
    # gr.Textbox yields "" (not None) when left empty, so reject blank values too.
    if not yt_url or not yt_url.strip():
        raise gr.Error("No youtube link submitted! Please put a working link.")
    if not dataset_name or not dataset_name.strip():
        raise gr.Error("No dataset name submitted! Please submit a dataset name. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
    yt_url = yt_url.strip()
    dataset_name = dataset_name.strip()

    total_step = 5
    current_step = 0

    html_embed_str = _return_yt_html_embed(yt_url)

    if oauth_token is None:
        gr.Warning("Make sure to click and login before using this demo.")
        return html_embed_str, [["transcripts will appear here"]], ""

    current_step += 1
    progress((current_step, total_step), desc="Load video.")
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)
        with open(filepath, "rb") as f:
            video_bytes = f.read()

    # Whisper needs audio at its own feature-extractor rate.
    whisper_sampling_rate = pipe.feature_extractor.sampling_rate
    inputs = ffmpeg_read(video_bytes, whisper_sampling_rate)
    inputs = {"array": inputs, "sampling_rate": whisper_sampling_rate}

    current_step += 1
    progress((current_step, total_step), desc="Transcribe using Whisper.")
    out = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
    text = out["text"]

    # Decode again at the dataset's rate for the audio that gets saved.
    inputs = ffmpeg_read(video_bytes, dataset_sampling_rate)

    current_step += 1
    progress((current_step, total_step), desc="Merge chunks.")
    chunks = naive_postprocess_whisper_chunks(out["chunks"], inputs, dataset_sampling_rate)

    current_step += 1
    progress((current_step, total_step), desc="Create dataset.")
    transcripts = []
    audios = []
    with tempfile.TemporaryDirectory() as tmpdirname:
        for i, chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for).")):
            # TODO: make sure 1D or 2D?
            arr = chunk["audio"]
            path = os.path.join(tmpdirname, f"{i}.wav")
            wavfile.write(path, dataset_sampling_rate, arr)
            if use_demucs == "separate-audio":
                # use demucs to separate vocals
                print(f"Separating vocals #{i}")
                path = separate_vocal(path)
            audios.append(path)
            transcripts.append(chunk["text"])

        # The dataset only stores paths into tmpdirname; the files are read when
        # pushing, so building and pushing must happen before the directory is removed.
        dataset = Dataset.from_dict({"audio": audios, "text": transcripts}).cast_column("audio", Audio())

        current_step += 1
        progress((current_step, total_step), desc="Push dataset.")
        dataset.push_to_hub(dataset_name, token=oauth_token.token)

    return html_embed_str, [[transcript] for transcript in transcripts], text
def naive_postprocess_whisper_chunks(chunks, audio_array, sampling_rate, stop_chars = ".!:;?", min_duration = 5):
    """Merge consecutive Whisper chunks into sentence-like (text, audio) samples.

    Chunks are concatenated until the merged text ends with one of
    *stop_chars* (ignoring trailing whitespace) AND the merged audio lasts at
    least *min_duration* seconds, or until the input runs out.

    Args:
        chunks: Whisper pipeline chunks, each ``{"text": str, "timestamp": (begin, end)}``
            with times in seconds. A ``None`` end (Whisper leaves the last
            chunk open-ended when the audio stops mid-window) means "until the
            end of *audio_array*". The list is not modified.
        audio_array: Audio samples indexed along the first axis.
        sampling_rate: Samples per second of *audio_array*.
        stop_chars: Characters that mark the end of a sentence.
        min_duration: Minimum merged duration in seconds.

    Returns:
        A list of ``{"text": str, "audio": np.ndarray}`` dictionaries, with
        stripped text and the concatenated audio slices.
    """
    min_samples = int(min_duration * sampling_rate)
    total_samples = len(audio_array)

    def to_samples(timestamp):
        # Convert a (begin, end) pair in seconds to sample indices.
        begin, end = timestamp
        start = int(begin * sampling_rate)
        stop = total_samples if end is None else int(end * sampling_rate)
        return start, stop

    def ends_sentence(text):
        # Empty text never closes a sentence; trailing spaces must not hide a stop char.
        stripped = text.rstrip()
        return bool(stripped) and stripped[-1] in stop_chars

    # deque: O(1) pops from the left, and the caller's list stays untouched.
    pending = deque(chunks)
    new_chunks = []
    while pending:
        current_chunk = pending.popleft()
        begin, end = to_samples(current_chunk["timestamp"])
        current_dur = end - begin
        text = current_chunk["text"]
        chunk_to_concat = [audio_array[begin:end]]
        while pending and (not ends_sentence(text) or current_dur < min_samples):
            ch = pending.popleft()
            begin, end = to_samples(ch["timestamp"])
            current_dur += end - begin
            text = "".join([text, ch["text"]])
            # TODO: add silence ?
            chunk_to_concat.append(audio_array[begin:end])
        new_chunks.append({
            "text": text.strip(),
            "audio": np.concatenate(chunk_to_concat),
        })
        print(f"LENGTH CHUNK #{len(new_chunks)}: {current_dur/sampling_rate}s")
    return new_chunks
# Custom CSS passed to gr.Blocks: centers the Markdown blocks created with
# elem_id="intro" (the per-tab headings).
css = """
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
# Intro text shown in both tabs. Fixes vs. the earlier copy: a missing space
# between sentences ("it.Demo"), a mis-encoded emoji, and grammar.
_DEMO_DESCRIPTION = (
    "This demo allows you to create a text-to-speech dataset from an input audio snippet and push it to the Hub to keep track of it. "
    f"Demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to automatically transcribe audio files"
    " of arbitrary length. It then merges chunks of audio and pushes them to the Hub."
)
# Shared UI strings (identical text in both tabs).
_DATASET_NAME_PLACEHOLDER = "Place your new dataset name here. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user."
_CLEANING_LABEL = "Audio separation and cleaning (takes longer - use it if your samples are not cleaned (background noise and music))"

with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.LoginButton()
        gr.LogoutButton()

    with gr.Tab("YouTube"):
        gr.Markdown("Create your own TTS dataset using Youtube", elem_id="intro")
        gr.Markdown(_DEMO_DESCRIPTION)
        with gr.Row():
            with gr.Column():
                audio_youtube = gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
                task_youtube = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
                cleaning_youtube = gr.Radio(["no-post-processing", "separate-audio"], label=_CLEANING_LABEL, value="separate-audio")
                textbox_youtube = gr.Textbox(lines=1, placeholder=_DATASET_NAME_PLACEHOLDER, label="Dataset name")
                with gr.Row():
                    clear_youtube = gr.ClearButton([audio_youtube, task_youtube, cleaning_youtube, textbox_youtube])
                    submit_youtube = gr.Button("Submit")
            with gr.Column():
                html_youtube = gr.HTML()
                dataset_youtube = gr.Dataset(label="Transcribed samples.", components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
                transcript_youtube = gr.Textbox(label="Transcription")

    with gr.Tab("Microphone or Audio file"):
        gr.Markdown("Create your own TTS dataset using your own recordings", elem_id="intro")
        gr.Markdown(_DEMO_DESCRIPTION)
        with gr.Row():
            with gr.Column():
                audio_file = gr.Audio(type="filepath")
                task_file = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
                cleaning_file = gr.Radio(["no-post-processing", "separate-audio"], label=_CLEANING_LABEL, value="separate-audio")
                textbox_file = gr.Textbox(lines=1, placeholder=_DATASET_NAME_PLACEHOLDER, label="Dataset name")
                with gr.Row():
                    clear_file = gr.ClearButton([audio_file, task_file, cleaning_file, textbox_file])
                    submit_file = gr.Button("Submit")
            with gr.Column():
                dataset_file = gr.Dataset(label="Transcribed samples.", components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
                transcript_file = gr.Textbox(label="Transcription")

    # Event listeners are registered inside the Blocks context. The OAuth token
    # and the progress tracker are injected by Gradio from the handler signatures.
    submit_file.click(transcribe, inputs=[audio_file, task_file, cleaning_file, textbox_file], outputs=[dataset_file, transcript_file])
    submit_youtube.click(yt_transcribe, inputs=[audio_youtube, task_youtube, cleaning_youtube, textbox_youtube], outputs=[html_youtube, dataset_youtube, transcript_youtube])

demo.launch(debug=True)