# vixtts-demo / app.py
import csv
import datetime
import os
import re
import time
import uuid
from io import StringIO
import gradio as gr
import nltk
import numpy as np
import pyrubberband
import spaces
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from nltk.sentiment import SentimentIntensityAnalyzer
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
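
# One-time setup: fetch NLTK's VADER lexicon for sentiment analysis, download
# the unidic dictionary (used for Japanese tokenization), and log GPU info.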
nltk.download('vader_lexicon')
os.system("python -m unidic download")
os.system('nvidia-smi')
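
# Model checkpoint settings: the viXTTS weights are fetched from the Hub into
# `model/` on first run; DeepSpeed inference is disabled.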
HF_TOKEN = None
api = HfApi(token=HF_TOKEN)
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False
os.makedirs(checkpoint_dir, exist_ok=True)
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )
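
# Load the XTTS config and weights, then move the model to GPU if available.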
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
MODEL.to(device)
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
if "es-AR" not in supported_languages:
    supported_languages.append("es-AR")
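
# Vietnamese text normalization: run TTSnorm, then clean up punctuation
# artifacts and spell out "AI" phonetically so it is pronounced correctly.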
def normalize_vietnamese_text(text):
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )
    return text
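
# Sentiment-driven post-processing: VADER's compound score (in [-1, 1])
# controls pitch shift, distortion strength, and sampling temperature.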
def analyze_sentiment(text):
    sia = SentimentIntensityAnalyzer()
    scores = sia.polarity_scores(text)
    return scores["compound"]


def change_pitch(audio_data, sampling_rate, sentiment):
    semitones = sentiment * 2
    return pyrubberband.pitch_shift(audio_data, sampling_rate, semitones)


def apply_distortion(audio_data, sentiment):
    distortion_factor = abs(sentiment) * 0.5
    return audio_data * (1 + distortion_factor * np.random.randn(len(audio_data)))
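
# Main synthesis entry point, wired to the Gradio button below.
# Returns an (audio_path, metrics_or_warning) pair.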
@spaces.GPU(duration=0)
def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"The selected language ({language}) is not available. Please choose one from the list."
        )
        return (None, metrics_text)
    speaker_wav = audio_file_pth
    if len(prompt) < 2:
        metrics_text = gr.Warning("Please enter a longer text.")
        return (None, metrics_text)
    try:
        metrics_text = ""
        t_latent = time.time()
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = gr.Warning(
                "Did you enable the microphone? There seems to be a problem with the reference audio."
            )
            return (None, metrics_text)
        # Insert a space and double sentence-final punctuation so text
        # splitting produces longer pauses between sentences.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)
        # Nudge the sampling temperature with sentiment, clamped to [0.5, 1.0].
        sentiment = analyze_sentiment(prompt)
        temperature = 0.75 + sentiment * 0.2
        temperature = max(0.5, min(temperature, 1.0))
        # XTTS has no dedicated "es-AR" voice; fall back to generic Spanish
        # for the model call.
        model_language = "es" if language == "es-AR" else language
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            model_language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=temperature,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += (
            f"Audio generation time: {round(inference_time*1000)} milliseconds\n"
        )
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
        # Post-process the 24 kHz waveform with the sentiment-driven effects.
        audio_data = np.array(out["wav"])
        modified_audio = change_pitch(audio_data, 24000, sentiment)
        modified_audio = apply_distortion(modified_audio, sentiment)
        torchaudio.save(
            "output.wav",
            torch.tensor(modified_audio, dtype=torch.float32).unsqueeze(0),
            24000,
        )
    except RuntimeError as e:
        if "device-side assert" in str(e):
            # Log the failing request to a dataset repo, then restart the
            # Space to recover from the CUDA assert.
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [str(item) for item in error_data]
            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )
            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
        else:
            if "Failed to decode" in str(e):
                metrics_text = gr.Warning(
                    "There seems to be a problem with the reference audio. Did you enable the microphone?"
                )
            else:
                metrics_text = gr.Warning(
                    "An unexpected error occurred. Please try again."
                )
        return (None, metrics_text)
    return ("output.wav", metrics_text)
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # viXTTS Demo ✨
                """
            )
        with gr.Column():
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Text to convert to speech",
                value="Hola, soy un modelo de texto a voz.",
            )
            language_gr = gr.Dropdown(
                label="Language",
                choices=[
                    "es-AR",
                    "vi",
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh-cn",
                    "ja",
                    "ko",
                    "hu",
                    "hi",
                ],
                max_choices=1,
                value="es-AR",
            )
            normalize_text = gr.Checkbox(
                label="Normalize Vietnamese text",
                info="Only applies to the Vietnamese language",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Reference audio (optional)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Generate speech 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )
        with gr.Column():
            audio_gr = gr.Audio(label="Generated audio", autoplay=True)
            out_text_gr = gr.Text(label="Metrics")

    tts_button.click(
        predict,
        [
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)
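
# A minimal client-side sketch for calling this app's API (assumes the
# `gradio_client` package and that the app is reachable at the given URL;
# both are illustrative, not part of this file):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   audio_path, metrics = client.predict(
#       "Hola, soy un modelo de texto a voz.",  # prompt
#       "es-AR",                                # language
#       "model/samples/nu-luu-loat.wav",        # reference audio
#       True,                                   # normalize_text
#       api_name="/predict",
#   )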