zhzluke96 commited on
Commit
9d9fe0d
1 Parent(s): ba0472f
.env.webui CHANGED
@@ -9,8 +9,9 @@ AUTH=
9
 
10
  # Model inference options
11
  HALF=True
12
- OFF_TQDM=True
13
  DEBUG_GENERATE=True
 
14
 
15
  # Text-to-Speech (TTS) configuration
16
  TTS_MAX_LEN=1000
 
9
 
10
  # Model inference options
11
  HALF=True
12
+ OFF_TQDM=
13
  DEBUG_GENERATE=True
14
+ PRELOAD_MODELS=True
15
 
16
  # Text-to-Speech (TTS) configuration
17
  TTS_MAX_LEN=1000
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -112,6 +112,8 @@ class Chat:
112
  dtype_gpt = dtype_gpt or dtype
113
  dtype_decoder = dtype_decoder or dtype
114
 
 
 
115
  if vocos_config_path:
116
  vocos = (
117
  Vocos.from_hparams(vocos_config_path)
@@ -119,7 +121,9 @@ class Chat:
119
  .eval()
120
  )
121
  assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
122
- vocos.load_state_dict(torch.load(vocos_ckpt_path))
 
 
123
  self.pretrain_models["vocos"] = vocos
124
  self.logger.log(logging.INFO, "vocos loaded.")
125
 
@@ -127,7 +131,7 @@ class Chat:
127
  cfg = OmegaConf.load(dvae_config_path)
128
  dvae = DVAE(**cfg).to(device=device, dtype=dtype_dvae).eval()
129
  assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
130
- dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
131
  self.pretrain_models["dvae"] = dvae
132
  self.logger.log(logging.INFO, "dvae loaded.")
133
 
@@ -135,7 +139,7 @@ class Chat:
135
  cfg = OmegaConf.load(gpt_config_path)
136
  gpt = GPT_warpper(**cfg).to(device=device, dtype=dtype_gpt).eval()
137
  assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
138
- gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
139
  if compile and "cuda" in str(device):
140
  self.logger.info("compile gpt model")
141
  gpt.gpt.forward = torch.compile(
@@ -146,21 +150,23 @@ class Chat:
146
  assert os.path.exists(
147
  spk_stat_path
148
  ), f"Missing spk_stat.pt: {spk_stat_path}"
149
- self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(
150
- device=device, dtype=dtype
151
- )
152
  self.logger.log(logging.INFO, "gpt loaded.")
153
 
154
  if decoder_config_path:
155
  cfg = OmegaConf.load(decoder_config_path)
156
  decoder = DVAE(**cfg).to(device=device, dtype=dtype_decoder).eval()
157
  assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
158
- decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
 
 
159
  self.pretrain_models["decoder"] = decoder
160
  self.logger.log(logging.INFO, "decoder loaded.")
161
 
162
  if tokenizer_path:
163
- tokenizer = torch.load(tokenizer_path, map_location=device)
164
  tokenizer.padding_side = "left"
165
  self.pretrain_models["tokenizer"] = tokenizer
166
  self.logger.log(logging.INFO, "tokenizer loaded.")
 
112
  dtype_gpt = dtype_gpt or dtype
113
  dtype_decoder = dtype_decoder or dtype
114
 
115
+ map_location = torch.device("cpu")
116
+
117
  if vocos_config_path:
118
  vocos = (
119
  Vocos.from_hparams(vocos_config_path)
 
121
  .eval()
122
  )
123
  assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
124
+ vocos.load_state_dict(
125
+ torch.load(vocos_ckpt_path, map_location=map_location)
126
+ )
127
  self.pretrain_models["vocos"] = vocos
128
  self.logger.log(logging.INFO, "vocos loaded.")
129
 
 
131
  cfg = OmegaConf.load(dvae_config_path)
132
  dvae = DVAE(**cfg).to(device=device, dtype=dtype_dvae).eval()
133
  assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
134
+ dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=map_location))
135
  self.pretrain_models["dvae"] = dvae
136
  self.logger.log(logging.INFO, "dvae loaded.")
137
 
 
139
  cfg = OmegaConf.load(gpt_config_path)
140
  gpt = GPT_warpper(**cfg).to(device=device, dtype=dtype_gpt).eval()
141
  assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
142
+ gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=map_location))
143
  if compile and "cuda" in str(device):
144
  self.logger.info("compile gpt model")
145
  gpt.gpt.forward = torch.compile(
 
150
  assert os.path.exists(
151
  spk_stat_path
152
  ), f"Missing spk_stat.pt: {spk_stat_path}"
153
+ self.pretrain_models["spk_stat"] = torch.load(
154
+ spk_stat_path, map_location=map_location
155
+ ).to(device=device, dtype=dtype)
156
  self.logger.log(logging.INFO, "gpt loaded.")
157
 
158
  if decoder_config_path:
159
  cfg = OmegaConf.load(decoder_config_path)
160
  decoder = DVAE(**cfg).to(device=device, dtype=dtype_decoder).eval()
161
  assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
162
+ decoder.load_state_dict(
163
+ torch.load(decoder_ckpt_path, map_location=map_location)
164
+ )
165
  self.pretrain_models["decoder"] = decoder
166
  self.logger.log(logging.INFO, "decoder loaded.")
167
 
168
  if tokenizer_path:
169
+ tokenizer = torch.load(tokenizer_path, map_location=map_location)
170
  tokenizer.padding_side = "left"
171
  self.pretrain_models["tokenizer"] = tokenizer
172
  self.logger.log(logging.INFO, "tokenizer loaded.")
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -1,6 +1,5 @@
1
  import gc
2
- import os
3
- from typing import List, Literal
4
 
5
  import numpy as np
6
  from modules.devices import devices
@@ -14,7 +13,6 @@ from modules.utils.constants import MODELS_DIR
14
  from pathlib import Path
15
 
16
  from threading import Lock
17
- from modules import config
18
 
19
  import logging
20
 
@@ -34,12 +32,13 @@ class ResembleEnhance:
34
 
35
  def load_model(self):
36
  hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
37
- enhancer = Enhancer(hparams).to(device=self.device, dtype=self.dtype).eval()
38
  state_dict = torch.load(
39
  Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
40
- map_location=self.device,
41
  )["module"]
42
  enhancer.load_state_dict(state_dict)
 
43
 
44
  self.hparams = hparams
45
  self.enhancer = enhancer
 
1
  import gc
2
+ from typing import Literal
 
3
 
4
  import numpy as np
5
  from modules.devices import devices
 
13
  from pathlib import Path
14
 
15
  from threading import Lock
 
16
 
17
  import logging
18
 
 
32
 
33
  def load_model(self):
34
  hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
35
+ enhancer = Enhancer(hparams)
36
  state_dict = torch.load(
37
  Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
38
+ map_location="cpu",
39
  )["module"]
40
  enhancer.load_state_dict(state_dict)
41
+ enhancer.to(device=self.device, dtype=self.dtype).eval()
42
 
43
  self.hparams = hparams
44
  self.enhancer = enhancer
modules/api/api_setup.py CHANGED
@@ -1,9 +1,10 @@
1
  import logging
 
2
  from modules.devices import devices
3
  import argparse
4
 
5
- import torch
6
  from modules import config
 
7
  from modules.utils import env
8
  from modules import generate_audio
9
  from modules.api.Api import APIManager
@@ -77,6 +78,11 @@ def setup_model_args(parser: argparse.ArgumentParser):
77
  action="store_true",
78
  help="Enable debug mode for audio generation",
79
  )
 
 
 
 
 
80
 
81
 
82
  def process_model_args(args):
@@ -87,6 +93,7 @@ def process_model_args(args):
87
  no_half = env.get_and_update_env(args, "no_half", False, bool)
88
  off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
89
  debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
 
90
 
91
  generate_audio.setup_lru_cache()
92
  devices.reset_device()
@@ -95,6 +102,10 @@ def process_model_args(args):
95
  if debug_generate:
96
  generate_audio.logger.setLevel(logging.DEBUG)
97
 
 
 
 
 
98
 
99
  def setup_uvicon_args(parser: argparse.ArgumentParser):
100
  parser.add_argument("--host", type=str, help="Host to run the server on")
 
1
  import logging
2
+ from modules.Enhancer.ResembleEnhance import load_enhancer
3
  from modules.devices import devices
4
  import argparse
5
 
 
6
  from modules import config
7
+ from modules.models import load_chat_tts
8
  from modules.utils import env
9
  from modules import generate_audio
10
  from modules.api.Api import APIManager
 
78
  action="store_true",
79
  help="Enable debug mode for audio generation",
80
  )
81
+ parser.add_argument(
82
+ "--preload_models",
83
+ action="store_true",
84
+ help="Preload all models at startup",
85
+ )
86
 
87
 
88
  def process_model_args(args):
 
93
  no_half = env.get_and_update_env(args, "no_half", False, bool)
94
  off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
95
  debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
96
+ preload_models = env.get_and_update_env(args, "preload_models", False, bool)
97
 
98
  generate_audio.setup_lru_cache()
99
  devices.reset_device()
 
102
  if debug_generate:
103
  generate_audio.logger.setLevel(logging.DEBUG)
104
 
105
+ if preload_models:
106
+ load_chat_tts()
107
+ load_enhancer()
108
+
109
 
110
  def setup_uvicon_args(parser: argparse.ArgumentParser):
111
  parser.add_argument("--host", type=str, help="Host to run the server on")
modules/webui/css/style.css CHANGED
@@ -66,9 +66,9 @@
66
  display: none !important;
67
  }
68
 
69
- .progress-bar{
70
  height: 30px !important;
71
- }
72
 
73
  .progress-bar span {
74
  text-align: right;
 
66
  display: none !important;
67
  }
68
 
69
+ /* .progress-bar{
70
  height: 30px !important;
71
+ } */
72
 
73
  .progress-bar span {
74
  text-align: right;
modules/webui/webui_utils.py CHANGED
@@ -1,14 +1,13 @@
1
- import io
2
  from typing import Union
3
  import numpy as np
4
 
5
  from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
6
- from modules.devices import devices
7
  from modules.synthesize_audio import synthesize_audio
8
  from modules.utils.hf import spaces
9
  from modules.webui import webui_config
10
 
11
  import torch
 
12
 
13
  from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLBreak, SSMLSegment
14
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
@@ -24,7 +23,6 @@ from modules import refiner
24
  from modules.utils import audio
25
  from modules.SentenceSplitter import SentenceSplitter
26
 
27
- from pydub import AudioSegment
28
  import torch.profiler
29
 
30
 
@@ -97,6 +95,7 @@ def synthesize_ssml(
97
  enable_denoise=False,
98
  eos: str = "[uv_break]",
99
  spliter_thr: int = 100,
 
100
  ):
101
  try:
102
  batch_size = int(batch_size)
@@ -157,6 +156,7 @@ def tts_generate(
157
  spk_file=None,
158
  spliter_thr: int = 100,
159
  eos: str = "[uv_break]",
 
160
  ):
161
  try:
162
  batch_size = int(batch_size)
@@ -219,7 +219,11 @@ def tts_generate(
219
 
220
  @torch.inference_mode()
221
  @spaces.GPU(duration=120)
222
- def refine_text(text: str, prompt: str):
 
 
 
 
223
  text = text_normalize(text)
224
  return refiner.refine_text(text, prompt=prompt)
225
 
 
 
1
  from typing import Union
2
  import numpy as np
3
 
4
  from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
 
5
  from modules.synthesize_audio import synthesize_audio
6
  from modules.utils.hf import spaces
7
  from modules.webui import webui_config
8
 
9
  import torch
10
+ import gradio as gr
11
 
12
  from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLBreak, SSMLSegment
13
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
 
23
  from modules.utils import audio
24
  from modules.SentenceSplitter import SentenceSplitter
25
 
 
26
  import torch.profiler
27
 
28
 
 
95
  enable_denoise=False,
96
  eos: str = "[uv_break]",
97
  spliter_thr: int = 100,
98
+ progress=gr.Progress(track_tqdm=True),
99
  ):
100
  try:
101
  batch_size = int(batch_size)
 
156
  spk_file=None,
157
  spliter_thr: int = 100,
158
  eos: str = "[uv_break]",
159
+ progress=gr.Progress(track_tqdm=True),
160
  ):
161
  try:
162
  batch_size = int(batch_size)
 
219
 
220
  @torch.inference_mode()
221
  @spaces.GPU(duration=120)
222
+ def refine_text(
223
+ text: str,
224
+ prompt: str,
225
+ progress=gr.Progress(track_tqdm=True),
226
+ ):
227
  text = text_normalize(text)
228
  return refiner.refine_text(text, prompt=prompt)
229