mrfakename committed on
Commit
626f70a
1 Parent(s): 51628d6

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: E2/F5 TTS
3
  emoji: 🗣️
4
  colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: true
9
- short_description: 'E2-TTS & F5-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
  sdk_version: 5.1.0
11
  ---
12
 
 
1
  ---
2
+ title: F5-TTS
3
  emoji: 🗣️
4
  colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: true
9
+ short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
  sdk_version: 5.1.0
11
  ---
12
 
README_REPO.md ADDED
@@ -0,0 +1,196 @@
1
+ # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
+
3
+ [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
4
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
5
+ [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
6
+ [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
7
+
8
+ **F5-TTS**: Diffusion Transformer with ConvNeXt V2; trains faster and runs faster at inference.
9
+
10
+ **E2 TTS**: Flat-UNet Transformer, the closest reproduction of the original paper.
11
+
12
+ **Sway Sampling**: Inference-time flow-step sampling strategy that greatly improves performance (sketched below).
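+
+ For intuition, Sway Sampling reshapes the otherwise uniform flow-step schedule so that more steps land early in the trajectory. Below is a minimal sketch, assuming the cosine-based form with the `nfe_step=32` and `sway_sampling_coef=-1.0` settings used by `app.py` / `inference-cli.py` in this commit; see the model code for the exact definition.
+
+ ```python
+ import torch
+
+ def sway_sampled_steps(nfe_step=32, sway_sampling_coef=-1.0):
+     # uniform flow steps in [0, 1]
+     t = torch.linspace(0, 1, nfe_step + 1)
+     # assumed sway schedule: a negative coefficient skews the steps toward t = 0
+     return t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+ ```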
13
+
14
+ ## Installation
15
+
16
+ Clone the repository:
17
+
18
+ ```bash
19
+ git clone https://github.com/SWivid/F5-TTS.git
20
+ cd F5-TTS
21
+ ```
22
+
23
+ Install PyTorch matching your CUDA version, e.g.:
24
+
25
+ ```bash
26
+ pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
27
+ pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
28
+ ```
29
+
30
+ Install other packages:
31
+
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+ ## Prepare Dataset
37
+
38
+ Example data-processing scripts are provided for Emilia and WenetSpeech4TTS; you can tailor your own script, along with a matching Dataset class in `model/dataset.py` (a minimal sketch follows the commands below).
39
+
40
+ ```bash
41
+ # Prepare a custom dataset to suit your needs:
+ # download the corresponding dataset first, then fill in its path in the script.
43
+
44
+ # Prepare the Emilia dataset
45
+ python scripts/prepare_emilia.py
46
+
47
+ # Prepare the Wenetspeech4TTS dataset
48
+ python scripts/prepare_wenetspeech4tts.py
49
+ ```
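+
+ A minimal sketch of what a custom dataset might look like. This is a hypothetical example for illustration only (the class name, metadata format, and returned fields are assumptions, not the repo's actual `model/dataset.py` API):
+
+ ```python
+ import json
+
+ import torchaudio
+ from torch.utils.data import Dataset
+
+
+ class MyCustomDataset(Dataset):  # hypothetical example, not the repo's class
+     """Reads a metadata.jsonl file with one {"audio_path": ..., "text": ...} entry per line."""
+
+     def __init__(self, metadata_path):
+         with open(metadata_path, encoding="utf-8") as f:
+             self.items = [json.loads(line) for line in f]
+
+     def __len__(self):
+         return len(self.items)
+
+     def __getitem__(self, idx):
+         item = self.items[idx]
+         audio, sample_rate = torchaudio.load(item["audio_path"])  # (channels, samples)
+         return {"audio": audio, "sample_rate": sample_rate, "text": item["text"]}
+ ```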
50
+
51
+ ## Training
52
+
53
+ Once your datasets are prepared, you can start the training process.
54
+
55
+ ```bash
56
+ # set up the accelerate config, e.g. multi-GPU DDP with fp16
+ # (saved to ~/.cache/huggingface/accelerate/default_config.yaml)
58
+ accelerate config
59
+ accelerate launch train.py
60
+ ```
61
+ For initial guidance on fine-tuning, see [#57](https://github.com/SWivid/F5-TTS/discussions/57).
62
+
63
+ ## Inference
64
+
65
+ To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or let `inference-cli` and `gradio_app` download them automatically.
66
+
67
+ A single generation currently supports up to 30s, which is the **TOTAL** length of the prompt audio plus the generated speech. Batched inference over chunks is supported by `inference-cli` and `gradio_app`.
68
+ - To avoid possible inference failures, read through the following guidelines (see the sketch after this list for how the duration budget is computed).
+ - A longer prompt audio leaves less room for the generated output; anything beyond the 30s total cannot be generated properly. Consider using a prompt audio shorter than 15s.
+ - Uppercase letters are uttered letter by letter, so use lowercase for normal words.
+ - Add spaces (" ") or punctuation (e.g. "," ".") to explicitly introduce pauses. This can also help if the first few words are skipped in code-switched generation (different languages are spoken at different speeds).
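+
+ For intuition, both apps budget the generation from the reference clip: the character budget (and the mel-frame duration) scales with how much of the 30s window the reference leaves free. A sketch using the `max_chars` formula from `app.py` / `inference-cli.py` in this commit (the example reference duration is illustrative):
+
+ ```python
+ # how many bytes of text the generation budget allows for a given reference clip
+ ref_audio_sec = 8.0  # illustrative reference duration; 30s is the total window
+ ref_text = "Some call me nature, others call me mother nature."
+
+ max_chars = int(len(ref_text.encode("utf-8")) / ref_audio_sec * (30 - ref_audio_sec))
+ print(max_chars)  # a longer reference leaves fewer characters for the generated text
+ ```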
72
+
73
+ ### CLI Inference
74
+
75
+ You can either specify everything in `inference-cli.toml` or override settings with command-line flags (the keys are sketched below). Leaving `--ref_text ""` lets an ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you hit a network error, consider using a local checkpoint by setting `ckpt_path` in `inference-cli.py`.
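+
+ The keys below are the ones `inference-cli.py` reads from `inference-cli.toml` in this commit; the values shown are illustrative (in particular, `output_dir` here is just an assumed example path):
+
+ ```python
+ # Illustrative settings, mirroring the inference-cli.toml keys (CLI flags override these)
+ config = {
+     "model": "F5-TTS",  # or "E2-TTS"
+     "ref_audio": "tests/ref_audio/test_en_1_ref_short.wav",
+     "ref_text": "Some call me nature, others call me mother nature.",
+     "gen_text": "I don't really care what you call me.",
+     "remove_silence": True,
+     "output_dir": "tests",  # assumed example output folder
+ }
+ ```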
76
+
77
+ ```bash
78
+ python inference-cli.py \
79
+ --model "F5-TTS" \
80
+ --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
81
+ --ref_text "Some call me nature, others call me mother nature." \
82
+ --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
83
+
84
+ python inference-cli.py \
85
+ --model "E2-TTS" \
86
+ --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
87
+ --ref_text "对,这就是我,万人敬仰的太乙真人。" \
88
+ --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
89
+ ```
90
+
91
+ ### Gradio App
92
+ Currently supported features:
93
+ - Chunk inference
94
+ - Podcast Generation
95
+ - Multiple Speech-Type Generation
96
+
97
+ You can launch a Gradio app (web interface) for inference. By default it loads checkpoints from Hugging Face; you may point `ckpt_path` to a local file in `gradio_app.py`, as in the sketch below. The app currently loads the ASR model, F5-TTS, and E2 TTS all at once, so it uses more GPU memory than `inference-cli`.
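+
+ For example, to load a local checkpoint instead of downloading from Hugging Face, edit `load_model` in `gradio_app.py`; the commented-out line in this commit shows the intended pattern (the local path below is illustrative):
+
+ ```python
+ # inside load_model(): swap the Hugging Face download for a local checkpoint file
+ exp_name, ckpt_step = "F5TTS_Base", 1200000  # values used for F5-TTS in this commit
+ # ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"  # .pt | .safetensors, illustrative local path
+ ```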
98
+
99
+ ```bash
100
+ python gradio_app.py
101
+ ```
102
+
103
+ You can specify the port/host:
104
+
105
+ ```bash
106
+ python gradio_app.py --port 7860 --host 0.0.0.0
107
+ ```
108
+
109
+ Or launch a share link:
110
+
111
+ ```bash
112
+ python gradio_app.py --share
113
+ ```
114
+
115
+ ### Speech Editing
116
+
117
+ To test speech editing capabilities, use the following command.
118
+
119
+ ```bash
120
+ python speech_edit.py
121
+ ```
122
+
123
+ ## Evaluation
124
+
125
+ ### Prepare Test Datasets
126
+
127
+ 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
128
+ 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
129
+ 3. Unzip the downloaded datasets and place them in the `data/` directory.
+ 4. Update the path to the test-clean data in `scripts/eval_infer_batch.py`.
+ 5. Our filtered LibriSpeech-PC 4–10s subset is already provided under `data/` in this repo.
132
+
133
+ ### Batch Inference for Test Set
134
+
135
+ To run batch inference for evaluations, execute the following commands:
136
+
137
+ ```bash
138
+ # batch inference for evaluations
139
+ accelerate config # if not set before
140
+ bash scripts/eval_infer_batch.sh
141
+ ```
142
+
143
+ ### Download Evaluation Model Checkpoints
144
+
145
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
146
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
147
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
148
+
149
+ ### Objective Evaluation
150
+
151
+ **Some Notes**
152
+
153
+ For faster-whisper with CUDA 11:
154
+
155
+ ```bash
156
+ pip install --force-reinstall ctranslate2==3.24.0
157
+ ```
158
+
159
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
160
+
161
+ ```bash
162
+ pip install faster-whisper==0.10.1
163
+ ```
164
+
165
+ Update the paths to your batch-inference results, then run the WER / SIM evaluations:
166
+ ```bash
167
+ # Evaluation for Seed-TTS test set
168
+ python scripts/eval_seedtts_testset.py
169
+
170
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
171
+ python scripts/eval_librispeech_test_clean.py
172
+ ```
173
+
174
+ ## Acknowledgements
175
+
176
+ - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
177
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
178
+ - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
179
+ - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
180
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
181
+ - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
182
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
183
+ - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
184
+
185
+ ## Citation
186
+ ```
187
+ @article{chen-etal-2024-f5tts,
188
+ title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
189
+ author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
190
+ journal={arXiv preprint arXiv:2410.06885},
191
+ year={2024},
192
+ }
193
+ ```
194
+ ## License
195
+
196
+ Our code is released under the MIT License.
app.py CHANGED
@@ -6,28 +6,53 @@ import gradio as gr
6
  import numpy as np
7
  import tempfile
8
  from einops import rearrange
9
- from ema_pytorch import EMA
10
  from vocos import Vocos
11
  from pydub import AudioSegment, silence
12
  from model import CFM, UNetT, DiT, MMDiT
13
  from cached_path import cached_path
14
  from model.utils import (
15
- get_tokenizer,
16
- convert_char_to_pinyin,
 
17
  save_spectrogram,
18
  )
19
  from transformers import pipeline
20
- import spaces
21
  import librosa
 
22
  import soundfile as sf
23
- from txtsplit import txtsplit
24
- from detoxify import Detoxify
25
 
 
 
 
 
 
26
 
27
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
28
 
29
- model = Detoxify('original', device=device)
 
 
 
 
30
 
 
31
 
32
  pipe = pipeline(
33
  "automatic-speech-recognition",
@@ -35,6 +60,7 @@ pipe = pipeline(
35
  torch_dtype=torch.float16,
36
  device=device,
37
  )
 
38
 
39
  # --------------------- Settings -------------------- #
40
 
@@ -44,20 +70,20 @@ hop_length = 256
44
  target_rms = 0.1
45
  nfe_step = 32 # 16, 32
46
  cfg_strength = 2.0
47
- ode_method = 'euler'
48
  sway_sampling_coef = -1.0
49
  speed = 1.0
50
  # fix_duration = 27 # None or float (duration in seconds)
51
  fix_duration = None
52
 
 
53
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
54
- checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
 
55
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
56
  model = CFM(
57
  transformer=model_cls(
58
- **model_cfg,
59
- text_num_embeds=vocab_size,
60
- mel_dim=n_mel_channels
61
  ),
62
  mel_spec_kwargs=dict(
63
  target_sample_rate=target_sample_rate,
@@ -70,64 +96,130 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
70
  vocab_char_map=vocab_char_map,
71
  ).to(device)
72
 
73
- ema_model = EMA(model, include_online_model=False).to(device)
74
- ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
75
- ema_model.copy_params_from_ema_to_model()
76
 
77
  return model
78
 
 
79
  # load models
80
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 
81
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
82
 
83
- F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
84
- E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
 
 
 
85
 
86
- @spaces.GPU
87
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
88
- print(gen_text)
89
- if model.predict(gen_text)['toxicity'] > 0.8:
90
- print("Flagged for toxicity:", gen_text)
91
- raise gr.Error("Your text was flagged for toxicity, please try again with a different text.")
92
- gr.Info("Converting audio...")
93
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
94
- aseg = AudioSegment.from_file(ref_audio_orig)
95
- # remove long silence in reference audio
96
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
97
- non_silent_wave = AudioSegment.silent(duration=0)
98
- for non_silent_seg in non_silent_segs:
99
- non_silent_wave += non_silent_seg
100
- aseg = non_silent_wave
101
- # Convert to mono
102
- aseg = aseg.set_channels(1)
103
- audio_duration = len(aseg)
104
- if audio_duration > 15000:
105
- gr.Warning("Audio is over 15s, clipping to only first 15s.")
106
- aseg = aseg[:15000]
107
- aseg.export(f.name, format="wav")
108
- ref_audio = f.name
109
  if exp_name == "F5-TTS":
110
  ema_model = F5TTS_ema_model
111
  elif exp_name == "E2-TTS":
112
  ema_model = E2TTS_ema_model
113
-
114
- if not ref_text.strip():
115
- gr.Info("No reference text provided, transcribing reference audio...")
116
- ref_text = outputs = pipe(
117
- ref_audio,
118
- chunk_length_s=30,
119
- batch_size=128,
120
- generate_kwargs={"task": "transcribe"},
121
- return_timestamps=False,
122
- )['text'].strip()
123
- gr.Info("Finished transcription")
124
- else:
125
- gr.Info("Using custom reference text...")
126
- audio, sr = torchaudio.load(ref_audio)
127
- max_chars = int(len(ref_text) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
128
- # Audio
129
  if audio.shape[0] > 1:
130
  audio = torch.mean(audio, dim=0, keepdim=True)
 
131
  rms = torch.sqrt(torch.mean(torch.square(audio)))
132
  if rms < target_rms:
133
  audio = audio * target_rms / rms
@@ -135,28 +227,25 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
135
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
136
  audio = resampler(audio)
137
  audio = audio.to(device)
138
- # Chunk
139
- chunks = txtsplit(gen_text, 0.7*max_chars, 0.9*max_chars)
140
- results = []
141
- generated_mel_specs = []
142
- for chunk in progress.tqdm(chunks):
143
  # Prepare the text
144
- text_list = [ref_text + chunk]
 
 
145
  final_text_list = convert_char_to_pinyin(text_list)
146
-
147
  # Calculate duration
148
  ref_audio_len = audio.shape[-1] // hop_length
149
- # if fix_duration is not None:
150
- # duration = int(fix_duration * target_sample_rate / hop_length)
151
- # else:
152
  zh_pause_punc = r"。,、;:?!"
153
  ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
154
  gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
155
- chunk = len(chunk.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
156
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
157
-
158
  # inference
159
- gr.Info(f"Generating audio using {exp_name}")
160
  with torch.inference_mode():
161
  generated, _ = ema_model.sample(
162
  cond=audio,
@@ -166,29 +255,26 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
166
  cfg_strength=cfg_strength,
167
  sway_sampling_coef=sway_sampling_coef,
168
  )
169
-
170
  generated = generated[:, ref_audio_len:, :]
171
- generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
172
- gr.Info("Running vocoder")
173
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
174
  generated_wave = vocos.decode(generated_mel_spec.cpu())
175
  if rms < target_rms:
176
  generated_wave = generated_wave * rms / target_rms
177
-
178
  # wav -> numpy
179
  generated_wave = generated_wave.squeeze().cpu().numpy()
180
- results.append(generated_wave)
181
- generated_wave = np.concatenate(results)
 
 
 
 
 
 
182
  if remove_silence:
183
- gr.Info("Removing audio silences... This may take a moment")
184
- # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
185
- # non_silent_wave = np.array([])
186
- # for interval in non_silent_intervals:
187
- # start, end = interval
188
- # non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
189
- # generated_wave = non_silent_wave
190
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
191
- sf.write(f.name, generated_wave, target_sample_rate)
192
  aseg = AudioSegment.from_file(f.name)
193
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
194
  non_silent_wave = AudioSegment.silent(duration=0)
@@ -196,65 +282,543 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
196
  non_silent_wave += non_silent_seg
197
  aseg = non_silent_wave
198
  aseg.export(f.name, format="wav")
199
- generated_wave, _ = torchaudio.load(f.name)
200
- generated_wave = generated_wave.squeeze().cpu().numpy()
201
 
 
 
 
 
 
 
202
 
203
- # spectogram
204
- # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
205
- # spectrogram_path = tmp_spectrogram.name
206
- # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
207
 
208
- return (target_sample_rate, generated_wave)
 
 
 
 
209
 
210
- with gr.Blocks() as app:
211
- gr.Markdown("""
212
- # E2/F5 TTS
213
 
214
- This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
 
 
215
 
216
- * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
217
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
 
 
 
218
 
219
- This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
 
 
 
 
 
220
 
221
- The checkpoints support English and Chinese.
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
224
 
225
- The model is licensed under the CC-BY-NC license, this demo cannot be used for commercial purposes.
 
 
226
 
227
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
228
- """)
 
 
 
 
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
231
- gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
232
- model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
 
 
233
  generate_btn = gr.Button("Synthesize", variant="primary")
234
  with gr.Accordion("Advanced Settings", open=False):
235
- ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
236
- remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
237
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  audio_output = gr.Audio(label="Synthesized Audio")
239
- # spectrogram_output = gr.Image(label="Spectrogram")
240
 
241
- generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
242
- gr.Markdown("""
243
- ## Run Locally
244
 
245
- Run this demo locally on CPU, CUDA, or MPS/Apple Silicon (requires macOS >= 14):
246
 
247
- First, ensure `ffmpeg` is installed.
 
248
 
249
- ```bash
250
- git clone https://huggingface.co/spaces/mrfakename/E2-F5-TTS
251
- cd E2-F5-TTS
252
- python -m pip install -r requirements.txt
253
- python app_local.py
254
- ```
255
 
256
- """)
257
- gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
258
 
259
 
260
- app.queue().launch()
 
 
6
  import numpy as np
7
  import tempfile
8
  from einops import rearrange
 
9
  from vocos import Vocos
10
  from pydub import AudioSegment, silence
11
  from model import CFM, UNetT, DiT, MMDiT
12
  from cached_path import cached_path
13
  from model.utils import (
14
+ load_checkpoint,
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
  save_spectrogram,
18
  )
19
  from transformers import pipeline
 
20
  import librosa
21
+ import click
22
  import soundfile as sf
 
 
23
 
24
+ try:
25
+ import spaces
26
+ USING_SPACES = True
27
+ except ImportError:
28
+ USING_SPACES = False
29
 
30
+ def gpu_decorator(func):
31
+ if USING_SPACES:
32
+ return spaces.GPU(func)
33
+ else:
34
+ return func
35
+
36
+
37
+
38
+ SPLIT_WORDS = [
39
+ "but", "however", "nevertheless", "yet", "still",
40
+ "therefore", "thus", "hence", "consequently",
41
+ "moreover", "furthermore", "additionally",
42
+ "meanwhile", "alternatively", "otherwise",
43
+ "namely", "specifically", "for example", "such as",
44
+ "in fact", "indeed", "notably",
45
+ "in contrast", "on the other hand", "conversely",
46
+ "in conclusion", "to summarize", "finally"
47
+ ]
48
 
49
+ device = (
50
+ "cuda"
51
+ if torch.cuda.is_available()
52
+ else "mps" if torch.backends.mps.is_available() else "cpu"
53
+ )
54
 
55
+ print(f"Using {device} device")
56
 
57
  pipe = pipeline(
58
  "automatic-speech-recognition",
 
60
  torch_dtype=torch.float16,
61
  device=device,
62
  )
63
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
 
65
  # --------------------- Settings -------------------- #
66
 
 
70
  target_rms = 0.1
71
  nfe_step = 32 # 16, 32
72
  cfg_strength = 2.0
73
+ ode_method = "euler"
74
  sway_sampling_coef = -1.0
75
  speed = 1.0
76
  # fix_duration = 27 # None or float (duration in seconds)
77
  fix_duration = None
78
 
79
+
80
  def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
  vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
  model = CFM(
85
  transformer=model_cls(
86
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
 
 
87
  ),
88
  mel_spec_kwargs=dict(
89
  target_sample_rate=target_sample_rate,
 
96
  vocab_char_map=vocab_char_map,
97
  ).to(device)
98
 
99
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
 
 
100
 
101
  return model
102
 
103
+
104
  # load models
105
+ F5TTS_model_cfg = dict(
106
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
+ )
108
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
 
110
+ F5TTS_ema_model = load_model(
111
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
+ )
113
+ E2TTS_ema_model = load_model(
114
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
+ )
116
 
117
+ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
+ if len(text.encode('utf-8')) <= max_chars:
119
+ return [text]
120
+ if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
+ text += '.'
122
+
123
+ sentences = re.split('([。.!?!?])', text)
124
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
+
126
+ batches = []
127
+ current_batch = ""
128
+
129
+ def split_by_words(text):
130
+ words = text.split()
131
+ current_word_part = ""
132
+ word_batches = []
133
+ for word in words:
134
+ if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
+ current_word_part += word + ' '
136
+ else:
137
+ if current_word_part:
138
+ # Try to find a suitable split word
139
+ for split_word in split_words:
140
+ split_index = current_word_part.rfind(' ' + split_word + ' ')
141
+ if split_index != -1:
142
+ word_batches.append(current_word_part[:split_index].strip())
143
+ current_word_part = current_word_part[split_index:].strip() + ' '
144
+ break
145
+ else:
146
+ # If no suitable split word found, just append the current part
147
+ word_batches.append(current_word_part.strip())
148
+ current_word_part = ""
149
+ current_word_part += word + ' '
150
+ if current_word_part:
151
+ word_batches.append(current_word_part.strip())
152
+ return word_batches
153
+
154
+ for sentence in sentences:
155
+ if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
+ current_batch += sentence
157
+ else:
158
+ # If adding this sentence would exceed the limit
159
+ if current_batch:
160
+ batches.append(current_batch)
161
+ current_batch = ""
162
+
163
+ # If the sentence itself is longer than max_chars, split it
164
+ if len(sentence.encode('utf-8')) > max_chars:
165
+ # First, try to split by colon
166
+ colon_parts = sentence.split(':')
167
+ if len(colon_parts) > 1:
168
+ for part in colon_parts:
169
+ if len(part.encode('utf-8')) <= max_chars:
170
+ batches.append(part)
171
+ else:
172
+ # If colon part is still too long, split by comma
173
+ comma_parts = re.split('[,,]', part)
174
+ if len(comma_parts) > 1:
175
+ current_comma_part = ""
176
+ for comma_part in comma_parts:
177
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
+ current_comma_part += comma_part + ','
179
+ else:
180
+ if current_comma_part:
181
+ batches.append(current_comma_part.rstrip(','))
182
+ current_comma_part = comma_part + ','
183
+ if current_comma_part:
184
+ batches.append(current_comma_part.rstrip(','))
185
+ else:
186
+ # If no comma, split by words
187
+ batches.extend(split_by_words(part))
188
+ else:
189
+ # If no colon, split by comma
190
+ comma_parts = re.split('[,,]', sentence)
191
+ if len(comma_parts) > 1:
192
+ current_comma_part = ""
193
+ for comma_part in comma_parts:
194
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
+ current_comma_part += comma_part + ','
196
+ else:
197
+ if current_comma_part:
198
+ batches.append(current_comma_part.rstrip(','))
199
+ current_comma_part = comma_part + ','
200
+ if current_comma_part:
201
+ batches.append(current_comma_part.rstrip(','))
202
+ else:
203
+ # If no comma, split by words
204
+ batches.extend(split_by_words(sentence))
205
+ else:
206
+ current_batch = sentence
207
+
208
+ if current_batch:
209
+ batches.append(current_batch)
210
+
211
+ return batches
212
+
213
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
  if exp_name == "F5-TTS":
215
  ema_model = F5TTS_ema_model
216
  elif exp_name == "E2-TTS":
217
  ema_model = E2TTS_ema_model
218
+
219
+ audio, sr = ref_audio
220
  if audio.shape[0] > 1:
221
  audio = torch.mean(audio, dim=0, keepdim=True)
222
+
223
  rms = torch.sqrt(torch.mean(torch.square(audio)))
224
  if rms < target_rms:
225
  audio = audio * target_rms / rms
 
227
  resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
  audio = resampler(audio)
229
  audio = audio.to(device)
230
+
231
+ generated_waves = []
232
+ spectrograms = []
233
+
234
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
  # Prepare the text
236
+ if len(ref_text[-1].encode('utf-8')) == 1:
237
+ ref_text = ref_text + " "
238
+ text_list = [ref_text + gen_text]
239
  final_text_list = convert_char_to_pinyin(text_list)
240
+
241
  # Calculate duration
242
  ref_audio_len = audio.shape[-1] // hop_length
 
 
 
243
  zh_pause_punc = r"。,、;:?!"
244
  ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
  gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
 
246
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
+
248
  # inference
 
249
  with torch.inference_mode():
250
  generated, _ = ema_model.sample(
251
  cond=audio,
 
255
  cfg_strength=cfg_strength,
256
  sway_sampling_coef=sway_sampling_coef,
257
  )
258
+
259
  generated = generated[:, ref_audio_len:, :]
260
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
 
 
261
  generated_wave = vocos.decode(generated_mel_spec.cpu())
262
  if rms < target_rms:
263
  generated_wave = generated_wave * rms / target_rms
264
+
265
  # wav -> numpy
266
  generated_wave = generated_wave.squeeze().cpu().numpy()
267
+
268
+ generated_waves.append(generated_wave)
269
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
+
271
+ # Combine all generated waves
272
+ final_wave = np.concatenate(generated_waves)
273
+
274
+ # Remove silence
275
  if remove_silence:
 
 
 
 
 
 
 
276
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
+ sf.write(f.name, final_wave, target_sample_rate)
278
  aseg = AudioSegment.from_file(f.name)
279
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
  non_silent_wave = AudioSegment.silent(duration=0)
 
282
  non_silent_wave += non_silent_seg
283
  aseg = non_silent_wave
284
  aseg.export(f.name, format="wav")
285
+ final_wave, _ = torchaudio.load(f.name)
286
+ final_wave = final_wave.squeeze().cpu().numpy()
287
 
288
+ # Create a combined spectrogram
289
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
+
291
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
+ spectrogram_path = tmp_spectrogram.name
293
+ save_spectrogram(combined_spectrogram, spectrogram_path)
294
 
295
+ return (target_sample_rate, final_wave), spectrogram_path
 
 
 
296
 
297
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
+ # only override the default split words when the user actually provided some
+ if custom_split_words.strip():
299
+ custom_words = [word.strip() for word in custom_split_words.split(',')]
300
+ global SPLIT_WORDS
301
+ SPLIT_WORDS = custom_words
302
 
303
+ print(gen_text)
 
 
304
 
305
+ gr.Info("Converting audio...")
306
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
+ aseg = AudioSegment.from_file(ref_audio_orig)
308
 
309
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
+ non_silent_wave = AudioSegment.silent(duration=0)
311
+ for non_silent_seg in non_silent_segs:
312
+ non_silent_wave += non_silent_seg
313
+ aseg = non_silent_wave
314
 
315
+ audio_duration = len(aseg)
316
+ if audio_duration > 15000:
317
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
+ aseg = aseg[:15000]
319
+ aseg.export(f.name, format="wav")
320
+ ref_audio = f.name
321
 
322
+ if not ref_text.strip():
323
+ gr.Info("No reference text provided, transcribing reference audio...")
324
+ ref_text = pipe(
325
+ ref_audio,
326
+ chunk_length_s=30,
327
+ batch_size=128,
328
+ generate_kwargs={"task": "transcribe"},
329
+ return_timestamps=False,
330
+ )["text"].strip()
331
+ gr.Info("Finished transcription")
332
+ else:
333
+ gr.Info("Using custom reference text...")
334
 
335
+ # Split the input text into batches
336
+ audio, sr = torchaudio.load(ref_audio)
337
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
+ gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
+ print('ref_text', ref_text)
340
+ for i, gen_text in enumerate(gen_text_batches):
341
+ print(f'gen_text {i}', gen_text)
342
+
343
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
+
346
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
+ # Split the script into speaker blocks
348
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
+
351
+ generated_audio_segments = []
352
+
353
+ for i in range(0, len(speaker_blocks), 2):
354
+ speaker = speaker_blocks[i]
355
+ text = speaker_blocks[i+1].strip()
356
+
357
+ # Determine which speaker is talking
358
+ if speaker == speaker1_name:
359
+ ref_audio = ref_audio1
360
+ ref_text = ref_text1
361
+ elif speaker == speaker2_name:
362
+ ref_audio = ref_audio2
363
+ ref_text = ref_text2
364
+ else:
365
+ continue # Skip if the speaker is neither speaker1 nor speaker2
366
+
367
+ # Generate audio for this block
368
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
+
370
+ # Convert the generated audio to a numpy array
371
+ sr, audio_data = audio
372
+
373
+ # Save the audio data as a WAV file
374
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
+ sf.write(temp_file.name, audio_data, sr)
376
+ audio_segment = AudioSegment.from_wav(temp_file.name)
377
+
378
+ generated_audio_segments.append(audio_segment)
379
+
380
+ # Add a short pause between speakers
381
+ pause = AudioSegment.silent(duration=500) # 500ms pause
382
+ generated_audio_segments.append(pause)
383
+
384
+ # Concatenate all audio segments
385
+ final_podcast = sum(generated_audio_segments)
386
+
387
+ # Export the final podcast
388
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
+ podcast_path = temp_file.name
390
+ final_podcast.export(podcast_path, format="wav")
391
+
392
+ return podcast_path
393
 
394
+ def parse_speechtypes_text(gen_text):
395
+ # Pattern to find (Emotion)
396
+ pattern = r'\((.*?)\)'
397
 
398
+ # Split the text by the pattern
399
+ tokens = re.split(pattern, gen_text)
400
+
401
+ segments = []
402
+
403
+ current_emotion = 'Regular'
404
 
405
+ for i in range(len(tokens)):
406
+ if i % 2 == 0:
407
+ # This is text
408
+ text = tokens[i].strip()
409
+ if text:
410
+ segments.append({'emotion': current_emotion, 'text': text})
411
+ else:
412
+ # This is emotion
413
+ emotion = tokens[i].strip()
414
+ current_emotion = emotion
415
+
416
+ return segments
417
+
418
+ def update_speed(new_speed):
419
+ global speed
420
+ speed = new_speed
421
+ return f"Speed set to: {speed}"
422
+
423
+ with gr.Blocks() as app_credits:
424
+ gr.Markdown("""
425
+ # Credits
426
+
427
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
+ """)
430
+ with gr.Blocks() as app_tts:
431
+ gr.Markdown("# Batched TTS")
432
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
+ model_choice = gr.Radio(
435
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
+ )
437
  generate_btn = gr.Button("Synthesize", variant="primary")
438
  with gr.Accordion("Advanced Settings", open=False):
439
+ ref_text_input = gr.Textbox(
440
+ label="Reference Text",
441
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
+ lines=2,
443
+ )
444
+ remove_silence = gr.Checkbox(
445
+ label="Remove Silences",
446
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
+ value=True,
448
+ )
449
+ split_words_input = gr.Textbox(
450
+ label="Custom Split Words",
451
+ info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
+ lines=2,
453
+ )
454
+ speed_slider = gr.Slider(
455
+ label="Speed",
456
+ minimum=0.3,
457
+ maximum=2.0,
458
+ value=speed,
459
+ step=0.1,
460
+ info="Adjust the speed of the audio.",
461
+ )
462
+ speed_slider.change(update_speed, inputs=speed_slider)
463
+
464
  audio_output = gr.Audio(label="Synthesized Audio")
465
+ spectrogram_output = gr.Image(label="Spectrogram")
466
+
467
+ generate_btn.click(
468
+ infer,
469
+ inputs=[
470
+ ref_audio_input,
471
+ ref_text_input,
472
+ gen_text_input,
473
+ model_choice,
474
+ remove_silence,
475
+ split_words_input,
476
+ ],
477
+ outputs=[audio_output, spectrogram_output],
478
+ )
479
+
480
+ with gr.Blocks() as app_podcast:
481
+ gr.Markdown("# Podcast Generation")
482
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
+
486
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
+
490
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
491
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
+
493
+ podcast_model_choice = gr.Radio(
494
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
+ )
496
+ podcast_remove_silence = gr.Checkbox(
497
+ label="Remove Silences",
498
+ value=True,
499
+ )
500
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
+ podcast_output = gr.Audio(label="Generated Podcast")
502
+
503
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
+
506
+ generate_podcast_btn.click(
507
+ podcast_generation,
508
+ inputs=[
509
+ script_input,
510
+ speaker1_name,
511
+ ref_audio_input1,
512
+ ref_text_input1,
513
+ speaker2_name,
514
+ ref_audio_input2,
515
+ ref_text_input2,
516
+ podcast_model_choice,
517
+ podcast_remove_silence,
518
+ ],
519
+ outputs=podcast_output,
520
+ )
521
+
522
+ def parse_emotional_text(gen_text):
523
+ # Pattern to find (Emotion)
524
+ pattern = r'\((.*?)\)'
525
+
526
+ # Split the text by the pattern
527
+ tokens = re.split(pattern, gen_text)
528
+
529
+ segments = []
530
+
531
+ current_emotion = 'Regular'
532
+
533
+ for i in range(len(tokens)):
534
+ if i % 2 == 0:
535
+ # This is text
536
+ text = tokens[i].strip()
537
+ if text:
538
+ segments.append({'emotion': current_emotion, 'text': text})
539
+ else:
540
+ # This is emotion
541
+ emotion = tokens[i].strip()
542
+ current_emotion = emotion
543
+
544
+ return segments
545
+
546
+ with gr.Blocks() as app_emotional:
547
+ # New section for emotional generation
548
+ gr.Markdown(
549
+ """
550
+ # Multiple Speech-Type Generation
551
+
552
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
+
554
+ **Example Input:**
555
+
556
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
+ """
558
+ )
559
+
560
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
+
562
+ # Regular speech type (mandatory)
563
+ with gr.Row():
564
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
+
568
+ # Additional speech types (up to 9 more)
569
+ max_speech_types = 10
570
+ speech_type_names = []
571
+ speech_type_audios = []
572
+ speech_type_ref_texts = []
573
+ speech_type_delete_btns = []
574
+
575
+ for i in range(max_speech_types - 1):
576
+ with gr.Row():
577
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
+ speech_type_names.append(name_input)
582
+ speech_type_audios.append(audio_input)
583
+ speech_type_ref_texts.append(ref_text_input)
584
+ speech_type_delete_btns.append(delete_btn)
585
+
586
+ # Button to add speech type
587
+ add_speech_type_btn = gr.Button("Add Speech Type")
588
+
589
+ # Keep track of current number of speech types
590
+ speech_type_count = gr.State(value=0)
591
+
592
+ # Function to add a speech type
593
+ def add_speech_type_fn(speech_type_count):
594
+ if speech_type_count < max_speech_types - 1:
595
+ speech_type_count += 1
596
+ # Prepare updates for the components
597
+ name_updates = []
598
+ audio_updates = []
599
+ ref_text_updates = []
600
+ delete_btn_updates = []
601
+ for i in range(max_speech_types - 1):
602
+ if i < speech_type_count:
603
+ name_updates.append(gr.update(visible=True))
604
+ audio_updates.append(gr.update(visible=True))
605
+ ref_text_updates.append(gr.update(visible=True))
606
+ delete_btn_updates.append(gr.update(visible=True))
607
+ else:
608
+ name_updates.append(gr.update())
609
+ audio_updates.append(gr.update())
610
+ ref_text_updates.append(gr.update())
611
+ delete_btn_updates.append(gr.update())
612
+ else:
613
+ # Optionally, show a warning
614
+ # gr.Warning("Maximum number of speech types reached.")
615
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
+
621
+ add_speech_type_btn.click(
622
+ add_speech_type_fn,
623
+ inputs=speech_type_count,
624
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
+ )
626
+
627
+ # Function to delete a speech type
628
+ def make_delete_speech_type_fn(index):
629
+ def delete_speech_type_fn(speech_type_count):
630
+ # Prepare updates
631
+ name_updates = []
632
+ audio_updates = []
633
+ ref_text_updates = []
634
+ delete_btn_updates = []
635
+
636
+ for i in range(max_speech_types - 1):
637
+ if i == index:
638
+ name_updates.append(gr.update(visible=False, value=''))
639
+ audio_updates.append(gr.update(visible=False, value=None))
640
+ ref_text_updates.append(gr.update(visible=False, value=''))
641
+ delete_btn_updates.append(gr.update(visible=False))
642
+ else:
643
+ name_updates.append(gr.update())
644
+ audio_updates.append(gr.update())
645
+ ref_text_updates.append(gr.update())
646
+ delete_btn_updates.append(gr.update())
647
+
648
+ speech_type_count = max(0, speech_type_count - 1)
649
+
650
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
+
652
+ return delete_speech_type_fn
653
+
654
+ for i, delete_btn in enumerate(speech_type_delete_btns):
655
+ delete_fn = make_delete_speech_type_fn(i)
656
+ delete_btn.click(
657
+ delete_fn,
658
+ inputs=speech_type_count,
659
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
+ )
661
+
662
+ # Text input for the prompt
663
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
+
665
+ # Model choice
666
+ model_choice_emotional = gr.Radio(
667
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
+ )
669
 
670
+ with gr.Accordion("Advanced Settings", open=False):
671
+ remove_silence_emotional = gr.Checkbox(
672
+ label="Remove Silences",
673
+ value=True,
674
+ )
675
+
676
+ # Generate button
677
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
+
679
+ # Output audio
680
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
+
682
+ def generate_emotional_speech(
683
+ regular_audio,
684
+ regular_ref_text,
685
+ gen_text,
686
+ *args,
687
+ ):
688
+ num_additional_speech_types = max_speech_types - 1
689
+ speech_type_names_list = args[:num_additional_speech_types]
690
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
+ model_choice = args[3 * num_additional_speech_types]
693
+ remove_silence = args[3 * num_additional_speech_types + 1]
694
+
695
+ # Collect the speech types and their audios into a dict
696
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
+
698
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
+ if name_input and audio_input:
700
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
+
702
+ # Parse the gen_text into segments
703
+ segments = parse_speechtypes_text(gen_text)
704
+
705
+ # For each segment, generate speech
706
+ generated_audio_segments = []
707
+ current_emotion = 'Regular'
708
+
709
+ for segment in segments:
710
+ emotion = segment['emotion']
711
+ text = segment['text']
712
+
713
+ if emotion in speech_types:
714
+ current_emotion = emotion
715
+ else:
716
+ # If emotion not available, default to Regular
717
+ current_emotion = 'Regular'
718
+
719
+ ref_audio = speech_types[current_emotion]['audio']
720
+ ref_text = speech_types[current_emotion].get('ref_text', '')
721
+
722
+ # Generate speech for this segment
723
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
+ sr, audio_data = audio
725
+
726
+ generated_audio_segments.append(audio_data)
727
+
728
+ # Concatenate all audio segments
729
+ if generated_audio_segments:
730
+ final_audio_data = np.concatenate(generated_audio_segments)
731
+ return (sr, final_audio_data)
732
+ else:
733
+ gr.Warning("No audio generated.")
734
+ return None
735
+
736
+ generate_emotional_btn.click(
737
+ generate_emotional_speech,
738
+ inputs=[
739
+ regular_audio,
740
+ regular_ref_text,
741
+ gen_text_input_emotional,
742
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
+ model_choice_emotional,
744
+ remove_silence_emotional,
745
+ ],
746
+ outputs=audio_output_emotional,
747
+ )
748
+
749
+ # Validation function to disable Generate button if speech types are missing
750
+ def validate_speech_types(
751
+ gen_text,
752
+ regular_name,
753
+ *args
754
+ ):
755
+ num_additional_speech_types = max_speech_types - 1
756
+ speech_type_names_list = args[:num_additional_speech_types]
757
+
758
+ # Collect the speech types names
759
+ speech_types_available = set()
760
+ if regular_name:
761
+ speech_types_available.add(regular_name)
762
+ for name_input in speech_type_names_list:
763
+ if name_input:
764
+ speech_types_available.add(name_input)
765
+
766
+ # Parse the gen_text to get the speech types used
767
+ segments = parse_emotional_text(gen_text)
768
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
769
+
770
+ # Check if all speech types in text are available
771
+ missing_speech_types = speech_types_in_text - speech_types_available
772
+
773
+ if missing_speech_types:
774
+ # Disable the generate button
775
+ return gr.update(interactive=False)
776
+ else:
777
+ # Enable the generate button
778
+ return gr.update(interactive=True)
779
+
780
+ gen_text_input_emotional.change(
781
+ validate_speech_types,
782
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
+ outputs=generate_emotional_btn
784
+ )
785
+ with gr.Blocks() as app:
786
+ gr.Markdown(
787
+ """
788
+ # E2/F5 TTS
789
 
790
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
 
792
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
 
795
+ The checkpoints support English and Chinese.
 
 
 
 
 
796
 
797
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
+
799
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
+ """
801
+ )
802
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
+
804
+ @click.command()
805
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
807
+ @click.option(
808
+ "--share",
809
+ "-s",
810
+ default=False,
811
+ is_flag=True,
812
+ help="Share the app via Gradio share link",
813
+ )
814
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
+ def main(port, host, share, api):
816
+ global app
817
+ print(f"Starting app...")
818
+ app.queue(api_open=api).launch(
819
+ server_name=host, server_port=port, share=share, show_api=api
820
+ )
821
 
822
 
823
+ if __name__ == "__main__":
824
+ main()
inference-cli.py ADDED
@@ -0,0 +1,378 @@
1
+ import re
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ import tempfile
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+ from pydub import AudioSegment, silence
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from cached_path import cached_path
11
+ from model.utils import (
12
+ load_checkpoint,
13
+ get_tokenizer,
14
+ convert_char_to_pinyin,
15
+ save_spectrogram,
16
+ )
17
+ from transformers import pipeline
18
+ import soundfile as sf
19
+ import tomli
20
+ import argparse
21
+ import tqdm
22
+ from pathlib import Path
23
+
24
+ parser = argparse.ArgumentParser(
25
+ prog="python3 inference-cli.py",
26
+ description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
27
+ epilog="Specify options above to override one or more settings from config.",
28
+ )
29
+ parser.add_argument(
30
+ "-c",
31
+ "--config",
32
+ help="Configuration file. Default=cli-config.toml",
33
+ default="inference-cli.toml",
34
+ )
35
+ parser.add_argument(
36
+ "-m",
37
+ "--model",
38
+ help="F5-TTS | E2-TTS",
39
+ )
40
+ parser.add_argument(
41
+ "-r",
42
+ "--ref_audio",
43
+ type=str,
44
+ help="Reference audio file < 15 seconds."
45
+ )
46
+ parser.add_argument(
47
+ "-s",
48
+ "--ref_text",
49
+ type=str,
50
+ default="666",
51
+ help="Subtitle for the reference audio."
52
+ )
53
+ parser.add_argument(
54
+ "-t",
55
+ "--gen_text",
56
+ type=str,
57
+ help="Text to generate.",
58
+ )
59
+ parser.add_argument(
60
+ "-o",
61
+ "--output_dir",
62
+ type=str,
63
+ help="Path to output folder..",
64
+ )
65
+ parser.add_argument(
66
+ "--remove_silence",
67
+ help="Remove silence.",
68
+ )
69
+ args = parser.parse_args()
70
+
71
+ config = tomli.load(open(args.config, "rb"))
72
+
73
+ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
74
+ ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
75
+ gen_text = args.gen_text if args.gen_text else config["gen_text"]
76
+ output_dir = args.output_dir if args.output_dir else config["output_dir"]
77
+ model = args.model if args.model else config["model"]
78
+ remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
79
+ wave_path = Path(output_dir)/"out.wav"
80
+ spectrogram_path = Path(output_dir)/"out.png"
81
+
82
+ SPLIT_WORDS = [
83
+ "but", "however", "nevertheless", "yet", "still",
84
+ "therefore", "thus", "hence", "consequently",
85
+ "moreover", "furthermore", "additionally",
86
+ "meanwhile", "alternatively", "otherwise",
87
+ "namely", "specifically", "for example", "such as",
88
+ "in fact", "indeed", "notably",
89
+ "in contrast", "on the other hand", "conversely",
90
+ "in conclusion", "to summarize", "finally"
91
+ ]
92
+
93
+ device = (
94
+ "cuda"
95
+ if torch.cuda.is_available()
96
+ else "mps" if torch.backends.mps.is_available() else "cpu"
97
+ )
98
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
99
+
100
+ print(f"Using {device} device")
101
+
102
+ # --------------------- Settings -------------------- #
103
+
104
+ target_sample_rate = 24000
105
+ n_mel_channels = 100
106
+ hop_length = 256
107
+ target_rms = 0.1
108
+ nfe_step = 32 # 16, 32
109
+ cfg_strength = 2.0
110
+ ode_method = "euler"
111
+ sway_sampling_coef = -1.0
112
+ speed = 1.0
113
+ # fix_duration = 27 # None or float (duration in seconds)
114
+ fix_duration = None
115
+
116
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
117
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
118
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
119
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
120
+ model = CFM(
121
+ transformer=model_cls(
122
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
123
+ ),
124
+ mel_spec_kwargs=dict(
125
+ target_sample_rate=target_sample_rate,
126
+ n_mel_channels=n_mel_channels,
127
+ hop_length=hop_length,
128
+ ),
129
+ odeint_kwargs=dict(
130
+ method=ode_method,
131
+ ),
132
+ vocab_char_map=vocab_char_map,
133
+ ).to(device)
134
+
135
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
136
+
137
+ return model
138
+
139
+
140
+ # load models
141
+ F5TTS_model_cfg = dict(
142
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
143
+ )
144
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
145
+
146
+ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
147
+ if len(text.encode('utf-8')) <= max_chars:
148
+ return [text]
149
+ if text[-1] not in ['。', '.', '!', '!', '?', '?']:
150
+ text += '.'
151
+
152
+ sentences = re.split('([。.!?!?])', text)
153
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
154
+
155
+ batches = []
156
+ current_batch = ""
157
+
158
+ def split_by_words(text):
159
+ words = text.split()
160
+ current_word_part = ""
161
+ word_batches = []
162
+ for word in words:
163
+ if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
164
+ current_word_part += word + ' '
165
+ else:
166
+ if current_word_part:
167
+ # Try to find a suitable split word
168
+ for split_word in split_words:
169
+ split_index = current_word_part.rfind(' ' + split_word + ' ')
170
+ if split_index != -1:
171
+ word_batches.append(current_word_part[:split_index].strip())
172
+ current_word_part = current_word_part[split_index:].strip() + ' '
173
+ break
174
+ else:
175
+ # If no suitable split word found, just append the current part
176
+ word_batches.append(current_word_part.strip())
177
+ current_word_part = ""
178
+ current_word_part += word + ' '
179
+ if current_word_part:
180
+ word_batches.append(current_word_part.strip())
181
+ return word_batches
182
+
183
+ for sentence in sentences:
184
+ if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
185
+ current_batch += sentence
186
+ else:
187
+ # If adding this sentence would exceed the limit
188
+ if current_batch:
189
+ batches.append(current_batch)
190
+ current_batch = ""
191
+
192
+ # If the sentence itself is longer than max_chars, split it
193
+ if len(sentence.encode('utf-8')) > max_chars:
194
+ # First, try to split by colon
195
+ colon_parts = sentence.split(':')
196
+ if len(colon_parts) > 1:
197
+ for part in colon_parts:
198
+ if len(part.encode('utf-8')) <= max_chars:
199
+ batches.append(part)
200
+ else:
201
+ # If colon part is still too long, split by comma
202
+ comma_parts = re.split('[,,]', part)
203
+ if len(comma_parts) > 1:
204
+ current_comma_part = ""
205
+ for comma_part in comma_parts:
206
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
207
+ current_comma_part += comma_part + ','
208
+ else:
209
+ if current_comma_part:
210
+ batches.append(current_comma_part.rstrip(','))
211
+ current_comma_part = comma_part + ','
212
+ if current_comma_part:
213
+ batches.append(current_comma_part.rstrip(','))
214
+ else:
215
+ # If no comma, split by words
216
+ batches.extend(split_by_words(part))
217
+ else:
218
+ # If no colon, split by comma
219
+ comma_parts = re.split('[,,]', sentence)
220
+ if len(comma_parts) > 1:
221
+ current_comma_part = ""
222
+ for comma_part in comma_parts:
223
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
224
+ current_comma_part += comma_part + ','
225
+ else:
226
+ if current_comma_part:
227
+ batches.append(current_comma_part.rstrip(','))
228
+ current_comma_part = comma_part + ','
229
+ if current_comma_part:
230
+ batches.append(current_comma_part.rstrip(','))
231
+ else:
232
+ # If no comma, split by words
233
+ batches.extend(split_by_words(sentence))
234
+ else:
235
+ current_batch = sentence
236
+
237
+ if current_batch:
238
+ batches.append(current_batch)
239
+
240
+ return batches
241
+
242
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
243
+ if model == "F5-TTS":
244
+ ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
245
+ elif model == "E2-TTS":
246
+ ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
247
+
248
+ audio, sr = ref_audio
249
+ if audio.shape[0] > 1:
250
+ audio = torch.mean(audio, dim=0, keepdim=True)
251
+
252
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
253
+ if rms < target_rms:
254
+ audio = audio * target_rms / rms
255
+ if sr != target_sample_rate:
256
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
257
+ audio = resampler(audio)
258
+ audio = audio.to(device)
259
+
260
+ generated_waves = []
261
+ spectrograms = []
262
+
263
+ for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
264
+ # Prepare the text
265
+ if len(ref_text[-1].encode('utf-8')) == 1:
266
+ ref_text = ref_text + " "
267
+ text_list = [ref_text + gen_text]
268
+ final_text_list = convert_char_to_pinyin(text_list)
269
+
270
+ # Calculate duration
271
+ ref_audio_len = audio.shape[-1] // hop_length
272
+ zh_pause_punc = r"。,、;:?!"
273
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
274
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
275
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
276
+
277
+ # inference
278
+ with torch.inference_mode():
279
+ generated, _ = ema_model.sample(
280
+ cond=audio,
281
+ text=final_text_list,
282
+ duration=duration,
283
+ steps=nfe_step,
284
+ cfg_strength=cfg_strength,
285
+ sway_sampling_coef=sway_sampling_coef,
286
+ )
287
+
288
+ generated = generated[:, ref_audio_len:, :]
289
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
290
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
291
+ if rms < target_rms:
292
+ generated_wave = generated_wave * rms / target_rms
293
+
294
+ # wav -> numpy
295
+ generated_wave = generated_wave.squeeze().cpu().numpy()
296
+
297
+ generated_waves.append(generated_wave)
298
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
299
+
300
+ # Combine all generated waves
301
+ final_wave = np.concatenate(generated_waves)
302
+
303
+ with open(wave_path, "wb") as f:
304
+ sf.write(f.name, final_wave, target_sample_rate)
305
+ # Remove silence
306
+ if remove_silence:
307
+ aseg = AudioSegment.from_file(f.name)
308
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
309
+ non_silent_wave = AudioSegment.silent(duration=0)
310
+ for non_silent_seg in non_silent_segs:
311
+ non_silent_wave += non_silent_seg
312
+ aseg = non_silent_wave
313
+ aseg.export(f.name, format="wav")
314
+ print(f.name)
315
+
316
+ # Create a combined spectrogram
317
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
318
+ save_spectrogram(combined_spectrogram, spectrogram_path)
319
+ print(spectrogram_path)
320
+
321
+
322
+ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
323
+ if custom_split_words.strip():
324
+ custom_words = [word.strip() for word in custom_split_words.split(',')]
325
+ global SPLIT_WORDS
326
+ SPLIT_WORDS = custom_words
327
+
328
+ print(gen_text)
329
+
330
+ print("Converting audio...")
331
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
332
+ aseg = AudioSegment.from_file(ref_audio_orig)
333
+
334
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
335
+ non_silent_wave = AudioSegment.silent(duration=0)
336
+ for non_silent_seg in non_silent_segs:
337
+ non_silent_wave += non_silent_seg
338
+ aseg = non_silent_wave
339
+
340
+ audio_duration = len(aseg)
341
+ if audio_duration > 15000:
342
+ print("Audio is over 15s, clipping to only first 15s.")
343
+ aseg = aseg[:15000]
344
+ aseg.export(f.name, format="wav")
345
+ ref_audio = f.name
346
+
347
+ if not ref_text.strip():
348
+ print("No reference text provided, transcribing reference audio...")
349
+ pipe = pipeline(
350
+ "automatic-speech-recognition",
351
+ model="openai/whisper-large-v3-turbo",
352
+ torch_dtype=torch.float16,
353
+ device=device,
354
+ )
355
+ ref_text = pipe(
356
+ ref_audio,
357
+ chunk_length_s=30,
358
+ batch_size=128,
359
+ generate_kwargs={"task": "transcribe"},
360
+ return_timestamps=False,
361
+ )["text"].strip()
362
+ print("Finished transcription")
363
+ else:
364
+ print("Using custom reference text...")
365
+
366
+ # Split the input text into batches
367
+ audio, sr = torchaudio.load(ref_audio)
368
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
369
+ gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
370
+ print('ref_text', ref_text)
371
+ for i, gen_text in enumerate(gen_text_batches):
372
+ print(f'gen_text {i}', gen_text)
373
+
374
+ print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
375
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
376
+
377
+
378
+ infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
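
As a side note on the duration heuristic in `infer_batch` above: the generated segment length is estimated by scaling the reference length (in mel frames) by the ratio of text lengths counted in UTF-8 bytes, with an extra bonus per Chinese pause punctuation mark. The sketch below reproduces only the core formula with illustrative inputs; the constants match the settings above, but the helper function and example texts are not part of the commit.

```python
# Minimal sketch of the duration heuristic used in infer_batch (illustrative only).
hop_length = 256
target_sample_rate = 24000

def estimate_duration_frames(ref_audio_samples: int, ref_text: str, gen_text: str, speed: float = 1.0) -> int:
    ref_audio_len = ref_audio_samples // hop_length   # reference length in mel frames
    ref_text_len = len(ref_text.encode("utf-8"))      # the script also adds 3 per Chinese pause mark
    gen_text_len = len(gen_text.encode("utf-8"))
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

# A 5-second 24 kHz reference clip with a longer generation text:
frames = estimate_duration_frames(
    5 * target_sample_rate,
    "Some call me nature.",
    "Some call me nature, others call me mother nature.",
)
print(frames, "mel frames ≈", round(frames * hop_length / target_sample_rate, 2), "seconds")
```
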
inference-cli.toml ADDED
@@ -0,0 +1,8 @@
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
+ # If set to an empty string "", the reference audio will be transcribed automatically.
5
+ ref_text = "Some call me nature, others call me mother nature."
6
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
+ remove_silence = true
8
+ output_dir = "tests"
model/cfm.py CHANGED
@@ -95,6 +95,7 @@ class CFM(nn.Module):
95
  no_ref_audio = False,
96
  duplicate_test = False,
97
  t_inter = 0.1,
 
98
  ):
99
  self.eval()
100
 
@@ -125,6 +126,8 @@ class CFM(nn.Module):
125
  # duration
126
 
127
  cond_mask = lens_to_mask(lens)
 
 
128
 
129
  if isinstance(duration, int):
130
  duration = torch.full((batch,), duration, device = device, dtype = torch.long)
@@ -142,7 +145,10 @@ class CFM(nn.Module):
142
  cond_mask = rearrange(cond_mask, '... -> ... 1')
143
  step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
144
 
145
- mask = lens_to_mask(duration)
 
 
 
146
 
147
  # test for no ref audio
148
  if no_ref_audio:
 
95
  no_ref_audio = False,
96
  duplicate_test = False,
97
  t_inter = 0.1,
98
+ edit_mask = None,
99
  ):
100
  self.eval()
101
 
 
126
  # duration
127
 
128
  cond_mask = lens_to_mask(lens)
129
+ if edit_mask is not None:
130
+ cond_mask = cond_mask & edit_mask
131
 
132
  if isinstance(duration, int):
133
  duration = torch.full((batch,), duration, device = device, dtype = torch.long)
 
145
  cond_mask = rearrange(cond_mask, '... -> ... 1')
146
  step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
147
 
148
+ if batch > 1:
149
+ mask = lens_to_mask(duration)
150
+ else: # save memory and speed up, as single inference need no mask currently
151
+ mask = None
152
 
153
  # test for no ref audio
154
  if no_ref_audio:
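
A hedged illustration of the `edit_mask` argument added above, with toy shapes rather than real mel spectrograms: after `cond_mask & edit_mask`, frames where the combined mask is True keep the conditioning audio, while frames where it is False are zeroed in `step_cond` and therefore re-generated.

```python
import torch

frames, mel_dim = 10, 4                                       # toy sizes, not the real 100-channel mels
cond_mask = torch.ones(1, frames, dtype=torch.bool)           # lens_to_mask(lens) for a full-length reference
edit_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 1, 1, 1, 1]], dtype=torch.bool)  # frames 3-5 marked for editing

cond_mask = cond_mask & edit_mask                             # as in the diff above
cond = torch.randn(1, frames, mel_dim)
# unsqueeze(-1) plays the role of rearrange(cond_mask, '... -> ... 1')
step_cond = torch.where(cond_mask.unsqueeze(-1), cond, torch.zeros_like(cond))
print(step_cond[0, :, 0])                                     # the three edited frames come out as zeros
```
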
model/dataset.py CHANGED
@@ -188,7 +188,7 @@ def load_dataset(
188
  dataset_type: str = "CustomDataset",
189
  audio_type: str = "raw",
190
  mel_spec_kwargs: dict = dict()
191
- ) -> CustomDataset | HFDataset:
192
 
193
  print("Loading dataset ...")
194
 
 
188
  dataset_type: str = "CustomDataset",
189
  audio_type: str = "raw",
190
  mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset:
192
 
193
  print("Loading dataset ...")
194
 
model/trainer.py CHANGED
@@ -138,19 +138,24 @@ class Trainer:
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
- latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
145
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
146
 
147
  if self.is_main:
148
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
149
 
150
- if self.scheduler:
151
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
152
-
153
- step = checkpoint['step']
 
 
 
 
 
 
 
154
  del checkpoint; gc.collect()
155
  return step
156
 
@@ -163,16 +168,16 @@ class Trainer:
163
  generator = None
164
 
165
  if self.batch_size_type == "sample":
166
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
167
  batch_size=self.batch_size, shuffle=True, generator=generator)
168
  elif self.batch_size_type == "frame":
169
  self.accelerator.even_batches = False
170
  sampler = SequentialSampler(train_dataset)
171
  batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
172
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
173
  batch_sampler=batch_sampler)
174
  else:
175
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but recieved {self.batch_size_type}")
176
 
177
  # accelerator.prepare() dispatches batches to devices;
178
  # which means the length of dataloader calculated before, should consider the number of devices
 
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
 
 
144
 
145
  if self.is_main:
146
  self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
 
148
+ if 'step' in checkpoint:
149
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
+ if self.scheduler:
152
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
+ step = checkpoint['step']
154
+ else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
+ step = 0
158
+
159
  del checkpoint; gc.collect()
160
  return step
161
 
 
168
  generator = None
169
 
170
  if self.batch_size_type == "sample":
171
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
  batch_size=self.batch_size, shuffle=True, generator=generator)
173
  elif self.batch_size_type == "frame":
174
  self.accelerator.even_batches = False
175
  sampler = SequentialSampler(train_dataset)
176
  batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
  batch_sampler=batch_sampler)
179
  else:
180
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
 
182
  # accelerator.prepare() dispatches batches to devices;
183
  # which means the length of dataloader calculated before, should consider the number of devices
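
To spell out the resume logic introduced above: a full training checkpoint carries a top-level `step` key and restores model, optimizer, and (if present) scheduler state, whereas an EMA-only checkpoint, such as a released pretrained model, has its `ema_model.` key prefix stripped and fine-tuning restarts from step 0. The dictionary below is a toy stand-in for such a checkpoint, not a real file.

```python
# Toy EMA-only checkpoint (illustrative values, not real tensors).
checkpoint = {
    "ema_model_state_dict": {
        "ema_model.transformer.norm.weight": 0.0,
        "initted": True,
        "step": 1200000,
    }
}

if "step" in checkpoint:
    step = checkpoint["step"]                      # full checkpoint: resume training where it left off
else:
    # EMA-only checkpoint: rename keys and drop EMA bookkeeping, then fine-tune from step 0
    model_state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k not in ["initted", "step"]
    }
    step = 0

print(step, list(model_state_dict))
```
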
model/utils.py CHANGED
@@ -134,7 +134,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
134
  - if use "byte", set to 256 (unicode byte range)
135
  '''
136
  if tokenizer in ["pinyin", "char"]:
137
- with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
138
  vocab_char_map = {}
139
  for i, char in enumerate(f):
140
  vocab_char_map[char[:-1]] = i
@@ -153,9 +153,11 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
153
  def convert_char_to_pinyin(text_list, polyphone = True):
154
  final_text_list = []
155
  god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
 
156
  for text in text_list:
157
  char_list = []
158
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
 
159
  for seg in jieba.cut(text):
160
  seg_byte_len = len(bytes(seg, 'UTF-8'))
161
  if seg_byte_len == len(seg): # if pure alphabets and symbols
@@ -273,6 +275,8 @@ def get_inference_prompt(
273
  ref_audio = resampler(ref_audio)
274
 
275
  # Text
 
 
276
  text = [prompt_text + gt_text]
277
  if tokenizer == "pinyin":
278
  text_list = convert_char_to_pinyin(text, polyphone = polyphone)
@@ -292,8 +296,8 @@ def get_inference_prompt(
292
  # ref_audio = gt_audio
293
  else:
294
  zh_pause_punc = r"。,、;:?!"
295
- ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
296
- gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
297
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
298
 
299
  # to mel spectrogram
@@ -543,3 +547,28 @@ def repetition_found(text, length = 2, tolerance = 10):
543
  if count > tolerance:
544
  return True
545
  return False
134
  - if use "byte", set to 256 (unicode byte range)
135
  '''
136
  if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
138
  vocab_char_map = {}
139
  for i, char in enumerate(f):
140
  vocab_char_map[char[:-1]] = i
 
153
  def convert_char_to_pinyin(text_list, polyphone = True):
154
  final_text_list = []
155
  god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
156
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
157
  for text in text_list:
158
  char_list = []
159
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
160
+ text = text.translate(custom_trans)
161
  for seg in jieba.cut(text):
162
  seg_byte_len = len(bytes(seg, 'UTF-8'))
163
  if seg_byte_len == len(seg): # if pure alphabets and symbols
 
275
  ref_audio = resampler(ref_audio)
276
 
277
  # Text
278
+ if len(prompt_text[-1].encode('utf-8')) == 1:
279
+ prompt_text = prompt_text + " "
280
  text = [prompt_text + gt_text]
281
  if tokenizer == "pinyin":
282
  text_list = convert_char_to_pinyin(text, polyphone = polyphone)
 
296
  # ref_audio = gt_audio
297
  else:
298
  zh_pause_punc = r"。,、;:?!"
299
+ ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
300
+ gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
301
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
302
 
303
  # to mel spectrogram
 
547
  if count > tolerance:
548
  return True
549
  return False
550
+
551
+
552
+ # load model checkpoint for inference
553
+
554
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
555
+ from ema_pytorch import EMA
556
+
557
+ ckpt_type = ckpt_path.split(".")[-1]
558
+ if ckpt_type == "safetensors":
559
+ from safetensors.torch import load_file
560
+ checkpoint = load_file(ckpt_path, device=device)
561
+ else:
562
+ checkpoint = torch.load(ckpt_path, map_location=device)
563
+
564
+ if use_ema == True:
565
+ ema_model = EMA(model, include_online_model = False).to(device)
566
+ if ckpt_type == "safetensors":
567
+ ema_model.load_state_dict(checkpoint)
568
+ else:
569
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
570
+ ema_model.copy_params_from_ema_to_model()
571
+ else:
572
+ model.load_state_dict(checkpoint['model_state_dict'])
573
+
574
+ return model
requirements.txt CHANGED
@@ -1,17 +1,24 @@
1
  accelerate>=0.33.0
 
 
2
  datasets
3
  einops>=0.8.0
4
  einx>=0.3.0
5
  ema_pytorch>=0.5.2
6
  faster_whisper
7
  funasr
 
8
  jieba
9
  jiwer
10
  librosa
11
  matplotlib
 
 
12
  pypinyin
13
- torch>=2.0
14
- torchaudio>=2.3.0
 
 
15
  torchdiffeq
16
  tqdm>=4.65.0
17
  transformers
@@ -19,9 +26,4 @@ vocos
19
  wandb
20
  x_transformers>=1.31.14
21
  zhconv
22
- zhon
23
- cached_path
24
- pydub
25
- txtsplit
26
- detoxify
27
- soundfile
 
1
  accelerate>=0.33.0
2
+ cached_path
3
+ click
4
  datasets
5
  einops>=0.8.0
6
  einx>=0.3.0
7
  ema_pytorch>=0.5.2
8
  faster_whisper
9
  funasr
10
+ gradio
11
  jieba
12
  jiwer
13
  librosa
14
  matplotlib
15
+ numpy==1.23.5
16
+ pydub
17
  pypinyin
18
+ safetensors
19
+ soundfile
20
+ # torch>=2.0
21
+ # torchaudio>=2.3.0
22
  torchdiffeq
23
  tqdm>=4.65.0
24
  transformers
 
26
  wandb
27
  x_transformers>=1.31.14
28
  zhconv
29
+ zhon
scripts/eval_infer_batch.py ADDED
@@ -0,0 +1,199 @@
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ import time
5
+ import random
6
+ from tqdm import tqdm
7
+ import argparse
8
+
9
+ import torch
10
+ import torchaudio
11
+ from accelerate import Accelerator
12
+ from einops import rearrange
13
+ from vocos import Vocos
14
+
15
+ from model import CFM, UNetT, DiT
16
+ from model.utils import (
17
+ load_checkpoint,
18
+ get_tokenizer,
19
+ get_seedtts_testset_metainfo,
20
+ get_librispeech_test_clean_metainfo,
21
+ get_inference_prompt,
22
+ )
23
+
24
+ accelerator = Accelerator()
25
+ device = f"cuda:{accelerator.process_index}"
26
+
27
+
28
+ # --------------------- Dataset Settings -------------------- #
29
+
30
+ target_sample_rate = 24000
31
+ n_mel_channels = 100
32
+ hop_length = 256
33
+ target_rms = 0.1
34
+
35
+ tokenizer = "pinyin"
36
+
37
+
38
+ # ---------------------- infer setting ---------------------- #
39
+
40
+ parser = argparse.ArgumentParser(description="batch inference")
41
+
42
+ parser.add_argument('-s', '--seed', default=None, type=int)
43
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
44
+ parser.add_argument('-n', '--expname', required=True)
45
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
46
+
47
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
48
+ parser.add_argument('-o', '--odemethod', default="euler")
49
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
50
+
51
+ parser.add_argument('-t', '--testset', required=True)
52
+
53
+ args = parser.parse_args()
54
+
55
+
56
+ seed = args.seed
57
+ dataset_name = args.dataset
58
+ exp_name = args.expname
59
+ ckpt_step = args.ckptstep
60
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
61
+
62
+ nfe_step = args.nfestep
63
+ ode_method = args.odemethod
64
+ sway_sampling_coef = args.swaysampling
65
+
66
+ testset = args.testset
67
+
68
+
69
+ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
70
+ cfg_strength = 2.
71
+ speed = 1.
72
+ use_truth_duration = False
73
+ no_ref_audio = False
74
+
75
+
76
+ if exp_name == "F5TTS_Base":
77
+ model_cls = DiT
78
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
79
+
80
+ elif exp_name == "E2TTS_Base":
81
+ model_cls = UNetT
82
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
83
+
84
+
85
+ if testset == "ls_pc_test_clean":
86
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
87
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
88
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
89
+
90
+ elif testset == "seedtts_test_zh":
91
+ metalst = "data/seedtts_testset/zh/meta.lst"
92
+ metainfo = get_seedtts_testset_metainfo(metalst)
93
+
94
+ elif testset == "seedtts_test_en":
95
+ metalst = "data/seedtts_testset/en/meta.lst"
96
+ metainfo = get_seedtts_testset_metainfo(metalst)
97
+
98
+
99
+ # path to save generated wavs
100
+ if seed is None: seed = random.randint(-10000, 10000)
101
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
102
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
103
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
104
+ f"_cfg{cfg_strength}_speed{speed}" \
105
+ f"{'_gt-dur' if use_truth_duration else ''}" \
106
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
107
+
108
+
109
+ # -------------------------------------------------#
110
+
111
+ use_ema = True
112
+
113
+ prompts_all = get_inference_prompt(
114
+ metainfo,
115
+ speed = speed,
116
+ tokenizer = tokenizer,
117
+ target_sample_rate = target_sample_rate,
118
+ n_mel_channels = n_mel_channels,
119
+ hop_length = hop_length,
120
+ target_rms = target_rms,
121
+ use_truth_duration = use_truth_duration,
122
+ infer_batch_size = infer_batch_size,
123
+ )
124
+
125
+ # Vocoder model
126
+ local = False
127
+ if local:
128
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
131
+ vocos.load_state_dict(state_dict)
132
+ vocos.eval()
133
+ else:
134
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
135
+
136
+ # Tokenizer
137
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
138
+
139
+ # Model
140
+ model = CFM(
141
+ transformer = model_cls(
142
+ **model_cfg,
143
+ text_num_embeds = vocab_size,
144
+ mel_dim = n_mel_channels
145
+ ),
146
+ mel_spec_kwargs = dict(
147
+ target_sample_rate = target_sample_rate,
148
+ n_mel_channels = n_mel_channels,
149
+ hop_length = hop_length,
150
+ ),
151
+ odeint_kwargs = dict(
152
+ method = ode_method,
153
+ ),
154
+ vocab_char_map = vocab_char_map,
155
+ ).to(device)
156
+
157
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
158
+
159
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
160
+ os.makedirs(output_dir)
161
+
162
+ # start batch inference
163
+ accelerator.wait_for_everyone()
164
+ start = time.time()
165
+
166
+ with accelerator.split_between_processes(prompts_all) as prompts:
167
+
168
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170
+ ref_mels = ref_mels.to(device)
171
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
172
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
173
+
174
+ # Inference
175
+ with torch.inference_mode():
176
+ generated, _ = model.sample(
177
+ cond = ref_mels,
178
+ text = final_text_list,
179
+ duration = total_mel_lens,
180
+ lens = ref_mel_lens,
181
+ steps = nfe_step,
182
+ cfg_strength = cfg_strength,
183
+ sway_sampling_coef = sway_sampling_coef,
184
+ no_ref_audio = no_ref_audio,
185
+ seed = seed,
186
+ )
187
+ # Final result
188
+ for i, gen in enumerate(generated):
189
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
191
+ generated_wave = vocos.decode(gen_mel_spec.cpu())
192
+ if ref_rms_list[i] < target_rms:
193
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
194
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
195
+
196
+ accelerator.wait_for_everyone()
197
+ if accelerator.is_main_process:
198
+ timediff = time.time() - start
199
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
scripts/eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
speech_edit.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from model.utils import (
11
+ load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+
44
+ if exp_name == "F5TTS_Base":
45
+ model_cls = DiT
46
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
47
+
48
+ elif exp_name == "E2TTS_Base":
49
+ model_cls = UNetT
50
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
51
+
52
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
53
+ output_dir = "tests"
54
+
55
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
56
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
57
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
58
+ # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
59
+ # [result will be saved at same path of audio file]
60
+ # [--language "zho" for Chinese, "eng" for English]
61
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
62
+
63
+ audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
64
+ origin_text = "Some call me nature, others call me mother nature."
65
+ target_text = "Some call me optimist, others call me realist."
66
+ parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # start & end times of "nature" & "mother nature", in seconds
67
+ fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
68
+
69
+ # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
70
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
71
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
72
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
73
+ # fix_duration = None # use origin text duration
74
+
75
+
76
+ # -------------------------------------------------#
77
+
78
+ use_ema = True
79
+
80
+ if not os.path.exists(output_dir):
81
+ os.makedirs(output_dir)
82
+
83
+ # Vocoder model
84
+ local = False
85
+ if local:
86
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
89
+ vocos.load_state_dict(state_dict)
90
+ vocos.eval()
91
+ else:
92
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
93
+
94
+ # Tokenizer
95
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96
+
97
+ # Model
98
+ model = CFM(
99
+ transformer = model_cls(
100
+ **model_cfg,
101
+ text_num_embeds = vocab_size,
102
+ mel_dim = n_mel_channels
103
+ ),
104
+ mel_spec_kwargs = dict(
105
+ target_sample_rate = target_sample_rate,
106
+ n_mel_channels = n_mel_channels,
107
+ hop_length = hop_length,
108
+ ),
109
+ odeint_kwargs = dict(
110
+ method = ode_method,
111
+ ),
112
+ vocab_char_map = vocab_char_map,
113
+ ).to(device)
114
+
115
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
116
+
117
+ # Audio
118
+ audio, sr = torchaudio.load(audio_to_edit)
119
+ if audio.shape[0] > 1:
120
+ audio = torch.mean(audio, dim=0, keepdim=True)
121
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
122
+ if rms < target_rms:
123
+ audio = audio * target_rms / rms
124
+ if sr != target_sample_rate:
125
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
126
+ audio = resampler(audio)
127
+ offset = 0
128
+ audio_ = torch.zeros(1, 0)
129
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
130
+ for part in parts_to_edit:
131
+ start, end = part
132
+ part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133
+ part_dur = part_dur * target_sample_rate
134
+ start = start * target_sample_rate
135
+ audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
136
+ edit_mask = torch.cat((edit_mask,
137
+ torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
138
+ torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
139
+ ), dim = -1)
140
+ offset = end * target_sample_rate
141
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
142
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
143
+ audio = audio.to(device)
144
+ edit_mask = edit_mask.to(device)
145
+
146
+ # Text
147
+ text_list = [target_text]
148
+ if tokenizer == "pinyin":
149
+ final_text_list = convert_char_to_pinyin(text_list)
150
+ else:
151
+ final_text_list = [text_list]
152
+ print(f"text : {text_list}")
153
+ print(f"pinyin: {final_text_list}")
154
+
155
+ # Duration
156
+ ref_audio_len = 0
157
+ duration = audio.shape[-1] // hop_length
158
+
159
+ # Inference
160
+ with torch.inference_mode():
161
+ generated, trajectory = model.sample(
162
+ cond = audio,
163
+ text = final_text_list,
164
+ duration = duration,
165
+ steps = nfe_step,
166
+ cfg_strength = cfg_strength,
167
+ sway_sampling_coef = sway_sampling_coef,
168
+ seed = seed,
169
+ edit_mask = edit_mask,
170
+ )
171
+ print(f"Generated mel: {generated.shape}")
172
+
173
+ # Final result
174
+ generated = generated[:, ref_audio_len:, :]
175
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
176
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
177
+ if rms < target_rms:
178
+ generated_wave = generated_wave * rms / target_rms
179
+
180
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
181
+ torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
182
+ print(f"Generated wav: {generated_wave.shape}")
train.py ADDED
@@ -0,0 +1,91 @@
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
+ from model.utils import get_tokenizer
3
+ from model.dataset import load_dataset
4
+
5
+
6
+ # -------------------------- Dataset Settings --------------------------- #
7
+
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin"
13
+ dataset_name = "Emilia_ZH_EN"
14
+
15
+
16
+ # -------------------------- Training Settings -------------------------- #
17
+
18
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
19
+
20
+ learning_rate = 7.5e-5
21
+
22
+ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
+ batch_size_type = "frame" # "frame" or "sample"
24
+ max_samples = 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
25
+ grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
+
28
+ epochs = 11 # use linear decay, thus epochs control the slope
29
+ num_warmup_updates = 20000 # warmup steps
30
+ save_per_updates = 50000 # save checkpoint per steps
31
+ last_per_steps = 5000 # save last checkpoint per steps
32
+
33
+ # model params
34
+ if exp_name == "F5TTS_Base":
35
+ wandb_resume_id = None
36
+ model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
+ elif exp_name == "E2TTS_Base":
39
+ wandb_resume_id = None
40
+ model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
+
43
+
44
+ # ----------------------------------------------------------------------- #
45
+
46
+ def main():
47
+
48
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
49
+
50
+ mel_spec_kwargs = dict(
51
+ target_sample_rate = target_sample_rate,
52
+ n_mel_channels = n_mel_channels,
53
+ hop_length = hop_length,
54
+ )
55
+
56
+ e2tts = CFM(
57
+ transformer = model_cls(
58
+ **model_cfg,
59
+ text_num_embeds = vocab_size,
60
+ mel_dim = n_mel_channels
61
+ ),
62
+ mel_spec_kwargs = mel_spec_kwargs,
63
+ vocab_char_map = vocab_char_map,
64
+ )
65
+
66
+ trainer = Trainer(
67
+ e2tts,
68
+ epochs,
69
+ learning_rate,
70
+ num_warmup_updates = num_warmup_updates,
71
+ save_per_updates = save_per_updates,
72
+ checkpoint_path = f'ckpts/{exp_name}',
73
+ batch_size = batch_size_per_gpu,
74
+ batch_size_type = batch_size_type,
75
+ max_samples = max_samples,
76
+ grad_accumulation_steps = grad_accumulation_steps,
77
+ max_grad_norm = max_grad_norm,
78
+ wandb_project = "CFM-TTS",
79
+ wandb_run_name = exp_name,
80
+ wandb_resume_id = wandb_resume_id,
81
+ last_per_steps = last_per_steps,
82
+ )
83
+
84
+ train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(train_dataset,
86
+ resumable_with_seed = 666 # seed for shuffling dataset
87
+ )
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
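
As a closing back-of-the-envelope note (an assumption-level calculation, not stated in the repo): with `batch_size_type = "frame"`, the 38,400-frame per-GPU budget configured above corresponds to roughly 6.8 minutes of 24 kHz audio per GPU per optimizer step at hop length 256.

```python
# Frames -> audio seconds for the per-GPU frame budget configured in train.py above.
hop_length = 256
target_sample_rate = 24000
frames_per_gpu = 38400

seconds_per_gpu = frames_per_gpu * hop_length / target_sample_rate
print(f"{seconds_per_gpu:.1f} s ≈ {seconds_per_gpu / 60:.2f} min of audio per GPU per step")
```
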