PoTaTo721 commited on
Commit
644fd7b
1 Parent(s): 469209d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -130
app.py CHANGED
@@ -5,7 +5,7 @@ import hydra
5
 
6
  # Download if not exists
7
  os.makedirs("checkpoints", exist_ok=True)
8
- snapshot_download(repo_id="fishaudio/fish-speech-1", local_dir="./checkpoints/fish-speech-1")
9
 
10
  print("All checkpoints downloaded")
11
 
@@ -30,8 +30,8 @@ os.environ["EINX_FILTER_TRACEBACK"] = "false"
30
 
31
  HEADER_MD = """# Fish Speech
32
 
33
- ## The demo in this space is version 1.0, Please check [Fish Audio](https://fish.audio) for the best model.
34
- ## 该 Demo 为 Fish Speech 1.0 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
35
 
36
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
37
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
@@ -39,14 +39,14 @@ A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https
39
  You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
40
  你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
41
 
42
- Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.
43
- 相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.
44
 
45
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
46
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
47
 
48
- The model running in this WebUI is Fish Speech V1 Medium SFT 4K.
49
- 在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT 4K.
50
  """
51
 
52
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -85,36 +85,27 @@ def inference(
85
  top_p,
86
  repetition_penalty,
87
  temperature,
88
- speaker,
89
  ):
90
  if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
91
- return None, f"Text is too long, please keep it under {args.max_gradio_length} characters."
92
-
93
- # Parse reference audio aka prompt
94
- prompt_tokens = None
95
- if enable_reference_audio and reference_audio is not None:
96
- # reference_audio_sr, reference_audio_content = reference_audio
97
- reference_audio_content, _ = librosa.load(
98
- reference_audio, sr=vqgan_model.sampling_rate, mono=True
99
- )
100
- audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
101
- None, None, :
102
- ]
103
-
104
- logger.info(
105
- f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
106
  )
107
 
108
- # VQ Encoder
109
- audio_lengths = torch.tensor(
110
- [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
111
- )
112
- prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
113
 
114
  # LLAMA Inference
115
  request = dict(
116
- tokenizer=llama_tokenizer,
117
- device=vqgan_model.device,
118
  max_new_tokens=max_new_tokens,
119
  text=text,
120
  top_p=top_p,
@@ -123,43 +114,246 @@ def inference(
123
  compile=args.compile,
124
  iterative_prompt=chunk_length > 0,
125
  chunk_length=chunk_length,
126
- max_length=args.max_length,
127
- speaker=speaker if speaker else None,
128
  prompt_tokens=prompt_tokens if enable_reference_audio else None,
129
  prompt_text=reference_text if enable_reference_audio else None,
130
  )
131
 
132
- payload = dict(
133
- response_queue=queue.Queue(),
134
- request=request,
 
 
 
135
  )
136
- llama_queue.put(payload)
137
 
138
- codes = []
 
 
 
 
139
  while True:
140
- result = payload["response_queue"].get()
141
- if result == "next":
142
- # TODO: handle next sentence
143
- continue
144
-
145
- if result == "done":
146
- if payload["success"] is False:
147
- return None, build_html_error_message(payload["response"])
148
  break
149
 
150
- codes.append(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- codes = torch.cat(codes, dim=1)
153
 
154
- # VQGAN Inference
155
- feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
156
- fake_audios = vqgan_model.decode(
157
- indices=codes[None], feature_lengths=feature_lengths, return_audios=True
158
- )[0, 0]
159
 
160
- fake_audios = fake_audios.float().cpu().numpy()
 
 
 
 
161
 
162
- return (vqgan_model.sampling_rate, fake_audios), None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  def build_app():
@@ -170,95 +364,182 @@ def build_app():
170
  app.load(
171
  None,
172
  None,
173
- js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
 
174
  )
175
 
176
  # Inference
177
  with gr.Row():
178
  with gr.Column(scale=3):
179
  text = gr.Textbox(
180
- label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
 
 
 
 
 
 
 
 
181
  )
182
 
183
  with gr.Row():
184
- with gr.Tab(label="Advanced Config / 高级参数"):
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  chunk_length = gr.Slider(
186
- label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
187
  minimum=0,
188
- maximum=100,
189
- value=30,
190
  step=8,
191
  )
192
 
193
  max_new_tokens = gr.Slider(
194
- label="Maximum tokens per batch, 0 means no limit / 每批最大令牌数,0 表示无限制",
195
- minimum=128,
196
- maximum=512,
197
- value=512, # 0 means no limit
198
  step=8,
199
  )
200
 
201
  top_p = gr.Slider(
202
- label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
 
 
 
 
203
  )
204
 
205
  repetition_penalty = gr.Slider(
206
- label="Repetition Penalty",
207
- minimum=0,
208
- maximum=2,
209
- value=1.5,
210
  step=0.01,
211
  )
212
 
213
  temperature = gr.Slider(
214
  label="Temperature",
215
- minimum=0,
216
- maximum=2,
217
  value=0.7,
218
  step=0.01,
219
  )
220
 
221
- speaker = gr.Textbox(
222
- label="Speaker / 说话人",
223
- placeholder="Type name of the speaker / 输入说话人的名称",
224
- lines=1,
225
- )
226
-
227
- with gr.Tab(label="Reference Audio / 参考音频"):
228
  gr.Markdown(
229
- "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
 
 
230
  )
231
 
232
  enable_reference_audio = gr.Checkbox(
233
- label="Enable Reference Audio / 启用参考音频",
234
  )
235
  reference_audio = gr.Audio(
236
- label="Reference Audio / 参考音频",
237
  type="filepath",
238
  )
239
- reference_text = gr.Textbox(
240
- label="Reference Text / 参考文本",
241
- placeholder="参考文本",
242
- lines=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  )
244
 
245
  with gr.Column(scale=3):
246
- with gr.Row():
247
- error = gr.HTML(label="Error Message / 错误信息")
248
- with gr.Row():
249
- audio = gr.Audio(label="Generated Audio / 音频", type="numpy")
 
 
 
 
 
 
 
 
 
 
 
250
 
 
 
 
 
 
 
 
 
251
  with gr.Row():
252
  with gr.Column(scale=3):
253
  generate = gr.Button(
254
- value="\U0001F3A7 Generate / 合成", variant="primary"
 
 
 
 
255
  )
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # # Submit
258
  generate.click(
259
- inference,
260
  [
261
- text,
262
  enable_reference_audio,
263
  reference_audio,
264
  reference_text,
@@ -267,12 +548,29 @@ def build_app():
267
  top_p,
268
  repetition_penalty,
269
  temperature,
270
- speaker,
 
271
  ],
272
- [audio, error],
273
  concurrency_limit=1,
274
  )
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  return app
277
 
278
 
@@ -281,74 +579,60 @@ def parse_args():
281
  parser.add_argument(
282
  "--llama-checkpoint-path",
283
  type=Path,
284
- default="checkpoints/text2semantic-sft-large-v1-4k.pth",
285
  )
286
  parser.add_argument(
287
- "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
288
- )
289
- parser.add_argument(
290
- "--vqgan-checkpoint-path",
291
  type=Path,
292
- default="checkpoints/vq-gan-group-fsq-2x1024.pth",
293
  )
294
- parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
295
- parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
296
  parser.add_argument("--device", type=str, default="cuda")
297
  parser.add_argument("--half", action="store_true")
298
- parser.add_argument("--max-length", type=int, default=2048)
299
  parser.add_argument("--compile", action="store_true")
300
  parser.add_argument("--max-gradio-length", type=int, default=0)
 
301
 
302
  return parser.parse_args()
303
 
304
 
305
  if __name__ == "__main__":
306
  args = parse_args()
307
-
308
  args.precision = torch.half if args.half else torch.bfloat16
309
- args.compile = True
310
- args.max_gradio_length = 1024
311
- args.tokenizer = "./checkpoints/fish-speech-1"
312
- args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
313
- args.llama_config_name = "dual_ar_2_codebook_medium"
314
- args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
315
- args.vqgan_config_name = "vqgan_pretrain"
316
 
317
  logger.info("Loading Llama model...")
318
  llama_queue = launch_thread_safe_queue(
319
- config_name=args.llama_config_name,
320
  checkpoint_path=args.llama_checkpoint_path,
321
  device=args.device,
322
  precision=args.precision,
323
- max_length=args.max_length,
324
  compile=args.compile,
325
  )
326
- llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
327
  logger.info("Llama model loaded, loading VQ-GAN model...")
328
 
329
- vqgan_model = load_vqgan_model(
330
- config_name=args.vqgan_config_name,
331
- checkpoint_path=args.vqgan_checkpoint_path,
332
  device=args.device,
333
  )
334
 
335
- logger.info("VQ-GAN model loaded, warming up...")
336
 
337
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
338
- inference(
339
- text="Hello, world!",
340
- enable_reference_audio=False,
341
- reference_audio=None,
342
- reference_text="",
343
- max_new_tokens=0,
344
- chunk_length=0,
345
- top_p=0.7,
346
- repetition_penalty=1.5,
347
- temperature=0.7,
348
- speaker=None,
 
349
  )
350
 
351
  logger.info("Warming up done, launching the web UI...")
352
 
353
  app = build_app()
354
- app.launch(show_api=False)
 
5
 
6
  # Download if not exists
7
  os.makedirs("checkpoints", exist_ok=True)
8
+ snapshot_download(repo_id="fishaudio/fish-speech-1.2-sft", local_dir="./checkpoints/fish-speech-1.2")
9
 
10
  print("All checkpoints downloaded")
11
 
 
30
 
31
  HEADER_MD = """# Fish Speech
32
 
33
+ ## The demo in this space is version 1.2, Please check [Fish Audio](https://fish.audio) for the best model.
34
+ ## 该 Demo 为 Fish Speech 1.2 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
35
 
36
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
37
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
 
39
  You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
40
  你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
41
 
42
+ Related code and weights are released under CC BY-NC-SA 4.0 License.
43
+ 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
44
 
45
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
46
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
47
 
48
+ The model running in this WebUI is Fish Speech V1.2 Medium SFT
49
+ 在此 WebUI 中运行的模型是 Fish Speech V1.2 Medium SFT
50
  """
51
 
52
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
85
  top_p,
86
  repetition_penalty,
87
  temperature,
88
+ streaming=False,
89
  ):
90
  if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
91
+ return (
92
+ None,
93
+ None,
94
+ i18n("Text is too long, please keep it under {} characters.").format(
95
+ args.max_gradio_length
96
+ ),
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
+ # Parse reference audio aka prompt
100
+ prompt_tokens = encode_reference(
101
+ decoder_model=decoder_model,
102
+ reference_audio=reference_audio,
103
+ enable_reference_audio=enable_reference_audio,
104
+ )
105
 
106
  # LLAMA Inference
107
  request = dict(
108
+ device=decoder_model.device,
 
109
  max_new_tokens=max_new_tokens,
110
  text=text,
111
  top_p=top_p,
 
114
  compile=args.compile,
115
  iterative_prompt=chunk_length > 0,
116
  chunk_length=chunk_length,
117
+ max_length=2048,
 
118
  prompt_tokens=prompt_tokens if enable_reference_audio else None,
119
  prompt_text=reference_text if enable_reference_audio else None,
120
  )
121
 
122
+ response_queue = queue.Queue()
123
+ llama_queue.put(
124
+ GenerateRequest(
125
+ request=request,
126
+ response_queue=response_queue,
127
+ )
128
  )
 
129
 
130
+ if streaming:
131
+ yield wav_chunk_header(), None, None
132
+
133
+ segments = []
134
+
135
  while True:
136
+ result: WrappedGenerateResponse = response_queue.get()
137
+ if result.status == "error":
138
+ yield None, None, build_html_error_message(result.response)
 
 
 
 
 
139
  break
140
 
141
+ result: GenerateResponse = result.response
142
+ if result.action == "next":
143
+ break
144
+
145
+ with torch.autocast(
146
+ device_type=(
147
+ "cpu"
148
+ if decoder_model.device.type == "mps"
149
+ else decoder_model.device.type
150
+ ),
151
+ dtype=args.precision,
152
+ ):
153
+ fake_audios = decode_vq_tokens(
154
+ decoder_model=decoder_model,
155
+ codes=result.codes,
156
+ )
157
+
158
+ fake_audios = fake_audios.float().cpu().numpy()
159
+ segments.append(fake_audios)
160
+
161
+ if streaming:
162
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
163
+
164
+ if len(segments) == 0:
165
+ return (
166
+ None,
167
+ None,
168
+ build_html_error_message(
169
+ i18n("No audio generated, please check the input text.")
170
+ ),
171
+ )
172
+
173
+ # No matter streaming or not, we need to return the final audio
174
+ audio = np.concatenate(segments, axis=0)
175
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
176
+
177
+ if torch.cuda.is_available():
178
+ torch.cuda.empty_cache()
179
+ gc.collect()
180
+
181
+
182
+ def inference_with_auto_rerank(
183
+ text,
184
+ enable_reference_audio,
185
+ reference_audio,
186
+ reference_text,
187
+ max_new_tokens,
188
+ chunk_length,
189
+ top_p,
190
+ repetition_penalty,
191
+ temperature,
192
+ use_auto_rerank,
193
+ streaming=False,
194
+ ):
195
+
196
+ max_attempts = 2 if use_auto_rerank else 1
197
+ best_wer = float("inf")
198
+ best_audio = None
199
+ best_sample_rate = None
200
+
201
+ for attempt in range(max_attempts):
202
+ audio_generator = inference(
203
+ text,
204
+ enable_reference_audio,
205
+ reference_audio,
206
+ reference_text,
207
+ max_new_tokens,
208
+ chunk_length,
209
+ top_p,
210
+ repetition_penalty,
211
+ temperature,
212
+ streaming=False,
213
+ )
214
+
215
+ # 获取音频数据
216
+ for _ in audio_generator:
217
+ pass
218
+ _, (sample_rate, audio), message = _
219
+
220
+ if audio is None:
221
+ return None, None, message
222
+
223
+ if not use_auto_rerank:
224
+ return None, (sample_rate, audio), None
225
+
226
+ asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
227
+ wer = calculate_wer(text, asr_result["text"])
228
+ if wer <= 0.3 and not asr_result["huge_gap"]:
229
+ return None, (sample_rate, audio), None
230
+
231
+ if wer < best_wer:
232
+ best_wer = wer
233
+ best_audio = audio
234
+ best_sample_rate = sample_rate
235
+
236
+ if attempt == max_attempts - 1:
237
+ break
238
+
239
+ return None, (best_sample_rate, best_audio), None
240
+
241
+
242
+ inference_stream = partial(inference, streaming=True)
243
+
244
+ n_audios = 4
245
+
246
+ global_audio_list = []
247
+ global_error_list = []
248
+
249
+
250
+ def inference_wrapper(
251
+ text,
252
+ enable_reference_audio,
253
+ reference_audio,
254
+ reference_text,
255
+ max_new_tokens,
256
+ chunk_length,
257
+ top_p,
258
+ repetition_penalty,
259
+ temperature,
260
+ batch_infer_num,
261
+ if_load_asr_model,
262
+ ):
263
+ audios = []
264
+ errors = []
265
+
266
+ for _ in range(batch_infer_num):
267
+ result = inference_with_auto_rerank(
268
+ text,
269
+ enable_reference_audio,
270
+ reference_audio,
271
+ reference_text,
272
+ max_new_tokens,
273
+ chunk_length,
274
+ top_p,
275
+ repetition_penalty,
276
+ temperature,
277
+ if_load_asr_model,
278
+ )
279
+
280
+ _, audio_data, error_message = result
281
+
282
+ audios.append(
283
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
284
+ )
285
+ errors.append(
286
+ gr.HTML(value=error_message if error_message else None, visible=True),
287
+ )
288
+
289
+ for _ in range(batch_infer_num, n_audios):
290
+ audios.append(
291
+ gr.Audio(value=None, visible=False),
292
+ )
293
+ errors.append(
294
+ gr.HTML(value=None, visible=False),
295
+ )
296
+
297
+ return None, *audios, *errors
298
+
299
+
300
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
301
+ buffer = io.BytesIO()
302
+
303
+ with wave.open(buffer, "wb") as wav_file:
304
+ wav_file.setnchannels(channels)
305
+ wav_file.setsampwidth(bit_depth // 8)
306
+ wav_file.setframerate(sample_rate)
307
+
308
+ wav_header_bytes = buffer.getvalue()
309
+ buffer.close()
310
+ return wav_header_bytes
311
+
312
+
313
+ def normalize_text(user_input, use_normalization):
314
+ if use_normalization:
315
+ return ChnNormedText(raw_text=user_input).normalize()
316
+ else:
317
+ return user_input
318
+
319
+
320
+ asr_model = None
321
 
 
322
 
323
+ def change_if_load_asr_model(if_load):
324
+ global asr_model
 
 
 
325
 
326
+ if if_load:
327
+ gr.Warning("Loading faster whisper model...")
328
+ if asr_model is None:
329
+ asr_model = load_model()
330
+ return gr.Checkbox(label="Unload faster whisper model", value=if_load)
331
 
332
+ if if_load is False:
333
+ gr.Warning("Unloading faster whisper model...")
334
+ del asr_model
335
+ asr_model = None
336
+ if torch.cuda.is_available():
337
+ torch.cuda.empty_cache()
338
+ gc.collect()
339
+ return gr.Checkbox(label="Load faster whisper model", value=if_load)
340
+
341
+
342
+ def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
343
+ if if_load and asr_model is not None:
344
+ if (
345
+ if_auto_label
346
+ and enable_ref
347
+ and ref_audio is not None
348
+ and ref_text.strip() == ""
349
+ ):
350
+ data, sample_rate = librosa.load(ref_audio)
351
+ res = batch_asr(asr_model, [data], sample_rate)[0]
352
+ ref_text = res["text"]
353
+ else:
354
+ gr.Warning("Whisper model not loaded!")
355
+
356
+ return gr.Textbox(value=ref_text)
357
 
358
 
359
  def build_app():
 
364
  app.load(
365
  None,
366
  None,
367
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
368
+ % args.theme,
369
  )
370
 
371
  # Inference
372
  with gr.Row():
373
  with gr.Column(scale=3):
374
  text = gr.Textbox(
375
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
376
+ )
377
+ refined_text = gr.Textbox(
378
+ label=i18n("Realtime Transform Text"),
379
+ placeholder=i18n(
380
+ "Normalization Result Preview (Currently Only Chinese)"
381
+ ),
382
+ lines=5,
383
+ interactive=False,
384
  )
385
 
386
  with gr.Row():
387
+ if_refine_text = gr.Checkbox(
388
+ label=i18n("Text Normalization"),
389
+ value=True,
390
+ scale=1,
391
+ )
392
+
393
+ if_load_asr_model = gr.Checkbox(
394
+ label=i18n("Load / Unload ASR model for auto-reranking"),
395
+ value=False,
396
+ scale=3,
397
+ )
398
+
399
+ with gr.Row():
400
+ with gr.Tab(label=i18n("Advanced Config")):
401
  chunk_length = gr.Slider(
402
+ label=i18n("Iterative Prompt Length, 0 means off"),
403
  minimum=0,
404
+ maximum=500,
405
+ value=100,
406
  step=8,
407
  )
408
 
409
  max_new_tokens = gr.Slider(
410
+ label=i18n("Maximum tokens per batch, 0 means no limit"),
411
+ minimum=0,
412
+ maximum=2048,
413
+ value=1024, # 0 means no limit
414
  step=8,
415
  )
416
 
417
  top_p = gr.Slider(
418
+ label="Top-P",
419
+ minimum=0.6,
420
+ maximum=0.9,
421
+ value=0.7,
422
+ step=0.01,
423
  )
424
 
425
  repetition_penalty = gr.Slider(
426
+ label=i18n("Repetition Penalty"),
427
+ minimum=1,
428
+ maximum=1.5,
429
+ value=1.2,
430
  step=0.01,
431
  )
432
 
433
  temperature = gr.Slider(
434
  label="Temperature",
435
+ minimum=0.6,
436
+ maximum=0.9,
437
  value=0.7,
438
  step=0.01,
439
  )
440
 
441
+ with gr.Tab(label=i18n("Reference Audio")):
 
 
 
 
 
 
442
  gr.Markdown(
443
+ i18n(
444
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
445
+ )
446
  )
447
 
448
  enable_reference_audio = gr.Checkbox(
449
+ label=i18n("Enable Reference Audio"),
450
  )
451
  reference_audio = gr.Audio(
452
+ label=i18n("Reference Audio"),
453
  type="filepath",
454
  )
455
+ with gr.Row():
456
+ if_auto_label = gr.Checkbox(
457
+ label=i18n("Auto Labeling"),
458
+ min_width=100,
459
+ scale=0,
460
+ value=False,
461
+ )
462
+ reference_text = gr.Textbox(
463
+ label=i18n("Reference Text"),
464
+ lines=1,
465
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
466
+ value="",
467
+ )
468
+ with gr.Tab(label=i18n("Batch Inference")):
469
+ batch_infer_num = gr.Slider(
470
+ label="Batch infer nums",
471
+ minimum=1,
472
+ maximum=n_audios,
473
+ step=1,
474
+ value=1,
475
  )
476
 
477
  with gr.Column(scale=3):
478
+ for _ in range(n_audios):
479
+ with gr.Row():
480
+ error = gr.HTML(
481
+ label=i18n("Error Message"),
482
+ visible=True if _ == 0 else False,
483
+ )
484
+ global_error_list.append(error)
485
+ with gr.Row():
486
+ audio = gr.Audio(
487
+ label=i18n("Generated Audio"),
488
+ type="numpy",
489
+ interactive=False,
490
+ visible=True if _ == 0 else False,
491
+ )
492
+ global_audio_list.append(audio)
493
 
494
+ with gr.Row():
495
+ stream_audio = gr.Audio(
496
+ label=i18n("Streaming Audio"),
497
+ streaming=True,
498
+ autoplay=True,
499
+ interactive=False,
500
+ show_download_button=True,
501
+ )
502
  with gr.Row():
503
  with gr.Column(scale=3):
504
  generate = gr.Button(
505
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
506
+ )
507
+ generate_stream = gr.Button(
508
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
509
+ variant="primary",
510
  )
511
 
512
+ text.input(
513
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
514
+ )
515
+
516
+ if_load_asr_model.change(
517
+ fn=change_if_load_asr_model,
518
+ inputs=[if_load_asr_model],
519
+ outputs=[if_load_asr_model],
520
+ )
521
+
522
+ if_auto_label.change(
523
+ fn=lambda: gr.Textbox(value=""),
524
+ inputs=[],
525
+ outputs=[reference_text],
526
+ ).then(
527
+ fn=change_if_auto_label,
528
+ inputs=[
529
+ if_load_asr_model,
530
+ if_auto_label,
531
+ enable_reference_audio,
532
+ reference_audio,
533
+ reference_text,
534
+ ],
535
+ outputs=[reference_text],
536
+ )
537
+
538
  # # Submit
539
  generate.click(
540
+ inference_wrapper,
541
  [
542
+ refined_text,
543
  enable_reference_audio,
544
  reference_audio,
545
  reference_text,
 
548
  top_p,
549
  repetition_penalty,
550
  temperature,
551
+ batch_infer_num,
552
+ if_load_asr_model,
553
  ],
554
+ [stream_audio, *global_audio_list, *global_error_list],
555
  concurrency_limit=1,
556
  )
557
 
558
+ generate_stream.click(
559
+ inference_stream,
560
+ [
561
+ refined_text,
562
+ enable_reference_audio,
563
+ reference_audio,
564
+ reference_text,
565
+ max_new_tokens,
566
+ chunk_length,
567
+ top_p,
568
+ repetition_penalty,
569
+ temperature,
570
+ ],
571
+ [stream_audio, global_audio_list[0], global_error_list[0]],
572
+ concurrency_limit=10,
573
+ )
574
  return app
575
 
576
 
 
579
  parser.add_argument(
580
  "--llama-checkpoint-path",
581
  type=Path,
582
+ default="checkpoints/fish-speech-1.2-sft",
583
  )
584
  parser.add_argument(
585
+ "--decoder-checkpoint-path",
 
 
 
586
  type=Path,
587
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
588
  )
589
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
 
590
  parser.add_argument("--device", type=str, default="cuda")
591
  parser.add_argument("--half", action="store_true")
 
592
  parser.add_argument("--compile", action="store_true")
593
  parser.add_argument("--max-gradio-length", type=int, default=0)
594
+ parser.add_argument("--theme", type=str, default="light")
595
 
596
  return parser.parse_args()
597
 
598
 
599
  if __name__ == "__main__":
600
  args = parse_args()
 
601
  args.precision = torch.half if args.half else torch.bfloat16
 
 
 
 
 
 
 
602
 
603
  logger.info("Loading Llama model...")
604
  llama_queue = launch_thread_safe_queue(
 
605
  checkpoint_path=args.llama_checkpoint_path,
606
  device=args.device,
607
  precision=args.precision,
 
608
  compile=args.compile,
609
  )
 
610
  logger.info("Llama model loaded, loading VQ-GAN model...")
611
 
612
+ decoder_model = load_decoder_model(
613
+ config_name=args.decoder_config_name,
614
+ checkpoint_path=args.decoder_checkpoint_path,
615
  device=args.device,
616
  )
617
 
618
+ logger.info("Decoder model loaded, warming up...")
619
 
620
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
621
+ list(
622
+ inference(
623
+ text="Hello, world!",
624
+ enable_reference_audio=False,
625
+ reference_audio=None,
626
+ reference_text="",
627
+ max_new_tokens=0,
628
+ chunk_length=100,
629
+ top_p=0.7,
630
+ repetition_penalty=1.2,
631
+ temperature=0.7,
632
+ )
633
  )
634
 
635
  logger.info("Warming up done, launching the web UI...")
636
 
637
  app = build_app()
638
+ app.launch(show_api=True)