zhzluke96 committed on
Commit
d5d0921
1 Parent(s): 2be0618
data/speakers/Bob_ft10.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91015b82a99c40034048090228b6d647ab99fd7b86e8babd6a7c3a9236e8d800
+ size 4508
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
 
 from huggingface_hub import snapshot_download
 
- logging.basicConfig(level=logging.ERROR)
+ logging.basicConfig(level=logging.INFO)
 
 
 class Chat:
modules/SynthesizeSegments.py CHANGED
@@ -1,4 +1,5 @@
 import copy
+ import re
 from box import Box
 from pydub import AudioSegment
 from typing import List, Union
@@ -160,7 +161,21 @@ class SynthesizeSegments:
         for i in range(0, len(bucket), self.batch_size):
             batch = bucket[i : i + self.batch_size]
             param_arr = [self.segment_to_generate_params(segment) for segment in batch]
-             texts = [params.text + self.eos for params in param_arr]
+ 
+             def append_eos(text: str):
+                 text = text.strip()
+                 eos_arr = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"]
+                 has_eos = False
+                 for eos in eos_arr:
+                     if eos in text:
+                         has_eos = True
+                         break
+                 if not has_eos:
+                     text += self.eos
+                 return text
+ 
+             # the end_of_text (eos) token is appended to each text here
+             texts = [append_eos(params.text) for params in param_arr]
 
             params = param_arr[0]
             audio_datas = generate_audio.generate_audio_batch(
@@ -182,6 +197,7 @@
 
             audio_segment = audio_data_to_segment(audio_data, sr)
             audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
+             # compare by Box object
             original_index = src_segments.index(segment)
             audio_segments[original_index] = audio_segment
 
@@ -226,13 +242,30 @@
 
             sentences = spliter.parse(text)
             for sentence in sentences:
-                 ret_segments.append(
-                     SSMLSegment(
-                         text=sentence,
-                         attrs=segment.attrs.copy(),
-                         params=copy.copy(segment.params),
-                     )
+                 seg = SSMLSegment(
+                     text=sentence,
+                     attrs=segment.attrs.copy(),
+                     params=copy.copy(segment.params),
                 )
+                 ret_segments.append(seg)
+                 setattr(seg, "_idx", len(ret_segments) - 1)
+ 
+         def is_none_speak_segment(segment: SSMLSegment):
+             text = segment.text.strip()
+             regexp = r"\[[^\]]+?\]"
+             text = re.sub(regexp, "", text)
+             text = text.strip()
+             if not text:
+                 return True
+             return False
+ 
+         # merge none_speak segments into the previous speak segment
+         for i in range(1, len(ret_segments)):
+             if is_none_speak_segment(ret_segments[i]):
+                 ret_segments[i - 1].text += ret_segments[i].text
+                 ret_segments[i].text = ""
+         # remove empty segments
+         ret_segments = [seg for seg in ret_segments if seg.text.strip()]
 
         return ret_segments
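
The `append_eos` change only appends the configured eos token when a segment does not already carry one of the break tags, so pauses written explicitly in the source text are not doubled. A standalone sketch of the same behavior, not part of the commit (`eos` is passed explicitly here instead of being read from `self.eos`):

EOS_TAGS = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"]


def append_eos(text: str, eos: str = "[uv_break]") -> str:
    # Append eos only when the text carries no explicit break tag yet.
    text = text.strip()
    if not any(tag in text for tag in EOS_TAGS):
        text += eos
    return text


assert append_eos("Hello there") == "Hello there[uv_break]"
assert append_eos("Hello there [lbreak]") == "Hello there [lbreak]"  # unchanged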
modules/api/app_config.py CHANGED
@@ -1,6 +1,6 @@
 app_description = """
- ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
- ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
+ 🍦 ChatTTS-Forge 是一个围绕 TTS 生成模型 ChatTTS 开发的项目,实现了 API Server 和 基于 Gradio 的 WebUI。<br/>
+ 🍦 ChatTTS-Forge is a project developed around the TTS generation model ChatTTS, implementing an API Server and a Gradio-based WebUI.
 
 项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
modules/api/impl/google_api.py CHANGED
@@ -1,38 +1,25 @@
- import base64
- from typing import Literal
+ from typing import Union
 from fastapi import HTTPException
 
- import io
- import soundfile as sf
 from pydantic import BaseModel
 
 
- from modules.Enhancer.ResembleEnhance import (
-     apply_audio_enhance,
-     apply_audio_enhance_full,
- )
 from modules.api.Api import APIManager
- from modules.synthesize_audio import synthesize_audio
- from modules.utils import audio
- from modules.utils.audio import apply_prosody_to_audio_data
- from modules.normalization import text_normalize
+ from modules.api.impl.handler.SSMLHandler import SSMLHandler
+ from modules.api.impl.handler.TTSHandler import TTSHandler
+ from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
 
- from modules import generate_audio as generate
- from modules.speaker import speaker_mgr
+ from modules.speaker import Speaker, speaker_mgr
 
 
- from modules.ssml_parser.SSMLParser import create_ssml_parser
- from modules.SynthesizeSegments import (
-     SynthesizeSegments,
-     combine_audio_segments,
- )
-
 from modules.api import utils as api_utils
 
 
 class SynthesisInput(BaseModel):
-     text: str = ""
-     ssml: str = ""
+     text: Union[str, None] = None
+     ssml: Union[str, None] = None
 
 
 class VoiceSelectionParams(BaseModel):
@@ -50,24 +37,15 @@ class VoiceSelectionParams(BaseModel):
 
 
 class AudioConfig(BaseModel):
-     audioEncoding: api_utils.AudioFormat = "mp3"
+     audioEncoding: AudioFormat = AudioFormat.mp3
     speakingRate: float = 1
     pitch: float = 0
     volumeGainDb: float = 0
     sampleRateHertz: int = 24000
-     batchSize: int = 1
+     batchSize: int = 4
     spliterThreshold: int = 100
 
 
- class EnhancerConfig(BaseModel):
-     enabled: bool = False
-     model: str = "resemble-enhance"
-     nfe: int = 32
-     solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
-     lambd: float = 0.5
-     tau: float = 0.5
-
-
 class GoogleTextSynthesizeRequest(BaseModel):
     input: SynthesisInput
     voice: VoiceSelectionParams
@@ -92,7 +70,11 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
     voice_name = voice.name
     infer_seed = voice.seed or 42
     eos = voice.eos or "[uv_break]"
-     audio_format = audioConfig.audioEncoding or "mp3"
+     audio_format = audioConfig.audioEncoding
+ 
+     if not isinstance(audio_format, AudioFormat) and isinstance(audio_format, str):
+         audio_format = AudioFormat(audio_format)
+ 
     speaking_rate = audioConfig.speakingRate or 1
     pitch = audioConfig.pitch or 0
     volume_gain_db = audioConfig.volumeGainDb or 0
@@ -101,6 +83,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
 
     spliter_threshold = audioConfig.spliterThreshold or 100
 
+     # TODO
     sample_rate = audioConfig.sampleRateHertz or 24000
 
     params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
@@ -111,92 +94,68 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
             status_code=422, detail="The specified voice name is not supported."
         )
 
-     if audio_format != "mp3" and audio_format != "wav":
+     if not isinstance(params.get("spk"), Speaker):
         raise HTTPException(
-             status_code=422, detail="Invalid audio encoding format specified."
+             status_code=422, detail="The specified voice name is not supported."
         )
 
-     if enhancerConfig.enabled:
-         # TODO enhancer params checker
-         pass
-
+     speaker = params.get("spk")
+     tts_config = ChatTTSConfig(
+         style=params.get("style", ""),
+         temperature=voice.temperature,
+         top_k=voice.topK,
+         top_p=voice.topP,
+     )
+     infer_config = InferConfig(
+         batch_size=batch_size,
+         spliter_threshold=spliter_threshold,
+         eos=eos,
+         seed=infer_seed,
+     )
+     adjust_config = AdjustConfig(
+         speaking_rate=speaking_rate,
+         pitch=pitch,
+         volume_gain_db=volume_gain_db,
+     )
+     enhancer_config = enhancerConfig
+ 
+     mime_type = f"audio/{audio_format.value}"
+     if audio_format == AudioFormat.mp3:
+         mime_type = "audio/mpeg"
     try:
         if input.text:
-             # handle the plain-text synthesis path
-             text = text_normalize(input.text, is_end=True)
-             sample_rate, audio_data = synthesize_audio(
-                 text,
-                 temperature=(
-                     voice.temperature
-                     if voice.temperature
-                     else params.get("temperature", 0.3)
-                 ),
-                 top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
-                 top_K=voice.topK if voice.topK else params.get("top_k", 20),
-                 spk=params.get("spk", -1),
-                 infer_seed=infer_seed,
-                 prompt1=params.get("prompt1", ""),
-                 prompt2=params.get("prompt2", ""),
-                 prefix=params.get("prefix", ""),
-                 batch_size=batch_size,
-                 spliter_threshold=spliter_threshold,
-                 end_of_sentence=eos,
-             )
+             text_content = input.text
+ 
+             handler = TTSHandler(
+                 text_content=text_content,
+                 spk=speaker,
+                 tts_config=tts_config,
+                 infer_config=infer_config,
+                 adjust_config=adjust_config,
+                 enhancer_config=enhancer_config,
+             )
 
-         elif input.ssml:
-             parser = create_ssml_parser()
-             segments = parser.parse(input.ssml)
-             for seg in segments:
-                 seg["text"] = text_normalize(seg["text"], is_end=True)
-
-             if len(segments) == 0:
-                 raise HTTPException(
-                     status_code=422, detail="The SSML text is empty or parsing failed."
-                 )
-
-             synthesize = SynthesizeSegments(
-                 batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
-             )
-             audio_segments = synthesize.synthesize_segments(segments)
-             combined_audio = combine_audio_segments(audio_segments)
-
-             sample_rate, audio_data = audio.pydub_to_np(combined_audio)
-         else:
-             raise HTTPException(
-                 status_code=422, detail="Either text or SSML input must be provided."
-             )
-
-         if enhancerConfig.enabled:
-             audio_data, sample_rate = apply_audio_enhance_full(
-                 audio_data=audio_data,
-                 sr=sample_rate,
-                 nfe=enhancerConfig.nfe,
-                 solver=enhancerConfig.solver,
-                 lambd=enhancerConfig.lambd,
-                 tau=enhancerConfig.tau,
-             )
+             base64_string = handler.enqueue_to_base64(format=audio_format)
+             return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
 
-         audio_data = apply_prosody_to_audio_data(
-             audio_data,
-             rate=speaking_rate,
-             pitch=pitch,
-             volume=volume_gain_db,
-             sr=sample_rate,
-         )
-
-         buffer = io.BytesIO()
-         sf.write(buffer, audio_data, sample_rate, format="wav")
-         buffer.seek(0)
+         elif input.ssml:
+             ssml_content = input.ssml
 
-         if audio_format == "mp3":
-             buffer = api_utils.wav_to_mp3(buffer)
+             handler = SSMLHandler(
+                 ssml_content=ssml_content,
+                 infer_config=infer_config,
+                 adjust_config=adjust_config,
+                 enhancer_config=enhancer_config,
+             )
 
-         base64_encoded = base64.b64encode(buffer.read())
-         base64_string = base64_encoded.decode("utf-8")
-
-         return {
-             "audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
-         }
+             base64_string = handler.enqueue_to_base64(format=audio_format)
+ 
+             return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
+ 
+         else:
+             raise HTTPException(
+                 status_code=422, detail="Invalid input text or ssml specified."
+             )
 
     except Exception as e:
         import logging
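
After this refactor the endpoint is a thin adapter from the Google-style request shape onto the shared TTSHandler/SSMLHandler pipeline. A sketch of a request body it accepts, not part of the commit (the field set is read off the models above; the `enhancerConfig` key is inferred from the handler code and the values are illustrative):

request_body = {
    "input": {"text": "Hello from ChatTTS-Forge"},  # exactly one of text / ssml
    "voice": {"name": "female2", "seed": 42, "eos": "[uv_break]"},
    "audioConfig": {
        "audioEncoding": "mp3",  # plain strings are coerced to AudioFormat
        "speakingRate": 1,
        "pitch": 0,
        "volumeGainDb": 0,
        "batchSize": 4,
        "spliterThreshold": 100,
    },
    "enhancerConfig": {"enabled": False},
}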
modules/api/impl/handler/AudioHandler.py ADDED
@@ -0,0 +1,37 @@
+ import base64
+ import io
+ import numpy as np
+ import soundfile as sf
+ 
+ from modules.api.impl.model.audio_model import AudioFormat
+ from modules.api import utils as api_utils
+ 
+ 
+ class AudioHandler:
+     def enqueue(self) -> tuple[np.ndarray, int]:
+         raise NotImplementedError
+ 
+     def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
+         audio_data, sample_rate = self.enqueue()
+ 
+         buffer = io.BytesIO()
+         sf.write(buffer, audio_data, sample_rate, format="wav")
+         buffer.seek(0)
+ 
+         if format == AudioFormat.mp3:
+             buffer = api_utils.wav_to_mp3(buffer)
+ 
+         return buffer
+ 
+     def enqueue_to_bytes(self, format: AudioFormat) -> bytes:
+         buffer = self.enqueue_to_buffer(format=format)
+         binary = buffer.read()
+         return binary
+ 
+     def enqueue_to_base64(self, format: AudioFormat) -> str:
+         binary = self.enqueue_to_bytes(format=format)
+ 
+         base64_encoded = base64.b64encode(binary)
+         base64_string = base64_encoded.decode("utf-8")
+ 
+         return base64_string
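
AudioHandler is a small template-method base class: a subclass implements only `enqueue()` (raw samples plus sample rate) and inherits WAV/MP3 encoding and base64 packaging. A minimal sketch of a hypothetical subclass, not part of the commit:

import numpy as np

from modules.api.impl.handler.AudioHandler import AudioHandler
from modules.api.impl.model.audio_model import AudioFormat


class SilenceHandler(AudioHandler):
    # Hypothetical handler that "synthesizes" one second of silence.
    def __init__(self, sample_rate: int = 24000):
        self.sample_rate = sample_rate

    def enqueue(self) -> tuple[np.ndarray, int]:
        # The only method a subclass must provide: samples plus rate.
        return np.zeros(self.sample_rate, dtype=np.float32), self.sample_rate


# The encoding variants come for free from the base class:
b64_wav = SilenceHandler().enqueue_to_base64(format=AudioFormat.wav)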
modules/api/impl/handler/SSMLHandler.py ADDED
@@ -0,0 +1,94 @@
+ from fastapi import HTTPException
+ import numpy as np
+ 
+ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
+ from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
+ from modules.api.impl.handler.AudioHandler import AudioHandler
+ from modules.api.impl.model.audio_model import AdjustConfig
+ from modules.api.impl.model.chattts_model import InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
+ from modules.normalization import text_normalize
+ from modules.ssml_parser.SSMLParser import create_ssml_parser
+ from modules.utils import audio
+ 
+ 
+ class SSMLHandler(AudioHandler):
+     def __init__(
+         self,
+         ssml_content: str,
+         infer_config: InferConfig,
+         adjust_config: AdjustConfig,
+         enhancer_config: EnhancerConfig,
+     ) -> None:
+         assert isinstance(ssml_content, str), "ssml_content must be a string."
+         assert isinstance(
+             infer_config, InferConfig
+         ), "infer_config must be an InferConfig object."
+         assert isinstance(
+             adjust_config, AdjustConfig
+         ), "adjest_config should be AdjustConfig"
+         assert isinstance(
+             enhancer_config, EnhancerConfig
+         ), "enhancer_config must be an EnhancerConfig object."
+ 
+         self.ssml_content = ssml_content
+         self.infer_config = infer_config
+         self.adjest_config = adjust_config
+         self.enhancer_config = enhancer_config
+ 
+         self.validate()
+ 
+     def validate(self):
+         # TODO params checker
+         pass
+ 
+     def enqueue(self) -> tuple[np.ndarray, int]:
+         ssml_content = self.ssml_content
+         infer_config = self.infer_config
+         adjust_config = self.adjest_config
+         enhancer_config = self.enhancer_config
+ 
+         parser = create_ssml_parser()
+         segments = parser.parse(ssml_content)
+         for seg in segments:
+             seg["text"] = text_normalize(seg["text"], is_end=True)
+ 
+         if len(segments) == 0:
+             raise HTTPException(
+                 status_code=422, detail="The SSML text is empty or parsing failed."
+             )
+ 
+         synthesize = SynthesizeSegments(
+             batch_size=infer_config.batch_size,
+             eos=infer_config.eos,
+             spliter_thr=infer_config.spliter_threshold,
+         )
+         audio_segments = synthesize.synthesize_segments(segments)
+         combined_audio = combine_audio_segments(audio_segments)
+ 
+         sample_rate, audio_data = audio.pydub_to_np(combined_audio)
+ 
+         if enhancer_config.enabled:
+             nfe = enhancer_config.nfe
+             solver = enhancer_config.solver
+             lambd = enhancer_config.lambd
+             tau = enhancer_config.tau
+ 
+             audio_data, sample_rate = apply_audio_enhance_full(
+                 audio_data=audio_data,
+                 sr=sample_rate,
+                 nfe=nfe,
+                 solver=solver,
+                 lambd=lambd,
+                 tau=tau,
+             )
+ 
+         audio_data = audio.apply_prosody_to_audio_data(
+             audio_data=audio_data,
+             rate=adjust_config.speed_rate,
+             pitch=adjust_config.pitch,
+             volume=adjust_config.volume_gain_db,
+             sr=sample_rate,
+         )
+ 
+         return audio_data, sample_rate
modules/api/impl/handler/TTSHandler.py ADDED
@@ -0,0 +1,97 @@
+ import numpy as np
+ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
+ from modules.api.impl.handler.AudioHandler import AudioHandler
+ from modules.api.impl.model.audio_model import AdjustConfig
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
+ from modules.normalization import text_normalize
+ from modules.speaker import Speaker
+ from modules.synthesize_audio import synthesize_audio
+ 
+ from modules.utils.audio import apply_prosody_to_audio_data
+ 
+ 
+ class TTSHandler(AudioHandler):
+     def __init__(
+         self,
+         text_content: str,
+         spk: Speaker,
+         tts_config: ChatTTSConfig,
+         infer_config: InferConfig,
+         adjust_config: AdjustConfig,
+         enhancer_config: EnhancerConfig,
+     ):
+         assert isinstance(text_content, str), "text_content should be str"
+         assert isinstance(spk, Speaker), "spk should be Speaker"
+         assert isinstance(
+             tts_config, ChatTTSConfig
+         ), "tts_config should be ChatTTSConfig"
+         assert isinstance(
+             infer_config, InferConfig
+         ), "infer_config should be InferConfig"
+         assert isinstance(
+             adjust_config, AdjustConfig
+         ), "adjest_config should be AdjustConfig"
+         assert isinstance(
+             enhancer_config, EnhancerConfig
+         ), "enhancer_config should be EnhancerConfig"
+ 
+         self.text_content = text_content
+         self.spk = spk
+         self.tts_config = tts_config
+         self.infer_config = infer_config
+         self.adjest_config = adjust_config
+         self.enhancer_config = enhancer_config
+ 
+         self.validate()
+ 
+     def validate(self):
+         # TODO params checker
+         pass
+ 
+     def enqueue(self) -> tuple[np.ndarray, int]:
+         text = text_normalize(self.text_content)
+         tts_config = self.tts_config
+         infer_config = self.infer_config
+         adjust_config = self.adjest_config
+         enhancer_config = self.enhancer_config
+ 
+         sample_rate, audio_data = synthesize_audio(
+             text,
+             spk=self.spk,
+             temperature=tts_config.temperature,
+             top_P=tts_config.top_p,
+             top_K=tts_config.top_k,
+             prompt1=tts_config.prompt1,
+             prompt2=tts_config.prompt2,
+             prefix=tts_config.prefix,
+             infer_seed=infer_config.seed,
+             batch_size=infer_config.batch_size,
+             spliter_threshold=infer_config.spliter_threshold,
+             end_of_sentence=infer_config.eos,
+         )
+ 
+         if enhancer_config.enabled:
+             nfe = enhancer_config.nfe
+             solver = enhancer_config.solver
+             lambd = enhancer_config.lambd
+             tau = enhancer_config.tau
+ 
+             audio_data, sample_rate = apply_audio_enhance_full(
+                 audio_data=audio_data,
+                 sr=sample_rate,
+                 nfe=nfe,
+                 solver=solver,
+                 lambd=lambd,
+                 tau=tau,
+             )
+ 
+         audio_data = apply_prosody_to_audio_data(
+             audio_data=audio_data,
+             rate=adjust_config.speed_rate,
+             pitch=adjust_config.pitch,
+             volume=adjust_config.volume_gain_db,
+             sr=sample_rate,
+         )
+ 
+         return audio_data, sample_rate
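
Both handlers are driven the same way from the endpoints: build the config objects, construct the handler, then request the desired encoding. A sketch, not part of the commit, loading the speaker file added in this commit via `Speaker.from_file` (the loader used elsewhere in this codebase; accepting a filesystem path is an assumption):

from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker

# Assumes Speaker.from_file accepts a path to a saved .pt speaker.
spk = Speaker.from_file("data/speakers/Bob_ft10.pt")

handler = TTSHandler(
    text_content="Hello from ChatTTS-Forge",
    spk=spk,
    tts_config=ChatTTSConfig(temperature=0.3, top_p=0.7, top_k=20),
    infer_config=InferConfig(batch_size=4, eos="[uv_break]", seed=42),
    adjust_config=AdjustConfig(speed_rate=1.0),
    enhancer_config=EnhancerConfig(enabled=False),
)

b64_mp3 = handler.enqueue_to_base64(format=AudioFormat.mp3)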
modules/api/impl/model/audio_model.py ADDED
@@ -0,0 +1,14 @@
+ from enum import Enum
+ 
+ from pydantic import BaseModel
+ 
+ 
+ class AudioFormat(str, Enum):
+     mp3 = "mp3"
+     wav = "wav"
+ 
+ 
+ class AdjustConfig(BaseModel):
+     pitch: float = 0
+     speed_rate: float = 1
+     volume_gain_db: float = 0
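
Deriving AudioFormat from both `str` and `Enum` is what makes the loose coercion in the endpoints work: values compare equal to their raw strings and can be constructed from them. A quick illustration, not part of the commit:

from modules.api.impl.model.audio_model import AudioFormat

fmt = AudioFormat("mp3")      # construct from a plain string
assert fmt is AudioFormat.mp3
assert fmt == "mp3"           # str subclass: equal to the raw value
assert fmt.value == "mp3"     # .value is used to build "audio/<fmt>" MIME types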
modules/api/impl/model/chattts_model.py ADDED
@@ -0,0 +1,19 @@
+ from pydantic import BaseModel
+ 
+ 
+ class ChatTTSConfig(BaseModel):
+     style: str = ""
+     temperature: float = 0.3
+     top_p: float = 0.7
+     top_k: int = 20
+     prompt1: str = ""
+     prompt2: str = ""
+     prefix: str = ""
+ 
+ 
+ class InferConfig(BaseModel):
+     batch_size: int = 4
+     spliter_threshold: int = 100
+     # end_of_sentence
+     eos: str = "[uv_break]"
+     seed: int = 42
modules/api/impl/model/enhancer_model.py ADDED
@@ -0,0 +1,11 @@
+ from typing import Literal
+ from pydantic import BaseModel
+ 
+ 
+ class EnhancerConfig(BaseModel):
+     enabled: bool = False
+     model: str = "resemble-enhance"
+     nfe: int = 32
+     solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
+     lambd: float = 0.5
+     tau: float = 0.5
modules/api/impl/openai_api.py CHANGED
@@ -1,42 +1,38 @@
 from fastapi import File, Form, HTTPException, Body, UploadFile
- from fastapi.responses import StreamingResponse
 
- import io
 from numpy import clip
- import soundfile as sf
 from pydantic import BaseModel, Field
- from fastapi.responses import FileResponse
-
+ from fastapi.responses import StreamingResponse
 
- from modules.synthesize_audio import synthesize_audio
- from modules.normalization import text_normalize
+ from modules.api.impl.handler.TTSHandler import TTSHandler
+ from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
 
- from modules import generate_audio as generate
 
 
- from typing import List, Literal, Optional, Union
- import pyrubberband as pyrb
+ from typing import List, Optional
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
 
- from modules.speaker import speaker_mgr
+ from modules.speaker import Speaker, speaker_mgr
 from modules.data import styles_mgr
 
- import numpy as np
-
 
 class AudioSpeechRequest(BaseModel):
     input: str  # the text to synthesize
     model: str = "chattts-4w"
     voice: str = "female2"
-     response_format: Literal["mp3", "wav"] = "mp3"
+     response_format: AudioFormat = "mp3"
     speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
     seed: int = 42
+ 
     temperature: float = 0.3
+     top_k: int = 20
+     top_p: float = 0.7
+ 
     style: str = ""
-     # whether to enable batch synthesis; a value <= 1 disables batching
-     # batch synthesis automatically splits sentences
     batch_size: int = Field(1, ge=1, le=20, description="Batch size")
     spliter_threshold: float = Field(
         100, ge=10, le=1024, description="Threshold for sentence spliter"
@@ -44,6 +40,9 @@ class AudioSpeechRequest(BaseModel):
     # end of sentence
     eos: str = "[uv_break]"
 
+     enhance: bool = False
+     denoise: bool = False
+ 
 
 async def openai_speech_api(
     request: AudioSpeechRequest = Body(
@@ -55,7 +54,14 @@ async def openai_speech_api(
     voice = request.voice
     style = request.style
     eos = request.eos
+     seed = request.seed
+ 
     response_format = request.response_format
+     if not isinstance(response_format, AudioFormat) and isinstance(
+         response_format, str
+     ):
+         response_format = AudioFormat(response_format)
+ 
     batch_size = request.batch_size
     spliter_threshold = request.spliter_threshold
     speed = request.speed
@@ -71,49 +77,45 @@ async def openai_speech_api(
     except:
         raise HTTPException(status_code=400, detail="Invalid style.")
 
-     try:
-         # Normalize the text
-         text = text_normalize(input_text, is_end=True)
-
-         # Calculate speaker and style based on input voice
-         params = api_utils.calc_spk_style(spk=voice, style=style)
-
-         spk = params.get("spk", -1)
-         seed = params.get("seed", request.seed or 42)
-         temperature = params.get("temperature", request.temperature or 0.3)
-         prompt1 = params.get("prompt1", "")
-         prompt2 = params.get("prompt2", "")
-         prefix = params.get("prefix", "")
-
-         # Generate audio
-         sample_rate, audio_data = synthesize_audio(
-             text,
-             temperature=temperature,
-             top_P=0.7,
-             top_K=20,
-             spk=spk,
-             infer_seed=seed,
-             batch_size=batch_size,
-             spliter_threshold=spliter_threshold,
-             prompt1=prompt1,
-             prompt2=prompt2,
-             prefix=prefix,
-             end_of_sentence=eos,
-         )
+     ctx_params = api_utils.calc_spk_style(spk=voice, style=style)
 
-         if speed != 1:
-             audio_data = pyrb.time_stretch(audio_data, sample_rate, speed)
+     speaker = ctx_params.get("spk")
+     if not isinstance(speaker, Speaker):
+         raise HTTPException(status_code=400, detail="Invalid voice.")
 
-         # Convert audio data to wav format
-         buffer = io.BytesIO()
-         sf.write(buffer, audio_data, sample_rate, format="wav")
-         buffer.seek(0)
+     tts_config = ChatTTSConfig(
+         style=style,
+         temperature=request.temperature,
+         top_k=request.top_k,
+         top_p=request.top_p,
+     )
+     infer_config = InferConfig(
+         batch_size=batch_size,
+         spliter_threshold=spliter_threshold,
+         eos=eos,
+         seed=seed,
+     )
+     adjust_config = AdjustConfig(speaking_rate=speed)
+     enhancer_config = EnhancerConfig(
+         enabled=request.enhance or request.denoise or False,
+         lambd=0.9 if request.denoise else 0.1,
+     )
+     try:
+         handler = TTSHandler(
+             text_content=input_text,
+             spk=speaker,
+             tts_config=tts_config,
+             infer_config=infer_config,
+             adjust_config=adjust_config,
+             enhancer_config=enhancer_config,
+         )
 
-         if response_format == "mp3":
-             # Convert wav to mp3
-             buffer = api_utils.wav_to_mp3(buffer)
+         buffer = handler.enqueue_to_buffer(response_format)
 
-         return StreamingResponse(buffer, media_type="audio/mp3")
+         mime_type = f"audio/{response_format.value}"
+         if response_format == AudioFormat.mp3:
+             mime_type = "audio/mpeg"
+         return StreamingResponse(buffer, media_type=mime_type)
 
     except Exception as e:
         import logging
@@ -150,7 +152,6 @@ class TranscriptionsVerboseResponse(BaseModel):
 def setup(app: APIManager):
     app.post(
         "/v1/audio/speech",
-         response_class=FileResponse,
         description="""
 openai api document:
 [https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
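
With the handler in place, the endpoint keeps its OpenAI-compatible shape while gaining the new knobs. A client-side sketch, not part of the commit (the base URL is a placeholder for your deployment):

import requests

BASE_URL = "http://localhost:7870"  # placeholder; adjust to your server

payload = {
    "input": "Hello from ChatTTS-Forge",
    "voice": "female2",
    "response_format": "mp3",  # coerced to AudioFormat server-side
    "seed": 42,
    "top_k": 20,
    "top_p": 0.7,
    "batch_size": 4,
    "denoise": True,  # maps to EnhancerConfig(enabled=True, lambd=0.9)
}

resp = requests.post(f"{BASE_URL}/v1/audio/speech", json=payload)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)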
modules/api/impl/refiner_api.py CHANGED
@@ -31,6 +31,7 @@ async def refiner_prompt_post(request: RefineTextRequest):
     text = request.text
     if request.normalize:
         text = text_normalize(request.text)
+     # TODO: spliter and batch processing could actually be done here as well
     refined_text = refiner.refine_text(
         text=text,
         prompt=request.prompt,
modules/api/impl/ssml_api.py CHANGED
@@ -1,27 +1,22 @@
 from fastapi import HTTPException, Body
 from fastapi.responses import StreamingResponse
 
- import io
 from pydantic import BaseModel
 from fastapi.responses import FileResponse
 
 
- from modules.normalization import text_normalize
- from modules.ssml_parser.SSMLParser import create_ssml_parser
- from modules.SynthesizeSegments import (
-     SynthesizeSegments,
-     combine_audio_segments,
- )
+ from modules.api.impl.handler.SSMLHandler import SSMLHandler
+ from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+ from modules.api.impl.model.chattts_model import InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
- from modules.api import utils as api_utils
-
 from modules.api.Api import APIManager
 
 
 class SSMLRequest(BaseModel):
     ssml: str
-     format: str = "mp3"
+     format: AudioFormat = "mp3"
 
     # NOTE: 🤔 maybe this value should be a system-level setting? passing it in per request feels odd
     batch_size: int = 4
@@ -31,6 +26,9 @@ class SSMLRequest(BaseModel):
 
     spliter_thr: int = 100
 
+     enhancer: EnhancerConfig = EnhancerConfig()
+     adjuster: AdjustConfig = AdjustConfig()
+ 
 
 async def synthesize_ssml_api(
     request: SSMLRequest = Body(
@@ -43,6 +41,8 @@ async def synthesize_ssml_api(
     batch_size = request.batch_size
     eos = request.eos
     spliter_thr = request.spliter_thr
+     enhancer = request.enhancer
+     adjuster = request.adjuster
 
     if batch_size < 1:
         raise HTTPException(
@@ -62,22 +62,27 @@ async def synthesize_ssml_api(
             status_code=400, detail="Format must be 'mp3' or 'wav'."
         )
 
-         parser = create_ssml_parser()
-         segments = parser.parse(ssml)
-         for seg in segments:
-             seg["text"] = text_normalize(seg["text"], is_end=True)
-
-         synthesize = SynthesizeSegments(
-             batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
-         )
-         audio_segments = synthesize.synthesize_segments(segments)
-         combined_audio = combine_audio_segments(audio_segments)
-         buffer = io.BytesIO()
-         combined_audio.export(buffer, format="wav")
-         buffer.seek(0)
-         if format == "mp3":
-             buffer = api_utils.wav_to_mp3(buffer)
-         return StreamingResponse(buffer, media_type=f"audio/{format}")
+         infer_config = InferConfig(
+             batch_size=batch_size,
+             spliter_threshold=spliter_thr,
+             eos=eos,
+         )
+         adjust_config = adjuster
+         enhancer_config = enhancer
+ 
+         handler = SSMLHandler(
+             ssml_content=ssml,
+             infer_config=infer_config,
+             adjust_config=adjust_config,
+             enhancer_config=enhancer_config,
+         )
+ 
+         buffer = handler.enqueue_to_buffer(format=request.format)
+ 
+         mime_type = f"audio/{format}"
+         if format == AudioFormat.mp3:
+             mime_type = "audio/mpeg"
+         return StreamingResponse(buffer, media_type=mime_type)
 
     except Exception as e:
         import logging
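
The SSML endpoint now nests the enhancer and prosody-adjustment options as sub-objects with defaults, so clients opt in field by field. A sketch of a request body, not part of the commit (the SSML snippet follows the project's SSML-like dialect and is illustrative only):

ssml_request = {
    "ssml": '<speak version="0.1"><voice spk="female2">Hello [lbreak]</voice></speak>',
    "format": "mp3",
    "batch_size": 4,
    "eos": "[uv_break]",
    "spliter_thr": 100,
    "enhancer": {"enabled": True, "nfe": 32, "solver": "midpoint"},
    "adjuster": {"pitch": 0, "speed_rate": 1.1, "volume_gain_db": 0},
}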
modules/api/impl/tts_api.py CHANGED
@@ -1,17 +1,18 @@
 from fastapi import Depends, HTTPException, Query
 from fastapi.responses import StreamingResponse
 
- import io
 from pydantic import BaseModel
- import soundfile as sf
 from fastapi.responses import FileResponse
 
 
- from modules.normalization import text_normalize
+ from modules.api.impl.handler.TTSHandler import TTSHandler
+ from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
+ from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
+ from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
- from modules.synthesize_audio import synthesize_audio
+ from modules.speaker import Speaker
 
 
 class TTSParams(BaseModel):
@@ -23,10 +24,10 @@ class TTSParams(BaseModel):
     temperature: float = Query(
         0.3, description="Temperature for sampling (may be overridden by style or spk)"
     )
-     top_P: float = Query(
+     top_p: float = Query(
         0.5, description="Top P for sampling (may be overridden by style or spk)"
     )
-     top_K: int = Query(
+     top_k: int = Query(
         20, description="Top K for sampling (may be overridden by style or spk)"
     )
     seed: int = Query(
@@ -38,7 +39,14 @@ class TTSParams(BaseModel):
     prefix: str = Query("", description="Text prefix for inference")
     bs: str = Query("8", description="Batch size for inference")
     thr: str = Query("100", description="Threshold for sentence spliter")
-     eos: str = Query("", description="End of sentence str")
+     eos: str = Query("[uv_break]", description="End of sentence str")
+ 
+     enhance: bool = Query(False, description="Enable enhancer")
+     denoise: bool = Query(False, description="Enable denoiser")
+ 
+     speed: float = Query(1.0, description="Speed of the audio")
+     pitch: float = Query(0, description="Pitch of the audio")
+     volume_gain: float = Query(0, description="Volume gain of the audio")
 
 
 async def synthesize_tts(params: TTSParams = Depends()):
@@ -55,18 +63,18 @@ async def synthesize_tts(params: TTSParams = Depends()):
             status_code=422, detail="Temperature must be between 0 and 1"
         )
 
-         # Validate top_P
-         if not (0 <= params.top_P <= 1):
-             raise HTTPException(status_code=422, detail="top_P must be between 0 and 1")
+         # Validate top_p
+         if not (0 <= params.top_p <= 1):
+             raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")
 
-         # Validate top_K
-         if params.top_K <= 0:
+         # Validate top_k
+         if params.top_k <= 0:
             raise HTTPException(
-                 status_code=422, detail="top_K must be a positive integer"
+                 status_code=422, detail="top_k must be a positive integer"
             )
-         if params.top_K > 100:
+         if params.top_k > 100:
             raise HTTPException(
-                 status_code=422, detail="top_K must be less than or equal to 100"
+                 status_code=422, detail="top_k must be less than or equal to 100"
             )
 
         # Validate format
@@ -76,11 +84,13 @@ async def synthesize_tts(params: TTSParams = Depends()):
             detail="Invalid format. Supported formats are mp3 and wav",
         )
 
-         text = text_normalize(params.text, is_end=False)
-
         calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
 
         spk = calc_params.get("spk", params.spk)
+         if not isinstance(spk, Speaker):
+             raise HTTPException(status_code=422, detail="Invalid speaker")
+ 
+         style = calc_params.get("style", params.style)
         seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
             "temperature", params.temperature
@@ -93,29 +103,46 @@ async def synthesize_tts(params: TTSParams = Depends()):
         batch_size = int(params.bs)
         threshold = int(params.thr)
 
-         sample_rate, audio_data = synthesize_audio(
-             text,
+         tts_config = ChatTTSConfig(
+             style=style,
             temperature=temperature,
-             top_P=params.top_P,
-             top_K=params.top_K,
-             spk=spk,
-             infer_seed=seed,
+             top_k=params.top_k,
+             top_p=params.top_p,
+             prefix=prefix,
             prompt1=prompt1,
             prompt2=prompt2,
-             prefix=prefix,
+         )
+         infer_config = InferConfig(
             batch_size=batch_size,
             spliter_threshold=threshold,
-             end_of_sentence=eos,
+             eos=eos,
+             seed=seed,
+         )
+         adjust_config = AdjustConfig(
+             pitch=params.pitch,
+             speed_rate=params.speed,
+             volume_gain_db=params.volume_gain,
+         )
+         enhancer_config = EnhancerConfig(
+             enabled=params.enhance or params.denoise or False,
+             lambd=0.9 if params.denoise else 0.1,
         )
 
-         buffer = io.BytesIO()
-         sf.write(buffer, audio_data, sample_rate, format="wav")
-         buffer.seek(0)
+         handler = TTSHandler(
+             text_content=params.text,
+             spk=spk,
+             tts_config=tts_config,
+             infer_config=infer_config,
+             adjust_config=adjust_config,
+             enhancer_config=enhancer_config,
+         )
 
-         if format == "mp3":
-             buffer = api_utils.wav_to_mp3(buffer)
+         buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
 
-         return StreamingResponse(buffer, media_type="audio/wav")
+         media_type = f"audio/{params.format}"
+         if params.format == "mp3":
+             media_type = "audio/mpeg"
+         return StreamingResponse(buffer, media_type=media_type)
 
     except Exception as e:
        import logging
modules/api/impl/xtts_v2_api.py CHANGED
@@ -30,8 +30,19 @@ class XTTS_V2_Settings:
         self.top_k = 20
         self.enable_text_splitting = True
 
+         # the extra settings below are not part of xtts_v2, but this system needs them
+         self.batch_size = 4
+         self.eos = "[uv_break]"
+         self.infer_seed = 42
+         self.use_decoder = True
+         self.prompt1 = ""
+         self.prompt2 = ""
+         self.prefix = ""
+         self.spliter_threshold = 100
+ 
 
 class TTSSettingsRequest(BaseModel):
+     # stream_chunk_size is currently used as the spliter_threshold
     stream_chunk_size: int
     temperature: float
     speed: float
@@ -41,6 +52,15 @@ class TTSSettingsRequest(BaseModel):
     top_k: int
     enable_text_splitting: bool
 
+     batch_size: int = None
+     eos: str = None
+     infer_seed: int = None
+     use_decoder: bool = None
+     prompt1: str = None
+     prompt2: str = None
+     prefix: str = None
+     spliter_threshold: int = None
+ 
 
 class SynthesisRequest(BaseModel):
     text: str
@@ -79,17 +99,22 @@ def setup(app: APIManager):
 
         text = text_normalize(text, is_end=True)
         sample_rate, audio_data = synthesize_audio(
-             text=text,
-             temperature=XTTSV2.temperature,
+             # TODO: these two params are unused for now... though the gpt could actually use them
             # length_penalty=XTTSV2.length_penalty,
             # repetition_penalty=XTTSV2.repetition_penalty,
+             text=text,
+             temperature=XTTSV2.temperature,
             top_P=XTTSV2.top_p,
             top_K=XTTSV2.top_k,
             spk=spk,
-             spliter_threshold=XTTSV2.stream_chunk_size,
-             # TODO: support configuring batch_size
-             batch_size=4,
-             end_of_sentence="[uv_break]",
+             spliter_threshold=XTTSV2.spliter_threshold,
+             batch_size=XTTSV2.batch_size,
+             end_of_sentence=XTTSV2.eos,
+             infer_seed=XTTSV2.infer_seed,
+             use_decoder=XTTSV2.use_decoder,
+             prompt1=XTTSV2.prompt1,
+             prompt2=XTTSV2.prompt2,
+             prefix=XTTSV2.prefix,
         )
 
         if XTTSV2.speed:
@@ -145,6 +170,8 @@ def setup(app: APIManager):
             )
 
             XTTSV2.stream_chunk_size = request.stream_chunk_size
+             XTTSV2.spliter_threshold = request.stream_chunk_size
+ 
             XTTSV2.temperature = request.temperature
             XTTSV2.speed = request.speed
             XTTSV2.length_penalty = request.length_penalty
@@ -152,6 +179,25 @@ def setup(app: APIManager):
             XTTSV2.top_p = request.top_p
             XTTSV2.top_k = request.top_k
             XTTSV2.enable_text_splitting = request.enable_text_splitting
+ 
+             # TODO: checker
+             if request.batch_size:
+                 XTTSV2.batch_size = request.batch_size
+             if request.eos:
+                 XTTSV2.eos = request.eos
+             if request.infer_seed:
+                 XTTSV2.infer_seed = request.infer_seed
+             if request.use_decoder:
+                 XTTSV2.use_decoder = request.use_decoder
+             if request.prompt1:
+                 XTTSV2.prompt1 = request.prompt1
+             if request.prompt2:
+                 XTTSV2.prompt2 = request.prompt2
+             if request.prefix:
+                 XTTSV2.prefix = request.prefix
+             if request.spliter_threshold:
+                 XTTSV2.spliter_threshold = request.spliter_threshold
+ 
             return {"message": "Settings successfully applied"}
         except Exception as e:
             if isinstance(e, HTTPException):
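
The extended TTSSettingsRequest lets an xtts_v2-style client pass ChatTTS-specific knobs alongside the standard fields; extras left unset keep the defaults from XTTS_V2_Settings. A sketch of a settings payload, not part of the commit:

settings = {
    # standard xtts_v2 fields (required by the request model)
    "stream_chunk_size": 120,  # reused as spliter_threshold
    "temperature": 0.3,
    "speed": 1.0,
    "length_penalty": 1.0,
    "repetition_penalty": 1.0,
    "top_p": 0.7,
    "top_k": 20,
    "enable_text_splitting": True,
    # ChatTTS-Forge extras (optional)
    "batch_size": 4,
    "eos": "[uv_break]",
    "infer_seed": 42,
}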
modules/api/utils.py CHANGED
@@ -1,9 +1,8 @@
 from pydantic import BaseModel
 from typing import Any, Union
 
- import torch
 
- from modules.speaker import Speaker, speaker_mgr
+ from modules.speaker import speaker_mgr
 
 
 from modules.data import styles_mgr
@@ -13,18 +12,10 @@ from pydub import AudioSegment
 from modules.ssml import merge_prompt
 
 
- from enum import Enum
-
-
 class ParamsTypeError(Exception):
     pass
 
 
- class AudioFormat(str, Enum):
-     mp3 = "mp3"
-     wav = "wav"
-
-
 class BaseResponse(BaseModel):
     message: str
     data: Any
@@ -35,7 +26,7 @@ def success_response(data: Any, message: str = "ok") -> BaseResponse:
 
 
 def wav_to_mp3(wav_data, bitrate="48k"):
-     audio = AudioSegment.from_wav(
+     audio: AudioSegment = AudioSegment.from_wav(
         wav_data,
     )
     return audio.export(format="mp3", bitrate=bitrate)
modules/devices/devices.py CHANGED
@@ -127,6 +127,12 @@ def reset_device():
     global dtype_gpt
     global dtype_decoder
 
+     if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
+         logger.warning(
+             "Cannot use half precision with CPU, using full precision instead"
+         )
+         config.runtime_env_vars.no_half = True
+ 
     if not config.runtime_env_vars.no_half:
         dtype = torch.float16
         dtype_dvae = torch.float16
@@ -144,7 +150,7 @@ def reset_device():
 
     logger.info("Using full precision: torch.float32")
 
-     if config.runtime_env_vars.use_cpu == "all":
+     if "all" in config.runtime_env_vars.use_cpu:
         device = cpu
     else:
         device = get_optimal_device()
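
The new guard exists because PyTorch's CPU backend lacks kernels for many float16 ops, so half precision on CPU tends to fail outright rather than merely run slowly. A quick repro sketch, not part of the commit (exact behavior depends on the PyTorch version; older releases raise errors like "addmm_impl_cpu_" not implemented for 'Half'):

import torch

x = torch.randn(4, 4, dtype=torch.float16, device="cpu")
try:
    y = x @ x  # many fp16 kernels are missing on CPU; this may raise RuntimeError
    print("fp16 matmul on CPU succeeded:", y.dtype)
except RuntimeError as e:
    print("fp16 on CPU failed, falling back to fp32:", e)
    y = (x.float() @ x.float()).half()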
modules/finetune/train_speaker.py CHANGED
@@ -45,9 +45,10 @@ def train_speaker_embeddings(
         )
         for speaker in dataset.speakers
     }
-     for speaker_embed in speaker_embeds.values():
-         std, mean = chat.pretrain_models["spk_stat"].chunk(2)
-         speaker_embed.data = speaker_embed.data * std + mean
+ 
+     for speaker_embed in speaker_embeds.values():
+         std, mean = chat.pretrain_models["spk_stat"].chunk(2)
+         speaker_embed.data = speaker_embed.data * std + mean
 
     SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
     AUDIO_EOS_TOKEN_ID = 0
@@ -166,13 +167,13 @@ def train_speaker_embeddings(
             audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
         )
         loss = audio_loss
-         if train_text:
-             text_logits = gpt.head_text(text_hidden_states)
-             text_loss = loss_fn(
-                 text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
-             )
-             loss += text_loss
-             logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
+ 
+         text_logits = gpt.head_text(text_hidden_states)
+         text_loss = loss_fn(
+             text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
+         )
+         loss += text_loss
+         logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
 
         gpt_gen_mel_specs = decoder_decoder(
             audio_hidden_states[:, :-1].transpose(1, 2)
@@ -181,7 +182,12 @@ def train_speaker_embeddings(
         loss += 0.01 * mse_loss
 
         optimizer.zero_grad()
-         loss.backward()
+ 
+         if train_text:
+             # just for test
+             text_loss.backward()
+         else:
+             loss.backward()
         torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
         optimizer.step()
         logger.meters["loss"].update(loss.item(), n=batch_size)
@@ -203,6 +209,7 @@ if __name__ == "__main__":
     from modules.speaker import Speaker
 
     config.runtime_env_vars.no_half = True
+     config.runtime_env_vars.use_cpu = []
     devices.reset_device()
 
     parser = argparse.ArgumentParser()
modules/prompts/news_oral_prompt.txt ADDED
@@ -0,0 +1,14 @@
+ # Task requirements
+ Task: convert a news script into spoken-style copy
+ 
+ You need to rewrite a news script into colloquial text suitable for reading aloud,
+ and, where appropriate, add paralanguage tags to give the text more variety.
+ 
+ The paralanguage tags currently available are:
+ - `[laugh]`: laughter
+ - `[uv_break]`: an unvoiced pause
+ - `[v_break]`: a voiced pause, such as "um" or "ah"
+ - `[lbreak]`: a long pause, usually marking the end of a paragraph
+ 
+ # Input
+ {{USER_INPUT}}
modules/prompts/podcast_prompt.txt ADDED
@@ -0,0 +1 @@
+ TODO
modules/ssml_parser/SSMLParser.py CHANGED
@@ -1,13 +1,10 @@
 from lxml import etree
 
 
- from typing import Any, List, Dict, Union
+ from typing import List, Union
 import logging
 
- from modules.data import styles_mgr
- from modules.speaker import speaker_mgr
 from box import Box
- import copy
 
 
 class SSMLContext(Box):
modules/webui/speaker/speaker_editor.py CHANGED
@@ -25,7 +25,7 @@ def speaker_editor_ui():
     spk: Speaker = Speaker.from_file(spk_file)
     spk.name = name
     spk.gender = gender
-     spk.desc = desc
+     spk.describe = desc
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
         torch.save(spk, tmp_file)
modules/webui/speaker/speaker_merger.py CHANGED
@@ -38,12 +38,8 @@ def merge_spk(
     tensor_c = spk_to_tensor(spk_c)
     tensor_d = spk_to_tensor(spk_d)
 
-     assert (
-         tensor_a is not None
-         or tensor_b is not None
-         or tensor_c is not None
-         or tensor_d is not None
-     ), "At least one speaker should be selected"
+     if tensor_a is None and tensor_b is None and tensor_c is None and tensor_d is None:
+         raise gr.Error("At least one speaker should be selected")
 
     merge_tensor = torch.zeros_like(
         tensor_a