|
import base64
import os
import secrets
import warnings
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
|
|
|
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub checkpoint from which the model, tokenizer and feature
# extractor are all loaded.
repo_id = "parler-tts/parler_tts_mini_v0.1"

model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

# Shared secret that gen_tts requires callers to supply. The fallback value is
# visible in this source file, so a deployment relying on it is effectively
# unprotected; warn loudly instead of failing silently open.
_DEFAULT_SECRET_TOKEN = 'default_secret'
SECRET_TOKEN = os.getenv('SECRET_TOKEN', _DEFAULT_SECRET_TOKEN)
if SECRET_TOKEN == _DEFAULT_SECRET_TOKEN:
    warnings.warn(
        "SECRET_TOKEN is unset or equal to the built-in default; the space is "
        "not protected. Set the SECRET_TOKEN environment variable.",
        RuntimeWarning,
    )

# Output sample rate of the generated waveform, taken from the checkpoint's
# feature extractor.
SAMPLE_RATE = feature_extractor.sampling_rate
# Seed applied before every generation call.
SEED = 42
|
|
|
def _token_matches(secret_token):
    """Return True if *secret_token* equals ``SECRET_TOKEN``, compared in constant time."""
    if not isinstance(secret_token, str):
        # Gradio normally passes a str; anything else (e.g. None) is rejected,
        # just as the plain ``!=`` comparison against a str would reject it.
        return False
    # compare_digest does not short-circuit on the first mismatching byte, so
    # response timing does not reveal how much of the token was correct.
    # Encode first: compare_digest raises TypeError on non-ASCII str inputs.
    return secrets.compare_digest(
        secret_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")
    )


def _encode_mp3_data_uri(audio_arr, sample_rate):
    """Encode a mono waveform as a base64 MP3 ``data:`` URI.

    Assumes *audio_arr* is a 1-D float waveform nominally in [-1, 1]
    (TODO confirm for every checkpoint this is used with).
    """
    # Clip before scaling: the model output is not guaranteed to stay within
    # [-1, 1], and casting an out-of-range float to int16 is platform-dependent
    # (typically wraps around), which turns slight overshoot into loud clicks.
    samples = (np.clip(audio_arr, -1.0, 1.0) * (2**15 - 1)).astype(np.int16)
    sound = AudioSegment(
        samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=samples.dtype.itemsize,
        channels=1,
    )

    buff_mp3 = BytesIO()
    sound.export(buff_mp3, format="mp3")
    audio_base64 = base64.b64encode(buff_mp3.getvalue()).decode('utf-8')
    return 'data:audio/mp3;base64,' + audio_base64


def gen_tts(secret_token, text, description):
    """Synthesize speech for *text* and return it as an MP3 data URI.

    Args:
        secret_token: Token supplied by the caller; must equal ``SECRET_TOKEN``.
        text: Words to speak; tokenized and passed as ``prompt_input_ids``.
        description: Natural-language description of the desired voice;
            tokenized and passed as ``input_ids``.

    Returns:
        A ``'data:audio/mp3;base64,...'`` string holding the encoded audio.

    Raises:
        gr.Error: If *secret_token* does not match ``SECRET_TOKEN``.
    """
    if not _token_matches(secret_token):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    inputs = tokenizer(description, return_tensors="pt").to(device)
    prompt = tokenizer(text, return_tensors="pt").to(device)

    # Reset the RNG state before sampling so every call starts from the same seed.
    set_seed(SEED)
    generation = model.generate(
        input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, do_sample=True, temperature=1.0
    )
    audio_arr = generation.cpu().numpy().squeeze()

    return _encode_mp3_data_uri(audio_arr, SAMPLE_RATE)
|
|
|
|
|
# Build the (headless) UI. The components below are hidden behind a
# full-viewport overlay; presumably the space is driven programmatically
# through the click endpoint rather than by human visitors.
with gr.Blocks() as app:
    # Fixed-position white overlay (z-index 100) covering the whole page and
    # pointing visitors to the public Parler-TTS demo instead.
    gr.HTML("""
    <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
    <div style="text-align: center; color: black;">
    <p style="color: black;">This space is a headless component of the cloud rendering engine used by AiTube.</p>
    <p style="color: black;">It is not available for public use, but you can use the <a href="https://huggingface.co/spaces/parler-tts/parler_tts_mini" target="_blank">original space</a>.</p>
    </div>
    </div>""")
    secret_token = gr.Textbox(label="Secret token")
    input_text = gr.Textbox(label="Input Text")
    description = gr.Textbox(label="Description")
    run_button = gr.Button("Generate Audio")
    # gen_tts returns the audio as a data-URI string, so a plain Textbox is
    # enough to carry it back to the caller.
    audio_out = gr.Textbox()

    # Order must match gen_tts(secret_token, text, description).
    inputs = [secret_token, input_text, description]
    outputs = [audio_out]
    run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)

app.queue()
app.launch()
|
|