import json
from datetime import datetime
from pathlib import Path

import gradio as gr
import requests
import sentry_sdk
import torch

from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse

sentry_sdk.init(
    dsn="https://26932c88424672dea1c43f8536dd6a0d@o4508269800849408.ingest.us.sentry.io/4508269808254976",
)
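

# Telemetry: every synthesis request is logged to Sentry with its metadata via
# log_user_request(), and exceptions raised during generation are reported in
# gen_tts() via sentry_sdk.capture_exception().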
def log_user_request(user_input: str, speaking_rate: float, device: str, execution_time: str):
    """
    Logs the user's request and the current date in Sentry.

    Args:
        user_input (str): Text entered by the user.
        speaking_rate (float): Rate of speech.
        device (str): Device used for inference.
        execution_time (str): Generation time in seconds.
    """
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Use an isolated scope so the tag and extras apply only to this event.
    with sentry_sdk.push_scope() as scope:
        scope.set_tag("event_type", "user_request")
        scope.set_extra("user_input", user_input)
        scope.set_extra("speaking_rate", str(speaking_rate))
        scope.set_extra("device", device)
        scope.set_extra("timestamp", current_datetime)
        scope.set_extra("execution_time", f"{execution_time} sec")
        sentry_sdk.capture_message("User request")
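

# The Matcha-TTS checkpoint and the HiFi-GAN generator weights are fetched from
# GitHub releases at import time, so they are on disk before the models are
# loaded further below.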
def download_file(url, save_path):
    """Download a file from `url` and write it to `save_path`."""
    print(f'---Loading from URL: {url} ---')
    response = requests.get(url)
    response.raise_for_status()
    # Make sure the target directory exists before writing.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'wb') as file:
        file.write(response.content)


url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Akyl-AI-TTS-v2/checkpoint_epoch.669.ckpt'
save_checkpoint_path = './checkpoints/checkpoint.ckpt'
url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1'
save_generator_path = './checkpoints/generator'

download_file(url_checkpoint, save_checkpoint_path)
download_file(url_generator, save_generator_path)


def log_event(input_text, log_file="usage_log.json"):
    """Append a JSON line with the timestamp and the input text to a local log file."""
    event_data = {'timestamp': datetime.now().isoformat(),
                  'text': input_text}
    with open(log_file, "a") as file:
        file.write(json.dumps(event_data) + "\n")
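

# Model loading: the Matcha-TTS acoustic model and the HiFi-GAN vocoder (with
# its denoiser) are loaded once at startup and reused for every request.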
def load_matcha(checkpoint_path, device):
    """Load the Matcha-TTS acoustic model from a Lightning checkpoint."""
    model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    _ = model.eval()
    return model


def load_hifigan(checkpoint_path, device):
    """Load the HiFi-GAN vocoder with the v1 config and prepare it for inference."""
    h = AttrDict(v1)
    hifigan = HiFiGAN(h).to(device)
    hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def load_vocoder(checkpoint_path, device):
    """Load the vocoder together with a matching denoiser."""
    vocoder = load_hifigan(checkpoint_path, device)
    denoiser = Denoiser(vocoder, mode="zeros")
    return vocoder, denoiser


def process_text(i: int, text: str, device: torch.device):
    """Convert raw text into a batch of phoneme IDs ready for synthesis."""
    print(f"[{i}] - Input text: {text}")
    # Intersperse the phoneme IDs with 0 (the blank token), as Matcha-TTS expects.
    x = torch.tensor(
        intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0),
        dtype=torch.long,
        device=device,
    )[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    print(f"[{i}] - Phonetised text: {x_phones}")
    # Patch a known mis-stressed phoneme sequence in the transcription.
    return {
        "x_orig": text,
        "x": x,
        "x_lengths": x_lengths,
        "x_phones": x_phones.replace('_q_ˌ_o_l_o_n_q_ˈ_ɑ_', '_q_ˌ_o_l_ˈ_o_n_q_ɑ_'),
    }
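

# Vocoding: HiFi-GAN converts the predicted mel spectrogram to a waveform; the
# denoiser (mode="zeros") subtracts the vocoder's characteristic bias noise,
# estimated from an all-zero input.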
def to_waveform(mel, vocoder, denoiser=None):
    """Vocode a mel spectrogram into a waveform, optionally denoising the result."""
    audio = vocoder(mel).clamp(-1, 1)
    if denoiser is not None:
        audio = denoiser(audio.squeeze(), strength=0.00025)
    return audio.cpu().squeeze()


@torch.inference_mode()
def process_text_gradio(text):
    output = process_text(1, text, device)
    # The interspersed blanks render as '_'; [1::2] drops them for display.
    return output["x_phones"][1::2], output["x"], output["x_lengths"]
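

# Synthesis parameters:
#   n_timesteps  - number of ODE solver steps in the flow-matching decoder
#   temperature  - scale of the noise the mel spectrogram is sampled from
#   length_scale - multiplier on predicted phoneme durations (lower is faster speech)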
@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1):
    # A negative speaker ID means the model is single-speaker / unconditioned.
    spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
    output = model.synthesise(
        text,
        text_length,
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spk,
        length_scale=length_scale,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
    return output["waveform"].numpy()


def get_inference(text, n_timesteps=20, mel_temp=0.667, length_scale=0.8, spk=-1):
    phones, text, text_lengths = process_text_gradio(text)
    return synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
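

# The globals below are referenced by the functions above; Python resolves
# globals at call time, so loading the models after the definitions is fine.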
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = './checkpoints/checkpoint.ckpt'
vocoder_path = './checkpoints/generator'
model = load_matcha(model_path, device)
vocoder, denoiser = load_vocoder(vocoder_path, device)


def gen_tts(text, speaking_rate):
    """Gradio callback: synthesise `text` and return a (sample_rate, waveform) tuple."""
    try:
        start_time = datetime.now()
        # gr.Audio(type="numpy") expects (sample_rate, np.ndarray); audio is
        # generated at 22050 Hz.
        output = 22050, get_inference(text=text, length_scale=speaking_rate)
        end_time = datetime.now()
        execution_time = str((end_time - start_time).total_seconds())

        log_user_request(text, speaking_rate, str(device), execution_time)
        return output
    except Exception as e:
        sentry_sdk.capture_exception(e)


default_text = "Баарыңарга салам, менин атым Акылай."  # "Hello everyone, my name is Akylai."


css = """
#share-btn-container {
    display: flex;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
    background-color: #000000;
    justify-content: center;
    align-items: center;
    border-radius: 9999px !important;
    width: 13rem;
    margin-top: 10px;
    margin-left: auto;
    flex: unset !important;
}
#share-btn {
    all: initial;
    color: #ffffff;
    font-weight: 600;
    cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif;
    margin-left: 0.5rem !important;
    padding-top: 0.25rem !important;
    padding-bottom: 0.25rem !important;
    right: 0;
}
#share-btn * {
    all: unset !important;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
img {
    display: block;
    margin: 0 auto;
    width: 132px !important;
    height: 132px !important;
}
"""
with gr.Blocks(css=css) as block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 700px; margin: 0 auto;">
            <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                    Akyl-AI TTS
                </h1>
            </div>
        </div>
        """
    )
    with gr.Row():
        image_path = "./photo_2024-04-07_15-59-52.png"
        gr.Image(image_path, show_label=False, width=132, height=132)
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
            speaking_rate = gr.Slider(label='Speaking rate', minimum=0.5, maximum=1, step=0.05, value=0.8, interactive=True, show_label=True, elem_id="speaking_rate")
            run_button = gr.Button("Generate Audio", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="AkylAi-TTS", type="numpy", elem_id="audio_out")

    inputs = [input_text, speaking_rate]
    outputs = [audio_out]
    run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)

block.queue()
block.launch(share=True)
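
# block.queue() serves concurrent users in order; share=True additionally
# exposes the app through a temporary public gradio.live link.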