import os
import sys

sys.path.insert(0, os.getcwd())

import argparse
import re
import time
import pandas
import numpy as np
from tqdm import tqdm
import random
import gradio as gr
import json

from utils import normalize_zh, batch_split, normalize_audio, combine_audio
from tts_model import load_chat_tts_model, clear_cuda_cache, generate_audio_for_seed
from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \
    DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH
import torch
import spaces

parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
parser.add_argument("--local_path", type=str, help="Path to local model if source is 'local'.")
parser.add_argument("--share", default=False, action="store_true", help="Share the server publicly.")

args = parser.parse_args()
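
# Illustration (not executed): typical ways to launch this script locally; the
# filename "app.py" is an assumption, not defined in this file.
#
#   python app.py --source huggingface --share
#   python app.py --source local --local_path models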

# Directory for saved audio seed files
SAVED_DIR = "saved_seeds"
# mkdir
if not os.path.exists(SAVED_DIR):
    os.makedirs(SAVED_DIR)

# Path of the JSON file that stores the saved seeds
SAVED_SEEDS_FILE = os.path.join(SAVED_DIR, "saved_seeds.json")

# Index of the seed currently selected in the seed list
SELECTED_SEED_INDEX = -1

# Initialize the JSON file if it does not exist yet
if not os.path.exists(SAVED_SEEDS_FILE):
    with open(SAVED_SEEDS_FILE, "w") as f:
        f.write("[]")

chat = load_chat_tts_model(source=args.source, local_path=args.local_path)
# chat = None
# chat = load_chat_tts_model(source="local", local_path=r"models")

# Maximum number of random voice draws (audio slots) per batch
max_audio_components = 10


# Load the saved seeds from disk
def load_seeds():
    global saved_seeds
    with open(SAVED_SEEDS_FILE, "r") as f:
        seeds = json.load(f)
    # Backward compatibility with the old JSON format: add the 'path' field if missing
    for seed in seeds:
        if 'path' not in seed:
            seed['path'] = None
    saved_seeds = seeds
    return saved_seeds


def display_seeds():
    seeds = load_seeds()
    # Convert to List[List] for the gr.DataFrame component
    return [[i, s['seed'], s['name'], s['path']] for i, s in enumerate(seeds)]


saved_seeds = load_seeds()
num_seeds_default = 2


def save_seeds():
    global saved_seeds
    with open(SAVED_SEEDS_FILE, "w") as f:
        json.dump(saved_seeds, f)
    saved_seeds = load_seeds()
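
# Illustration (not executed): saved_seeds.json holds a list of entries with the
# fields consumed by add_seed()/display_seeds(); the values below are made-up examples.
#
#   [
#       {"seed": 1234, "name": "1234", "path": "saved_seeds/seed_1234.wav"},
#       {"seed": 42, "name": "42", "path": null}
#   ]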


# Add a seed entry
def add_seed(seed, name, audio_path, save=True):
    for s in saved_seeds:
        if s['seed'] == seed:
            return False
    saved_seeds.append({
        'seed': seed,
        'name': name,
        'path': audio_path
    })
    if save:
        save_seeds()


# Rename a seed entry
def modify_seed(seed, name, save=True):
    for s in saved_seeds:
        if s['seed'] == seed:
            s['name'] = name
            if save:
                save_seeds()
            return True
    return False


# Delete a seed entry
def delete_seed(seed, save=True):
    for s in saved_seeds:
        if s['seed'] == seed:
            saved_seeds.remove(s)
            if save:
                save_seeds()
            return True
    return False


def generate_seeds(num_seeds, texts, tq):
    """
    Generate audio for random seeds and return the resulting files.
    :param num_seeds: number of random seeds to draw
    :param texts: multi-line text; each line becomes one audio segment
    :param tq: tqdm-compatible progress wrapper (falls back to tqdm)
    :return: list of (audio_filename, seed) tuples
    """
    seeds = []
    sample_rate = 24000
    # Split the text by line and normalize digits/punctuation
    texts = [normalize_zh(_) for _ in texts.split('\n') if _.strip()]
    print(texts)
    if not tq:
        tq = tqdm
    for _ in tq(range(num_seeds), desc="随机音色生成中..."):
        seed = np.random.randint(0, 9999)
        filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", None, 0.3, 0.7, 20)
        seeds.append((filename, seed))
        clear_cuda_cache()
    return seeds
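
# Illustration (not executed): generate_seeds() pairs every generated wav file with
# the seed that produced it, which is what audio_interface() unpacks below. The
# text value here is a made-up example.
#
#   for wav_path, seed in generate_seeds(2, "四川美食确实以辣闻名。\n第二行测试文本。", None):
#       print(seed, wav_path)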


# Save the seed attached to a "保存种子" button
def do_save_seed(seed, audio_path):
    print(f"Saving seed {seed} to {audio_path}")
    seed = seed.replace('保存种子 ', '').strip()
    if not seed:
        return
    add_seed(int(seed), seed, audio_path)
    gr.Info(f"Seed {seed} has been saved.")


def do_save_seeds(seeds):
    assert isinstance(seeds, pandas.DataFrame)
    seeds = seeds.drop(columns=['Index'])
    # Convert the DataFrame to a list of dicts with lower-case keys
    result = [{k.lower(): v for k, v in row.items()} for row in seeds.to_dict(orient='records')]
    print(result)
    if result:
        global saved_seeds
        saved_seeds = result
        save_seeds()
        gr.Info("Seeds have been saved.")
    return result


def do_delete_seed(val):
    # Extract the index from the button label via the [(\d+)] pattern
    index = re.search(r'\[(\d+)\]', val)
    global saved_seeds
    if index:
        index = int(index.group(1))
        seed = saved_seeds[index]['seed']
        delete_seed(seed)
        gr.Info(f"Seed {seed} has been deleted.")
    return display_seeds()


# Play back the audio stored for a saved seed
def do_play_seed(val):
    # Extract the index from the button label via the [(\d+)] pattern
    index = re.search(r'\[(\d+)\]', val)
    if index:
        index = int(index.group(1))
        seed = saved_seeds[index]['seed']
        audio_path = saved_seeds[index]['path']
        if audio_path:
            return gr.update(visible=True, value=audio_path)
    return gr.update(visible=False, value=None)


def seed_change_btn():
    global SELECTED_SEED_INDEX
    if SELECTED_SEED_INDEX == -1:
        return ['删除', '试听']
    return [f'删除 idx=[{SELECTED_SEED_INDEX[0]}]', f'试听 idx=[{SELECTED_SEED_INDEX[0]}]']
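
# Note: selecting a cell of the seed gr.DataFrame yields evt.index as [row, column]
# (see seed_change below), so SELECTED_SEED_INDEX[0] above is the selected row.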


def audio_interface(num_seeds, texts, progress=gr.Progress()):
    """
    Generate audio for a batch of random seeds.
    :param num_seeds: number of seeds to draw
    :param texts: multi-line input text
    :param progress: Gradio progress tracker
    :return: flattened (audio, button label, audio path) triples for the UI components
    """
    seeds = generate_seeds(num_seeds, texts, progress.tqdm)
    wavs = [_[0] for _ in seeds]
    seeds = [f"保存种子 {_[1]}" for _ in seeds]
    # Pad the remaining slots so every component receives a value
    all_wavs = wavs + [None] * (max_audio_components - len(wavs))
    all_seeds = seeds + [''] * (max_audio_components - len(seeds))
    return [item for pair in zip(all_wavs, all_seeds, all_wavs) for item in pair]


# States holding the file paths of the most recently generated seed audios
audio_paths = [gr.State(value=None) for _ in range(max_audio_components)]


def audio_interface_with_paths(num_seeds, texts, progress=gr.Progress()):
    """
    Same as audio_interface, but also carries the audio paths.
    """
    results = audio_interface(num_seeds, texts, progress)
    wavs = results[::3]  # every third item is an audio file path (audio, label, path triples)
    for i, wav in enumerate(wavs):
        audio_paths[i].value = wav  # assign directly to the State components
    return results


def audio_interface_empty(num_seeds, texts, progress=gr.Progress(track_tqdm=True)):
    return [None, "", None] * max_audio_components


def update_audio_components(slider_value):
    # Toggle visibility of the audio slots according to the slider value
    k = int(slider_value)
    audios = [gr.Audio(visible=True)] * k + [gr.Audio(visible=False)] * (max_audio_components - k)
    tbs = [gr.Textbox(visible=True)] * k + [gr.Textbox(visible=False)] * (max_audio_components - k)
    stats = [gr.State(value=None)] * max_audio_components
    print(f'k={k}, audios={len(audios)}')
    return [item for pair in zip(audios, tbs, stats) for item in pair]


def seed_change(evt: gr.SelectData):
    # print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    global SELECTED_SEED_INDEX
    SELECTED_SEED_INDEX = evt.index
    return evt.index
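
# Illustration (not executed): the "音色抽卡" tab lays out max_audio_components
# triples of (gr.Audio, gr.Button, gr.State), so every callback bound to
# audio_components returns one flat list in that repeating order:
#
#   [audio_0, button_0, state_0, audio_1, button_1, state_1, ...]
#
# audio_interface(), audio_interface_empty() and update_audio_components() all
# return flat lists in this order.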


def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P,
                       top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None,
                       progress=gr.Progress()):
    from tts_model import generate_audio_for_seed
    from utils import split_text, replace_tokens, restore_tokens
    if seed in [0, -1, None]:
        seed = random.randint(1, 9999)
    content = ''
    if os.path.isfile(text_file):
        # Ignore file inputs; only plain text from the textbox is handled
        content = ""
    elif isinstance(text_file, str):
        content = text_file
    # Replace [uv_break]/[laugh] with _uv_break_/_laugh_ before splitting, then restore them
    content = replace_tokens(content)
    texts = split_text(content, min_length=min_length)
    for i, text in enumerate(texts):
        texts[i] = restore_tokens(text)

    if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
        raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")

    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    try:
        output_files = generate_audio_for_seed(
            chat=chat,
            seed=seed,
            texts=texts,
            batch_size=batch_size,
            speed=speed,
            refine_text_prompt=refine_text_prompt,
            roleid=roleid,
            temperature=temperature,
            top_P=top_P,
            top_K=top_K,
            cur_tqdm=progress.tqdm,
            skip_save=False,
            skip_refine_text=not refine_text,
            speaker_type=speaker_type,
            pt_file=pt_file,
        )
        return output_files
    except Exception as e:
        raise e
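
# Illustration (not executed): a direct call mirroring what the "生成音频" button
# does, using the defaults from config.py; the text and seed are made-up examples.
#
#   files = generate_tts_audio(
#       "大家好,这是一个测试文本。", 1, 2222, DEFAULT_SPEED, DEFAULT_ORAL, DEFAULT_LAUGH,
#       DEFAULT_BK, DEFAULT_SEG_LENGTH, DEFAULT_BATCH_SIZE, DEFAULT_TEMPERATURE,
#       DEFAULT_TOP_P, DEFAULT_TOP_K,
#   )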


def generate_tts_audio_stream(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature,
                              top_P, top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None,
                              stream_mode="fake"):
    from utils import split_text, replace_tokens, restore_tokens
    from tts_model import deterministic
    if seed in [0, -1, None]:
        seed = random.randint(1, 9999)
    content = ''
    if os.path.isfile(text_file):
        # Ignore file inputs; only plain text from the textbox is handled
        content = ""
    elif isinstance(text_file, str):
        content = text_file
    # Replace [uv_break]/[laugh] with _uv_break_/_laugh_ before splitting, then restore them
    content = replace_tokens(content)
    # texts = [normalize_zh(_) for _ in content.split('\n') if _.strip()]
    texts = split_text(content, min_length=min_length)
    for i, text in enumerate(texts):
        texts[i] = restore_tokens(text)

    if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
        raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")

    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    print(f"speaker_type: {speaker_type}")
    if speaker_type == "seed":
        if seed in [None, -1, 0, "", "random"]:
            seed = np.random.randint(0, 9999)
        deterministic(seed)
        rnd_spk_emb = chat.sample_random_speaker()
    elif speaker_type == "role":
        # Load the packaged role voices from the JSON file
        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
            slct_idx_loaded = json.load(json_file)
        # Convert the stored lists back to Tensor objects
        for key in slct_idx_loaded:
            tensor_list = slct_idx_loaded[key]["tensor"]
            slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
        # Pack the voice tensor into params_infer_code so this voice is used consistently
        rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
        # temperature = 0.001
    elif speaker_type == "pt":
        print(pt_file)
        rnd_spk_emb = torch.load(pt_file)
        print(rnd_spk_emb.shape)
        if rnd_spk_emb.shape != (768,):
            raise ValueError("维度应为 768。")
    else:
        raise ValueError(f"Invalid speaker_type: {speaker_type}. ")

    params_infer_code = {
        'spk_emb': rnd_spk_emb,
        'prompt': f'[speed_{speed}]',
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }

    if stream_mode == "real":
        # Real streaming: yield each chunk as soon as the model produces it
        for text in texts:
            _params_infer_code = {**params_infer_code}
            wavs_gen = chat.infer(text, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
                                  use_decoder=True, skip_refine_text=True, stream=True)
            for gen in wavs_gen:
                wavs = [np.array([[]])]
                wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
                audio = wavs[0][0]
                yield 24000, normalize_audio(audio)
            clear_cuda_cache()
    else:
        # Fake streaming: infer batch by batch and yield each combined batch
        for text in batch_split(texts, batch_size):
            _params_infer_code = {**params_infer_code}
            wavs = chat.infer(text, params_infer_code=_params_infer_code, params_refine_text=params_refine_text,
                              use_decoder=True, skip_refine_text=False, stream=False)
            combined_audio = combine_audio(wavs)
            yield 24000, combined_audio[0]
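
# Illustration (not executed): the streaming handler is a generator of
# (sample_rate, waveform) tuples, so it can also be consumed outside Gradio;
# the text and seed are made-up examples.
#
#   for sr, chunk in generate_tts_audio_stream(
#           "大家好,这是一个流式测试。", 1, 2222, DEFAULT_SPEED, DEFAULT_ORAL, DEFAULT_LAUGH,
#           DEFAULT_BK, DEFAULT_SEG_LENGTH, DEFAULT_BATCH_SIZE, DEFAULT_TEMPERATURE,
#           DEFAULT_TOP_P, DEFAULT_TOP_K, stream_mode="fake"):
#       print(sr, len(chunk))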


def generate_refine(text_file, oral, laugh, bk, temperature, top_P, top_K, progress=gr.Progress()):
    from tts_model import generate_refine_text
    from utils import split_text, replace_tokens, restore_tokens, replace_space_between_chinese
    seed = random.randint(1, 9999)
    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    content = ''
    if os.path.isfile(text_file):
        # Ignore file inputs; only plain text from the textbox is handled
        content = ""
    elif isinstance(text_file, str):
        content = text_file
    if re.search(r'\[uv_break\]|\[laugh\]', content) is not None:
        gr.Info("检测到 [uv_break] [laugh],不能重复 refine ")
        # print("检测到 [uv_break] [laugh],不能重复 refine ")
        return content
    batch_size = 5

    content = replace_tokens(content)
    texts = split_text(content, min_length=120)
    print(texts)
    for i, text in enumerate(texts):
        texts[i] = restore_tokens(text)
    txts = []
    for batch in progress.tqdm(batch_split(texts, batch_size), desc="Refine Text Please Wait ..."):
        txts.extend(generate_refine_text(chat, seed, batch, refine_text_prompt, temperature, top_P, top_K))
    return replace_space_between_chinese('\n\n'.join(txts))


def generate_seed():
    new_seed = random.randint(1, 9999)
    return {
        "__type__": "update",
        "value": new_seed
    }


def update_label(text):
    word_count = len(text)
    return gr.update(label=f"朗读文本({word_count} 字)")


def inser_token(text, btn):
    if btn == "+笑声":
        return gr.update(
            value=text + "[laugh]"
        )
    elif btn == "+停顿":
        return gr.update(
            value=text + "[uv_break]"
        )
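
# Illustration (not executed): the "+停顿"/"+笑声" buttons simply append ChatTTS
# control tokens to the current textbox value, e.g.
#
#   inser_token("今天天气不错", "+停顿")  # -> update(value="今天天气不错[uv_break]")
#   inser_token("今天天气不错", "+笑声")  # -> update(value="今天天气不错[laugh]")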


with gr.Blocks() as demo:
    # Project link
    # gr.Markdown("""
    # <div style='text-align: center; font-size: 16px;'>
    # 🌟 <a href='https://github.com/6drf21e/ChatTTS_colab'>项目地址 欢迎 start</a> 🌟
    # </div>
    # """)
    gr.Markdown("# Deployed by [chattts.dev](https://chattts.dev?refer=hf-story-telling)")
    with gr.Tab("角色扮演"):
        def txt_2_script(text):
            lines = text.split("\n")
            data = []
            for line in lines:
                if not line.strip():
                    continue
                parts = line.split("::")
                if len(parts) != 2:
                    continue
                data.append({
                    "character": parts[0],
                    "txt": parts[1]
                })
            return data

        def script_2_txt(data):
            assert isinstance(data, list)
            result = []
            for item in data:
                txt = item['txt'].replace('\n', ' ')
                result.append(f"{item['character']}::{txt}")
            return "\n".join(result)

        def get_characters(lines):
            assert isinstance(lines, list)
            characters = list([_["character"] for _ in lines])
            unique_characters = list(dict.fromkeys(characters))
            print([[character, 0] for character in unique_characters])
            return [[character, 0, 5, 2, 0, 4] for character in unique_characters]

        def get_txt_characters(text):
            return get_characters(txt_2_script(text))

        def llm_change(model):
            llm_setting = {
                "gpt-3.5-turbo-0125": ["https://api.openai.com/v1"],
                "gpt-4o": ["https://api.openai.com/v1"],
                "deepseek-chat": ["https://api.deepseek.com"],
                "yi-large": ["https://api.lingyiwanwu.com/v1"]
            }
            if model in llm_setting:
                return llm_setting[model][0]
            else:
                gr.Error("Model not found.")
                return None

        def ai_script_generate(model, api_base, api_key, text, progress=gr.Progress(track_tqdm=True)):
            from llm_utils import llm_operation
            from config import LLM_PROMPT
            scripts = llm_operation(api_base, api_key, model, LLM_PROMPT, text, required_keys=["txt", "character"])
            return script_2_txt(scripts)
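
        # Illustration (not executed): each script line uses the "角色::文本" format,
        # so txt_2_script() and script_2_txt() round-trip between the two representations:
        #
        #   txt_2_script("旁白::从前有座山。\n年轻女性::山里有座庙。")
        #   # -> [{"character": "旁白", "txt": "从前有座山。"},
        #   #     {"character": "年轻女性", "txt": "山里有座庙。"}]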

        def generate_script_audio(text, models_seeds, progress=gr.Progress()):
            scripts = txt_2_script(text)  # parse the text into a script
            characters = get_characters(scripts)  # extract the characters from the script

            import pandas as pd
            from collections import defaultdict
            import itertools
            from tts_model import generate_audio_for_seed
            from utils import combine_audio, save_audio, normalize_zh

            assert isinstance(models_seeds, pd.DataFrame)

            # Batch helper
            def batch(iterable, batch_size):
                it = iter(iterable)
                while True:
                    batch = list(itertools.islice(it, batch_size))
                    if not batch:
                        break
                    yield batch

            print('1')
            column_mapping = {
                '角色': 'character',
                '种子': 'seed',
                '语速': 'speed',
                '口语': 'oral',
                '笑声': 'laugh',
                '停顿': 'break'
            }
            # Rename the DataFrame columns to English keys and convert to records
            models_seeds = models_seeds.rename(columns=column_mapping).to_dict(orient='records')
            # models_seeds = models_seeds.to_dict(orient='records')
            print('2')

            # Check that every character has a seed assigned
            print(models_seeds)
            seed_lookup = {seed['character']: seed for seed in models_seeds}

            character_seeds = {}
            missing_seeds = []
            # Iterate over all characters
            for character in characters:
                character_name = character[0]
                seed_info = seed_lookup.get(character_name)
                if seed_info:
                    character_seeds[character_name] = seed_info
                else:
                    missing_seeds.append(character_name)

            if missing_seeds:
                missing_characters_str = ', '.join(missing_seeds)
                gr.Info(f"以下角色没有种子,请先设置种子:{missing_characters_str}")
                return None

            print(f'character_seeds:{character_seeds}')
            # return
            refine_text_prompt = "[oral_2][laugh_0][break_4]"
            all_wavs = []

            # Group the lines by character to speed up inference
            grouped_lines = defaultdict(list)
            for line in scripts:
                grouped_lines[line["character"]].append(line)

            batch_results = {character: [] for character in grouped_lines}

            batch_size = 5  # batch size

            # Process the lines character by character
            for character, lines in progress.tqdm(grouped_lines.items(), desc="生成剧本音频"):
                info = character_seeds[character]
                seed = info["seed"]
                speed = info["speed"]
                orla = info["oral"]
                laugh = info["laugh"]
                bk = info["break"]

                refine_text_prompt = f"[oral_{orla}][laugh_{laugh}][break_{bk}]"
                print(f'3 lines:{lines}')

                # Process line by line (batching via batch(lines, batch_size) is disabled)
                for batch_lines in lines:  # batch(lines, batch_size):
                    texts = [normalize_zh(line["txt"]) for line in [batch_lines]]
                    print(f"seed={seed} t={texts} c={character} s={speed} r={refine_text_prompt}")
                    wavs = generate_audio_for_seed(chat, int(seed), texts, DEFAULT_BATCH_SIZE, speed,
                                                   refine_text_prompt, None, DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
                                                   DEFAULT_TOP_K, skip_save=True)  # batch-generate the texts
                    # wavs = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", None, 0.3, 0.7, 20)
                    batch_results[character].extend(wavs)
            print('4')

            # Restore the original line order
            for line in scripts:
                character = line["character"]
                all_wavs.append(batch_results[character].pop(0))
            print('5')

            # Combine all audio segments
            audio = combine_audio(all_wavs)
            fname = f"script_{int(time.time())}.wav"
            return save_audio(fname, audio)
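
        # Illustration (not executed): models_seeds is the 角色种子 DataFrame; after the
        # column renaming above, one record looks like
        #   {"character": "旁白", "seed": 2222, "speed": 3, "oral": 0, "laugh": 0, "break": 2}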

        script_example = {
            "lines": [
                {"txt": "在一个风和日丽的下午,小红帽准备去森林里看望她的奶奶。", "character": "旁白"},
                {"txt": "小红帽说", "character": "旁白"},
                {"txt": "我要给奶奶带点好吃的。", "character": "年轻女性"},
                {"txt": "在森林里,小红帽遇到了狡猾的大灰狼。", "character": "旁白"},
                {"txt": "大灰狼说", "character": "旁白"},
                {"txt": "小红帽,你的篮子里装的是什么?", "character": "中年男性"},
                {"txt": "小红帽回答", "character": "旁白"},
                {"txt": "这是给奶奶的蛋糕和果酱。", "character": "年轻女性"},
                {"txt": "大灰狼心生一计,决定先到奶奶家等待小红帽。", "character": "旁白"},
                {"txt": "当小红帽到达奶奶家时,她发现大灰狼伪装成了奶奶。", "character": "旁白"},
                {"txt": "小红帽疑惑的问", "character": "旁白"},
                {"txt": "奶奶,你的耳朵怎么这么尖?", "character": "年轻女性"},
                {"txt": "大灰狼慌张地回答", "character": "旁白"},
                {"txt": "哦,这是为了更好地听你说话。", "character": "中年男性"},
                {"txt": "小红帽越发觉得不对劲,最终发现了大灰狼的诡计。", "character": "旁白"},
                {"txt": "她大声呼救,森林里的猎人听到后赶来救了她和奶奶。", "character": "旁白"},
                {"txt": "从此,小红帽再也没有单独进入森林,而是和家人一起去看望奶奶。", "character": "旁白"}
            ]
        }

        ai_text_default = "武侠小说《花木兰大战周树人》 要符合人物背景"

        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                gr.Markdown("### AI脚本")
                gr.Markdown("""
                为确保生成效果稳定,仅支持与 GPT-4 相当的模型,推荐使用 4o yi-large deepseek。
                如果没有反应,请检查日志中的错误信息。如果提示格式错误,请重试几次。国内模型可能会受到风控影响,建议更换文本内容后再试。
                申请渠道(免费额度):
                - [https://platform.deepseek.com/](https://platform.deepseek.com/)
                - [https://platform.lingyiwanwu.com/](https://platform.lingyiwanwu.com/)
                """)
                # Sign-up channels are listed in the markdown above
                with gr.Row(equal_height=True):
                    # Model selection: only gpt-4o, deepseek-chat and yi-large are offered
                    model_select = gr.Radio(label="选择模型", choices=["gpt-4o", "deepseek-chat", "yi-large"],
                                            value="gpt-4o", interactive=True, )
                with gr.Row(equal_height=True):
                    openai_api_base_input = gr.Textbox(label="OpenAI API Base URL",
                                                       placeholder="请输入API Base URL",
                                                       value=r"https://api.openai.com/v1")
                    openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
                                                      value="sk-xxxxxxx", type="password")
                # AI prompt text
                ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
                                           value=ai_text_default)
                # Button that triggers AI script generation
                ai_script_generate_button = gr.Button("AI脚本生成")

            with gr.Column(scale=3):
                gr.Markdown("### 脚本")
                gr.Markdown(
                    "脚本可以手工编写也可以从左侧的AI脚本生成按钮生成。脚本格式 **角色::文本** 一行为一句” 注意是::")
                script_text = "\n".join(
                    [f"{_.get('character', '')}::{_.get('txt', '')}" for _ in script_example['lines']])
                script_text_input = gr.Textbox(label="脚本格式 “角色::文本 一行为一句” 注意是::",
                                               placeholder="请输入文本...",
                                               lines=12, value=script_text)
                script_translate_button = gr.Button("步骤①:提取角色")

            with gr.Column(scale=1):
                gr.Markdown("### 角色种子")
                # DataFrame holding the converted script roles
                # Default data: [speed_5][oral_2][laugh_0][break_4]
                default_data = [
                    ["旁白", 2222, 3, 0, 0, 2],
                    ["年轻女性", 2, 5, 2, 0, 2],
                    ["中年男性", 2424, 5, 2, 0, 2]
                ]
                script_data = gr.DataFrame(
                    value=default_data,
                    label="角色对应的音色种子,从抽卡那获取",
                    headers=["角色", "种子", "语速", "口语", "笑声", "停顿"],
                    datatype=["str", "number", "number", "number", "number", "number"],
                    interactive=True,
                    col_count=(6, "fixed"),
                )
                # Button that generates the script audio
                script_generate_audio = gr.Button("步骤②:生成音频")
                # Output audio of the generated script
                script_audio = gr.Audio(label="AI生成的音频", interactive=False)

        # Script-related events
        # Step 1: extract the characters from the script
        script_translate_button.click(
            get_txt_characters,
            inputs=[script_text_input],
            outputs=script_data
        )
        # Switch the API base URL when the model changes
        model_select.change(
            llm_change,
            inputs=[model_select],
            outputs=[openai_api_base_input]
        )
        # AI script generation
        ai_script_generate_button.click(
            ai_script_generate,
            inputs=[model_select, openai_api_base_input, openai_api_key_input, ai_text_input],
            outputs=[script_text_input]
        )
        # Step 2: generate the audio
        script_generate_audio.click(
            generate_script_audio,
            inputs=[script_text_input, script_data],
            outputs=[script_audio]
        )
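
        # Workflow of this tab: "步骤①:提取角色" fills script_data with one row per
        # character, and "步骤②:生成音频" then reads both the script text and that
        # table to synthesize the full script audio.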

    with gr.Tab("音色抽卡"):
        with gr.Row():
            with gr.Column(scale=1):
                texts = [
                    "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
                    "我是一个充满活力的人,喜欢运动,喜欢旅行,喜欢尝试新鲜事物。我喜欢挑战自己,不断突破自己的极限,让自己变得更加强大。",
                    "罗森宣布将于7月24日退市,在华门店超6000家!",
                ]
                # gr.Markdown("### 随机音色抽卡")
                # gr.Markdown("""
                # 免抽卡,直接找稳定音色👇
                # [ModelScope ChatTTS Speaker(国内)](https://modelscope.cn/studios/ttwwwaa/ChatTTS_Speaker) | [HuggingFace ChatTTS Speaker(国外)](https://huggingface.co/spaces/taa/ChatTTS_Speaker)
                # 在相同的 seed 和 温度等参数下,音色具有一定的一致性。点击下面的“随机音色生成”按钮将生成多个 seed。找到满意的音色后,点击音频下方“保存”按钮。
                # **注意:不同机器使用相同种子生成的音频音色可能不同,同一机器使用相同种子多次生成的音频音色也可能变化。**
                # """)
                input_text = gr.Textbox(label="测试文本",
                                        info="**每行文本**都会生成一段音频,最终输出的音频是将这些音频段合成后的结果。建议使用**多行文本**进行测试,以确保音色稳定性。",
                                        lines=4, placeholder="请输入文本...", value='\n'.join(texts))

                num_seeds = gr.Slider(minimum=1, maximum=max_audio_components, step=1, label="seed生成数量",
                                      value=num_seeds_default)

                generate_button = gr.Button("随机音色抽卡🎲", variant="primary")

                # Saved seeds
                gr.Markdown("### 种子管理界面")
                seed_list = gr.DataFrame(
                    label="种子列表",
                    headers=["Index", "Seed", "Name", "Path"],
                    datatype=["number", "number", "str", "str"],
                    interactive=True,
                    col_count=(4, "fixed"),
                    value=display_seeds
                )

                with gr.Row():
                    refresh_button = gr.Button("刷新")
                    save_button = gr.Button("保存")
                    del_button = gr.Button("删除")
                    play_button = gr.Button("试听")

                with gr.Row():
                    # Player for the saved seed audio
                    audio_player = gr.Audio(label="播放已保存种子音频", visible=False)

                # Wire the buttons to their callbacks
                refresh_button.click(display_seeds, outputs=seed_list)
                seed_list.select(seed_change).success(seed_change_btn, outputs=[del_button, play_button])
                save_button.click(do_save_seeds, inputs=[seed_list], outputs=None)
                del_button.click(do_delete_seed, inputs=del_button, outputs=seed_list)
                play_button.click(do_play_seed, inputs=play_button, outputs=audio_player)

            with gr.Column(scale=1):
                audio_components = []
                for i in range(max_audio_components):
                    visible = i < num_seeds_default
                    a = gr.Audio(f"Audio {i}", visible=visible)
                    t = gr.Button("Seed", visible=visible)
                    s = gr.State(value=None)
                    t.click(do_save_seed, inputs=[t, s], outputs=None).success(display_seeds, outputs=seed_list)
                    audio_components.append(a)
                    audio_components.append(t)
                    audio_components.append(s)

                num_seeds.change(update_audio_components, inputs=num_seeds, outputs=audio_components)

            # output = gr.Column()
            # audio = gr.Audio(label="Output Audio")
            generate_button.click(
                audio_interface_empty,
                inputs=[num_seeds, input_text],
                outputs=audio_components
            ).success(audio_interface, inputs=[num_seeds, input_text], outputs=audio_components)
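
            # The click chain first clears every audio slot via audio_interface_empty
            # and only runs audio_interface once that succeeds, so stale audio from a
            # previous draw is not shown while the new batch is being generated.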

    with gr.Tab("长音频生成"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 文本")
                # gr.Markdown("请上传要转换的文本文件(.txt 格式)。")
                # text_file_input = gr.File(label="文本文件", file_types=[".txt"])
                default_text = "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。"
                text_file_input = gr.Textbox(label=f"朗读文本(字数: {len(default_text)})", lines=4,
                                             placeholder="Please Input Text...", value=default_text)
                # Refresh the word count in the label whenever the textbox changes
                text_file_input.change(update_label, inputs=text_file_input, outputs=text_file_input)
                # Buttons that insert pause/laugh tokens
                with gr.Row():
                    break_button = gr.Button("+停顿", variant="secondary")
                    laugh_button = gr.Button("+笑声", variant="secondary")
                    refine_button = gr.Button("Refine Text(预处理 加入停顿词、笑声等)", variant="secondary")

            with gr.Column():
                gr.Markdown("### 配置参数")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("音色选择")
                        num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
                        speaker_stat = gr.State(value="seed")
                        tab_seed = gr.Tab(label="种子")
                        with tab_seed:
                            with gr.Row():
                                seed_input = gr.Number(label="指定种子", info="种子决定音色 0则随机", value=None,
                                                       precision=0)
                                generate_audio_seed = gr.Button("\U0001F3B2")
                        tab_roleid = gr.Tab(label="内置音色")
                        with tab_roleid:
                            roleid_input = gr.Dropdown(label="内置音色",
                                                       choices=[("发姐", "1"),
                                                                ("纯情男大学生", "2"),
                                                                ("阳光开朗大男孩", "3"),
                                                                ("知心小姐姐", "4"),
                                                                ("电视台女主持", "5"),
                                                                ("魅力大叔", "6"),
                                                                ("优雅甜美", "7"),
                                                                ("贴心男宝2", "21"),
                                                                ("正式打工人", "8"),
                                                                ("贴心男宝1", "9")],
                                                       value="1",
                                                       info="选择音色后会覆盖种子。感谢 @QuantumDriver 提供音色")
                        tab_pt = gr.Tab(label="上传.PT文件")
                        with tab_pt:
                            pt_input = gr.File(label="上传音色文件", file_types=[".pt"], height=100)

                with gr.Row():
                    style_select = gr.Radio(label="预设参数", info="语速部分可自行更改",
                                            choices=["小说朗读", "对话", "中英混合", "默认"], value="默认",
                                            interactive=True, )
                with gr.Row():
                    # Refine switch
                    refine_text_input = gr.Checkbox(label="Refine",
                                                    info="打开后会自动根据下方参数添加笑声/停顿等。关闭后可自行添加 [uv_break] [laugh] 或者点击下方 Refine 按钮先行转换",
                                                    value=True)
                    speed_input = gr.Slider(label="语速", minimum=1, maximum=10, value=DEFAULT_SPEED, step=1)
                with gr.Row():
                    oral_input = gr.Slider(label="口语化", minimum=0, maximum=9, value=DEFAULT_ORAL, step=1)
                    laugh_input = gr.Slider(label="笑声", minimum=0, maximum=2, value=DEFAULT_LAUGH, step=1)
                    bk_input = gr.Slider(label="停顿", minimum=0, maximum=7, value=DEFAULT_BK, step=1)
                # gr.Markdown("### 文本参数")
                with gr.Row():
                    min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段",
                                                 value=DEFAULT_SEG_LENGTH, precision=0)
                    batch_size_input = gr.Number(label="批大小", info="越高越快 太高爆显存 4G推荐3 其他酌情",
                                                 value=DEFAULT_BATCH_SIZE, precision=0)
                with gr.Accordion("其他参数", open=False):
                    with gr.Row():
                        # temperature / top_P / top_K
                        temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01,
                                                      value=DEFAULT_TEMPERATURE)
                        top_P_input = gr.Slider(label="top_P", minimum=0.1, maximum=0.9, step=0.05,
                                                value=DEFAULT_TOP_P)
                        top_K_input = gr.Slider(label="top_K", minimum=1, maximum=20, step=1, value=DEFAULT_TOP_K)
                        # Reset button
                        reset_button = gr.Button("重置")

        with gr.Row():
            with gr.Column():
                generate_button = gr.Button("生成音频", variant="primary")
            with gr.Column():
                generate_button_stream = gr.Button("流式生成音频(一边播放一边推理)", variant="primary")
                stream_select = gr.Radio(label="流输出方式",
                                         info="真流式为实验功能,播放效果:卡播卡播卡播(⏳🎵⏳🎵⏳🎵);伪流式为分段推理后输出,播放效果:卡卡卡播播播播(⏳⏳🎵🎵🎵🎵)。伪流式批次建议4以上减少卡顿",
                                         choices=[("真", "real"), ("伪", "fake")], value="fake", interactive=True, )

        with gr.Row():
            output_audio = gr.Audio(label="生成的音频文件")
            output_audio_stream = gr.Audio(label="流式音频", value=None,
                                           streaming=True,
                                           autoplay=True,
                                           # disable auto play for Windows, due to https://developer.chrome.com/blog/autoplay#webaudio
                                           interactive=False,
                                           show_label=True)

        generate_audio_seed.click(generate_seed,
                                  inputs=[],
                                  outputs=seed_input)

        def do_tab_change(evt: gr.SelectData):
            print(evt.selected, evt.index, evt.value, evt.target)
            kv = {
                "种子": "seed",
                "内置音色": "role",
                "上传.PT文件": "pt"
            }
            return kv.get(evt.value, "seed")

        tab_seed.select(do_tab_change, outputs=speaker_stat)
        tab_roleid.select(do_tab_change, outputs=speaker_stat)
        tab_pt.select(do_tab_change, outputs=speaker_stat)

        def do_style_select(x):
            if x == "小说朗读":
                return [4, 0, 0, 2]
            elif x == "对话":
                return [5, 5, 1, 4]
            elif x == "中英混合":
                return [4, 1, 0, 3]
            else:
                return [DEFAULT_SPEED, DEFAULT_ORAL, DEFAULT_LAUGH, DEFAULT_BK]
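
        # The four returned values map, in order, to speed_input, oral_input,
        # laugh_input and bk_input, e.g. "小说朗读" -> [speed_4][oral_0][laugh_0][break_2].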

        # Apply the preset parameters when style_select changes
        style_select.change(
            do_style_select,
            inputs=style_select,
            outputs=[speed_input, oral_input, laugh_input, bk_input]
        )

        # Refine button
        refine_button.click(
            generate_refine,
            inputs=[text_file_input, oral_input, laugh_input, bk_input, temperature_input, top_P_input, top_K_input],
            outputs=text_file_input
        )

        # Reset button: restore temperature, top_P and top_K to their defaults
        reset_button.click(
            lambda: [0.3, 0.7, 20],
            inputs=None,
            outputs=[temperature_input, top_P_input, top_K_input]
        )

        generate_button.click(
            fn=generate_tts_audio,
            inputs=[
                text_file_input,
                num_seeds_input,
                seed_input,
                speed_input,
                oral_input,
                laugh_input,
                bk_input,
                min_length_input,
                batch_size_input,
                temperature_input,
                top_P_input,
                top_K_input,
                roleid_input,
                refine_text_input,
                speaker_stat,
                pt_input
            ],
            outputs=[output_audio]
        )

        generate_button_stream.click(
            fn=generate_tts_audio_stream,
            inputs=[
                text_file_input,
                num_seeds_input,
                seed_input,
                speed_input,
                oral_input,
                laugh_input,
                bk_input,
                min_length_input,
                batch_size_input,
                temperature_input,
                top_P_input,
                top_K_input,
                roleid_input,
                refine_text_input,
                speaker_stat,
                pt_input,
                stream_select
            ],
            outputs=[output_audio_stream]
        )

        break_button.click(
            inser_token,
            inputs=[text_file_input, break_button],
            outputs=text_file_input
        )

        laugh_button.click(
            inser_token,
            inputs=[text_file_input, laugh_button],
            outputs=text_file_input
        )

demo.launch(share=args.share, inbrowser=True)