lengyue233 committed
Commit 12b4214
1 Parent(s): 662d788

Update model to large sft
app.py CHANGED
@@ -1,34 +1,26 @@
 import subprocess as sp
 import os
+from huggingface_hub import hf_hub_download
 
 # Download if not exists
 os.makedirs("checkpoints", exist_ok=True)
-
-if not os.path.exists("checkpoints/text2semantic-medium-v1-2k.pth"):
-    print("Downloading text2semantic-medium-v1-2k.pth")
-    sp.run(["wget", "-q", "-O", "checkpoints/text2semantic-medium-v1-2k.pth", os.environ["CKPT_SEMANTIC"]])
-
-if not os.path.exists("checkpoints/vq-gan-group-fsq-2x1024.pth"):
-    print("Downloading vq-gan-group-fsq-2x1024.pth")
-    sp.run(["wget", "-q", "-O", "checkpoints/vq-gan-group-fsq-2x1024.pth", os.environ["CKPT_VQGAN"]])
+hf_hub_download("fishaudio/fish-speech-1", "./checkpoints/fish-speech-1")
 
 print("All checkpoints downloaded")
 
 import html
+import os
+import threading
 from argparse import ArgumentParser
-from io import BytesIO
 from pathlib import Path
 
 import gradio as gr
 import librosa
-import spaces
 import torch
 from loguru import logger
-from torchaudio import functional as AF
 from transformers import AutoTokenizer
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 
 # Make einx happy
@@ -52,16 +44,30 @@ We are not responsible for any misuse of the model, please consider your local l
 
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
+try:
+    import spaces
+
+    GPU_DECORATOR = spaces.GPU
+except ImportError:
+
+    def GPU_DECORATOR(func):
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        return wrapper
+
 
 def build_html_error_message(error):
     return f"""
-    <div style="color: red; font-weight: bold;">
+    <div style="color: red;
+    font-weight: bold;">
    {html.escape(error)}
    </div>
    """
 
 
-@spaces.GPU
+@GPU_DECORATOR
+@torch.inference_mode()
 def inference(
     text,
     enable_reference_audio,
@@ -73,13 +79,10 @@ def inference(
     top_p,
     repetition_penalty,
     temperature,
-    speaker=None,
+    speaker,
 ):
-    if len(reference_text) > 100:
-        return None, "Ref text is too long, please keep it under 100 characters."
-
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return None, "Text is too long, please keep it under 1000 characters."
+        return None, f"Text is too long, please keep it under {args.max_gradio_length} characters."
 
     # Parse reference audio aka prompt
     prompt_tokens = None
@@ -103,11 +106,9 @@ def inference(
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=max_new_tokens,
         text=text,
         top_k=int(top_k) if top_k > 0 else None,
@@ -123,7 +124,18 @@ def inference(
         prompt_text=reference_text if enable_reference_audio else None,
     )
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -151,9 +163,7 @@ def build_app():
         with gr.Row():
             with gr.Column(scale=3):
                 text = gr.Textbox(
-                    label="Input Text / 输入文本",
-                    placeholder=TEXTBOX_PLACEHOLDER,
-                    lines=15,
+                    label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                 )
 
                 with gr.Row():
@@ -198,11 +208,11 @@ def build_app():
                            step=0.01,
                        )
 
-                        # speaker = gr.Textbox(
-                        #     label="Speaker / 说话人",
-                        #     placeholder="Type name of the speaker / 输入说话人的名称",
-                        #     lines=1,
-                        # )
+                        speaker = gr.Textbox(
+                            label="Speaker / 说话人",
+                            placeholder="Type name of the speaker / 输入说话人的名称",
+                            lines=1,
+                        )
 
                     with gr.Tab(label="Reference Audio / 参考音频"):
                         gr.Markdown(
@@ -248,7 +258,7 @@ def build_app():
                top_p,
                repetition_penalty,
                temperature,
-                # speaker,
+                speaker,
            ],
            [audio, error],
            concurrency_limit=1,
@@ -262,10 +272,10 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/text2semantic-medium-v1-2k.pth",
+        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
     )
     parser.add_argument(
-        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+        "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
     )
     parser.add_argument(
         "--vqgan-checkpoint-path",
@@ -278,7 +288,7 @@ def parse_args():
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
-    parser.add_argument("--max-gradio-length", type=int, default=1024)
+    parser.add_argument("--max-gradio-length", type=int, default=0)
 
     return parser.parse_args()
 
@@ -288,9 +298,15 @@ if __name__ == "__main__":
 
     args.precision = torch.half if args.half else torch.bfloat16
     args.compile = True
+    args.max_gradio_length = 1024
+    args.tokenizer = "./checkpoints/fish-speech-1"
+    args.llama_checkpoint_path = "./checkpoints/text2semantic-sft-large-v1-4k.pth"
+    args.llama_config_name = "dual_ar_2_codebook_large"
+    args.vqgan_checkpoint_path = "./checkpoints/vq-gan-group-fsq-2x1024.pth"
+    args.vqgan_config_name = "vqgan_pretrain"
 
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
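
Side note on the checkpoint download change above: the diff replaces the wget calls with a single huggingface_hub call. The committed code uses hf_hub_download; the sketch below instead uses snapshot_download, the huggingface_hub helper for fetching an entire model repository, and the local_dir value is only an illustrative assumption, not part of this diff.

    # Sketch only: pull every file of the model repo into a local folder.
    # snapshot_download fetches the whole repo; the target path is illustrative.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="fishaudio/fish-speech-1",
        local_dir="./checkpoints/fish-speech-1",
    )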
tools/extract_model.py DELETED
@@ -1,21 +0,0 @@
-import click
-import torch
-from loguru import logger
-
-
-@click.command()
-@click.argument("model_path")
-@click.argument("output_path")
-def main(model_path, output_path):
-    if model_path == output_path:
-        logger.error("Model path and output path are the same")
-        return
-
-    logger.info(f"Loading model from {model_path}")
-    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
-    torch.save(state_dict, output_path)
-    logger.info(f"Model saved to {output_path}")
-
-
-if __name__ == "__main__":
-    main()
tools/llama/build_dataset.py DELETED
@@ -1,165 +0,0 @@
-import itertools
-import os
-import re
-from collections import defaultdict
-from functools import partial
-from multiprocessing import Pool
-from pathlib import Path
-
-import click
-import numpy as np
-from loguru import logger
-from tqdm import tqdm
-
-from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
-from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
-from fish_speech.utils.file import load_filelist
-
-# To avoid CPU overload
-os.environ["MKL_NUM_THREADS"] = "1"
-os.environ["OMP_NUM_THREADS"] = "1"
-
-
-def task_generator_folder(root: Path, text_extension: str):
-    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
-    files = sorted(files)
-
-    grouped_files = defaultdict(list)
-    for file in tqdm(files, desc=f"Grouping {root}"):
-        p = str(file.parent)
-
-        try:
-            if isinstance(text_extension, str):
-                texts = [file.with_suffix(text_extension).read_text()]
-            else:
-                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
-        except Exception as e:
-            logger.error(f"Failed to read text {file}: {e}")
-            continue
-
-        grouped_files[p].append((file, texts))
-
-    logger.info(
-        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
-    )
-    for name, subset in grouped_files.items():
-        yield name, subset, "folder"
-
-
-def task_generator_filelist(filelist):
-    grouped_files = defaultdict(list)
-    for filename, speaker, _, text in load_filelist(filelist):
-        grouped_files[speaker].append((Path(filename), [text]))
-
-    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
-    for speaker, values in grouped_files.items():
-        yield speaker, values, "filelist"
-
-
-def run_task(task):
-    name, subset, source = task
-
-    # Parse the files
-    sentences = []
-    for file in subset:
-        file, texts = file
-
-        np_file = file.with_suffix(".npy")
-        if np_file.exists() is False:
-            logger.warning(f"Can't find {np_file}")
-            continue
-
-        new_texts = []
-
-        for text in texts:
-            # Simple cleaning: replace { xxx } and < xxx > with space
-            text = re.sub(r"\{.*?\}", " ", text)
-            text = re.sub(r"<.*?>", " ", text)
-            text = re.sub(r"\s+", " ", text)
-            new_texts.append(text)
-
-        try:
-            semantics = np.load(np_file)
-        except Exception as e:
-            logger.error(f"Failed to parse {file}: {e}")
-            continue
-
-        if isinstance(semantics, np.ndarray):
-            semantics = semantics.tolist()
-
-        sentences.append(
-            Sentence(
-                texts=new_texts,
-                semantics=[Semantics(values=s) for s in semantics],
-            )
-        )
-
-    # Pack the sentences
-    return pack_pb_stream(
-        TextData(
-            source=source,
-            name=name,
-            sentences=sentences,
-        )
-    )
-
-
-@click.command()
-@click.option(
-    "--input",
-    type=click.Path(path_type=Path),
-    required=True,
-    help="A folder containing the dataset or a filelist",
-    multiple=True,
-)
-@click.option(
-    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
-)
-@click.option("--num-workers", type=int, default=16)
-@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
-@click.option(
-    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
-)
-def main(input, output, num_workers, text_extension, shard_size):
-    generator_fns = []
-
-    for f in input:
-        assert f.exists(), f"{f} not found"
-
-        if f.is_dir():
-            generator_fn = task_generator_folder(f, text_extension)
-        else:
-            generator_fn = task_generator_filelist(f)
-
-        generator_fns.append(generator_fn)
-
-    generator_fn = itertools.chain(*generator_fns)
-    output.mkdir(parents=True, exist_ok=True)
-
-    dataset_fp = None
-    tar_idx = 0
-    written_size = 0
-
-    with Pool(num_workers) as p:
-        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
-            if dataset_fp is None:
-                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
-
-            dataset_fp.write(result)
-            written_size += len(result)
-
-            if written_size > shard_size * 1024 * 1024:
-                logger.info(f"Finished writing {tar_idx} shards to {output}")
-                dataset_fp.close()
-                dataset_fp = None
-                written_size = 0
-                tar_idx += 1
-
-    if dataset_fp is not None:
-        dataset_fp.close()
-
-    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
-
-
-if __name__ == "__main__":
-    main()
tools/llama/generate.py CHANGED
@@ -1,9 +1,12 @@
 import os
+import queue
+import threading
 import time
 from pathlib import Path
 from typing import Optional, Tuple, Union
 
 import click
+import hydra
 import numpy as np
 import torch
 import torch._dynamo.config
@@ -361,6 +364,7 @@ def encode_tokens(
 def load_model(
     config_name, checkpoint_path, device, precision, max_length, compile=False
 ):
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
     with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
         cfg = compose(
             config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
@@ -456,6 +460,7 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
+    is_streaming: bool = False,
 ):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -496,6 +501,10 @@
         all_codes = []
         seg_idx = 0
 
+        if use_prompt:
+            seg_idx = 1
+            global_encoded.append(encoded[0])
+
         while seg_idx < len(encoded):
             logger.info(
                 f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -562,10 +571,7 @@
             codes = y[1:, prompt_length:-2].clone()
 
             codes = codes - 2
-            if not (codes >= 0).all():
-                global_encoded.pop()
-                logger.warning(f"Negative code found: {codes}, retrying ...")
-                continue
+            assert (codes >= 0).all(), f"Negative code found"
 
             decoded = y[:, prompt_length:-1].clone()
             if decoded[0, -1] != im_end_id: # <im_end>
@@ -576,13 +582,63 @@
 
             # But for global encoding, we should keep the <im_end> token
             global_encoded.append(decoded)
-            all_codes.append(codes)
+
+            if is_streaming:
+                assert (codes >= 0).all(), f"Negative code found: {codes}"
+                yield codes
+            else:
+                all_codes.append(codes)
+
             seg_idx += 1
 
-        codes = torch.cat(all_codes, dim=1)
-        assert (codes >= 0).all(), f"Negative code found: {codes}"
+        if is_streaming:
+            # This indicates the end of the current sample
+            yield None
+        else:
+            all_codes = torch.cat(all_codes, dim=1)
+            assert (all_codes >= 0).all(), f"Negative code found: {codes}"
+            yield all_codes
+
+
+def launch_thread_safe_queue(
+    config_name,
+    checkpoint_path,
+    device,
+    precision,
+    max_length,
+    compile=False,
+):
+    input_queue = queue.Queue()
+
+    def worker():
+        model, decode_one_token = load_model(
+            config_name, checkpoint_path, device, precision, max_length, compile=compile
+        )
+
+        while True:
+            item = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item["request"]
+            event = item["event"]
+
+            try:
+                item["success"] = True
+                item["response"] = list(
+                    generate_long(
+                        model=model, decode_one_token=decode_one_token, **kwargs
+                    )
+                )
+            except Exception as e:
+                item["success"] = False
+                item["response"] = e
+
+            event.set()
+
+    threading.Thread(target=worker, daemon=True).start()
 
-        yield codes
+    return input_queue
 
 
 @click.command()
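
For orientation, launch_thread_safe_queue (added above) starts one daemon thread that owns the Llama model and serves requests submitted as payload dicts: each payload carries the generate_long kwargs plus a threading.Event, and the worker fills in success/response before setting the event. The caller-side sketch below mirrors what app.py does in this same commit; the checkpoint, config, and tokenizer paths and the request values are illustrative assumptions, and only the kwargs visible in this diff are shown.

    import threading

    import torch
    from transformers import AutoTokenizer

    from tools.llama.generate import launch_thread_safe_queue

    # Start the single background worker that owns the Llama model (paths illustrative).
    llama_queue = launch_thread_safe_queue(
        config_name="dual_ar_2_codebook_large",
        checkpoint_path="checkpoints/text2semantic-sft-large-v1-4k.pth",
        device="cuda",
        precision=torch.bfloat16,
        max_length=2048,
        compile=False,
    )
    tokenizer = AutoTokenizer.from_pretrained("./checkpoints/fish-speech-1")

    # One request: generate_long kwargs plus an Event the worker sets when finished.
    payload = dict(
        event=threading.Event(),
        request=dict(
            tokenizer=tokenizer,
            device="cuda",
            max_new_tokens=1024,
            text="Hello world. 你好世界.",
            top_k=None,
        ),
    )
    llama_queue.put(payload)
    payload["event"].wait()

    if payload["success"] is False:
        raise payload["response"]  # the worker forwards the exception object

    codes = payload["response"][0]  # first yielded tensor of semantic codes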
tools/llama/rebuild_tokenizer.py DELETED
@@ -1,57 +0,0 @@
-from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-
-# Initialize a tokenizer
-tokenizer = Tokenizer(models.BPE())
-
-# Customize pre-tokenization and decoding
-tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
-tokenizer.decoder = decoders.ByteLevel()
-tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
-# Don't train the tokenizer
-trainer = trainers.BpeTrainer(
-    vocab_size=0,
-    min_frequency=2,
-    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
-    special_tokens=[
-        "<|begin_of_sequence|>",
-        "<|end_of_sequence|>",
-        "<|im_start|>",
-        "<|im_sep|>",  # system, user, assistant, etc.
-        "<|im_end|>",
-        "<|semantic|>",  # audio features
-        "<|pad|>",
-    ],
-)
-
-# <|im_start|>user<|im_sep|>...<|im_end|>
-# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
-tokenizer.train_from_iterator([], trainer=trainer)
-
-print(len(tokenizer.get_vocab()))
-x = tokenizer.encode(
-    "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
-).ids
-print(x, len(x))
-print(tokenizer.decode(x, skip_special_tokens=True))
-
-
-tokenizer = PreTrainedTokenizerFast(
-    tokenizer_object=tokenizer,
-    pad_token="<|pad|>",
-    bos_token="<|begin_of_sequence|>",
-    eos_token="<|end_of_sequence|>",
-)
-
-# Try tokenizing a new sequence
-sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
-encoded = tokenizer(sequence).input_ids
-
-print("Test encoding....")
-print(f"\tSentence: {sequence}")
-print(f"\tEncoded: {encoded}")
-print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
-print(f"\tDecoded: {tokenizer.decode(encoded)}")
-
-tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
tools/merge_asr_files.py DELETED
@@ -1,55 +0,0 @@
-import os
-from pathlib import Path
-
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
-
-
-def merge_and_delete_files(save_dir, original_files):
-    save_path = Path(save_dir)
-    audio_slice_files = list_files(
-        path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
-    )
-    audio_files = {}
-    label_files = {}
-    for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
-        rel_path = Path(file_path).relative_to(save_path)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-        if file_path.suffix == ".wav":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            audio = AudioSegment.from_wav(file_path)
-            if prefix in audio_files.keys():
-                audio_files[prefix] = audio_files[prefix] + audio
-            else:
-                audio_files[prefix] = audio
-
-        elif file_path.suffix == ".lab":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            with open(file_path, "r", encoding="utf-8") as f:
-                label = f.read()
-            if prefix in label_files.keys():
-                label_files[prefix] = label_files[prefix] + ", " + label
-            else:
-                label_files[prefix] = label
-
-    for prefix, audio in audio_files.items():
-        output_audio_path = save_path / f"{prefix}.wav"
-        audio.export(output_audio_path, format="wav")
-
-    for prefix, label in label_files.items():
-        output_label_path = save_path / f"{prefix}.lab"
-        with open(output_label_path, "w", encoding="utf-8") as f:
-            f.write(label)
-
-    for file_path in original_files:
-        os.remove(file_path)
-
-
-if __name__ == "__main__":
-    merge_and_delete_files("/made/by/spicysama/laziman", [__file__])
tools/vqgan/create_train_split.py DELETED
@@ -1,54 +0,0 @@
-import math
-from pathlib import Path
-from random import Random
-
-import click
-from loguru import logger
-from tqdm import tqdm
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-
-@click.command()
-@click.argument("root", type=click.Path(exists=True, path_type=Path))
-@click.option("--val-ratio", type=float, default=None)
-@click.option("--val-count", type=int, default=None)
-@click.option("--filelist", default=None, type=Path)
-def main(root, val_ratio, val_count, filelist):
-    if filelist:
-        files = [i[0] for i in load_filelist(filelist)]
-    else:
-        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
-
-    logger.info(f"Found {len(files)} files")
-    files = [str(file.relative_to(root)) for file in tqdm(files)]
-
-    Random(42).shuffle(files)
-
-    if val_count is None and val_ratio is None:
-        logger.info("Validation ratio and count not specified, using min(20%, 100)")
-        val_size = min(100, math.ceil(len(files) * 0.2))
-    elif val_count is not None and val_ratio is not None:
-        logger.error("Cannot specify both val_count and val_ratio")
-        return
-    elif val_count is not None:
-        if val_count < 1 or val_count > len(files):
-            logger.error("val_count must be between 1 and number of files")
-            return
-        val_size = val_count
-    else:
-        val_size = math.ceil(len(files) * val_ratio)
-
-    logger.info(f"Using {val_size} files for validation")
-
-    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
-        f.write("\n".join(files[val_size:]))
-
-    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
-        f.write("\n".join(files[:val_size]))
-
-    logger.info("Done")
-
-
-if __name__ == "__main__":
-    main()
tools/vqgan/extract_vq.py DELETED
@@ -1,213 +0,0 @@
-import os
-import subprocess as sp
-import sys
-import time
-from datetime import timedelta
-from functools import lru_cache
-from pathlib import Path
-from random import Random
-
-import click
-import numpy as np
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
-
-RANK = int(os.environ.get("SLURM_PROCID", 0))
-WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
-
-logger_format = (
-    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
-    "<level>{level: <8}</level> | "
-    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
-    "{extra[rank]} - <level>{message}</level>"
-)
-logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
-logger.remove()
-logger.add(sys.stderr, format=logger_format)
-
-
-@lru_cache(maxsize=1)
-def get_model(
-    config_name: str = "vqgan_pretrain",
-    checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
-):
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
-
-    model: LightningModule = instantiate(cfg.model)
-    state_dict = torch.load(
-        checkpoint_path,
-        map_location=model.device,
-    )
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
-    model.cuda()
-
-    logger.info(f"Loaded model")
-    return model
-
-
-@torch.inference_mode()
-def process_batch(files: list[Path], model) -> float:
-    wavs = []
-    audio_lengths = []
-    new_files = []
-    max_length = total_time = 0
-
-    for file in files:
-        try:
-            wav, sr = torchaudio.load(
-                str(file), backend="sox"
-            )  # Need to install libsox-dev
-        except Exception as e:
-            logger.error(f"Error reading {file}: {e}")
-            continue
-
-        if wav.shape[0] > 1:
-            wav = wav.mean(dim=0, keepdim=True)
-
-        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
-        total_time += len(wav) / model.sampling_rate
-        max_length = max(max_length, len(wav))
-
-        wavs.append(wav)
-        audio_lengths.append(len(wav))
-        new_files.append(file)
-
-    files = new_files
-
-    # Pad to max length
-    for i, wav in enumerate(wavs):
-        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-
-    audios = torch.stack(wavs, dim=0)[:, None]
-    audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
-
-    # Calculate lengths
-    indices, feature_lengths = model.encode(audios, audio_lengths)
-
-    # Save to disk
-    outputs = indices.cpu().numpy()
-
-    for file, length, feature, audio_length in zip(
-        files, feature_lengths, outputs, audio_lengths
-    ):
-        feature = feature[:, :length]
-
-        # (T,)
-        with open(file.with_suffix(".npy"), "wb") as f:
-            np.save(f, feature)
-
-    return total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--num-workers", default=1)
-@click.option("--config-name", default="vqgan_pretrain")
-@click.option(
-    "--checkpoint-path",
-    default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
-)
-@click.option("--batch-size", default=64)
-@click.option("--filelist", default=None, type=Path)
-def main(
-    folder: str,
-    num_workers: int,
-    config_name: str,
-    checkpoint_path: str,
-    batch_size: int,
-    filelist: Path,
-):
-    if num_workers > 1 and WORLD_SIZE != num_workers:
-        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
-
-        logger.info(f"Spawning {num_workers} workers")
-
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
-        else:
-            visible_devices = visible_devices.split(",")
-
-        processes = []
-        for i in range(num_workers):
-            env = os.environ.copy()
-            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
-            env["SLURM_PROCID"] = str(i)
-            env["SLURM_NTASKS"] = str(num_workers)
-
-            processes.append(
-                sp.Popen(
-                    [sys.executable] + sys.argv.copy(),
-                    env=env,
-                )
-            )
-
-        for p in processes:
-            p.wait()
-
-        logger.info(f"All workers finished")
-        return
-
-    # This is a worker
-    logger.info(f"Starting worker")
-    if filelist:
-        files = [i[0] for i in load_filelist(filelist)]
-    else:
-        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
-
-    print(f"Found {len(files)} files")
-    # files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
-
-    total_files = len(files)
-    files = files[RANK::WORLD_SIZE]
-    logger.info(f"Processing {len(files)}/{total_files} files")
-
-    # Batch processing
-    total_time = 0
-    begin_time = time.time()
-    processed_files = 0
-    model = get_model(config_name, checkpoint_path)
-
-    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
-        batch = files[idx : idx + batch_size]
-        batch_time = process_batch(batch, model)
-
-        total_time += batch_time
-        processed_files += len(batch)
-
-        if (n_batch + 1) % 10 == 0:
-            eta = (
-                (time.time() - begin_time)
-                / processed_files
-                * (len(files) - processed_files)
-            )
-            logger.info(
-                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
-                + f"ETA: {timedelta(seconds=round(eta))}s"
-            )
-
-    logger.info(
-        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
-    )
-
-
-if __name__ == "__main__":
-    main()
tools/whisper_asr.py DELETED
@@ -1,113 +0,0 @@
-"""
-Used to transcribe all audio files in one folder into another folder.
-e.g.
-Directory structure:
---pre_data_root
-----SP_1
-------01.wav
-------02.wav
-------......
-----SP_2
-------01.wav
-------02.wav
-------......
-Use
-python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
-to transcribe the first speaker.
-
-Use
-python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
-to transcribe the second speaker.
-
-Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
-"""
-from pathlib import Path
-
-import click
-import librosa
-import soundfile as sf
-import whisper
-from loguru import logger
-from merge_asr_files import merge_and_delete_files
-from tqdm import tqdm
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
-
-
-@click.command()
-@click.option("--model-size", default="large", help="Size of the Whisper model")
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
-    "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option(
-    "--sample-rate",
-    default=None,
-    type=int,
-    help="Output sample rate, default to input sample rate",
-)
-@click.option("--device", default="cuda", help="Device to use")
-@click.option("--language", default="ZH", help="Language of the transcription")
-def main(model_size, audio_dir, save_dir, sample_rate, device, language):
-    logger.info("Loading / Downloading OpenAI Whisper model...")
-    model = whisper.load_model(
-        name=model_size,
-        device=device,
-        download_root=str(Path(".cache/whisper").resolve()),
-    )
-    logger.info("Model loaded.")
-
-    save_path = Path(save_dir)
-    save_path.mkdir(parents=True, exist_ok=True)
-    original_files = []
-    audio_files = list_files(
-        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
-    )
-    for file_path in tqdm(audio_files, desc="Processing audio file"):
-        file_stem = file_path.stem
-        file_suffix = file_path.suffix
-
-        rel_path = Path(file_path).relative_to(audio_dir)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
-        if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
-            save_path / rel_path.parent / f"{rel_path.stem}.lab"
-        ).exists():
-            continue
-
-        audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
-        transcription = model.transcribe(str(file_path), language=language)
-
-        for segment in transcription.get("segments", []):
-            id, text, start, end = (
-                segment["id"],
-                segment["text"],
-                segment["start"],
-                segment["end"],
-            )
-
-            extract = audio[..., int(start * sr) : int(end * sr)]
-            audio_save_path = (
-                save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
-            )
-            sf.write(
-                audio_save_path,
-                extract,
-                samplerate=sr,
-            )
-            original_files.append(audio_save_path)
-
-            transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
-            with open(
-                transcript_save_path,
-                "w",
-                encoding="utf-8",
-            ) as f:
-                f.write(text)
-            original_files.append(transcript_save_path)
-
-    merge_and_delete_files(save_dir, original_files)
-
-
-if __name__ == "__main__":
-    main()