# Hugging Face Space file-view residue (not code): author Simonlob,
# last commit "Update app.py" (ddbbef2, verified).
from pathlib import Path
import argparse
import soundfile as sf
import torch
import io
import argparse
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
import gradio as gr
import requests
import sentry_sdk
from sentry_sdk import capture_message, capture_exception
from datetime import datetime
# Telemetry: successful requests (log_user_request) and generation failures
# (gen_tts) are reported to this Sentry project.
_SENTRY_DSN = "https://26932c88424672dea1c43f8536dd6a0d@o4508269800849408.ingest.us.sentry.io/4508269808254976"
sentry_sdk.init(dsn=_SENTRY_DSN)
def log_user_request(user_input: str, speaking_rate: float, device: "str | torch.device", execution_time: str):
    """
    Log one successful TTS request to Sentry as an informational message.

    The event is tagged ``event_type=user_request`` and carries the request
    details as "extra" data, so it can be filtered apart from error events.

    Args:
        user_input (str): Text the user asked to synthesise. NOTE: it is sent
            to Sentry verbatim.
        speaking_rate (float): Slider value (used as the model's length_scale).
        device (str | torch.device): Device the model ran on. It is coerced to
            ``str`` so the event shows e.g. ``"cpu"`` rather than an object repr.
        execution_time (str): Generation time in seconds.
    """
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # NOTE(review): push_scope() is deprecated in sentry-sdk 2.x in favour of
    # new_scope()/isolation_scope(); confirm the pinned SDK version before migrating.
    with sentry_sdk.push_scope() as scope:
        scope.set_tag("event_type", "user_request")
        scope.set_extra("user_input", user_input)
        scope.set_extra("speaking_rate", str(speaking_rate))
        # The caller passes a torch.device; store its plain string form.
        scope.set_extra("device", str(device))
        scope.set_extra("timestamp", current_datetime)
        scope.set_extra("execution_time", f'{execution_time} sec')
        sentry_sdk.capture_message("User request")
def download_file(url, save_path, timeout=60):
    """Download ``url`` to ``save_path``, creating parent directories as needed.

    The response body is streamed in chunks to a sibling ``.part`` file. That
    file is renamed onto ``save_path`` only after the transfer completes, so
    an interrupted download never leaves a truncated file at the final path.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: Destination file path.
        timeout: Seconds to wait for the connection and between received bytes.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Previously the error page itself was written to ``save_path``.
        requests.RequestException: On connection errors or timeouts.
    """
    print(f'---Loading from URL: {url} ---')
    destination = Path(save_path)
    # './checkpoints' is not guaranteed to exist on a fresh checkout.
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial = destination.with_name(destination.name + ".part")
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial, 'wb') as file:
            # Stream in 1 MiB chunks instead of holding the whole checkpoint in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    partial.replace(destination)
# Release assets fetched on every start-up:
#   - the Matcha-TTS checkpoint later loaded by load_matcha,
#   - the HiFi-GAN generator weights later loaded by load_vocoder.
# Previous checkpoint, kept for reference:
#   https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt
url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Akyl-AI-TTS-v2/checkpoint_epoch.669.ckpt'
save_checkpoint_path = './checkpoints/checkpoint.ckpt'
url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1'
save_generator_path = './checkpoints/generator'

for _url, _destination in (
    (url_checkpoint, save_checkpoint_path),
    (url_generator, save_generator_path),
):
    download_file(_url, _destination)
def log_event(input_text, log_file="usage_log.json"):
    """Append one JSON-lines record ``{"timestamp", "text"}`` to ``log_file``.

    Args:
        input_text: The text to record.
        log_file: Path of the JSON-lines file; created if it does not exist.
    """
    # ``json`` is not imported at module level; without this local import the
    # function raised NameError on its first call.
    import json

    event_data = {'timestamp': datetime.now().isoformat(), 'text': input_text}
    # Explicit UTF-8 so the log file does not depend on the platform locale.
    with open(log_file, "a", encoding="utf-8") as file:
        file.write(json.dumps(event_data) + "\n")
def load_matcha(checkpoint_path, device):
    """Restore the Matcha-TTS acoustic model via ``MatchaTTS.load_from_checkpoint``.

    Tensors are mapped onto ``device`` while loading. The model is switched
    to evaluation mode before it is returned.
    """
    matcha = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    matcha.eval()
    return matcha
def load_hifigan(checkpoint_path, device):
    """Build the HiFi-GAN generator from the ``v1`` config and load its weights.

    ``checkpoint_path`` must hold a dict with the generator state dict under the
    ``"generator"`` key. The returned module is in eval mode and has had
    ``remove_weight_norm()`` applied.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    # NOTE(review): torch.load unpickles this file, which is downloaded from the
    # network at start-up; consider weights_only=True once the pinned torch
    # version is confirmed to support it.
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator"])
    generator.eval()
    generator.remove_weight_norm()
    return generator
def load_vocoder(checkpoint_path, device):
    """Load the HiFi-GAN vocoder and build a denoiser around it.

    Returns:
        tuple: ``(vocoder, denoiser)``, where ``denoiser`` is
        ``Denoiser(vocoder, mode="zeros")`` wrapping the same vocoder instance.
    """
    # (The former ``vocoder = None`` was a dead store, overwritten immediately.)
    vocoder = load_hifigan(checkpoint_path, device)
    denoiser = Denoiser(vocoder, mode="zeros")
    return vocoder, denoiser
def process_text(i: int, text: str, device: torch.device):
    """Turn raw input text into model-ready token tensors.

    The text goes through ``text_to_sequence`` with the ``kyrgyz_cleaners``
    pipeline. A 0 id is interspersed between the token ids, and the result
    becomes a batch of one on ``device``. Progress is printed, tagged ``[i]``.

    Returns:
        dict with keys:
            ``x_orig``: the input text,
            ``x``: token ids of shape ``(1, T)``,
            ``x_lengths``: tensor ``[T]``,
            ``x_phones``: decoded phone string (with the stress fix-up below).
    """
    print(f"[{i}] - Input text: {text}")
    token_ids = intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0)
    x = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    phones = sequence_to_text(x.squeeze(0).tolist())
    print(f"[{i}] - Phonetised text: {phones}")
    # The stress-mark correction is applied only to the returned display string.
    # ``x`` (what the model consumes) is unchanged, and the line printed above
    # shows the uncorrected form.
    # NOTE(review): if the intent was to fix pronunciation, the correction has
    # to happen before ``x`` is built.
    corrected_phones = phones.replace('_q_ˌ_o_l_o_n_q_ˈ_ɑ_', '_q_ˌ_o_l_ˈ_o_n_q_ɑ_')
    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": corrected_phones}
def to_waveform(mel, vocoder, denoiser=None):
    """Vocode ``mel`` into a CPU waveform with all size-1 dimensions squeezed out.

    The vocoder output is clamped to [-1, 1]. Denoising happens after the
    clamp, so the final signal is not re-clamped.

    Args:
        mel: Mel spectrogram accepted by ``vocoder``.
        vocoder: Callable mapping a mel to an audio tensor.
        denoiser: Optional callable ``denoiser(audio, strength=...)``. It
            receives the squeezed audio and a fixed strength of 0.00025.

    Returns:
        torch.Tensor on CPU.
    """
    clipped = vocoder(mel).clamp(-1, 1)
    if denoiser is None:
        return clipped.cpu().squeeze()
    return denoiser(clipped.squeeze(), strength=0.00025).cpu().squeeze()
@torch.inference_mode()
def process_text_gradio(text):
    """Tokenise ``text`` on the module-level ``device`` (request tag fixed at 1).

    Returns:
        tuple: ``(phones, x, x_lengths)``. ``phones`` keeps every second
        character of the phonetised string. This presumably drops the blank
        symbols added by ``intersperse``; verify against matcha.utils.
    """
    processed = process_text(1, text, device)
    phones = processed["x_phones"][1::2]
    return phones, processed["x"], processed["x_lengths"]
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1):
    """Run the acoustic model and the vocoder; return the waveform as NumPy.

    Despite the name, this returns audio samples rather than a mel
    spectrogram: the ``"mel"`` produced by ``model.synthesise`` goes straight
    through ``to_waveform``. It relies on the module-level ``model``,
    ``vocoder``, ``denoiser`` and ``device``.

    Args:
        text: Token-id tensor from ``process_text``.
        text_length: Matching length tensor.
        n_timesteps, temperature, length_scale: Forwarded to ``model.synthesise``.
        spk: Speaker id; any negative value disables speaker conditioning.
    """
    if spk >= 0:
        speaker = torch.tensor([spk], device=device, dtype=torch.long)
    else:
        speaker = None
    result = model.synthesise(
        text,
        text_length,
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=speaker,
        length_scale=length_scale,
    )
    waveform = to_waveform(result["mel"], vocoder, denoiser)
    result["waveform"] = waveform
    return waveform.numpy()
def get_inference(text, n_timesteps=20, mel_temp=0.667, length_scale=0.8, spk=-1):
    """Synthesise ``text`` end to end and return the waveform as a NumPy array.

    Args:
        text: Raw input text.
        n_timesteps: Forwarded to ``model.synthesise`` as ``n_timesteps``.
        mel_temp: Forwarded to ``model.synthesise`` as ``temperature``.
        length_scale: Forwarded as ``length_scale``; the UI feeds its
            "Speaking rate" slider here.
        spk: Speaker id; negative means no speaker conditioning.
    """
    _phones, token_ids, token_lengths = process_text_gradio(text)
    # Synthesise exactly once. The previous version called synthesise_mel a
    # second time just to print the result type. That doubled the cost of every
    # request and discarded the first result.
    waveform = synthesise_mel(token_ids, token_lengths, n_timesteps, mel_temp, length_scale, spk)
    print(type(waveform))
    return waveform
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse the download destinations so the load paths cannot drift out of sync
# with where the files were actually saved.
model_path = save_checkpoint_path
vocoder_path = save_generator_path
model = load_matcha(model_path, device)
vocoder, denoiser = load_vocoder(vocoder_path, device)
def gen_tts(text, speaking_rate):
    """Gradio click handler: synthesise ``text`` and return ``(sample_rate, waveform)``.

    Successful requests are reported to Sentry via ``log_user_request``. A
    synthesis failure is reported with ``capture_exception`` and then raised
    as ``gr.Error``, so the user sees a message. Previously ``None`` was
    returned silently and the audio player stayed empty.

    Args:
        text: Text from the input box.
        speaking_rate: Slider value, forwarded to the model as ``length_scale``.

    Returns:
        tuple: ``(22050, numpy waveform)``, the ``(rate, data)`` pair that
        ``gr.Audio(type="numpy")`` displays.

    Raises:
        gr.Error: If synthesis fails.
    """
    import time  # perf_counter is monotonic, unlike datetime.now()

    try:
        start = time.perf_counter()
        waveform = get_inference(text=text, length_scale=speaking_rate)
        execution_time = str(time.perf_counter() - start)
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise gr.Error("Audio generation failed. Please try again.") from e
    try:
        log_user_request(text, speaking_rate, device, execution_time)
    except Exception as e:
        # Telemetry is best-effort: a Sentry problem must not cost the user the
        # audio that was already generated.
        sentry_sdk.capture_exception(e)
    # NOTE(review): 22050 Hz is assumed to match the vocoder's output rate
    # (HiFi-GAN ``v1`` config); confirm against matcha.hifigan.config.
    return 22050, waveform
# Text pre-filled in the input box.
default_text = "Баарыңарга салам, менин атым Акылай."
# Page styles passed to gr.Blocks.
# A stray "}" used to follow the ".wrap" rule. Browsers then discarded the
# following "img" rule, so it never applied; the brace has been removed.
# NOTE(review): no component in this file uses the share-btn elem_ids, so the
# #share-btn* rules look like leftovers from a template.
css = """
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
flex: unset !important;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor: pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.25rem !important;
padding-bottom: 0.25rem !important;
right:0;
}
#share-btn * {
all: unset !important;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
img {
display: block;
margin: 0 auto;
width: 132px !important;
height: 132px !important;
}
"""
# Centered page title shown at the top of the app.
_HEADER_HTML = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
Akyl-AI TTS
</h1>
</div>
</div>
"""

with gr.Blocks(css=css) as block:
    gr.HTML(_HEADER_HTML)

    # Logo row; the image file is expected next to app.py.
    with gr.Row():
        image_path = "./photo_2024-04-07_15-59-52.png"
        gr.Image(image_path, label=None, width=132, height=132, show_label=False)

    # Controls on the left, generated audio on the right.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                lines=2,
                value=default_text,
                elem_id="input_text",
            )
            speaking_rate = gr.Slider(
                label='Speaking rate',
                minimum=0.5,
                maximum=1,
                step=0.05,
                value=0.8,
                interactive=True,
                show_label=True,
                elem_id="speaking_rate",
            )
            run_button = gr.Button("Generate Audio", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="AkylAi-TTS", type="numpy", elem_id="audio_out")

    # Event wiring has to happen inside the Blocks context.
    inputs = [input_text, speaking_rate]
    outputs = [audio_out]
    run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)

block.queue()
# NOTE(review): share=True asks Gradio for a public share link when run
# locally; confirm this is intended for every deployment target.
block.launch(share=True)