# Hugging Face Spaces page status at capture time: "Build error"
import csv
import datetime
import functools
import os
import re
import subprocess
import sys
import time
import uuid
from io import StringIO

import gradio as gr
import nltk
import numpy as np
import pyrubberband
import spaces
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from nltk.sentiment import SentimentIntensityAnalyzer
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
# --- One-time environment setup ---------------------------------------------
nltk.download("vader_lexicon")  # lexicon used by analyze_sentiment()
# Run unidic's downloader with the interpreter that is running this app; a
# bare "python" on PATH may belong to a different environment. check=False
# keeps the previous os.system semantics of not aborting startup on failure.
subprocess.run([sys.executable, "-m", "unidic", "download"], check=False)
os.system("nvidia-smi")  # diagnostic only: logs the visible GPUs

HF_TOKEN = None
api = HfApi(token=HF_TOKEN)

# --- Model checkpoint -------------------------------------------------------
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    # The viXTTS snapshot is expected to provide model/config/vocab; the
    # speaker file is fetched separately from coqui/XTTS-v2.
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )

xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL.to(device)

# NOTE: this aliases config.languages, so the appends below also mutate the
# loaded config object (kept deliberately; predict() validates against it).
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
if "es-AR" not in supported_languages:
    supported_languages.append("es-AR")
# Standalone "AI" / "A.I" acronym. The word boundaries stop it from matching
# the letters "AI" inside upper-case words such as Vietnamese "HAI" or "MAI",
# which a plain str.replace("AI", ...) would mangle into "HÂy Ai".
AI_ACRONYM_PATTERN = re.compile(r"\bA\.?I\b")


def normalize_vietnamese_text(text):
    """Normalize Vietnamese *text* for TTS.

    Runs vinorm's ``TTSnorm`` (keeping the original case), then tidies the
    result: replaces ``..``, ``!.`` and ``?.`` with single punctuation, drops
    spaces before ``.`` and ``,``, strips double and single quotes, and spells
    the standalone acronym "AI" / "A.I" out as "Ây Ai".

    Args:
        text: Raw input text.

    Returns:
        The normalized text.
    """
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    return AI_ACRONYM_PATTERN.sub("Ây Ai", text)
@functools.lru_cache(maxsize=1)
def _sentiment_analyzer():
    """Return the shared VADER analyzer, creating it on first use.

    Constructing ``SentimentIntensityAnalyzer`` loads the VADER lexicon, so it
    is built once rather than on every request.
    """
    return SentimentIntensityAnalyzer()


def analyze_sentiment(text):
    """Return VADER's compound sentiment score for *text*.

    The compound score is normalized to [-1, 1] (negative to positive);
    ``predict`` uses it to pick the sampling temperature and to drive the
    pitch and noise effects.

    NOTE(review): VADER's lexicon is English; Spanish or Vietnamese prompts
    will likely score close to 0. Confirm this is the intended behavior.
    """
    return _sentiment_analyzer().polarity_scores(text)["compound"]
def change_pitch(audio_data, sampling_rate, sentiment):
    """Pitch-shift *audio_data* in proportion to *sentiment*.

    The shift is two semitones per unit of sentiment (positive raises the
    pitch, negative lowers it), performed by Rubber Band via ``pyrubberband``.

    Args:
        audio_data: Audio samples accepted by ``pyrubberband.pitch_shift``.
        sampling_rate: Sample rate of *audio_data* in Hz.
        sentiment: Score that scales the shift.

    Returns:
        Whatever ``pyrubberband.pitch_shift`` returns for the shifted audio.
    """
    n_steps = 2 * sentiment
    return pyrubberband.pitch_shift(audio_data, sampling_rate, n_steps=n_steps)
def apply_distortion(audio_data, sentiment, rng=None):
    """Modulate *audio_data* with Gaussian noise scaled by ``|sentiment|``.

    Each sample is multiplied by ``1 + 0.5 * |sentiment| * N(0, 1)``, so a
    neutral score (0) leaves the signal unchanged and strongly polar scores
    make it rougher.

    Args:
        audio_data: Array-like of samples, any shape (noise is drawn per
            element, so multichannel arrays work too).
        sentiment: Score whose magnitude controls the noise strength.
        rng: Optional ``numpy.random.Generator`` for reproducible output.
            When omitted the legacy global NumPy RNG is used, as before.

    Returns:
        A new floating-point array with the same shape as *audio_data*.
        Values are NOT clipped and may fall outside [-1, 1].
    """
    audio = np.asarray(audio_data)
    distortion_factor = abs(sentiment) * 0.5
    # Draw noise matching the full shape; the previous len()-based draw only
    # broadcast correctly for 1-D input.
    if rng is None:
        noise = np.random.randn(*audio.shape)
    else:
        noise = rng.standard_normal(audio.shape)
    return audio * (1 + distortion_factor * noise)
# Sample rate used for every XTTS output operation in this file (RTF,
# effects and the saved WAV).
_SAMPLE_RATE = 24000
# Dataset that receives inputs which triggered an unrecoverable CUDA error.
_FLAGGED_DATASET_REPO = "coqui/xtts-flagged-dataset"


def _warn(message):
    """Show *message* as a Gradio warning toast and return it unchanged.

    ``gr.Warning`` is called only for its toast side effect; returning the
    message lets callers also display it in the metrics textbox.
    """
    gr.Warning(message)
    return message


def _flag_device_assert(prompt, language, audio_file_pth):
    """Best-effort report of a CUDA device-side assert, then restart the Space.

    Uploads a one-row CSV (timestamp, prompt, language, reference path) and,
    when present, the reference audio to ``_FLAGGED_DATASET_REPO``. Then asks
    Hugging Face to restart the running Space unless it is still building.
    Every network step is guarded so a missing token or HTTP error is logged
    instead of replacing the original failure.

    NOTE(review): uploading requires write access to the flagged dataset;
    confirm this reporting is still wanted for this Space.
    """
    error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
    row = [str(value) for value in (error_time, prompt, language, audio_file_pth)]
    buffer = StringIO()
    csv.writer(buffer).writerow(row)
    try:
        api.upload_file(
            path_or_fileobj=buffer.getvalue().encode(),
            path_in_repo=error_time + "_" + str(uuid.uuid4()) + ".csv",
            repo_id=_FLAGGED_DATASET_REPO,
            repo_type="dataset",
        )
        if audio_file_pth:
            api.upload_file(
                path_or_fileobj=audio_file_pth,
                path_in_repo=error_time + "_reference_" + str(uuid.uuid4()) + ".wav",
                repo_id=_FLAGGED_DATASET_REPO,
                repo_type="dataset",
            )
    except Exception as upload_error:  # best effort: never mask the CUDA error
        print("Could not upload flagged data:", upload_error)

    # The module-level ``repo_id`` names the *model* repo, not a Space; the
    # running Space's own id is exposed in the SPACE_ID environment variable.
    space_id = os.environ.get("SPACE_ID")
    if not space_id:
        print("SPACE_ID is not set; skipping automatic Space restart.")
        return
    try:
        if api.get_space_runtime(repo_id=space_id).stage != "BUILDING":
            api.restart_space(repo_id=space_id)
    except Exception as restart_error:
        print("Could not restart the Space:", restart_error)


def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    """Synthesize *prompt* in *language* with the voice from *audio_file_pth*.

    Pipeline: validate inputs, compute XTTS speaker conditioning from the
    reference audio, optionally normalize Vietnamese text, derive a sampling
    temperature from the prompt's sentiment, run XTTS inference, pitch-shift
    and add noise according to the sentiment, and write ``output.wav``.

    Args:
        prompt: Text to speak (at least 2 characters).
        language: Language code; must be in ``supported_languages``.
        audio_file_pth: Path to the reference speaker audio.
        normalize_text: Apply Vietnamese normalization when ``language == "vi"``.

    Returns:
        ``(audio_path, metrics_text)``. On success this is ``"output.wav"`` and
        the timing metrics. On failure it is ``None`` and the user-facing
        warning message, which is also shown as a toast.
    """
    if language not in supported_languages:
        return (
            None,
            _warn(
                f"El idioma seleccionado ({language}) no está disponible. Por favor, elige uno de la lista."
            ),
        )
    if not prompt or len(prompt) < 2:
        return (None, _warn("Por favor, introduce un texto más largo."))

    try:
        gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
            audio_path=audio_file_pth,
            gpt_cond_len=30,
            gpt_cond_chunk_len=4,
            max_ref_length=60,
        )
    except Exception as e:  # treat any failure here as an unusable reference clip
        print("Speaker encoding error", str(e))
        return (
            None,
            _warn(
                "¿Has activado el micrófono? Parece que hay un problema con la referencia de audio."
            ),
        )

    metrics_text = ""
    try:
        # Insert a space and double a sentence-final ".", "。" or "?" that
        # follows a word or non-ASCII character ("Hola." -> "Hola .."),
        # presumably to give XTTS a clearer pause cue.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2\2", prompt)
        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        # The sentiment score nudges the sampling temperature around 0.75,
        # clamped to [0.5, 1.0].
        sentiment = analyze_sentiment(prompt)
        temperature = max(0.5, min(0.75 + sentiment * 0.2, 1.0))

        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=temperature,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += (
            f"Tiempo de generación de audio: {round(inference_time*1000)} milisegundos\n"
        )

        num_samples = out["wav"].shape[-1]
        if num_samples == 0:
            # Nothing was generated; the RTF below would divide by zero.
            return (
                None,
                _warn("Se ha producido un error inesperado. Por favor, inténtalo de nuevo."),
            )
        # RTF = processing time / duration of the generated audio.
        real_time_factor = inference_time / (num_samples / _SAMPLE_RATE)
        metrics_text += f"Factor de tiempo real (RTF): {real_time_factor:.2f}\n"

        audio = np.asarray(out["wav"], dtype=np.float32)
        audio = change_pitch(audio, _SAMPLE_RATE, sentiment)
        audio = apply_distortion(audio, sentiment)
        # apply_distortion multiplies by float64 noise, so the result is
        # float64 and may leave [-1, 1]. Clip it and store it as float32 so
        # the WAV stays in range and uses a dtype torchaudio.save supports.
        audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
        # NOTE(review): every request overwrites the same file; this is only
        # safe while the queue runs one request at a time.
        torchaudio.save(
            "output.wav", torch.from_numpy(audio).unsqueeze(0), _SAMPLE_RATE
        )
    except RuntimeError as e:
        message = str(e)
        if "device-side assert" in message:
            # CUDA device-side asserts typically leave the CUDA context
            # unusable, hence the report-and-restart path. Return here
            # instead of serving a stale output.wav.
            _flag_device_assert(prompt, language, audio_file_pth)
            return (
                None,
                _warn(
                    "Error irrecuperable en la GPU. Por favor, inténtalo de nuevo en unos minutos."
                ),
            )
        if "Failed to decode" in message:
            return (
                None,
                _warn(
                    "Parece que hay un problema con la referencia de audio. ¿Has activado el micrófono?"
                ),
            )
        print("RuntimeError:", message)
        return (
            None,
            _warn("Se ha producido un error inesperado. Por favor, inténtalo de nuevo."),
        )

    return ("output.wav", metrics_text)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# viXTTS Demo ✨")
        with gr.Column():
            pass  # empty right-hand column keeps the title at half width

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Texto a convertir a voz",
                value="Hola, soy un modelo de texto a voz.",
            )
            # Single-select dropdown; predict() re-validates the choice
            # against supported_languages.
            language_gr = gr.Dropdown(
                label="Idioma",
                choices=[
                    "es-AR",
                    "vi",
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh-cn",
                    "ja",
                    "ko",
                    "hu",
                    "hi",
                ],
                value="es-AR",
            )
            normalize_text = gr.Checkbox(
                label="Normalizar texto en vietnamita",
                info="Solo aplicable al idioma vietnamita",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Audio de referencia (opcional)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Generar voz 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )
        with gr.Column():
            audio_gr = gr.Audio(label="Audio generado", autoplay=True)
            out_text_gr = gr.Text(label="Métricas")

    # Event listeners must be registered inside the Blocks context.
    tts_button.click(
        predict,
        [
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()

# Launch only when run as a script so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    demo.launch(debug=True, show_api=True, share=True)