Spaces:
Running
on
Zero
Running
on
Zero
zhzluke96
commited on
Commit
•
1df74c6
1
Parent(s):
3710ae9
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.webui +1 -1
- language/zh-CN.json +3 -0
- modules/ChatTTS/ChatTTS/core.py +1 -1
- modules/SynthesizeSegments.py +49 -16
- modules/api/api_setup.py +5 -3
- modules/api/impl/google_api.py +11 -12
- modules/api/impl/openai_api.py +4 -0
- modules/api/impl/ssml_api.py +17 -3
- modules/api/impl/tts_api.py +3 -0
- modules/api/impl/xtts_v2_api.py +160 -0
- modules/devices/devices.py +1 -1
- modules/finetune/__init__.py +0 -0
- modules/finetune/model/__init__.py +0 -0
- modules/finetune/model/encoder.py +87 -0
- modules/finetune/model/wavenet.py +227 -0
- modules/finetune/train_gpt.py +246 -0
- modules/finetune/train_speaker.py +296 -0
- modules/finetune/utils/__init__.py +0 -0
- modules/finetune/utils/dataset.py +487 -0
- modules/finetune/utils/logger.py +409 -0
- modules/finetune/utils/model.py +19 -0
- modules/finetune/utils/output.py +146 -0
- modules/generate_audio.py +2 -0
- modules/normalization.py +5 -0
- modules/repos_static/resemble_enhance/data/distorter/base.py +2 -1
- modules/repos_static/resemble_enhance/data/distorter/custom.py +8 -3
- modules/repos_static/resemble_enhance/data/distorter/sox.py +32 -8
- modules/repos_static/resemble_enhance/data/utils.py +4 -2
- modules/repos_static/resemble_enhance/denoiser/denoiser.py +2 -1
- modules/repos_static/resemble_enhance/enhancer/download.py +8 -3
- modules/repos_static/resemble_enhance/enhancer/enhancer.py +5 -2
- modules/repos_static/resemble_enhance/enhancer/hparams.py +4 -3
- modules/repos_static/resemble_enhance/enhancer/inference.py +2 -1
- modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +20 -9
- modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +2 -1
- modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +34 -7
- modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +8 -3
- modules/repos_static/resemble_enhance/hparams.py +5 -2
- modules/speaker.py +5 -3
- modules/ssml_parser/SSMLParser.py +34 -18
- modules/synthesize_audio.py +27 -38
- modules/utils/audio.py +5 -1
- modules/webui/app.py +3 -0
- modules/webui/finetune/ProcessMonitor.py +92 -0
- modules/webui/finetune/ft_tab.py +13 -0
- modules/webui/finetune/ft_ui_utils.py +49 -0
- modules/webui/finetune/speaker_ft_tab.py +130 -0
- modules/webui/localization_runtime.py +126 -0
- modules/webui/ssml/podcast_tab.py +2 -65
- modules/webui/ssml/ssml_tab.py +21 -2
.env.webui
CHANGED
@@ -17,5 +17,5 @@ TTS_MAX_LEN=1000
|
|
17 |
SSML_MAX_LEN=3000
|
18 |
MAX_BATCH_SIZE=12
|
19 |
|
20 |
-
V_GIT_TAG="🤗hf(0.
|
21 |
V_GIT_COMMIT=main
|
|
|
17 |
SSML_MAX_LEN=3000
|
18 |
MAX_BATCH_SIZE=12
|
19 |
|
20 |
+
V_GIT_TAG="🤗hf(0.6.1-rc)"
|
21 |
V_GIT_COMMIT=main
|
language/zh-CN.json
CHANGED
@@ -80,6 +80,9 @@
|
|
80 |
"readme": "readme",
|
81 |
"changelog": "changelog",
|
82 |
"💼Speaker file": "💼音色文件",
|
|
|
|
|
|
|
83 |
"TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
|
84 |
"SSML_SPLITER_GUIDE": [
|
85 |
"- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`",
|
|
|
80 |
"readme": "readme",
|
81 |
"changelog": "changelog",
|
82 |
"💼Speaker file": "💼音色文件",
|
83 |
+
"🎛️Spliter": "🎛️分割器配置",
|
84 |
+
"eos": "句尾词",
|
85 |
+
"Spliter Threshold": "分割器阈值",
|
86 |
"TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
|
87 |
"SSML_SPLITER_GUIDE": [
|
88 |
"- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`",
|
modules/ChatTTS/ChatTTS/core.py
CHANGED
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
|
|
17 |
|
18 |
from huggingface_hub import snapshot_download
|
19 |
|
20 |
-
logging.basicConfig(level=logging.
|
21 |
|
22 |
|
23 |
class Chat:
|
|
|
17 |
|
18 |
from huggingface_hub import snapshot_download
|
19 |
|
20 |
+
logging.basicConfig(level=logging.ERROR)
|
21 |
|
22 |
|
23 |
class Chat:
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,8 +1,10 @@
|
|
|
|
1 |
from box import Box
|
2 |
from pydub import AudioSegment
|
3 |
from typing import List, Union
|
4 |
from scipy.io.wavfile import write
|
5 |
import io
|
|
|
6 |
from modules.api.utils import calc_spk_style
|
7 |
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
|
8 |
from modules.utils import rng
|
@@ -56,27 +58,27 @@ def to_number(value, t, default=0):
|
|
56 |
|
57 |
|
58 |
class TTSAudioSegment(Box):
|
59 |
-
text: str
|
60 |
-
temperature: float
|
61 |
-
top_P: float
|
62 |
-
top_K: int
|
63 |
-
spk: int
|
64 |
-
infer_seed: int
|
65 |
-
prompt1: str
|
66 |
-
prompt2: str
|
67 |
-
prefix: str
|
68 |
-
|
69 |
-
_type: str
|
70 |
-
|
71 |
def __init__(self, *args, **kwargs):
|
72 |
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
class SynthesizeSegments:
|
76 |
-
def __init__(self, batch_size: int = 8):
|
77 |
self.batch_size = batch_size
|
78 |
self.batch_default_spk_seed = rng.np_rng()
|
79 |
self.batch_default_infer_seed = rng.np_rng()
|
|
|
|
|
80 |
|
81 |
def segment_to_generate_params(
|
82 |
self, segment: Union[SSMLSegment, SSMLBreak]
|
@@ -85,9 +87,11 @@ class SynthesizeSegments:
|
|
85 |
return TTSAudioSegment(_type="break")
|
86 |
|
87 |
if segment.get("params", None) is not None:
|
88 |
-
|
|
|
|
|
89 |
|
90 |
-
text = segment.get("text", ""
|
91 |
is_end = segment.get("is_end", False)
|
92 |
|
93 |
text = str(text).strip()
|
@@ -156,7 +160,7 @@ class SynthesizeSegments:
|
|
156 |
for i in range(0, len(bucket), self.batch_size):
|
157 |
batch = bucket[i : i + self.batch_size]
|
158 |
param_arr = [self.segment_to_generate_params(segment) for segment in batch]
|
159 |
-
texts = [params.text for params in param_arr]
|
160 |
|
161 |
params = param_arr[0]
|
162 |
audio_datas = generate_audio.generate_audio_batch(
|
@@ -204,9 +208,38 @@ class SynthesizeSegments:
|
|
204 |
|
205 |
return buckets
|
206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
def synthesize_segments(
|
208 |
self, segments: List[Union[SSMLSegment, SSMLBreak]]
|
209 |
) -> List[AudioSegment]:
|
|
|
210 |
audio_segments = [None] * len(segments)
|
211 |
buckets = self.bucket_segments(segments)
|
212 |
|
|
|
1 |
+
import copy
|
2 |
from box import Box
|
3 |
from pydub import AudioSegment
|
4 |
from typing import List, Union
|
5 |
from scipy.io.wavfile import write
|
6 |
import io
|
7 |
+
from modules.SentenceSplitter import SentenceSplitter
|
8 |
from modules.api.utils import calc_spk_style
|
9 |
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
|
10 |
from modules.utils import rng
|
|
|
58 |
|
59 |
|
60 |
class TTSAudioSegment(Box):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def __init__(self, *args, **kwargs):
|
62 |
super().__init__(*args, **kwargs)
|
63 |
+
self._type = kwargs.get("_type", "voice")
|
64 |
+
self.text = kwargs.get("text", "")
|
65 |
+
self.temperature = kwargs.get("temperature", 0.3)
|
66 |
+
self.top_P = kwargs.get("top_P", 0.5)
|
67 |
+
self.top_K = kwargs.get("top_K", 20)
|
68 |
+
self.spk = kwargs.get("spk", -1)
|
69 |
+
self.infer_seed = kwargs.get("infer_seed", -1)
|
70 |
+
self.prompt1 = kwargs.get("prompt1", "")
|
71 |
+
self.prompt2 = kwargs.get("prompt2", "")
|
72 |
+
self.prefix = kwargs.get("prefix", "")
|
73 |
|
74 |
|
75 |
class SynthesizeSegments:
|
76 |
+
def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
|
77 |
self.batch_size = batch_size
|
78 |
self.batch_default_spk_seed = rng.np_rng()
|
79 |
self.batch_default_infer_seed = rng.np_rng()
|
80 |
+
self.eos = eos
|
81 |
+
self.spliter_thr = spliter_thr
|
82 |
|
83 |
def segment_to_generate_params(
|
84 |
self, segment: Union[SSMLSegment, SSMLBreak]
|
|
|
87 |
return TTSAudioSegment(_type="break")
|
88 |
|
89 |
if segment.get("params", None) is not None:
|
90 |
+
params = segment.get("params")
|
91 |
+
text = segment.get("text", None) or segment.text or ""
|
92 |
+
return TTSAudioSegment(**params, text=text)
|
93 |
|
94 |
+
text = segment.get("text", None) or segment.text or ""
|
95 |
is_end = segment.get("is_end", False)
|
96 |
|
97 |
text = str(text).strip()
|
|
|
160 |
for i in range(0, len(bucket), self.batch_size):
|
161 |
batch = bucket[i : i + self.batch_size]
|
162 |
param_arr = [self.segment_to_generate_params(segment) for segment in batch]
|
163 |
+
texts = [params.text + self.eos for params in param_arr]
|
164 |
|
165 |
params = param_arr[0]
|
166 |
audio_datas = generate_audio.generate_audio_batch(
|
|
|
208 |
|
209 |
return buckets
|
210 |
|
211 |
+
def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
|
212 |
+
"""
|
213 |
+
将 segments 中的 text 经过 spliter 处理成多个 segments
|
214 |
+
"""
|
215 |
+
spliter = SentenceSplitter(threshold=self.spliter_thr)
|
216 |
+
ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []
|
217 |
+
|
218 |
+
for segment in segments:
|
219 |
+
if isinstance(segment, SSMLBreak):
|
220 |
+
ret_segments.append(segment)
|
221 |
+
continue
|
222 |
+
|
223 |
+
text = segment.text
|
224 |
+
if not text:
|
225 |
+
continue
|
226 |
+
|
227 |
+
sentences = spliter.parse(text)
|
228 |
+
for sentence in sentences:
|
229 |
+
ret_segments.append(
|
230 |
+
SSMLSegment(
|
231 |
+
text=sentence,
|
232 |
+
attrs=segment.attrs.copy(),
|
233 |
+
params=copy.copy(segment.params),
|
234 |
+
)
|
235 |
+
)
|
236 |
+
|
237 |
+
return ret_segments
|
238 |
+
|
239 |
def synthesize_segments(
|
240 |
self, segments: List[Union[SSMLSegment, SSMLBreak]]
|
241 |
) -> List[AudioSegment]:
|
242 |
+
segments = self.split_segments(segments)
|
243 |
audio_segments = [None] * len(segments)
|
244 |
buckets = self.bucket_segments(segments)
|
245 |
|
modules/api/api_setup.py
CHANGED
@@ -18,6 +18,7 @@ from modules.api.impl import (
|
|
18 |
speaker_api,
|
19 |
ping_api,
|
20 |
models_api,
|
|
|
21 |
)
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
@@ -35,6 +36,7 @@ def create_api(app, exclude=[]):
|
|
35 |
google_api.setup(app_mgr)
|
36 |
openai_api.setup(app_mgr)
|
37 |
refiner_api.setup(app_mgr)
|
|
|
38 |
|
39 |
return app_mgr
|
40 |
|
@@ -42,9 +44,9 @@ def create_api(app, exclude=[]):
|
|
42 |
def setup_model_args(parser: argparse.ArgumentParser):
|
43 |
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
44 |
parser.add_argument(
|
45 |
-
"--
|
46 |
action="store_true",
|
47 |
-
help="
|
48 |
)
|
49 |
parser.add_argument(
|
50 |
"--off_tqdm",
|
@@ -82,7 +84,7 @@ def process_model_args(args):
|
|
82 |
compile = env.get_and_update_env(args, "compile", False, bool)
|
83 |
device_id = env.get_and_update_env(args, "device_id", None, str)
|
84 |
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
|
85 |
-
|
86 |
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
|
87 |
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
|
88 |
|
|
|
18 |
speaker_api,
|
19 |
ping_api,
|
20 |
models_api,
|
21 |
+
xtts_v2_api,
|
22 |
)
|
23 |
|
24 |
logger = logging.getLogger(__name__)
|
|
|
36 |
google_api.setup(app_mgr)
|
37 |
openai_api.setup(app_mgr)
|
38 |
refiner_api.setup(app_mgr)
|
39 |
+
xtts_v2_api.setup(app_mgr)
|
40 |
|
41 |
return app_mgr
|
42 |
|
|
|
44 |
def setup_model_args(parser: argparse.ArgumentParser):
|
45 |
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
46 |
parser.add_argument(
|
47 |
+
"--no_half",
|
48 |
action="store_true",
|
49 |
+
help="Disalbe half precision for model inference",
|
50 |
)
|
51 |
parser.add_argument(
|
52 |
"--off_tqdm",
|
|
|
84 |
compile = env.get_and_update_env(args, "compile", False, bool)
|
85 |
device_id = env.get_and_update_env(args, "device_id", None, str)
|
86 |
use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
|
87 |
+
no_half = env.get_and_update_env(args, "no_half", False, bool)
|
88 |
off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
|
89 |
debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
|
90 |
|
modules/api/impl/google_api.py
CHANGED
@@ -13,6 +13,7 @@ from modules.Enhancer.ResembleEnhance import (
|
|
13 |
)
|
14 |
from modules.api.Api import APIManager
|
15 |
from modules.synthesize_audio import synthesize_audio
|
|
|
16 |
from modules.utils.audio import apply_prosody_to_audio_data
|
17 |
from modules.normalization import text_normalize
|
18 |
|
@@ -44,6 +45,9 @@ class VoiceSelectionParams(BaseModel):
|
|
44 |
topK: int = 20
|
45 |
seed: int = 42
|
46 |
|
|
|
|
|
|
|
47 |
|
48 |
class AudioConfig(BaseModel):
|
49 |
audioEncoding: api_utils.AudioFormat = "mp3"
|
@@ -87,6 +91,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
87 |
language_code = voice.languageCode
|
88 |
voice_name = voice.name
|
89 |
infer_seed = voice.seed or 42
|
|
|
90 |
audio_format = audioConfig.audioEncoding or "mp3"
|
91 |
speaking_rate = audioConfig.speakingRate or 1
|
92 |
pitch = audioConfig.pitch or 0
|
@@ -94,11 +99,9 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
94 |
|
95 |
batch_size = audioConfig.batchSize or 1
|
96 |
|
97 |
-
# TODO spliter_threshold
|
98 |
spliter_threshold = audioConfig.spliterThreshold or 100
|
99 |
|
100 |
-
|
101 |
-
sample_rate_hertz = audioConfig.sampleRateHertz or 24000
|
102 |
|
103 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
104 |
|
@@ -137,10 +140,10 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
137 |
prefix=params.get("prefix", ""),
|
138 |
batch_size=batch_size,
|
139 |
spliter_threshold=spliter_threshold,
|
|
|
140 |
)
|
141 |
|
142 |
elif input.ssml:
|
143 |
-
# 处理SSML合成逻辑
|
144 |
parser = create_ssml_parser()
|
145 |
segments = parser.parse(input.ssml)
|
146 |
for seg in segments:
|
@@ -151,17 +154,13 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
151 |
status_code=422, detail="The SSML text is empty or parsing failed."
|
152 |
)
|
153 |
|
154 |
-
synthesize = SynthesizeSegments(
|
|
|
|
|
155 |
audio_segments = synthesize.synthesize_segments(segments)
|
156 |
combined_audio = combine_audio_segments(audio_segments)
|
157 |
|
158 |
-
|
159 |
-
combined_audio.export(buffer, format="wav")
|
160 |
-
|
161 |
-
buffer.seek(0)
|
162 |
-
|
163 |
-
audio_data = buffer.read()
|
164 |
-
|
165 |
else:
|
166 |
raise HTTPException(
|
167 |
status_code=422, detail="Either text or SSML input must be provided."
|
|
|
13 |
)
|
14 |
from modules.api.Api import APIManager
|
15 |
from modules.synthesize_audio import synthesize_audio
|
16 |
+
from modules.utils import audio
|
17 |
from modules.utils.audio import apply_prosody_to_audio_data
|
18 |
from modules.normalization import text_normalize
|
19 |
|
|
|
45 |
topK: int = 20
|
46 |
seed: int = 42
|
47 |
|
48 |
+
# end_of_sentence
|
49 |
+
eos: str = "[uv_break]"
|
50 |
+
|
51 |
|
52 |
class AudioConfig(BaseModel):
|
53 |
audioEncoding: api_utils.AudioFormat = "mp3"
|
|
|
91 |
language_code = voice.languageCode
|
92 |
voice_name = voice.name
|
93 |
infer_seed = voice.seed or 42
|
94 |
+
eos = voice.eos or "[uv_break]"
|
95 |
audio_format = audioConfig.audioEncoding or "mp3"
|
96 |
speaking_rate = audioConfig.speakingRate or 1
|
97 |
pitch = audioConfig.pitch or 0
|
|
|
99 |
|
100 |
batch_size = audioConfig.batchSize or 1
|
101 |
|
|
|
102 |
spliter_threshold = audioConfig.spliterThreshold or 100
|
103 |
|
104 |
+
sample_rate = audioConfig.sampleRateHertz or 24000
|
|
|
105 |
|
106 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
107 |
|
|
|
140 |
prefix=params.get("prefix", ""),
|
141 |
batch_size=batch_size,
|
142 |
spliter_threshold=spliter_threshold,
|
143 |
+
end_of_sentence=eos,
|
144 |
)
|
145 |
|
146 |
elif input.ssml:
|
|
|
147 |
parser = create_ssml_parser()
|
148 |
segments = parser.parse(input.ssml)
|
149 |
for seg in segments:
|
|
|
154 |
status_code=422, detail="The SSML text is empty or parsing failed."
|
155 |
)
|
156 |
|
157 |
+
synthesize = SynthesizeSegments(
|
158 |
+
batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
|
159 |
+
)
|
160 |
audio_segments = synthesize.synthesize_segments(segments)
|
161 |
combined_audio = combine_audio_segments(audio_segments)
|
162 |
|
163 |
+
sample_rate, audio_data = audio.pydub_to_np(combined_audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
else:
|
165 |
raise HTTPException(
|
166 |
status_code=422, detail="Either text or SSML input must be provided."
|
modules/api/impl/openai_api.py
CHANGED
@@ -41,6 +41,8 @@ class AudioSpeechRequest(BaseModel):
|
|
41 |
spliter_threshold: float = Field(
|
42 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
43 |
)
|
|
|
|
|
44 |
|
45 |
|
46 |
async def openai_speech_api(
|
@@ -52,6 +54,7 @@ async def openai_speech_api(
|
|
52 |
input_text = request.input
|
53 |
voice = request.voice
|
54 |
style = request.style
|
|
|
55 |
response_format = request.response_format
|
56 |
batch_size = request.batch_size
|
57 |
spliter_threshold = request.spliter_threshold
|
@@ -95,6 +98,7 @@ async def openai_speech_api(
|
|
95 |
prompt1=prompt1,
|
96 |
prompt2=prompt2,
|
97 |
prefix=prefix,
|
|
|
98 |
)
|
99 |
|
100 |
if speed != 1:
|
|
|
41 |
spliter_threshold: float = Field(
|
42 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
43 |
)
|
44 |
+
# end of sentence
|
45 |
+
eos: str = "[uv_break]"
|
46 |
|
47 |
|
48 |
async def openai_speech_api(
|
|
|
54 |
input_text = request.input
|
55 |
voice = request.voice
|
56 |
style = request.style
|
57 |
+
eos = request.eos
|
58 |
response_format = request.response_format
|
59 |
batch_size = request.batch_size
|
60 |
spliter_threshold = request.spliter_threshold
|
|
|
98 |
prompt1=prompt1,
|
99 |
prompt2=prompt2,
|
100 |
prefix=prefix,
|
101 |
+
end_of_sentence=eos,
|
102 |
)
|
103 |
|
104 |
if speed != 1:
|
modules/api/impl/ssml_api.py
CHANGED
@@ -26,8 +26,13 @@ class SSMLRequest(BaseModel):
|
|
26 |
# NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
|
27 |
batch_size: int = 4
|
28 |
|
|
|
|
|
29 |
|
30 |
-
|
|
|
|
|
|
|
31 |
request: SSMLRequest = Body(
|
32 |
..., description="JSON body with SSML string and format"
|
33 |
)
|
@@ -36,12 +41,19 @@ async def synthesize_ssml(
|
|
36 |
ssml = request.ssml
|
37 |
format = request.format.lower()
|
38 |
batch_size = request.batch_size
|
|
|
|
|
39 |
|
40 |
if batch_size < 1:
|
41 |
raise HTTPException(
|
42 |
status_code=400, detail="Batch size must be greater than 0."
|
43 |
)
|
44 |
|
|
|
|
|
|
|
|
|
|
|
45 |
if not ssml or ssml == "":
|
46 |
raise HTTPException(status_code=400, detail="SSML content is required.")
|
47 |
|
@@ -55,7 +67,9 @@ async def synthesize_ssml(
|
|
55 |
for seg in segments:
|
56 |
seg["text"] = text_normalize(seg["text"], is_end=True)
|
57 |
|
58 |
-
synthesize = SynthesizeSegments(
|
|
|
|
|
59 |
audio_segments = synthesize.synthesize_segments(segments)
|
60 |
combined_audio = combine_audio_segments(audio_segments)
|
61 |
buffer = io.BytesIO()
|
@@ -77,4 +91,4 @@ async def synthesize_ssml(
|
|
77 |
|
78 |
|
79 |
def setup(api_manager: APIManager):
|
80 |
-
api_manager.post("/v1/ssml", response_class=FileResponse)(
|
|
|
26 |
# NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
|
27 |
batch_size: int = 4
|
28 |
|
29 |
+
# end of sentence
|
30 |
+
eos: str = "[uv_break]"
|
31 |
|
32 |
+
spliter_thr: int = 100
|
33 |
+
|
34 |
+
|
35 |
+
async def synthesize_ssml_api(
|
36 |
request: SSMLRequest = Body(
|
37 |
..., description="JSON body with SSML string and format"
|
38 |
)
|
|
|
41 |
ssml = request.ssml
|
42 |
format = request.format.lower()
|
43 |
batch_size = request.batch_size
|
44 |
+
eos = request.eos
|
45 |
+
spliter_thr = request.spliter_thr
|
46 |
|
47 |
if batch_size < 1:
|
48 |
raise HTTPException(
|
49 |
status_code=400, detail="Batch size must be greater than 0."
|
50 |
)
|
51 |
|
52 |
+
if spliter_thr < 50:
|
53 |
+
raise HTTPException(
|
54 |
+
status_code=400, detail="Spliter threshold must be greater than 50."
|
55 |
+
)
|
56 |
+
|
57 |
if not ssml or ssml == "":
|
58 |
raise HTTPException(status_code=400, detail="SSML content is required.")
|
59 |
|
|
|
67 |
for seg in segments:
|
68 |
seg["text"] = text_normalize(seg["text"], is_end=True)
|
69 |
|
70 |
+
synthesize = SynthesizeSegments(
|
71 |
+
batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
|
72 |
+
)
|
73 |
audio_segments = synthesize.synthesize_segments(segments)
|
74 |
combined_audio = combine_audio_segments(audio_segments)
|
75 |
buffer = io.BytesIO()
|
|
|
91 |
|
92 |
|
93 |
def setup(api_manager: APIManager):
|
94 |
+
api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml_api)
|
modules/api/impl/tts_api.py
CHANGED
@@ -38,6 +38,7 @@ class TTSParams(BaseModel):
|
|
38 |
prefix: str = Query("", description="Text prefix for inference")
|
39 |
bs: str = Query("8", description="Batch size for inference")
|
40 |
thr: str = Query("100", description="Threshold for sentence spliter")
|
|
|
41 |
|
42 |
|
43 |
async def synthesize_tts(params: TTSParams = Depends()):
|
@@ -87,6 +88,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
87 |
prefix = params.prefix or calc_params.get("prefix", params.prefix)
|
88 |
prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
|
89 |
prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
|
|
|
90 |
|
91 |
batch_size = int(params.bs)
|
92 |
threshold = int(params.thr)
|
@@ -103,6 +105,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
103 |
prefix=prefix,
|
104 |
batch_size=batch_size,
|
105 |
spliter_threshold=threshold,
|
|
|
106 |
)
|
107 |
|
108 |
buffer = io.BytesIO()
|
|
|
38 |
prefix: str = Query("", description="Text prefix for inference")
|
39 |
bs: str = Query("8", description="Batch size for inference")
|
40 |
thr: str = Query("100", description="Threshold for sentence spliter")
|
41 |
+
eos: str = Query("", description="End of sentence str")
|
42 |
|
43 |
|
44 |
async def synthesize_tts(params: TTSParams = Depends()):
|
|
|
88 |
prefix = params.prefix or calc_params.get("prefix", params.prefix)
|
89 |
prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
|
90 |
prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
|
91 |
+
eos = params.eos or ""
|
92 |
|
93 |
batch_size = int(params.bs)
|
94 |
threshold = int(params.thr)
|
|
|
105 |
prefix=prefix,
|
106 |
batch_size=batch_size,
|
107 |
spliter_threshold=threshold,
|
108 |
+
end_of_sentence=eos,
|
109 |
)
|
110 |
|
111 |
buffer = io.BytesIO()
|
modules/api/impl/xtts_v2_api.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from fastapi import HTTPException
|
3 |
+
from fastapi.responses import StreamingResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from modules.api import utils as api_utils
|
6 |
+
from modules.api.Api import APIManager
|
7 |
+
|
8 |
+
import soundfile as sf
|
9 |
+
|
10 |
+
from modules import config
|
11 |
+
from modules.normalization import text_normalize
|
12 |
+
from modules.speaker import speaker_mgr
|
13 |
+
from modules.synthesize_audio import synthesize_audio
|
14 |
+
|
15 |
+
import logging
|
16 |
+
|
17 |
+
from modules.utils.audio import apply_prosody_to_audio_data
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class XTTS_V2_Settings:
|
23 |
+
def __init__(self):
|
24 |
+
self.stream_chunk_size = 100
|
25 |
+
self.temperature = 0.3
|
26 |
+
self.speed = 1
|
27 |
+
self.length_penalty = 0.5
|
28 |
+
self.repetition_penalty = 1.0
|
29 |
+
self.top_p = 0.7
|
30 |
+
self.top_k = 20
|
31 |
+
self.enable_text_splitting = True
|
32 |
+
|
33 |
+
|
34 |
+
class TTSSettingsRequest(BaseModel):
|
35 |
+
stream_chunk_size: int
|
36 |
+
temperature: float
|
37 |
+
speed: float
|
38 |
+
length_penalty: float
|
39 |
+
repetition_penalty: float
|
40 |
+
top_p: float
|
41 |
+
top_k: int
|
42 |
+
enable_text_splitting: bool
|
43 |
+
|
44 |
+
|
45 |
+
class SynthesisRequest(BaseModel):
|
46 |
+
text: str
|
47 |
+
speaker_wav: str
|
48 |
+
language: str
|
49 |
+
|
50 |
+
|
51 |
+
def setup(app: APIManager):
|
52 |
+
XTTSV2 = XTTS_V2_Settings()
|
53 |
+
|
54 |
+
@app.get("/v1/xtts_v2/speakers")
|
55 |
+
async def speakers():
|
56 |
+
spks = speaker_mgr.list_speakers()
|
57 |
+
return [
|
58 |
+
{
|
59 |
+
"name": spk.name,
|
60 |
+
"voice_id": spk.id,
|
61 |
+
# TODO: 也许可以放一个 "/v1/tts" 接口地址在这里
|
62 |
+
"preview_url": "",
|
63 |
+
}
|
64 |
+
for spk in spks
|
65 |
+
]
|
66 |
+
|
67 |
+
@app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
|
68 |
+
async def tts_to_audio(request: SynthesisRequest):
|
69 |
+
text = request.text
|
70 |
+
# speaker_wav 就是 speaker id 。。。
|
71 |
+
voice_id = request.speaker_wav
|
72 |
+
language = request.language
|
73 |
+
|
74 |
+
spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
|
75 |
+
voice_id
|
76 |
+
)
|
77 |
+
if spk is None:
|
78 |
+
raise HTTPException(status_code=400, detail="Invalid speaker id")
|
79 |
+
|
80 |
+
text = text_normalize(text, is_end=True)
|
81 |
+
sample_rate, audio_data = synthesize_audio(
|
82 |
+
text=text,
|
83 |
+
temperature=XTTSV2.temperature,
|
84 |
+
# length_penalty=XTTSV2.length_penalty,
|
85 |
+
# repetition_penalty=XTTSV2.repetition_penalty,
|
86 |
+
top_P=XTTSV2.top_p,
|
87 |
+
top_K=XTTSV2.top_k,
|
88 |
+
spk=spk,
|
89 |
+
spliter_threshold=XTTSV2.stream_chunk_size,
|
90 |
+
# TODO 支持设置 batch_size
|
91 |
+
batch_size=4,
|
92 |
+
end_of_sentence="[uv_break]",
|
93 |
+
)
|
94 |
+
|
95 |
+
if XTTSV2.speed:
|
96 |
+
audio_data = apply_prosody_to_audio_data(
|
97 |
+
audio_data,
|
98 |
+
rate=XTTSV2.speed,
|
99 |
+
sr=sample_rate,
|
100 |
+
)
|
101 |
+
|
102 |
+
# to mp3
|
103 |
+
buffer = io.BytesIO()
|
104 |
+
sf.write(buffer, audio_data, sample_rate, format="wav")
|
105 |
+
buffer.seek(0)
|
106 |
+
|
107 |
+
buffer = api_utils.wav_to_mp3(buffer)
|
108 |
+
|
109 |
+
return StreamingResponse(buffer, media_type="audio/mpeg")
|
110 |
+
|
111 |
+
@app.get("/v1/xtts_v2/tts_stream")
|
112 |
+
async def tts_stream():
|
113 |
+
raise HTTPException(status_code=501, detail="Not implemented")
|
114 |
+
|
115 |
+
@app.post("/v1/xtts_v2/set_tts_settings")
|
116 |
+
async def set_tts_settings(request: TTSSettingsRequest):
|
117 |
+
try:
|
118 |
+
if request.stream_chunk_size < 50:
|
119 |
+
raise HTTPException(
|
120 |
+
status_code=400, detail="stream_chunk_size must be greater than 0"
|
121 |
+
)
|
122 |
+
if request.temperature < 0:
|
123 |
+
raise HTTPException(
|
124 |
+
status_code=400, detail="temperature must be greater than 0"
|
125 |
+
)
|
126 |
+
if request.speed < 0:
|
127 |
+
raise HTTPException(
|
128 |
+
status_code=400, detail="speed must be greater than 0"
|
129 |
+
)
|
130 |
+
if request.length_penalty < 0:
|
131 |
+
raise HTTPException(
|
132 |
+
status_code=400, detail="length_penalty must be greater than 0"
|
133 |
+
)
|
134 |
+
if request.repetition_penalty < 0:
|
135 |
+
raise HTTPException(
|
136 |
+
status_code=400, detail="repetition_penalty must be greater than 0"
|
137 |
+
)
|
138 |
+
if request.top_p < 0:
|
139 |
+
raise HTTPException(
|
140 |
+
status_code=400, detail="top_p must be greater than 0"
|
141 |
+
)
|
142 |
+
if request.top_k < 0:
|
143 |
+
raise HTTPException(
|
144 |
+
status_code=400, detail="top_k must be greater than 0"
|
145 |
+
)
|
146 |
+
|
147 |
+
XTTSV2.stream_chunk_size = request.stream_chunk_size
|
148 |
+
XTTSV2.temperature = request.temperature
|
149 |
+
XTTSV2.speed = request.speed
|
150 |
+
XTTSV2.length_penalty = request.length_penalty
|
151 |
+
XTTSV2.repetition_penalty = request.repetition_penalty
|
152 |
+
XTTSV2.top_p = request.top_p
|
153 |
+
XTTSV2.top_k = request.top_k
|
154 |
+
XTTSV2.enable_text_splitting = request.enable_text_splitting
|
155 |
+
return {"message": "Settings successfully applied"}
|
156 |
+
except Exception as e:
|
157 |
+
if isinstance(e, HTTPException):
|
158 |
+
raise e
|
159 |
+
logger.error(e)
|
160 |
+
raise HTTPException(status_code=500, detail=str(e))
|
modules/devices/devices.py
CHANGED
@@ -127,7 +127,7 @@ def reset_device():
|
|
127 |
global dtype_gpt
|
128 |
global dtype_decoder
|
129 |
|
130 |
-
if config.runtime_env_vars.
|
131 |
dtype = torch.float16
|
132 |
dtype_dvae = torch.float16
|
133 |
dtype_vocos = torch.float16
|
|
|
127 |
global dtype_gpt
|
128 |
global dtype_decoder
|
129 |
|
130 |
+
if not config.runtime_env_vars.no_half:
|
131 |
dtype = torch.float16
|
132 |
dtype_dvae = torch.float16
|
133 |
dtype_vocos = torch.float16
|
modules/finetune/__init__.py
ADDED
File without changes
|
modules/finetune/model/__init__.py
ADDED
File without changes
|
modules/finetune/model/encoder.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder
|
5 |
+
|
6 |
+
from .wavenet import WaveNet
|
7 |
+
|
8 |
+
|
9 |
+
def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
|
10 |
+
return {
|
11 |
+
"idim": decoder.conv_out.out_channels,
|
12 |
+
"odim": decoder.conv_in[0].in_channels,
|
13 |
+
"n_layer": len(decoder.decoder_block),
|
14 |
+
"bn_dim": decoder.conv_in[0].out_channels,
|
15 |
+
"hidden": decoder.conv_in[2].out_channels,
|
16 |
+
"kernel": decoder.decoder_block[0].dwconv.kernel_size[0],
|
17 |
+
"dilation": decoder.decoder_block[0].dwconv.dilation[0],
|
18 |
+
"down": decoder.up,
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
class DVAEEncoder(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
idim: int,
|
26 |
+
odim: int,
|
27 |
+
n_layer: int = 12,
|
28 |
+
bn_dim: int = 64,
|
29 |
+
hidden: int = 256,
|
30 |
+
kernel: int = 7,
|
31 |
+
dilation: int = 2,
|
32 |
+
down: bool = False,
|
33 |
+
) -> None:
|
34 |
+
super().__init__()
|
35 |
+
self.wavenet = WaveNet(
|
36 |
+
input_channels=100,
|
37 |
+
residual_channels=idim,
|
38 |
+
residual_layers=20,
|
39 |
+
dilation_cycle=4,
|
40 |
+
)
|
41 |
+
self.conv_in_transpose = nn.ConvTranspose1d(
|
42 |
+
idim, hidden, kernel_size=1, bias=False
|
43 |
+
)
|
44 |
+
# nn.Sequential(
|
45 |
+
# nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False),
|
46 |
+
# nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False)
|
47 |
+
# )
|
48 |
+
self.encoder_block = nn.ModuleList(
|
49 |
+
[
|
50 |
+
ConvNeXtBlock(
|
51 |
+
hidden,
|
52 |
+
hidden * 4,
|
53 |
+
kernel,
|
54 |
+
dilation,
|
55 |
+
)
|
56 |
+
for _ in range(n_layer)
|
57 |
+
]
|
58 |
+
)
|
59 |
+
self.conv_out_transpose = nn.Sequential(
|
60 |
+
nn.Conv1d(hidden, bn_dim, 3, 1, 1),
|
61 |
+
nn.GELU(),
|
62 |
+
nn.Conv1d(bn_dim, odim, 3, 1, 1),
|
63 |
+
)
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self,
|
67 |
+
audio_mel_specs: torch.Tensor, # (batch_size, audio_len*2, 100)
|
68 |
+
audio_attention_mask: torch.Tensor, # (batch_size, audio_len)
|
69 |
+
conditioning=None,
|
70 |
+
) -> torch.Tensor:
|
71 |
+
mel_attention_mask = (
|
72 |
+
audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1)
|
73 |
+
)
|
74 |
+
x: torch.Tensor = self.wavenet(
|
75 |
+
audio_mel_specs.transpose(1, 2)
|
76 |
+
) # (batch_size, idim, audio_len*2)
|
77 |
+
x = x * mel_attention_mask.unsqueeze(1)
|
78 |
+
x = self.conv_in_transpose(x) # (batch_size, hidden, audio_len*2)
|
79 |
+
for f in self.encoder_block:
|
80 |
+
x = f(x, conditioning)
|
81 |
+
x = self.conv_out_transpose(x) # (batch_size, odim, audio_len*2)
|
82 |
+
x = (
|
83 |
+
x.view(x.size(0), x.size(1), 2, x.size(2) // 2)
|
84 |
+
.permute(0, 3, 1, 2)
|
85 |
+
.flatten(2)
|
86 |
+
)
|
87 |
+
return x # (batch_size, audio_len, audio_dim=odim*2)
|
modules/finetune/model/wavenet.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/vqgan/modules/wavenet.py"""
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class Mish(nn.Module):
|
12 |
+
def forward(self, x):
|
13 |
+
return x * torch.tanh(F.softplus(x))
|
14 |
+
|
15 |
+
|
16 |
+
class DiffusionEmbedding(nn.Module):
|
17 |
+
"""Diffusion Step Embedding"""
|
18 |
+
|
19 |
+
def __init__(self, d_denoiser):
|
20 |
+
super(DiffusionEmbedding, self).__init__()
|
21 |
+
self.dim = d_denoiser
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
device = x.device
|
25 |
+
half_dim = self.dim // 2
|
26 |
+
emb = math.log(10000) / (half_dim - 1)
|
27 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
28 |
+
emb = x[:, None] * emb[None, :]
|
29 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
30 |
+
return emb
|
31 |
+
|
32 |
+
|
33 |
+
class LinearNorm(nn.Module):
|
34 |
+
"""LinearNorm Projection"""
|
35 |
+
|
36 |
+
def __init__(self, in_features, out_features, bias=False):
|
37 |
+
super(LinearNorm, self).__init__()
|
38 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
39 |
+
|
40 |
+
nn.init.xavier_uniform_(self.linear.weight)
|
41 |
+
if bias:
|
42 |
+
nn.init.constant_(self.linear.bias, 0.0)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.linear(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class ConvNorm(nn.Module):
|
50 |
+
"""1D Convolution"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_channels,
|
55 |
+
out_channels,
|
56 |
+
kernel_size=1,
|
57 |
+
stride=1,
|
58 |
+
padding=None,
|
59 |
+
dilation=1,
|
60 |
+
bias=True,
|
61 |
+
w_init_gain="linear",
|
62 |
+
):
|
63 |
+
super(ConvNorm, self).__init__()
|
64 |
+
|
65 |
+
if padding is None:
|
66 |
+
assert kernel_size % 2 == 1
|
67 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
68 |
+
|
69 |
+
self.conv = nn.Conv1d(
|
70 |
+
in_channels,
|
71 |
+
out_channels,
|
72 |
+
kernel_size=kernel_size,
|
73 |
+
stride=stride,
|
74 |
+
padding=padding,
|
75 |
+
dilation=dilation,
|
76 |
+
bias=bias,
|
77 |
+
)
|
78 |
+
nn.init.kaiming_normal_(self.conv.weight)
|
79 |
+
|
80 |
+
def forward(self, signal):
|
81 |
+
conv_signal = self.conv(signal)
|
82 |
+
|
83 |
+
return conv_signal
|
84 |
+
|
85 |
+
|
86 |
+
class ResidualBlock(nn.Module):
|
87 |
+
"""Residual Block"""
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
residual_channels,
|
92 |
+
use_linear_bias=False,
|
93 |
+
dilation=1,
|
94 |
+
condition_channels=None,
|
95 |
+
):
|
96 |
+
super(ResidualBlock, self).__init__()
|
97 |
+
self.conv_layer = ConvNorm(
|
98 |
+
residual_channels,
|
99 |
+
2 * residual_channels,
|
100 |
+
kernel_size=3,
|
101 |
+
stride=1,
|
102 |
+
padding=dilation,
|
103 |
+
dilation=dilation,
|
104 |
+
)
|
105 |
+
|
106 |
+
if condition_channels is not None:
|
107 |
+
self.diffusion_projection = LinearNorm(
|
108 |
+
residual_channels, residual_channels, use_linear_bias
|
109 |
+
)
|
110 |
+
self.condition_projection = ConvNorm(
|
111 |
+
condition_channels, 2 * residual_channels, kernel_size=1
|
112 |
+
)
|
113 |
+
|
114 |
+
self.output_projection = ConvNorm(
|
115 |
+
residual_channels, 2 * residual_channels, kernel_size=1
|
116 |
+
)
|
117 |
+
|
118 |
+
def forward(self, x, condition=None, diffusion_step=None):
|
119 |
+
y = x
|
120 |
+
|
121 |
+
if diffusion_step is not None:
|
122 |
+
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
123 |
+
y = y + diffusion_step
|
124 |
+
|
125 |
+
y = self.conv_layer(y)
|
126 |
+
|
127 |
+
if condition is not None:
|
128 |
+
condition = self.condition_projection(condition)
|
129 |
+
y = y + condition
|
130 |
+
|
131 |
+
gate, filter = torch.chunk(y, 2, dim=1)
|
132 |
+
y = torch.sigmoid(gate) * torch.tanh(filter)
|
133 |
+
|
134 |
+
y = self.output_projection(y)
|
135 |
+
residual, skip = torch.chunk(y, 2, dim=1)
|
136 |
+
|
137 |
+
return (x + residual) / math.sqrt(2.0), skip
|
138 |
+
|
139 |
+
|
140 |
+
class WaveNet(nn.Module):
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
input_channels: Optional[int] = None,
|
144 |
+
output_channels: Optional[int] = None,
|
145 |
+
residual_channels: int = 512,
|
146 |
+
residual_layers: int = 20,
|
147 |
+
dilation_cycle: Optional[int] = 4,
|
148 |
+
is_diffusion: bool = False,
|
149 |
+
condition_channels: Optional[int] = None,
|
150 |
+
):
|
151 |
+
super().__init__()
|
152 |
+
|
153 |
+
# Input projection
|
154 |
+
self.input_projection = None
|
155 |
+
if input_channels is not None and input_channels != residual_channels:
|
156 |
+
self.input_projection = ConvNorm(
|
157 |
+
input_channels, residual_channels, kernel_size=1
|
158 |
+
)
|
159 |
+
|
160 |
+
if input_channels is None:
|
161 |
+
input_channels = residual_channels
|
162 |
+
|
163 |
+
self.input_channels = input_channels
|
164 |
+
|
165 |
+
# Residual layers
|
166 |
+
self.residual_layers = nn.ModuleList(
|
167 |
+
[
|
168 |
+
ResidualBlock(
|
169 |
+
residual_channels=residual_channels,
|
170 |
+
use_linear_bias=False,
|
171 |
+
dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
|
172 |
+
condition_channels=condition_channels,
|
173 |
+
)
|
174 |
+
for i in range(residual_layers)
|
175 |
+
]
|
176 |
+
)
|
177 |
+
|
178 |
+
# Skip projection
|
179 |
+
self.skip_projection = ConvNorm(
|
180 |
+
residual_channels, residual_channels, kernel_size=1
|
181 |
+
)
|
182 |
+
|
183 |
+
# Output projection
|
184 |
+
self.output_projection = None
|
185 |
+
if output_channels is not None and output_channels != residual_channels:
|
186 |
+
self.output_projection = ConvNorm(
|
187 |
+
residual_channels, output_channels, kernel_size=1
|
188 |
+
)
|
189 |
+
|
190 |
+
if is_diffusion:
|
191 |
+
self.diffusion_embedding = DiffusionEmbedding(residual_channels)
|
192 |
+
self.mlp = nn.Sequential(
|
193 |
+
LinearNorm(residual_channels, residual_channels * 4, False),
|
194 |
+
Mish(),
|
195 |
+
LinearNorm(residual_channels * 4, residual_channels, False),
|
196 |
+
)
|
197 |
+
|
198 |
+
self.apply(self._init_weights)
|
199 |
+
|
200 |
+
def _init_weights(self, m):
|
201 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
202 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
203 |
+
if getattr(m, "bias", None) is not None:
|
204 |
+
nn.init.constant_(m.bias, 0)
|
205 |
+
|
206 |
+
def forward(self, x, t=None, condition=None):
|
207 |
+
if self.input_projection is not None:
|
208 |
+
x = self.input_projection(x)
|
209 |
+
x = F.silu(x)
|
210 |
+
|
211 |
+
if t is not None:
|
212 |
+
t = self.diffusion_embedding(t)
|
213 |
+
t = self.mlp(t)
|
214 |
+
|
215 |
+
skip = []
|
216 |
+
for layer in self.residual_layers:
|
217 |
+
x, skip_connection = layer(x, condition, t)
|
218 |
+
skip.append(skip_connection)
|
219 |
+
|
220 |
+
x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
|
221 |
+
x = self.skip_projection(x)
|
222 |
+
|
223 |
+
if self.output_projection is not None:
|
224 |
+
x = F.silu(x)
|
225 |
+
x = self.output_projection(x)
|
226 |
+
|
227 |
+
return x
|
modules/finetune/train_gpt.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
import peft
|
5 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
6 |
+
from utils.dataset import AudioCollator
|
7 |
+
from utils.logger import MetricLogger
|
8 |
+
from utils.output import ansi, get_ansi_len, output_iter
|
9 |
+
|
10 |
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
11 |
+
|
12 |
+
|
13 |
+
def train_gpt_lora(
|
14 |
+
chat,
|
15 |
+
dataset,
|
16 |
+
decoder_encoder,
|
17 |
+
dvae_encoder,
|
18 |
+
batch_size=16,
|
19 |
+
epochs=10,
|
20 |
+
train_text=True,
|
21 |
+
speaker_embeds=None,
|
22 |
+
lora_r=8,
|
23 |
+
lora_alpha=16,
|
24 |
+
):
|
25 |
+
if speaker_embeds is None:
|
26 |
+
speaker_embeds = {}
|
27 |
+
|
28 |
+
tokenizer = chat.pretrain_models["tokenizer"]
|
29 |
+
decoder_decoder = chat.pretrain_models["decoder"]
|
30 |
+
decoder_decoder.eval().requires_grad_(False)
|
31 |
+
decoder_encoder.to(device=dataset.device).eval().requires_grad_(False)
|
32 |
+
dvae_decoder = chat.pretrain_models["dvae"]
|
33 |
+
dvae_decoder.eval().requires_grad_(False)
|
34 |
+
dvae_encoder.to(device=dataset.device).eval().requires_grad_(False)
|
35 |
+
|
36 |
+
gpt = chat.pretrain_models["gpt"]
|
37 |
+
gpt.train().requires_grad_()
|
38 |
+
|
39 |
+
# Add LoRA to GPT model
|
40 |
+
lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha)
|
41 |
+
gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config)
|
42 |
+
|
43 |
+
speaker_embeds = {
|
44 |
+
speaker: torch.randn(768, device=dataset.device, requires_grad=True)
|
45 |
+
for speaker in dataset.speakers
|
46 |
+
} | speaker_embeds
|
47 |
+
|
48 |
+
for speaker_embed in speaker_embeds.values():
|
49 |
+
std, mean = chat.pretrain_models["spk_stat"].chunk(2)
|
50 |
+
speaker_embed.data = speaker_embed.data * std + mean
|
51 |
+
|
52 |
+
SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
|
53 |
+
AUDIO_EOS_TOKEN_ID = 0
|
54 |
+
AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
|
55 |
+
|
56 |
+
train_params = list(gpt.parameters()) + list(speaker_embeds.values())
|
57 |
+
optimizer = torch.optim.Adam(
|
58 |
+
gpt.parameters(), lr=1e-3, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
|
59 |
+
)
|
60 |
+
optimizer.add_param_group({"params": speaker_embeds.values(), "lr": 1e-1})
|
61 |
+
|
62 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
63 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
|
64 |
+
|
65 |
+
loader = torch.utils.data.DataLoader(
|
66 |
+
dataset,
|
67 |
+
batch_size=batch_size,
|
68 |
+
shuffle=True,
|
69 |
+
collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
|
70 |
+
)
|
71 |
+
logger = MetricLogger()
|
72 |
+
logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
|
73 |
+
|
74 |
+
for _epoch in range(epochs):
|
75 |
+
_epoch += 1
|
76 |
+
logger.reset()
|
77 |
+
header = "{blue_light}{0}: {1}{reset}".format(
|
78 |
+
"Epoch", output_iter(_epoch, epochs), **ansi
|
79 |
+
)
|
80 |
+
header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
|
81 |
+
iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
|
82 |
+
|
83 |
+
for batch in iterator:
|
84 |
+
speakers = batch["speaker"]
|
85 |
+
text_input_ids = batch["text_input_ids"]
|
86 |
+
text_attention_mask = batch["text_attention_mask"]
|
87 |
+
audio_mel_specs = batch["audio_mel_specs"]
|
88 |
+
audio_attention_mask = batch["audio_attention_mask"]
|
89 |
+
|
90 |
+
batch_size, text_len = text_attention_mask.size()
|
91 |
+
|
92 |
+
dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
|
93 |
+
_, dvae_audio_input_ids = quantize(
|
94 |
+
dvae_decoder.vq_layer.quantizer, dvae_audio_latents
|
95 |
+
)
|
96 |
+
dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
|
97 |
+
|
98 |
+
extended_audio_attention_mask = torch.cat(
|
99 |
+
[
|
100 |
+
audio_attention_mask,
|
101 |
+
torch.zeros(
|
102 |
+
(batch_size, 1),
|
103 |
+
dtype=audio_attention_mask.dtype,
|
104 |
+
device=audio_attention_mask.device,
|
105 |
+
),
|
106 |
+
],
|
107 |
+
dim=1,
|
108 |
+
)
|
109 |
+
extended_audio_input_ids = torch.cat(
|
110 |
+
[
|
111 |
+
dvae_audio_input_ids,
|
112 |
+
AUDIO_PAD_TOKEN_ID
|
113 |
+
* torch.ones(
|
114 |
+
(batch_size, 1, gpt.num_vq),
|
115 |
+
dtype=dvae_audio_input_ids.dtype,
|
116 |
+
device=dvae_audio_input_ids.device,
|
117 |
+
),
|
118 |
+
],
|
119 |
+
dim=1,
|
120 |
+
)
|
121 |
+
|
122 |
+
indices = audio_attention_mask.int().sum(dim=1)
|
123 |
+
for i in range(batch_size):
|
124 |
+
extended_audio_attention_mask[i, indices[i]] = 1
|
125 |
+
extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
|
126 |
+
|
127 |
+
input_ids = torch.cat(
|
128 |
+
[
|
129 |
+
text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
|
130 |
+
extended_audio_input_ids,
|
131 |
+
],
|
132 |
+
dim=1,
|
133 |
+
)
|
134 |
+
attention_mask = torch.cat(
|
135 |
+
[text_attention_mask, extended_audio_attention_mask], dim=1
|
136 |
+
)
|
137 |
+
text_mask = torch.cat(
|
138 |
+
[
|
139 |
+
torch.ones_like(text_attention_mask, dtype=bool),
|
140 |
+
torch.zeros_like(extended_audio_attention_mask, dtype=bool),
|
141 |
+
],
|
142 |
+
dim=1,
|
143 |
+
)
|
144 |
+
labels = input_ids.clone()
|
145 |
+
labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
|
146 |
+
|
147 |
+
inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
|
148 |
+
|
149 |
+
indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
|
150 |
+
for i, speaker in enumerate(speakers):
|
151 |
+
inputs_embeds[i, indices[i]] = torch.nn.functional.normalize(
|
152 |
+
speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
|
153 |
+
p=2.0,
|
154 |
+
dim=-1,
|
155 |
+
eps=1e-12,
|
156 |
+
).unsqueeze(0)
|
157 |
+
|
158 |
+
outputs = gpt.gpt.forward(
|
159 |
+
inputs_embeds=inputs_embeds, attention_mask=attention_mask
|
160 |
+
)
|
161 |
+
hidden_states = outputs.last_hidden_state
|
162 |
+
text_hidden_states = hidden_states[:, : text_len - 1]
|
163 |
+
audio_hidden_states = hidden_states[:, text_len - 1 : -1]
|
164 |
+
|
165 |
+
audio_logits = torch.stack(
|
166 |
+
[gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
|
167 |
+
dim=2,
|
168 |
+
)
|
169 |
+
audio_loss = loss_fn(
|
170 |
+
audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
|
171 |
+
)
|
172 |
+
loss = audio_loss
|
173 |
+
|
174 |
+
if train_text:
|
175 |
+
text_logits = gpt.head_text(text_hidden_states)
|
176 |
+
text_loss = loss_fn(
|
177 |
+
text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
|
178 |
+
)
|
179 |
+
loss += text_loss
|
180 |
+
logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
|
181 |
+
|
182 |
+
gpt_gen_mel_specs = decoder_decoder(
|
183 |
+
audio_hidden_states[:, :-1].transpose(1, 2)
|
184 |
+
).transpose(1, 2)
|
185 |
+
mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
|
186 |
+
loss += 0.01 * mse_loss
|
187 |
+
|
188 |
+
optimizer.zero_grad()
|
189 |
+
loss.backward()
|
190 |
+
torch.nn.utils.clip_grad_norm_(train_params, 1.0)
|
191 |
+
optimizer.step()
|
192 |
+
|
193 |
+
logger.meters["loss"].update(loss.item(), n=batch_size)
|
194 |
+
logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
|
195 |
+
logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
|
196 |
+
|
197 |
+
lr_scheduler.step()
|
198 |
+
optimizer.zero_grad()
|
199 |
+
return speaker_embeds
|
200 |
+
|
201 |
+
|
202 |
+
# Example usage
|
203 |
+
def main():
|
204 |
+
# Load necessary models and data paths
|
205 |
+
chat = ChatTTS.Chat()
|
206 |
+
chat.load_models()
|
207 |
+
dataset = XzListTar(
|
208 |
+
root="data/all.list",
|
209 |
+
tokenizer=chat.pretrain_models["tokenizer"],
|
210 |
+
vocos_model=chat.pretrain_models["vocos"],
|
211 |
+
tar_path="data/Xz.tar",
|
212 |
+
tar_in_memory=True,
|
213 |
+
process_ahead=True,
|
214 |
+
)
|
215 |
+
|
216 |
+
decoder_encoder = DVAEEncoder(
|
217 |
+
**get_encoder_config(chat.pretrain_models["decoder"].decoder)
|
218 |
+
)
|
219 |
+
dvae_encoder = DVAEEncoder(
|
220 |
+
**get_encoder_config(chat.pretrain_models["dvae"].decoder)
|
221 |
+
)
|
222 |
+
|
223 |
+
# Train GPT with LoRA
|
224 |
+
speaker_embeds = train_gpt_lora(
|
225 |
+
chat=chat,
|
226 |
+
dataset=dataset,
|
227 |
+
decoder_encoder=decoder_encoder,
|
228 |
+
dvae_encoder=dvae_encoder,
|
229 |
+
batch_size=32,
|
230 |
+
epochs=10,
|
231 |
+
train_text=True,
|
232 |
+
lora_r=8,
|
233 |
+
lora_alpha=16,
|
234 |
+
)
|
235 |
+
|
236 |
+
# Save LoRA parameters and embeddings
|
237 |
+
lora_save_path = "./saved_models/gpt_lora.pth"
|
238 |
+
peft.save_pretrained(gpt.gpt, lora_save_path)
|
239 |
+
np.savez(
|
240 |
+
"./saved_models/speaker_embeds.npz",
|
241 |
+
**{k: v.cpu().numpy() for k, v in speaker_embeds.items()}
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
if __name__ == "__main__":
|
246 |
+
main()
|
modules/finetune/train_speaker.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import transformers
|
4 |
+
|
5 |
+
from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
|
6 |
+
from modules.finetune.utils.output import get_ansi_len, output_iter, ansi
|
7 |
+
from .utils.logger import MetricLogger
|
8 |
+
from .utils.dataset import AudioCollator, XzListTar
|
9 |
+
from .utils.model import quantize
|
10 |
+
|
11 |
+
IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
|
12 |
+
|
13 |
+
|
14 |
+
def train_speaker_embeddings(
|
15 |
+
chat,
|
16 |
+
dataset,
|
17 |
+
gpt,
|
18 |
+
batch_size=16,
|
19 |
+
epochs=10,
|
20 |
+
train_text=True,
|
21 |
+
speaker_embeds=None,
|
22 |
+
):
|
23 |
+
tokenizer = chat.pretrain_models["tokenizer"]
|
24 |
+
|
25 |
+
decoder_decoder = chat.pretrain_models["decoder"]
|
26 |
+
decoder_decoder.eval().requires_grad_(False)
|
27 |
+
decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to(
|
28 |
+
device=dataset.device
|
29 |
+
)
|
30 |
+
decoder_encoder.eval().requires_grad_(False)
|
31 |
+
|
32 |
+
dvae_decoder = chat.pretrain_models["dvae"]
|
33 |
+
dvae_decoder.eval().requires_grad_(False)
|
34 |
+
dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
|
35 |
+
device=dataset.device
|
36 |
+
)
|
37 |
+
dvae_encoder.eval().requires_grad_(False)
|
38 |
+
|
39 |
+
if speaker_embeds is None:
|
40 |
+
speaker_embeds = {
|
41 |
+
speaker: torch.randn(
|
42 |
+
768,
|
43 |
+
device=dataset.device,
|
44 |
+
requires_grad=True,
|
45 |
+
)
|
46 |
+
for speaker in dataset.speakers
|
47 |
+
}
|
48 |
+
for speaker_embed in speaker_embeds.values():
|
49 |
+
std, mean = chat.pretrain_models["spk_stat"].chunk(2)
|
50 |
+
speaker_embed.data = speaker_embed.data * std + mean
|
51 |
+
|
52 |
+
SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
|
53 |
+
AUDIO_EOS_TOKEN_ID = 0
|
54 |
+
AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
|
55 |
+
|
56 |
+
optimizer = torch.optim.Adam(
|
57 |
+
speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
|
58 |
+
)
|
59 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
60 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
|
61 |
+
|
62 |
+
loader = torch.utils.data.DataLoader(
|
63 |
+
dataset,
|
64 |
+
batch_size=batch_size,
|
65 |
+
shuffle=True,
|
66 |
+
collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
|
67 |
+
)
|
68 |
+
logger = MetricLogger()
|
69 |
+
logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
|
70 |
+
|
71 |
+
for _epoch in range(epochs):
|
72 |
+
_epoch += 1
|
73 |
+
logger.reset()
|
74 |
+
header = "{blue_light}{0}: {1}{reset}".format(
|
75 |
+
"Epoch", output_iter(_epoch, epochs), **ansi
|
76 |
+
)
|
77 |
+
header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
|
78 |
+
iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
|
79 |
+
|
80 |
+
for batch in iterator:
|
81 |
+
speakers = batch["speaker"]
|
82 |
+
text_input_ids = batch["text_input_ids"]
|
83 |
+
text_attention_mask = batch["text_attention_mask"]
|
84 |
+
audio_mel_specs = batch["audio_mel_specs"]
|
85 |
+
audio_attention_mask = batch["audio_attention_mask"]
|
86 |
+
|
87 |
+
batch_size, text_len = text_attention_mask.size()
|
88 |
+
|
89 |
+
dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
|
90 |
+
_, dvae_audio_input_ids = quantize(
|
91 |
+
dvae_decoder.vq_layer.quantizer, dvae_audio_latents
|
92 |
+
)
|
93 |
+
dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
|
94 |
+
|
95 |
+
extended_audio_attention_mask = torch.cat(
|
96 |
+
[
|
97 |
+
audio_attention_mask,
|
98 |
+
torch.zeros(
|
99 |
+
(batch_size, 1),
|
100 |
+
dtype=audio_attention_mask.dtype,
|
101 |
+
device=audio_attention_mask.device,
|
102 |
+
),
|
103 |
+
],
|
104 |
+
dim=1,
|
105 |
+
)
|
106 |
+
extended_audio_input_ids = torch.cat(
|
107 |
+
[
|
108 |
+
dvae_audio_input_ids,
|
109 |
+
AUDIO_PAD_TOKEN_ID
|
110 |
+
* torch.ones(
|
111 |
+
(batch_size, 1, gpt.num_vq),
|
112 |
+
dtype=dvae_audio_input_ids.dtype,
|
113 |
+
device=dvae_audio_input_ids.device,
|
114 |
+
),
|
115 |
+
],
|
116 |
+
dim=1,
|
117 |
+
)
|
118 |
+
indices = audio_attention_mask.int().sum(dim=1)
|
119 |
+
for i in range(batch_size):
|
120 |
+
extended_audio_attention_mask[i, indices[i]] = 1
|
121 |
+
extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
|
122 |
+
|
123 |
+
input_ids = torch.cat(
|
124 |
+
[
|
125 |
+
text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
|
126 |
+
extended_audio_input_ids,
|
127 |
+
],
|
128 |
+
dim=1,
|
129 |
+
)
|
130 |
+
attention_mask = torch.cat(
|
131 |
+
[text_attention_mask, extended_audio_attention_mask], dim=1
|
132 |
+
)
|
133 |
+
text_mask = torch.cat(
|
134 |
+
[
|
135 |
+
torch.ones_like(text_attention_mask, dtype=bool),
|
136 |
+
torch.zeros_like(extended_audio_attention_mask, dtype=bool),
|
137 |
+
],
|
138 |
+
dim=1,
|
139 |
+
)
|
140 |
+
|
141 |
+
labels = input_ids.clone()
|
142 |
+
labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
|
143 |
+
|
144 |
+
inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
|
145 |
+
|
146 |
+
indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
|
147 |
+
for i, speaker in enumerate(speakers):
|
148 |
+
inputs_embeds[i, indices[i]] = F.normalize(
|
149 |
+
speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
|
150 |
+
p=2.0,
|
151 |
+
dim=-1,
|
152 |
+
eps=1e-12,
|
153 |
+
).unsqueeze(0)
|
154 |
+
outputs = gpt.gpt.forward(
|
155 |
+
inputs_embeds=inputs_embeds, attention_mask=attention_mask
|
156 |
+
)
|
157 |
+
hidden_states = outputs.last_hidden_state
|
158 |
+
text_hidden_states = hidden_states[:, : text_len - 1]
|
159 |
+
audio_hidden_states = hidden_states[:, text_len - 1 : -1]
|
160 |
+
|
161 |
+
audio_logits = torch.stack(
|
162 |
+
[gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
|
163 |
+
dim=2,
|
164 |
+
)
|
165 |
+
audio_loss = loss_fn(
|
166 |
+
audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
|
167 |
+
)
|
168 |
+
loss = audio_loss
|
169 |
+
if train_text:
|
170 |
+
text_logits = gpt.head_text(text_hidden_states)
|
171 |
+
text_loss = loss_fn(
|
172 |
+
text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
|
173 |
+
)
|
174 |
+
loss += text_loss
|
175 |
+
logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
|
176 |
+
|
177 |
+
gpt_gen_mel_specs = decoder_decoder(
|
178 |
+
audio_hidden_states[:, :-1].transpose(1, 2)
|
179 |
+
).transpose(1, 2)
|
180 |
+
mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
|
181 |
+
loss += 0.01 * mse_loss
|
182 |
+
|
183 |
+
optimizer.zero_grad()
|
184 |
+
loss.backward()
|
185 |
+
torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
|
186 |
+
optimizer.step()
|
187 |
+
logger.meters["loss"].update(loss.item(), n=batch_size)
|
188 |
+
logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
|
189 |
+
logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
|
190 |
+
lr_scheduler.step()
|
191 |
+
optimizer.zero_grad()
|
192 |
+
return speaker_embeds
|
193 |
+
|
194 |
+
|
195 |
+
if __name__ == "__main__":
|
196 |
+
import argparse
|
197 |
+
import os
|
198 |
+
import numpy as np
|
199 |
+
import pathlib
|
200 |
+
from modules.models import load_chat_tts
|
201 |
+
from modules.devices import devices
|
202 |
+
from modules import config
|
203 |
+
from modules.speaker import Speaker
|
204 |
+
|
205 |
+
config.runtime_env_vars.no_half = True
|
206 |
+
devices.reset_device()
|
207 |
+
|
208 |
+
parser = argparse.ArgumentParser()
|
209 |
+
parser.add_argument("--save_folder", type=str, default="./")
|
210 |
+
parser.add_argument("--batch_size", type=int, default=16)
|
211 |
+
parser.add_argument("--epochs", type=int, default=100)
|
212 |
+
parser.add_argument("--train_text", action="store_true", help="train text loss")
|
213 |
+
# 初始化 speaker
|
214 |
+
parser.add_argument("--init_speaker", type=str)
|
215 |
+
parser.add_argument(
|
216 |
+
"--data_path",
|
217 |
+
type=str,
|
218 |
+
default="datasets/data_speaker_a/speaker_a.list",
|
219 |
+
help="the data_path to json/list file",
|
220 |
+
)
|
221 |
+
parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
|
222 |
+
parser.add_argument(
|
223 |
+
"--tar_in_memory", action="store_true", help="load tarball in memory"
|
224 |
+
)
|
225 |
+
|
226 |
+
args = parser.parse_args()
|
227 |
+
|
228 |
+
data_path: str = args.data_path
|
229 |
+
tar_path: str | None = args.tar_path
|
230 |
+
tar_in_memory: bool = args.tar_in_memory
|
231 |
+
train_text: bool = args.train_text
|
232 |
+
# gpt_lora: bool = args.gpt_lora
|
233 |
+
# gpt_kbit: int = args.gpt_kbit
|
234 |
+
save_folder: str = args.save_folder
|
235 |
+
batch_size: int = args.batch_size
|
236 |
+
epochs: int = args.epochs
|
237 |
+
init_speaker: str = args.init_speaker
|
238 |
+
|
239 |
+
speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz")
|
240 |
+
|
241 |
+
chat = load_chat_tts()
|
242 |
+
dataset = XzListTar(
|
243 |
+
root=data_path,
|
244 |
+
tokenizer=chat.pretrain_models["tokenizer"],
|
245 |
+
vocos_model=chat.pretrain_models["vocos"],
|
246 |
+
tar_path=tar_path,
|
247 |
+
tar_in_memory=tar_in_memory,
|
248 |
+
device=devices.device,
|
249 |
+
# speakers=None, # set(['speaker_A', 'speaker_B'])
|
250 |
+
)
|
251 |
+
|
252 |
+
print("len(dataset)", len(dataset))
|
253 |
+
|
254 |
+
speaker_embeds = None
|
255 |
+
if init_speaker:
|
256 |
+
spk: Speaker = Speaker.from_file(init_speaker)
|
257 |
+
speaker_embeds = {
|
258 |
+
speaker: torch.tensor(
|
259 |
+
spk.emb,
|
260 |
+
device=devices.device,
|
261 |
+
requires_grad=True,
|
262 |
+
)
|
263 |
+
for speaker in dataset.speakers
|
264 |
+
}
|
265 |
+
|
266 |
+
speaker_embeds = train_speaker_embeddings(
|
267 |
+
chat,
|
268 |
+
dataset,
|
269 |
+
chat.pretrain_models["gpt"],
|
270 |
+
batch_size=batch_size,
|
271 |
+
epochs=epochs,
|
272 |
+
train_text=train_text,
|
273 |
+
speaker_embeds=speaker_embeds,
|
274 |
+
)
|
275 |
+
speaker_outs = {
|
276 |
+
speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
|
277 |
+
for speaker, speaker_embed in speaker_embeds.items()
|
278 |
+
}
|
279 |
+
time_str = np.datetime_as_string(np.datetime64("now", "s"))
|
280 |
+
time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")
|
281 |
+
for speaker, speaker_out in speaker_outs.items():
|
282 |
+
torch.save(
|
283 |
+
speaker_out,
|
284 |
+
pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
|
285 |
+
)
|
286 |
+
|
287 |
+
# example
|
288 |
+
"""
|
289 |
+
python -m modules.finetune.train_speaker \
|
290 |
+
--data_path datasets/data_speaker_a/speaker_a.list \
|
291 |
+
--save_folder ./data \
|
292 |
+
--init_speaker ./data/speakers/Bob.pt \
|
293 |
+
--epochs 100 \
|
294 |
+
--batch_size 6 \
|
295 |
+
--train_text
|
296 |
+
"""
|
modules/finetune/utils/__init__.py
ADDED
File without changes
|
modules/finetune/utils/dataset.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import functools
|
3 |
+
import json
|
4 |
+
import tarfile
|
5 |
+
import io
|
6 |
+
import logging
|
7 |
+
import abc
|
8 |
+
import typing
|
9 |
+
|
10 |
+
import torch.utils.data
|
11 |
+
import torchaudio
|
12 |
+
from torchvision.datasets.utils import download_url
|
13 |
+
import transformers
|
14 |
+
import vocos
|
15 |
+
|
16 |
+
from modules.ChatTTS.ChatTTS.utils.infer_utils import (
|
17 |
+
count_invalid_characters,
|
18 |
+
apply_character_map,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class LazyDataType(typing.TypedDict):
|
23 |
+
filepath: str
|
24 |
+
speaker: str
|
25 |
+
lang: str
|
26 |
+
text: str
|
27 |
+
|
28 |
+
|
29 |
+
class DataType(LazyDataType):
|
30 |
+
text_input_ids: torch.Tensor # (batch_size, text_len)
|
31 |
+
text_attention_mask: torch.Tensor # (batch_size, text_len)
|
32 |
+
audio_mel_specs: torch.Tensor # (batch_size, audio_len*2, 100)
|
33 |
+
audio_attention_mask: torch.Tensor # (batch_size, audio_len)
|
34 |
+
|
35 |
+
|
36 |
+
class XzListTarKwargsType(typing.TypedDict):
|
37 |
+
tokenizer: typing.Union[transformers.PreTrainedTokenizer, None]
|
38 |
+
vocos_model: typing.Union[vocos.Vocos, None]
|
39 |
+
device: typing.Union[str, torch.device, None]
|
40 |
+
speakers: typing.Union[typing.Iterable[str], None]
|
41 |
+
sample_rate: typing.Union[int]
|
42 |
+
default_speaker: typing.Union[str, None]
|
43 |
+
default_lang: typing.Union[str, None]
|
44 |
+
tar_in_memory: typing.Union[bool, None]
|
45 |
+
process_ahead: typing.Union[bool, None]
|
46 |
+
|
47 |
+
|
48 |
+
class AudioFolder(torch.utils.data.Dataset, abc.ABC):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
root: str | io.BytesIO,
|
52 |
+
tokenizer: transformers.PreTrainedTokenizer | None = None,
|
53 |
+
vocos_model: vocos.Vocos | None = None,
|
54 |
+
device: str | torch.device | None = None,
|
55 |
+
speakers: typing.Iterable[str] | None = None,
|
56 |
+
sample_rate: int = 24_000,
|
57 |
+
default_speaker: str | None = None,
|
58 |
+
default_lang: str | None = None,
|
59 |
+
tar_path: str | None = None,
|
60 |
+
tar_in_memory: bool = False,
|
61 |
+
process_ahead: bool = False,
|
62 |
+
) -> None:
|
63 |
+
self.root = root
|
64 |
+
self.sample_rate = sample_rate
|
65 |
+
self.default_speaker = default_speaker
|
66 |
+
self.default_lang = default_lang
|
67 |
+
|
68 |
+
self.logger = logging.getLogger(__name__)
|
69 |
+
self.normalizer = {}
|
70 |
+
|
71 |
+
self.tokenizer = tokenizer
|
72 |
+
self.vocos = vocos_model
|
73 |
+
self.vocos_device = (
|
74 |
+
None if self.vocos is None else next(self.vocos.parameters()).device
|
75 |
+
)
|
76 |
+
self.device = device or self.vocos_device
|
77 |
+
|
78 |
+
# tar -cvf ../Xz.tar *
|
79 |
+
# tar -xf Xz.tar -C ./Xz
|
80 |
+
self.tar_path = tar_path
|
81 |
+
self.tar_file = None
|
82 |
+
self.tar_io = None
|
83 |
+
if tar_path is not None:
|
84 |
+
if tar_in_memory:
|
85 |
+
with open(tar_path, "rb") as f:
|
86 |
+
self.tar_io = io.BytesIO(f.read())
|
87 |
+
self.tar_file = tarfile.open(fileobj=self.tar_io)
|
88 |
+
else:
|
89 |
+
self.tar_file = tarfile.open(tar_path)
|
90 |
+
|
91 |
+
self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)
|
92 |
+
|
93 |
+
self.text_input_ids: dict[int, torch.Tensor] = {}
|
94 |
+
self.audio_mel_specs: dict[int, torch.Tensor] = {}
|
95 |
+
if process_ahead:
|
96 |
+
for n, item in enumerate(self.lazy_data):
|
97 |
+
self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
|
98 |
+
self.text_input_ids[n] = self.preprocess_text(
|
99 |
+
item["text"], item["lang"]
|
100 |
+
)
|
101 |
+
if self.tar_file is not None:
|
102 |
+
self.tar_file.close()
|
103 |
+
if self.tar_io is not None:
|
104 |
+
self.tar_io.close()
|
105 |
+
|
106 |
+
@abc.abstractmethod
|
107 |
+
def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ...
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
@abc.abstractmethod
|
111 |
+
def save_config(
|
112 |
+
save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
|
113 |
+
) -> None: ...
|
114 |
+
|
115 |
+
def __len__(self):
|
116 |
+
return len(self.lazy_data)
|
117 |
+
|
118 |
+
def __getitem__(self, n: int) -> DataType:
|
119 |
+
lazy_data = self.lazy_data[n]
|
120 |
+
if n in self.audio_mel_specs:
|
121 |
+
audio_mel_specs = self.audio_mel_specs[n]
|
122 |
+
text_input_ids = self.text_input_ids[n]
|
123 |
+
else:
|
124 |
+
audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
|
125 |
+
text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
|
126 |
+
self.audio_mel_specs[n] = audio_mel_specs
|
127 |
+
self.text_input_ids[n] = text_input_ids
|
128 |
+
if len(self.audio_mel_specs) == len(self.lazy_data):
|
129 |
+
if self.tar_file is not None:
|
130 |
+
self.tar_file.close()
|
131 |
+
if self.tar_io is not None:
|
132 |
+
self.tar_io.close()
|
133 |
+
text_attention_mask = torch.ones(
|
134 |
+
len(text_input_ids), device=text_input_ids.device
|
135 |
+
)
|
136 |
+
audio_attention_mask = torch.ones(
|
137 |
+
(len(audio_mel_specs) + 1) // 2,
|
138 |
+
device=audio_mel_specs.device,
|
139 |
+
)
|
140 |
+
return {
|
141 |
+
"filepath": lazy_data["filepath"],
|
142 |
+
"speaker": lazy_data["speaker"],
|
143 |
+
"lang": lazy_data["lang"],
|
144 |
+
"text": lazy_data["text"],
|
145 |
+
"text_input_ids": text_input_ids,
|
146 |
+
"text_attention_mask": text_attention_mask,
|
147 |
+
"audio_mel_specs": audio_mel_specs,
|
148 |
+
"audio_attention_mask": audio_attention_mask,
|
149 |
+
}
|
150 |
+
|
151 |
+
def get_lazy_data(
|
152 |
+
self,
|
153 |
+
root: str | io.BytesIO,
|
154 |
+
speakers: typing.Iterable[str] | None = None,
|
155 |
+
) -> tuple[list[LazyDataType], set[str]]:
|
156 |
+
if speakers is not None:
|
157 |
+
new_speakers = set(speakers)
|
158 |
+
else:
|
159 |
+
new_speakers = set()
|
160 |
+
lazy_data = []
|
161 |
+
|
162 |
+
raw_data = self.get_raw_data(root)
|
163 |
+
folder_path = os.path.dirname(root) if isinstance(root, str) else ""
|
164 |
+
for item in raw_data:
|
165 |
+
if "speaker" not in item:
|
166 |
+
item["speaker"] = self.default_speaker
|
167 |
+
if "lang" not in item:
|
168 |
+
item["lang"] = self.default_lang
|
169 |
+
|
170 |
+
if speakers is not None and item["speaker"] not in speakers:
|
171 |
+
continue
|
172 |
+
if speakers is None and item["speaker"] not in new_speakers:
|
173 |
+
new_speakers.add(item["speaker"])
|
174 |
+
if self.tar_file is None and isinstance(root, str):
|
175 |
+
filepath = os.path.join(folder_path, item["filepath"])
|
176 |
+
else:
|
177 |
+
filepath = item["filepath"]
|
178 |
+
lazy_data.append(
|
179 |
+
{
|
180 |
+
"filepath": filepath,
|
181 |
+
"speaker": item["speaker"],
|
182 |
+
"lang": item["lang"].lower(),
|
183 |
+
"text": item["text"],
|
184 |
+
}
|
185 |
+
)
|
186 |
+
return lazy_data, new_speakers
|
187 |
+
|
188 |
+
def preprocess_text(
|
189 |
+
self,
|
190 |
+
text: str,
|
191 |
+
lang: str,
|
192 |
+
) -> torch.Tensor:
|
193 |
+
invalid_characters = count_invalid_characters(text)
|
194 |
+
if len(invalid_characters):
|
195 |
+
# self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
|
196 |
+
text = apply_character_map(text)
|
197 |
+
|
198 |
+
# if not skip_refine_text:
|
199 |
+
# text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
200 |
+
# text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
201 |
+
# text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
202 |
+
# if refine_text_only:
|
203 |
+
# return text
|
204 |
+
|
205 |
+
text = f"[Stts][spk_emb]{text}[Ptts]"
|
206 |
+
# text = f'[Stts][empty_spk]{text}[Ptts]'
|
207 |
+
|
208 |
+
text_token = self.tokenizer(
|
209 |
+
text, return_tensors="pt", add_special_tokens=False
|
210 |
+
).to(device=self.device)
|
211 |
+
return text_token["input_ids"].squeeze(0)
|
212 |
+
|
213 |
+
def preprocess_audio(self, filepath: str) -> torch.Tensor:
|
214 |
+
if self.tar_file is not None:
|
215 |
+
file = self.tar_file.extractfile(filepath)
|
216 |
+
waveform, sample_rate = torchaudio.load(file)
|
217 |
+
else:
|
218 |
+
waveform, sample_rate = torchaudio.load(filepath)
|
219 |
+
waveform = waveform.to(device=self.vocos_device)
|
220 |
+
if sample_rate != self.sample_rate:
|
221 |
+
waveform = torchaudio.functional.resample(
|
222 |
+
waveform,
|
223 |
+
orig_freq=sample_rate,
|
224 |
+
new_freq=self.sample_rate,
|
225 |
+
)
|
226 |
+
mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
|
227 |
+
return (
|
228 |
+
mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
|
229 |
+
) # (audio_len*2, 100)
|
230 |
+
|
231 |
+
|
232 |
+
class JsonFolder(AudioFolder):
|
233 |
+
"""
|
234 |
+
In json file, each item is formatted as following example:
|
235 |
+
`{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.
|
236 |
+
|
237 |
+
filepath is relative to the dirname of root json file.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
|
241 |
+
with open(root, "r", encoding="utf-8") as f:
|
242 |
+
raw_data = json.load(f)
|
243 |
+
return raw_data
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
def save_config(
|
247 |
+
save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
|
248 |
+
) -> None:
|
249 |
+
save_data = [item.copy() for item in lazy_data]
|
250 |
+
for item in save_data:
|
251 |
+
item["filepath"] = os.path.relpath(item["filepath"], rel_path)
|
252 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
253 |
+
json.dump(save_data, f, ensure_ascii=False, indent=4)
|
254 |
+
|
255 |
+
|
256 |
+
class ListFolder(AudioFolder):
|
257 |
+
"""
|
258 |
+
In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator.
|
259 |
+
`path/to/file.wav|John|ZH|Hello`.
|
260 |
+
|
261 |
+
filepath is relative to the dirname of root list file.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
|
265 |
+
raw_data = []
|
266 |
+
with open(root, "r", encoding="utf-8") as f:
|
267 |
+
for line in f.readlines():
|
268 |
+
line = line.strip().removesuffix("\n")
|
269 |
+
if len(line) == 0:
|
270 |
+
continue
|
271 |
+
filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
|
272 |
+
raw_data.append(
|
273 |
+
{
|
274 |
+
"text": text,
|
275 |
+
"filepath": filepath,
|
276 |
+
"speaker": speaker,
|
277 |
+
"lang": lang,
|
278 |
+
}
|
279 |
+
)
|
280 |
+
return raw_data
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def save_config(
|
284 |
+
save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
|
285 |
+
) -> None:
|
286 |
+
save_data = [item.copy() for item in lazy_data]
|
287 |
+
for item in save_data:
|
288 |
+
item["filepath"] = os.path.relpath(item["filepath"], rel_path)
|
289 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
290 |
+
for item in save_data:
|
291 |
+
f.write(
|
292 |
+
f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
|
293 |
+
)
|
294 |
+
|
295 |
+
|
296 |
+
class XzListTar(ListFolder):
|
297 |
+
def __init__(
|
298 |
+
self,
|
299 |
+
*args,
|
300 |
+
root: str | io.BytesIO,
|
301 |
+
tar_path: str | None = None,
|
302 |
+
**kwargs,
|
303 |
+
):
|
304 |
+
if isinstance(root, io.BytesIO):
|
305 |
+
assert tar_path is not None
|
306 |
+
else:
|
307 |
+
# make sure root is a list file
|
308 |
+
if not root.endswith(".list"): # folder case
|
309 |
+
if os.path.isfile(root):
|
310 |
+
raise FileExistsError(f"{root} is a file!")
|
311 |
+
elif not os.path.exists(root):
|
312 |
+
os.makedirs(root)
|
313 |
+
root = os.path.join(root, "all.list")
|
314 |
+
if isinstance(root, str) and not os.path.isfile(root):
|
315 |
+
# prepare all.list
|
316 |
+
self.concat_dataset(
|
317 |
+
save_folder=os.path.dirname(root),
|
318 |
+
langs=kwargs.get("langs", ["zh", "en"]),
|
319 |
+
)
|
320 |
+
|
321 |
+
super().__init__(root, *args, tar_path=tar_path, **kwargs)
|
322 |
+
|
323 |
+
def concat_dataset(
|
324 |
+
self, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
|
325 |
+
) -> None:
|
326 |
+
if save_folder is None:
|
327 |
+
save_folder = os.path.dirname(self.root)
|
328 |
+
if os.path.isfile(save_folder):
|
329 |
+
raise FileExistsError(f"{save_folder} already exists as a file!")
|
330 |
+
elif not os.path.exists(save_folder):
|
331 |
+
os.makedirs(save_folder)
|
332 |
+
lazy_data = []
|
333 |
+
|
334 |
+
for member in self.tar_file.getmembers():
|
335 |
+
if not member.isfile():
|
336 |
+
continue
|
337 |
+
if member.name.endswith(".list"):
|
338 |
+
print(member.name)
|
339 |
+
root_io = self.tar_file.extractfile(member)
|
340 |
+
lazy_data += ListFolder(root_io).lazy_data
|
341 |
+
if member.name.endswith(".json"):
|
342 |
+
print(member.name)
|
343 |
+
root_io = self.tar_file.extractfile(member)
|
344 |
+
lazy_data += JsonFolder(root_io).lazy_data
|
345 |
+
if langs is not None:
|
346 |
+
lazy_data = [item for item in lazy_data if item["lang"] in langs]
|
347 |
+
ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
|
348 |
+
JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
|
349 |
+
print(f"all.list and all.json are saved to {save_folder}")
|
350 |
+
|
351 |
+
|
352 |
+
class XzListFolder(ListFolder):
|
353 |
+
"""
|
354 |
+
[Xz乔希](https://space.bilibili.com/5859321)
|
355 |
+
|
356 |
+
Only look at the basename of filepath in list file. Previous folder paths are ignored.
|
357 |
+
Files are organized as `[list basename]/[file basename]`
|
358 |
+
|
359 |
+
Example tree structure:
|
360 |
+
|
361 |
+
[folder]
|
362 |
+
├── speaker_A
|
363 |
+
│ ├── 1.wav
|
364 |
+
│ └── 2.wav
|
365 |
+
├── speaker_A.list
|
366 |
+
├── speaker_B
|
367 |
+
│ ├── 1.wav
|
368 |
+
│ └── 2.wav
|
369 |
+
└── speaker_B.list
|
370 |
+
"""
|
371 |
+
|
372 |
+
def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
|
373 |
+
raw_data = super().get_raw_data(root)
|
374 |
+
for item in raw_data:
|
375 |
+
item["filepath"] = os.path.join(
|
376 |
+
os.path.basename(root).removesuffix(".list"),
|
377 |
+
os.path.basename(item["filepath"]),
|
378 |
+
)
|
379 |
+
return raw_data
|
380 |
+
|
381 |
+
|
382 |
+
class AudioCollator:
|
383 |
+
def __init__(self, text_pad: int = 0, audio_pad: int = 0):
|
384 |
+
self.text_pad = text_pad
|
385 |
+
self.audio_pad = audio_pad
|
386 |
+
|
387 |
+
def __call__(self, batch: list[DataType]):
|
388 |
+
batch = [x for x in batch if x is not None]
|
389 |
+
|
390 |
+
audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch)
|
391 |
+
text_maxlen = max(len(item["text_attention_mask"]) for item in batch)
|
392 |
+
|
393 |
+
filepath = []
|
394 |
+
speaker = []
|
395 |
+
lang = []
|
396 |
+
text = []
|
397 |
+
text_input_ids = []
|
398 |
+
text_attention_mask = []
|
399 |
+
audio_mel_specs = []
|
400 |
+
audio_attention_mask = []
|
401 |
+
|
402 |
+
for x in batch:
|
403 |
+
filepath.append(x["filepath"])
|
404 |
+
speaker.append(x["speaker"])
|
405 |
+
lang.append(x["lang"])
|
406 |
+
text.append(x["text"])
|
407 |
+
text_input_ids.append(
|
408 |
+
torch.nn.functional.pad(
|
409 |
+
x["text_input_ids"],
|
410 |
+
(text_maxlen - len(x["text_input_ids"]), 0),
|
411 |
+
value=self.text_pad,
|
412 |
+
)
|
413 |
+
)
|
414 |
+
text_attention_mask.append(
|
415 |
+
torch.nn.functional.pad(
|
416 |
+
x["text_attention_mask"],
|
417 |
+
(text_maxlen - len(x["text_attention_mask"]), 0),
|
418 |
+
value=0,
|
419 |
+
)
|
420 |
+
)
|
421 |
+
audio_mel_specs.append(
|
422 |
+
torch.nn.functional.pad(
|
423 |
+
x["audio_mel_specs"],
|
424 |
+
(0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
|
425 |
+
value=self.audio_pad,
|
426 |
+
)
|
427 |
+
)
|
428 |
+
audio_attention_mask.append(
|
429 |
+
torch.nn.functional.pad(
|
430 |
+
x["audio_attention_mask"],
|
431 |
+
(0, audio_maxlen - len(x["audio_attention_mask"])),
|
432 |
+
value=0,
|
433 |
+
)
|
434 |
+
)
|
435 |
+
return {
|
436 |
+
"filepath": filepath,
|
437 |
+
"speaker": speaker,
|
438 |
+
"lang": lang,
|
439 |
+
"text": text,
|
440 |
+
"text_input_ids": torch.stack(text_input_ids),
|
441 |
+
"text_attention_mask": torch.stack(text_attention_mask),
|
442 |
+
"audio_mel_specs": torch.stack(audio_mel_specs),
|
443 |
+
"audio_attention_mask": torch.stack(audio_attention_mask),
|
444 |
+
}
|
445 |
+
|
446 |
+
|
447 |
+
def formalize_xz_list(src_folder: str):
|
448 |
+
for root, _, files in os.walk(src_folder):
|
449 |
+
for file in files:
|
450 |
+
if file.endswith(".list"):
|
451 |
+
filepath = os.path.join(root, file)
|
452 |
+
print(filepath)
|
453 |
+
lazy_data = XzListFolder(filepath).lazy_data
|
454 |
+
XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder)
|
455 |
+
|
456 |
+
|
457 |
+
def concat_dataset(
|
458 |
+
src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
|
459 |
+
) -> None:
|
460 |
+
if save_folder is None:
|
461 |
+
save_folder = src_folder
|
462 |
+
if os.path.isfile(save_folder):
|
463 |
+
raise FileExistsError(f"{save_folder} already exists as a file!")
|
464 |
+
elif not os.path.exists(save_folder):
|
465 |
+
os.makedirs(save_folder)
|
466 |
+
lazy_data = []
|
467 |
+
same_folder = os.path.samefile(src_folder, save_folder)
|
468 |
+
for root, _, files in os.walk(src_folder):
|
469 |
+
for file in files:
|
470 |
+
filepath = os.path.join(root, file)
|
471 |
+
if same_folder and file in ("all.list", "all.json"):
|
472 |
+
continue
|
473 |
+
if file.endswith(".list"):
|
474 |
+
print(filepath)
|
475 |
+
lazy_data += ListFolder(filepath).lazy_data
|
476 |
+
if file.endswith(".json"):
|
477 |
+
print(filepath)
|
478 |
+
lazy_data += JsonFolder(filepath).lazy_data
|
479 |
+
if langs is not None:
|
480 |
+
lazy_data = [item for item in lazy_data if item["lang"] in langs]
|
481 |
+
ListFolder.save_config(
|
482 |
+
os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder
|
483 |
+
)
|
484 |
+
JsonFolder.save_config(
|
485 |
+
os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder
|
486 |
+
)
|
487 |
+
print(f"all.list and all.json are saved to {save_folder}")
|
modules/finetune/utils/logger.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import statistics
|
4 |
+
import time
|
5 |
+
from collections import defaultdict, deque
|
6 |
+
from tqdm import tqdm as tqdm_class
|
7 |
+
|
8 |
+
from typing import Generator, Iterable, TypeVar
|
9 |
+
from typing_extensions import Self
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.distributed as dist
|
13 |
+
|
14 |
+
from .output import ansi, prints, get_ansi_len
|
15 |
+
|
16 |
+
__all__ = ["SmoothedValue", "MetricLogger"]
|
17 |
+
|
18 |
+
MB = 1 << 20
|
19 |
+
T = TypeVar("T")
|
20 |
+
|
21 |
+
|
22 |
+
class SmoothedValue:
|
23 |
+
r"""Track a series of values and provide access to smoothed values over a
|
24 |
+
window or the global series average.
|
25 |
+
|
26 |
+
See Also:
|
27 |
+
https://github.com/pytorch/vision/blob/main/references/classification/utils.py
|
28 |
+
|
29 |
+
Args:
|
30 |
+
name (str): Name string.
|
31 |
+
window_size (int): The :attr:`maxlen` of :class:`~collections.deque`.
|
32 |
+
fmt (str): The format pattern of ``str(self)``.
|
33 |
+
|
34 |
+
Attributes:
|
35 |
+
name (str): Name string.
|
36 |
+
fmt (str): The string pattern.
|
37 |
+
deque (~collections.deque): The unique data series.
|
38 |
+
count (int): The amount of data.
|
39 |
+
total (float): The sum of all data.
|
40 |
+
|
41 |
+
median (float): The median of :attr:`deque`.
|
42 |
+
avg (float): The avg of :attr:`deque`.
|
43 |
+
global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
|
44 |
+
max (float): The max of :attr:`deque`.
|
45 |
+
min (float): The min of :attr:`deque`.
|
46 |
+
last_value (float): The last value of :attr:`deque`.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}"
|
51 |
+
):
|
52 |
+
self.name = name
|
53 |
+
self.deque: deque[float] = deque(maxlen=window_size)
|
54 |
+
self.count: int = 0
|
55 |
+
self.total: float = 0.0
|
56 |
+
self.fmt = fmt
|
57 |
+
|
58 |
+
def update(self, value: float, n: int = 1) -> Self:
|
59 |
+
r"""Update :attr:`n` pieces of data with same :attr:`value`.
|
60 |
+
|
61 |
+
.. code-block:: python
|
62 |
+
|
63 |
+
self.deque.append(value)
|
64 |
+
self.total += value * n
|
65 |
+
self.count += n
|
66 |
+
|
67 |
+
Args:
|
68 |
+
value (float): the value to update.
|
69 |
+
n (int): the number of data with same :attr:`value`.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
SmoothedValue: return ``self`` for stream usage.
|
73 |
+
"""
|
74 |
+
self.deque.append(value)
|
75 |
+
self.total += value * n
|
76 |
+
self.count += n
|
77 |
+
return self
|
78 |
+
|
79 |
+
def update_list(self, value_list: list[float]) -> Self:
|
80 |
+
r"""Update :attr:`value_list`.
|
81 |
+
|
82 |
+
.. code-block:: python
|
83 |
+
|
84 |
+
for value in value_list:
|
85 |
+
self.deque.append(value)
|
86 |
+
self.total += value
|
87 |
+
self.count += len(value_list)
|
88 |
+
|
89 |
+
Args:
|
90 |
+
value_list (list[float]): the value list to update.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
SmoothedValue: return ``self`` for stream usage.
|
94 |
+
"""
|
95 |
+
for value in value_list:
|
96 |
+
self.deque.append(value)
|
97 |
+
self.total += value
|
98 |
+
self.count += len(value_list)
|
99 |
+
return self
|
100 |
+
|
101 |
+
def reset(self) -> Self:
|
102 |
+
r"""Reset ``deque``, ``count`` and ``total`` to be empty.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
SmoothedValue: return ``self`` for stream usage.
|
106 |
+
"""
|
107 |
+
self.deque = deque(maxlen=self.deque.maxlen)
|
108 |
+
self.count = 0
|
109 |
+
self.total = 0.0
|
110 |
+
return self
|
111 |
+
|
112 |
+
def synchronize_between_processes(self):
|
113 |
+
r"""
|
114 |
+
Warning:
|
115 |
+
Does NOT synchronize the deque!
|
116 |
+
"""
|
117 |
+
if not (dist.is_available() and dist.is_initialized()):
|
118 |
+
return
|
119 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
120 |
+
dist.barrier()
|
121 |
+
dist.all_reduce(t)
|
122 |
+
t = t.tolist()
|
123 |
+
self.count = int(t[0])
|
124 |
+
self.total = float(t[1])
|
125 |
+
|
126 |
+
@property
|
127 |
+
def median(self) -> float:
|
128 |
+
try:
|
129 |
+
return statistics.median(self.deque)
|
130 |
+
except Exception:
|
131 |
+
return 0.0
|
132 |
+
|
133 |
+
@property
|
134 |
+
def avg(self) -> float:
|
135 |
+
try:
|
136 |
+
return statistics.mean(self.deque)
|
137 |
+
except Exception:
|
138 |
+
return 0.0
|
139 |
+
|
140 |
+
@property
|
141 |
+
def global_avg(self) -> float:
|
142 |
+
try:
|
143 |
+
return self.total / self.count
|
144 |
+
except Exception:
|
145 |
+
return 0.0
|
146 |
+
|
147 |
+
@property
|
148 |
+
def max(self) -> float:
|
149 |
+
try:
|
150 |
+
return max(self.deque)
|
151 |
+
except Exception:
|
152 |
+
return 0.0
|
153 |
+
|
154 |
+
@property
|
155 |
+
def min(self) -> float:
|
156 |
+
try:
|
157 |
+
return min(self.deque)
|
158 |
+
except Exception:
|
159 |
+
return 0.0
|
160 |
+
|
161 |
+
@property
|
162 |
+
def last_value(self) -> float:
|
163 |
+
try:
|
164 |
+
return self.deque[-1]
|
165 |
+
except Exception:
|
166 |
+
return 0.0
|
167 |
+
|
168 |
+
def __str__(self):
|
169 |
+
return self.fmt.format(
|
170 |
+
name=self.name,
|
171 |
+
count=self.count,
|
172 |
+
total=self.total,
|
173 |
+
median=self.median,
|
174 |
+
avg=self.avg,
|
175 |
+
global_avg=self.global_avg,
|
176 |
+
min=self.min,
|
177 |
+
max=self.max,
|
178 |
+
last_value=self.last_value,
|
179 |
+
)
|
180 |
+
|
181 |
+
def __format__(self, format_spec: str) -> str:
|
182 |
+
return self.__str__()
|
183 |
+
|
184 |
+
|
185 |
+
class MetricLogger:
|
186 |
+
r"""
|
187 |
+
See Also:
|
188 |
+
https://github.com/pytorch/vision/blob/main/references/classification/utils.py
|
189 |
+
|
190 |
+
Args:
|
191 |
+
delimiter (str): The delimiter to join different meter strings.
|
192 |
+
Defaults to ``''``.
|
193 |
+
meter_length (int): The minimum length for each meter.
|
194 |
+
Defaults to ``20``.
|
195 |
+
tqdm (bool): Whether to use tqdm to show iteration information.
|
196 |
+
Defaults to ``env['tqdm']``.
|
197 |
+
indent (int): The space indent for the entire string.
|
198 |
+
Defaults to ``0``.
|
199 |
+
|
200 |
+
Attributes:
|
201 |
+
meters (dict[str, SmoothedValue]): The meter dict.
|
202 |
+
iter_time (SmoothedValue): Iteration time meter.
|
203 |
+
data_time (SmoothedValue): Data loading time meter.
|
204 |
+
memory (SmoothedValue): Memory usage meter.
|
205 |
+
"""
|
206 |
+
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
delimiter: str = "",
|
210 |
+
meter_length: int = 20,
|
211 |
+
tqdm: bool = True,
|
212 |
+
indent: int = 0,
|
213 |
+
**kwargs,
|
214 |
+
):
|
215 |
+
self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
|
216 |
+
self.create_meters(**kwargs)
|
217 |
+
self.delimiter = delimiter
|
218 |
+
self.meter_length = meter_length
|
219 |
+
self.tqdm = tqdm
|
220 |
+
self.indent = indent
|
221 |
+
|
222 |
+
self.iter_time = SmoothedValue()
|
223 |
+
self.data_time = SmoothedValue()
|
224 |
+
self.memory = SmoothedValue(fmt="{max:.0f}")
|
225 |
+
|
226 |
+
def create_meters(self, **kwargs: str) -> Self:
|
227 |
+
r"""Create meters with specific ``fmt`` in :attr:`self.meters`.
|
228 |
+
|
229 |
+
``self.meters[meter_name] = SmoothedValue(fmt=fmt)``
|
230 |
+
|
231 |
+
Args:
|
232 |
+
**kwargs: ``(meter_name: fmt)``
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
MetricLogger: return ``self`` for stream usage.
|
236 |
+
"""
|
237 |
+
for k, v in kwargs.items():
|
238 |
+
self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v)
|
239 |
+
return self
|
240 |
+
|
241 |
+
def update(self, n: int = 1, **kwargs: float) -> Self:
|
242 |
+
r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.
|
243 |
+
|
244 |
+
``self.meters[meter_name].update(float(value), n=n)``
|
245 |
+
|
246 |
+
Args:
|
247 |
+
n (int): the number of data with same value.
|
248 |
+
**kwargs: ``{meter_name: value}``.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
MetricLogger: return ``self`` for stream usage.
|
252 |
+
"""
|
253 |
+
for k, v in kwargs.items():
|
254 |
+
if k not in self.meters:
|
255 |
+
self.meters[k] = SmoothedValue()
|
256 |
+
self.meters[k].update(float(v), n=n)
|
257 |
+
return self
|
258 |
+
|
259 |
+
def update_list(self, **kwargs: list) -> Self:
|
260 |
+
r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.
|
261 |
+
|
262 |
+
``self.meters[meter_name].update_list(value_list)``
|
263 |
+
|
264 |
+
Args:
|
265 |
+
**kwargs: ``{meter_name: value_list}``.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
MetricLogger: return ``self`` for stream usage.
|
269 |
+
"""
|
270 |
+
for k, v in kwargs.items():
|
271 |
+
self.meters[k].update_list(v)
|
272 |
+
return self
|
273 |
+
|
274 |
+
def reset(self) -> Self:
|
275 |
+
r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
MetricLogger: return ``self`` for stream usage.
|
279 |
+
"""
|
280 |
+
for meter in self.meters.values():
|
281 |
+
meter.reset()
|
282 |
+
return self
|
283 |
+
|
284 |
+
def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
|
285 |
+
r"""Generate formatted string based on keyword arguments.
|
286 |
+
|
287 |
+
``key: value`` with max length to be :attr:`self.meter_length`.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
cut_too_long (bool): Whether to cut too long values to first 5 characters.
|
291 |
+
Defaults to ``True``.
|
292 |
+
strip (bool): Whether to strip trailing whitespaces.
|
293 |
+
Defaults to ``True``.
|
294 |
+
**kwargs: Keyword arguments to generate string.
|
295 |
+
"""
|
296 |
+
str_list: list[str] = []
|
297 |
+
for k, v in kwargs.items():
|
298 |
+
v_str = str(v)
|
299 |
+
_str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
|
300 |
+
max_length = self.meter_length + get_ansi_len(_str)
|
301 |
+
if cut_too_long:
|
302 |
+
_str = _str[:max_length]
|
303 |
+
str_list.append(_str.ljust(max_length))
|
304 |
+
_str = self.delimiter.join(str_list)
|
305 |
+
if strip:
|
306 |
+
_str = _str.rstrip()
|
307 |
+
return _str
|
308 |
+
|
309 |
+
def __getattr__(self, attr: str) -> float:
|
310 |
+
if attr in self.meters:
|
311 |
+
return self.meters[attr]
|
312 |
+
if attr in vars(self): # TODO: use hasattr
|
313 |
+
return vars(self)[attr]
|
314 |
+
raise AttributeError(
|
315 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
316 |
+
)
|
317 |
+
|
318 |
+
def __str__(self) -> str:
|
319 |
+
return self.get_str(**self.meters)
|
320 |
+
|
321 |
+
def synchronize_between_processes(self):
|
322 |
+
for meter in self.meters.values():
|
323 |
+
meter.synchronize_between_processes()
|
324 |
+
|
325 |
+
def log_every(
|
326 |
+
self,
|
327 |
+
iterable: Iterable[T],
|
328 |
+
header: str = "",
|
329 |
+
tqdm: bool = None,
|
330 |
+
tqdm_header: str = "Iter",
|
331 |
+
indent: int = None,
|
332 |
+
verbose: int = 1,
|
333 |
+
) -> Generator[T, None, None]:
|
334 |
+
r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.
|
335 |
+
|
336 |
+
* Middle Output:
|
337 |
+
``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
|
338 |
+
* Final Output
|
339 |
+
``{header} str(self) {memory} {iter_time} {data_time} {total_time}``
|
340 |
+
|
341 |
+
Args:
|
342 |
+
iterable (~collections.abc.Iterable): The raw iterator.
|
343 |
+
header (str): The header string for final output.
|
344 |
+
Defaults to ``''``.
|
345 |
+
tqdm (bool): Whether to use tqdm to show iteration information.
|
346 |
+
Defaults to ``self.tqdm``.
|
347 |
+
tqdm_header (str): The header string for middle output.
|
348 |
+
Defaults to ``'Iter'``.
|
349 |
+
indent (int): The space indent for the entire string.
|
350 |
+
if ``None``, use ``self.indent``.
|
351 |
+
Defaults to ``None``.
|
352 |
+
verbose (int): The verbose level of output information.
|
353 |
+
"""
|
354 |
+
tqdm = tqdm if tqdm is not None else self.tqdm
|
355 |
+
indent = indent if indent is not None else self.indent
|
356 |
+
iterator = iterable
|
357 |
+
if len(header) != 0:
|
358 |
+
header = header.ljust(30 + get_ansi_len(header))
|
359 |
+
if tqdm:
|
360 |
+
length = len(str(len(iterable)))
|
361 |
+
pattern: str = (
|
362 |
+
"{tqdm_header}: {blue_light}"
|
363 |
+
"[ {red}{{n_fmt:>{length}}}{blue_light} "
|
364 |
+
"/ {red}{{total_fmt}}{blue_light} ]{reset}"
|
365 |
+
).format(tqdm_header=tqdm_header, length=length, **ansi)
|
366 |
+
offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
|
367 |
+
pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
|
368 |
+
time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
|
369 |
+
bar_format = f"{pattern}{{desc}}{time_str}"
|
370 |
+
iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)
|
371 |
+
|
372 |
+
self.iter_time.reset()
|
373 |
+
self.data_time.reset()
|
374 |
+
self.memory.reset()
|
375 |
+
|
376 |
+
end = time.time()
|
377 |
+
start_time = time.time()
|
378 |
+
for obj in iterator:
|
379 |
+
cur_data_time = time.time() - end
|
380 |
+
self.data_time.update(cur_data_time)
|
381 |
+
yield obj
|
382 |
+
cur_iter_time = time.time() - end
|
383 |
+
self.iter_time.update(cur_iter_time)
|
384 |
+
if torch.cuda.is_available():
|
385 |
+
cur_memory = torch.cuda.max_memory_allocated() / MB
|
386 |
+
self.memory.update(cur_memory)
|
387 |
+
if tqdm:
|
388 |
+
_dict = {k: v for k, v in self.meters.items()}
|
389 |
+
if verbose > 2 and torch.cuda.is_available():
|
390 |
+
_dict.update(memory=f"{cur_memory:.0f} MB")
|
391 |
+
if verbose > 1:
|
392 |
+
_dict.update(
|
393 |
+
iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s"
|
394 |
+
)
|
395 |
+
iterator.set_description_str(self.get_str(**_dict, strip=False))
|
396 |
+
end = time.time()
|
397 |
+
self.synchronize_between_processes()
|
398 |
+
total_time = time.time() - start_time
|
399 |
+
total_time_str = tqdm_class.format_interval(total_time)
|
400 |
+
|
401 |
+
_dict = {k: v for k, v in self.meters.items()}
|
402 |
+
if verbose > 2 and torch.cuda.is_available():
|
403 |
+
_dict.update(memory=f"{str(self.memory)} MB")
|
404 |
+
if verbose > 1:
|
405 |
+
_dict.update(
|
406 |
+
iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s"
|
407 |
+
)
|
408 |
+
_dict.update(time=total_time_str)
|
409 |
+
prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)
|
modules/finetune/utils/model.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ
|
4 |
+
|
5 |
+
|
6 |
+
def quantize(
|
7 |
+
quantizer: GroupedResidualFSQ,
|
8 |
+
audio_latents: torch.Tensor, # (batch_size, audio_len, audio_dim=1024)
|
9 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
10 |
+
# feat shape (batch_size, audio_len, audio_dim)
|
11 |
+
# ind shape (GFSQ.G, batch_size, audio_len, GFSQ.R)
|
12 |
+
# num_vq=GFSQ.G*GFSQ.R
|
13 |
+
feat, ind = quantizer(audio_latents)
|
14 |
+
audio_quantized_latents = feat # (batch_size, audio_len, audio_dim)
|
15 |
+
audio_input_ids = rearrange( # (batch_size, audio_len, num_vq)
|
16 |
+
ind,
|
17 |
+
"g b t r ->b t (g r)",
|
18 |
+
)
|
19 |
+
return audio_quantized_latents, audio_input_ids
|
modules/finetune/utils/output.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from contextlib import contextmanager
|
6 |
+
|
7 |
+
|
8 |
+
class ANSI:
|
9 |
+
ansi_color = {
|
10 |
+
"black": "\033[30m",
|
11 |
+
"red": "\033[31m",
|
12 |
+
"green": "\033[32m",
|
13 |
+
"yellow": "\033[33m",
|
14 |
+
"blue": "\033[34m",
|
15 |
+
"purple": "\033[35m",
|
16 |
+
"blue_light": "\033[36m",
|
17 |
+
"white": "\033[37m",
|
18 |
+
"reset": "\033[0m",
|
19 |
+
"upline": "\033[1A",
|
20 |
+
"clear_line": "\033[2K",
|
21 |
+
"clear": "\033[2J",
|
22 |
+
}
|
23 |
+
ansi_nocolor = {
|
24 |
+
"black": "",
|
25 |
+
"red": "",
|
26 |
+
"green": "",
|
27 |
+
"yellow": "",
|
28 |
+
"blue": "",
|
29 |
+
"purple": "",
|
30 |
+
"blue_light": "",
|
31 |
+
"white": "",
|
32 |
+
"reset": "",
|
33 |
+
"upline": "\033[1A\033[",
|
34 |
+
"clear_line": "\033[K",
|
35 |
+
"clear": "\033[2J",
|
36 |
+
}
|
37 |
+
|
38 |
+
def __init__(self):
|
39 |
+
self._dict = ANSI.ansi_color if ("--color" in sys.argv) else ANSI.ansi_nocolor
|
40 |
+
|
41 |
+
def switch(self, color: bool):
|
42 |
+
self._dict = ANSI.ansi_color if color else ANSI.ansi_nocolor
|
43 |
+
|
44 |
+
def keys(self):
|
45 |
+
return self._dict.keys()
|
46 |
+
|
47 |
+
def items(self):
|
48 |
+
return self._dict.items()
|
49 |
+
|
50 |
+
def __getitem__(self, key):
|
51 |
+
return self._dict[key]
|
52 |
+
|
53 |
+
def __str__(self):
|
54 |
+
return str(self._dict)
|
55 |
+
|
56 |
+
def __repr__(self):
|
57 |
+
return repr(self._dict)
|
58 |
+
|
59 |
+
|
60 |
+
ansi = ANSI()
|
61 |
+
|
62 |
+
|
63 |
+
def remove_ansi(s: str) -> str:
|
64 |
+
ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
|
65 |
+
return ansi_escape.sub("", s)
|
66 |
+
|
67 |
+
|
68 |
+
def get_ansi_len(s: str) -> int:
|
69 |
+
return len(s) - len(remove_ansi(s))
|
70 |
+
|
71 |
+
|
72 |
+
def prints(*args: str, indent: int = 0, prefix: str = "", **kwargs):
|
73 |
+
assert indent >= 0
|
74 |
+
new_args = []
|
75 |
+
for arg in args:
|
76 |
+
new_args.append(indent_str(str(arg), indent=indent))
|
77 |
+
if len(new_args):
|
78 |
+
new_args[0] = prefix + str(new_args[0])
|
79 |
+
print(*new_args, **kwargs)
|
80 |
+
|
81 |
+
|
82 |
+
def output_iter(_iter: int, iteration: int = None, iter_len: int = 4) -> str:
|
83 |
+
if iteration is None:
|
84 |
+
pattern = "{blue_light}[ {red}{0}{blue_light} ]{reset}"
|
85 |
+
return pattern.format(str(_iter).rjust(iter_len), **ansi)
|
86 |
+
else:
|
87 |
+
iter_str = str(iteration)
|
88 |
+
length = len(iter_str)
|
89 |
+
pattern = (
|
90 |
+
"{blue_light}[ {red}{0}{blue_light} " "/ {red}{1}{blue_light} ]{reset}"
|
91 |
+
)
|
92 |
+
return pattern.format(str(_iter).rjust(length), iter_str, **ansi)
|
93 |
+
|
94 |
+
|
95 |
+
def indent_str(s_: str, indent: int = 0) -> str:
|
96 |
+
# modified from torch.nn.modules._addindent
|
97 |
+
if indent > 0 and s_:
|
98 |
+
s_ = indent * " " + str(s_[:-1]).replace("\n", "\n" + indent * " ") + s_[-1]
|
99 |
+
return s_
|
100 |
+
|
101 |
+
|
102 |
+
class IndentRedirect: # TODO: inherit TextIOWrapper?
|
103 |
+
def __init__(self, buffer: bool = True, indent: int = 0):
|
104 |
+
self.__console__ = sys.stdout
|
105 |
+
self.indent = indent
|
106 |
+
self.__buffer: str = None
|
107 |
+
if buffer:
|
108 |
+
self.__buffer = ""
|
109 |
+
|
110 |
+
def write(self, text: str, indent: int = None):
|
111 |
+
indent = indent if indent is not None else self.indent
|
112 |
+
text = indent_str(text, indent=indent)
|
113 |
+
if self.__buffer is None:
|
114 |
+
self.__console__.write(text)
|
115 |
+
else:
|
116 |
+
self.__buffer += text
|
117 |
+
|
118 |
+
def flush(self):
|
119 |
+
if self.__buffer is not None:
|
120 |
+
self.__console__.write(self.__buffer)
|
121 |
+
self.__buffer = ""
|
122 |
+
self.__console__.flush()
|
123 |
+
|
124 |
+
@contextmanager
|
125 |
+
def __call__(self) -> None:
|
126 |
+
try:
|
127 |
+
sys.stdout = self
|
128 |
+
yield
|
129 |
+
finally:
|
130 |
+
sys.stdout = self.__console__
|
131 |
+
self.__buffer = ""
|
132 |
+
|
133 |
+
def enable(self):
|
134 |
+
sys.stdout = self
|
135 |
+
|
136 |
+
def disable(self):
|
137 |
+
if self.__buffer is not None:
|
138 |
+
self.__buffer = ""
|
139 |
+
sys.stdout = self.__console__
|
140 |
+
|
141 |
+
@property
|
142 |
+
def buffer(self) -> str:
|
143 |
+
return self.__buffer
|
144 |
+
|
145 |
+
|
146 |
+
redirect = IndentRedirect()
|
modules/generate_audio.py
CHANGED
@@ -76,6 +76,8 @@ def generate_audio_batch(
|
|
76 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
77 |
logger.debug(("spk", spk))
|
78 |
elif isinstance(spk, Speaker):
|
|
|
|
|
79 |
params_infer_code["spk_emb"] = spk.emb
|
80 |
logger.debug(("spk", spk.name))
|
81 |
else:
|
|
|
76 |
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
|
77 |
logger.debug(("spk", spk))
|
78 |
elif isinstance(spk, Speaker):
|
79 |
+
if not isinstance(spk.emb, torch.Tensor):
|
80 |
+
raise ValueError("spk.pt is broken, please retrain the model.")
|
81 |
params_infer_code["spk_emb"] = spk.emb
|
82 |
logger.debug(("spk", spk.name))
|
83 |
else:
|
modules/normalization.py
CHANGED
@@ -120,6 +120,7 @@ character_map = {
|
|
120 |
"~": " ",
|
121 |
"~": " ",
|
122 |
"/": " ",
|
|
|
123 |
}
|
124 |
|
125 |
character_to_word = {
|
@@ -282,6 +283,9 @@ def text_normalize(text, is_end=False):
|
|
282 |
|
283 |
|
284 |
if __name__ == "__main__":
|
|
|
|
|
|
|
285 |
test_cases = [
|
286 |
"ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.",
|
287 |
" [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
|
@@ -319,6 +323,7 @@ State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
|
|
319 |
"""
|
320 |
120米
|
321 |
有12%的概率会下雨
|
|
|
322 |
""",
|
323 |
]
|
324 |
|
|
|
120 |
"~": " ",
|
121 |
"~": " ",
|
122 |
"/": " ",
|
123 |
+
"·": " ",
|
124 |
}
|
125 |
|
126 |
character_to_word = {
|
|
|
283 |
|
284 |
|
285 |
if __name__ == "__main__":
|
286 |
+
from modules.devices import devices
|
287 |
+
|
288 |
+
devices.reset_device()
|
289 |
test_cases = [
|
290 |
"ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.",
|
291 |
" [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
|
|
|
323 |
"""
|
324 |
120米
|
325 |
有12%的概率会下雨
|
326 |
+
埃隆·马斯克
|
327 |
""",
|
328 |
]
|
329 |
|
modules/repos_static/resemble_enhance/data/distorter/base.py
CHANGED
@@ -2,6 +2,7 @@ import itertools
|
|
2 |
import os
|
3 |
import random
|
4 |
import time
|
|
|
5 |
import warnings
|
6 |
|
7 |
import numpy as np
|
@@ -87,7 +88,7 @@ class Choice(Effect):
|
|
87 |
|
88 |
|
89 |
class Permutation(Effect):
|
90 |
-
def __init__(self, *effects, n: int
|
91 |
super().__init__()
|
92 |
self.effects = effects
|
93 |
self.n = n
|
|
|
2 |
import os
|
3 |
import random
|
4 |
import time
|
5 |
+
from typing import Union
|
6 |
import warnings
|
7 |
|
8 |
import numpy as np
|
|
|
88 |
|
89 |
|
90 |
class Permutation(Effect):
|
91 |
+
def __init__(self, *effects, n: Union[int, None] = None):
|
92 |
super().__init__()
|
93 |
self.effects = effects
|
94 |
self.n = n
|
modules/repos_static/resemble_enhance/data/distorter/custom.py
CHANGED
@@ -3,6 +3,7 @@ import random
|
|
3 |
from dataclasses import dataclass
|
4 |
from functools import cached_property
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
import librosa
|
8 |
import numpy as np
|
@@ -16,7 +17,7 @@ _logger = logging.getLogger(__name__)
|
|
16 |
|
17 |
@dataclass
|
18 |
class RandomRIR(Effect):
|
19 |
-
rir_dir: Path
|
20 |
rir_rate: int = 44_000
|
21 |
rir_suffix: str = ".npy"
|
22 |
deterministic: bool = False
|
@@ -49,7 +50,9 @@ class RandomRIR(Effect):
|
|
49 |
|
50 |
length = len(wav)
|
51 |
|
52 |
-
wav = librosa.resample(
|
|
|
|
|
53 |
rir = self._sample_rir()
|
54 |
|
55 |
wav = signal.convolve(wav, rir, mode="same")
|
@@ -58,7 +61,9 @@ class RandomRIR(Effect):
|
|
58 |
if actlev > 0.99:
|
59 |
wav = (wav / actlev) * 0.98
|
60 |
|
61 |
-
wav = librosa.resample(
|
|
|
|
|
62 |
|
63 |
if abs(length - len(wav)) > 10:
|
64 |
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
|
|
|
3 |
from dataclasses import dataclass
|
4 |
from functools import cached_property
|
5 |
from pathlib import Path
|
6 |
+
from typing import Union
|
7 |
|
8 |
import librosa
|
9 |
import numpy as np
|
|
|
17 |
|
18 |
@dataclass
|
19 |
class RandomRIR(Effect):
|
20 |
+
rir_dir: Union[Path, None]
|
21 |
rir_rate: int = 44_000
|
22 |
rir_suffix: str = ".npy"
|
23 |
deterministic: bool = False
|
|
|
50 |
|
51 |
length = len(wav)
|
52 |
|
53 |
+
wav = librosa.resample(
|
54 |
+
wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast"
|
55 |
+
)
|
56 |
rir = self._sample_rir()
|
57 |
|
58 |
wav = signal.convolve(wav, rir, mode="same")
|
|
|
61 |
if actlev > 0.99:
|
62 |
wav = (wav / actlev) * 0.98
|
63 |
|
64 |
+
wav = librosa.resample(
|
65 |
+
wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast"
|
66 |
+
)
|
67 |
|
68 |
if abs(length - len(wav)) > 10:
|
69 |
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
|
modules/repos_static/resemble_enhance/data/distorter/sox.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
import random
|
|
|
4 |
import warnings
|
5 |
from functools import partial
|
6 |
|
@@ -29,7 +30,9 @@ class AttachableEffect(Effect):
|
|
29 |
chain = augment.EffectChain()
|
30 |
chain = self.attach(chain)
|
31 |
tensor = torch.from_numpy(wav)[None].float() # (1, T)
|
32 |
-
tensor = chain.apply(
|
|
|
|
|
33 |
wav = tensor.numpy()[0] # (T,)
|
34 |
return wav
|
35 |
|
@@ -41,7 +44,9 @@ class SoxEffect(AttachableEffect):
|
|
41 |
self.kwargs = kwargs
|
42 |
|
43 |
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
44 |
-
_logger.debug(
|
|
|
|
|
45 |
if not hasattr(chain, self.effect_name):
|
46 |
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
|
47 |
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
|
@@ -115,21 +120,30 @@ class Randint(Generator):
|
|
115 |
|
116 |
|
117 |
class Concat(Generator):
|
118 |
-
def __init__(self, *parts: Generator
|
119 |
self.parts = parts
|
120 |
|
121 |
def __call__(self):
|
122 |
-
return "".join(
|
|
|
|
|
123 |
|
124 |
|
125 |
class RandomLowpassDistorter(SoxEffect):
|
126 |
def __init__(self, low=2000, high=16000):
|
127 |
-
super().__init__(
|
|
|
|
|
128 |
|
129 |
|
130 |
class RandomBandpassDistorter(SoxEffect):
|
131 |
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
|
132 |
-
super().__init__(
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
@staticmethod
|
135 |
def _fn(low, high, min_width, max_width):
|
@@ -139,7 +153,15 @@ class RandomBandpassDistorter(SoxEffect):
|
|
139 |
|
140 |
|
141 |
class RandomEqualizer(SoxEffect):
|
142 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
super().__init__(
|
144 |
"equalizer",
|
145 |
Uniform(low, high),
|
@@ -150,7 +172,9 @@ class RandomEqualizer(SoxEffect):
|
|
150 |
|
151 |
class RandomOverdrive(SoxEffect):
|
152 |
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
|
153 |
-
super().__init__(
|
|
|
|
|
154 |
|
155 |
|
156 |
class RandomReverb(Chain):
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
import random
|
4 |
+
from typing import Union
|
5 |
import warnings
|
6 |
from functools import partial
|
7 |
|
|
|
30 |
chain = augment.EffectChain()
|
31 |
chain = self.attach(chain)
|
32 |
tensor = torch.from_numpy(wav)[None].float() # (1, T)
|
33 |
+
tensor = chain.apply(
|
34 |
+
tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}
|
35 |
+
)
|
36 |
wav = tensor.numpy()[0] # (T,)
|
37 |
return wav
|
38 |
|
|
|
44 |
self.kwargs = kwargs
|
45 |
|
46 |
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
47 |
+
_logger.debug(
|
48 |
+
f"Attaching {self.effect_name} with {self.args} and {self.kwargs}"
|
49 |
+
)
|
50 |
if not hasattr(chain, self.effect_name):
|
51 |
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
|
52 |
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
|
|
|
120 |
|
121 |
|
122 |
class Concat(Generator):
|
123 |
+
def __init__(self, *parts: Union[Generator, str]):
|
124 |
self.parts = parts
|
125 |
|
126 |
def __call__(self):
|
127 |
+
return "".join(
|
128 |
+
[part if isinstance(part, str) else part() for part in self.parts]
|
129 |
+
)
|
130 |
|
131 |
|
132 |
class RandomLowpassDistorter(SoxEffect):
|
133 |
def __init__(self, low=2000, high=16000):
|
134 |
+
super().__init__(
|
135 |
+
"sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))
|
136 |
+
)
|
137 |
|
138 |
|
139 |
class RandomBandpassDistorter(SoxEffect):
|
140 |
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
|
141 |
+
super().__init__(
|
142 |
+
"sinc",
|
143 |
+
"-n",
|
144 |
+
Randint(50, 200),
|
145 |
+
partial(self._fn, low, high, min_width, max_width),
|
146 |
+
)
|
147 |
|
148 |
@staticmethod
|
149 |
def _fn(low, high, min_width, max_width):
|
|
|
153 |
|
154 |
|
155 |
class RandomEqualizer(SoxEffect):
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
low=100,
|
159 |
+
high=4000,
|
160 |
+
q_low=1,
|
161 |
+
q_high=5,
|
162 |
+
db_low: int = -30,
|
163 |
+
db_high: int = 30,
|
164 |
+
):
|
165 |
super().__init__(
|
166 |
"equalizer",
|
167 |
Uniform(low, high),
|
|
|
172 |
|
173 |
class RandomOverdrive(SoxEffect):
|
174 |
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
|
175 |
+
super().__init__(
|
176 |
+
"overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)
|
177 |
+
)
|
178 |
|
179 |
|
180 |
class RandomReverb(Chain):
|
modules/repos_static/resemble_enhance/data/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from pathlib import Path
|
2 |
-
from typing import Callable
|
3 |
|
4 |
from torch import Tensor
|
5 |
|
@@ -16,7 +16,9 @@ def rglob_audio_files(path: Path):
|
|
16 |
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
|
17 |
|
18 |
|
19 |
-
def mix_fg_bg(
|
|
|
|
|
20 |
"""
|
21 |
Args:
|
22 |
fg: (b, t)
|
|
|
1 |
from pathlib import Path
|
2 |
+
from typing import Callable, Union
|
3 |
|
4 |
from torch import Tensor
|
5 |
|
|
|
16 |
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
|
17 |
|
18 |
|
19 |
+
def mix_fg_bg(
|
20 |
+
fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7
|
21 |
+
):
|
22 |
"""
|
23 |
Args:
|
24 |
fg: (b, t)
|
modules/repos_static/resemble_enhance/denoiser/denoiser.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging
|
|
|
2 |
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
@@ -154,7 +155,7 @@ class Denoiser(nn.Module):
|
|
154 |
sep_sin = sin * cos_res + cos * sin_res
|
155 |
return sep_mag, sep_cos, sep_sin
|
156 |
|
157 |
-
def forward(self, x: Tensor, y: Tensor
|
158 |
"""
|
159 |
Args:
|
160 |
x: (b t), a mixed audio
|
|
|
1 |
import logging
|
2 |
+
from typing import Union
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
155 |
sep_sin = sin * cos_res + cos * sin_res
|
156 |
return sep_mag, sep_cos, sep_sin
|
157 |
|
158 |
+
def forward(self, x: Tensor, y: Union[Tensor, None] = None):
|
159 |
"""
|
160 |
Args:
|
161 |
x: (b t), a mixed audio
|
modules/repos_static/resemble_enhance/enhancer/download.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import logging
|
2 |
from pathlib import Path
|
|
|
3 |
|
4 |
import torch
|
5 |
|
@@ -12,14 +13,18 @@ def get_source_url(relpath):
|
|
12 |
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
|
13 |
|
14 |
|
15 |
-
def get_target_path(relpath: str
|
16 |
if run_dir is None:
|
17 |
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
|
18 |
return Path(run_dir) / relpath
|
19 |
|
20 |
|
21 |
-
def download(run_dir: str
|
22 |
-
relpaths = [
|
|
|
|
|
|
|
|
|
23 |
for relpath in relpaths:
|
24 |
path = get_target_path(relpath, run_dir=run_dir)
|
25 |
if path.exists():
|
|
|
1 |
import logging
|
2 |
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
|
5 |
import torch
|
6 |
|
|
|
13 |
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
|
14 |
|
15 |
|
16 |
+
def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
|
17 |
if run_dir is None:
|
18 |
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
|
19 |
return Path(run_dir) / relpath
|
20 |
|
21 |
|
22 |
+
def download(run_dir: Union[str, Path, None] = None):
|
23 |
+
relpaths = [
|
24 |
+
"hparams.yaml",
|
25 |
+
"ds/G/latest",
|
26 |
+
"ds/G/default/mp_rank_00_model_states.pt",
|
27 |
+
]
|
28 |
for relpath in relpaths:
|
29 |
path = get_target_path(relpath, run_dir=run_dir)
|
30 |
if path.exists():
|
modules/repos_static/resemble_enhance/enhancer/enhancer.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging
|
|
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
4 |
import pandas as pd
|
@@ -109,7 +110,7 @@ class Enhancer(nn.Module):
|
|
109 |
return self.mel_fn(x)[..., :-1] # (b d t)
|
110 |
return self.mel_fn(x)
|
111 |
|
112 |
-
def _may_denoise(self, x: Tensor, y: Tensor
|
113 |
if self.hp.lcfm_training_mode == "cfm":
|
114 |
return self.denoiser(x, y)
|
115 |
return x
|
@@ -126,7 +127,9 @@ class Enhancer(nn.Module):
|
|
126 |
self.lcfm.eval_tau_(tau)
|
127 |
self._eval_lambd = lambd
|
128 |
|
129 |
-
def forward(
|
|
|
|
|
130 |
"""
|
131 |
Args:
|
132 |
x: (b t), mix wavs (fg + bg)
|
|
|
1 |
import logging
|
2 |
+
from typing import Union
|
3 |
|
4 |
import matplotlib.pyplot as plt
|
5 |
import pandas as pd
|
|
|
110 |
return self.mel_fn(x)[..., :-1] # (b d t)
|
111 |
return self.mel_fn(x)
|
112 |
|
113 |
+
def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
|
114 |
if self.hp.lcfm_training_mode == "cfm":
|
115 |
return self.denoiser(x, y)
|
116 |
return x
|
|
|
127 |
self.lcfm.eval_tau_(tau)
|
128 |
self._eval_lambd = lambd
|
129 |
|
130 |
+
def forward(
|
131 |
+
self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None
|
132 |
+
):
|
133 |
"""
|
134 |
Args:
|
135 |
x: (b t), mix wavs (fg + bg)
|
modules/repos_static/resemble_enhance/enhancer/hparams.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from pathlib import Path
|
|
|
3 |
|
4 |
from ..hparams import HParams as HParamsBase
|
5 |
|
@@ -17,7 +18,7 @@ class HParams(HParamsBase):
|
|
17 |
|
18 |
vocoder_extra_dim: int = 32
|
19 |
|
20 |
-
gan_training_start_step: int
|
21 |
-
enhancer_stage1_run_dir: Path
|
22 |
|
23 |
-
denoiser_run_dir: Path
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
|
5 |
from ..hparams import HParams as HParamsBase
|
6 |
|
|
|
18 |
|
19 |
vocoder_extra_dim: int = 32
|
20 |
|
21 |
+
gan_training_start_step: Union[int, None] = 5_000
|
22 |
+
enhancer_stage1_run_dir: Union[Path, None] = None
|
23 |
|
24 |
+
denoiser_run_dir: Union[Path, None] = None
|
modules/repos_static/resemble_enhance/enhancer/inference.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
from functools import cache
|
3 |
from pathlib import Path
|
|
|
4 |
|
5 |
import torch
|
6 |
|
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|
13 |
|
14 |
|
15 |
@cache
|
16 |
-
def load_enhancer(run_dir: str
|
17 |
run_dir = download(run_dir)
|
18 |
hp = HParams.load(run_dir)
|
19 |
enhancer = Enhancer(hp)
|
|
|
1 |
import logging
|
2 |
from functools import cache
|
3 |
from pathlib import Path
|
4 |
+
from typing import Union
|
5 |
|
6 |
import torch
|
7 |
|
|
|
14 |
|
15 |
|
16 |
@cache
|
17 |
+
def load_enhancer(run_dir: Union[str, Path, None], device):
|
18 |
run_dir = download(run_dir)
|
19 |
hp = HParams.load(run_dir)
|
20 |
enhancer = Enhancer(hp)
|
modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import logging
|
2 |
from dataclasses import dataclass
|
3 |
from functools import partial
|
4 |
-
from typing import Protocol
|
5 |
|
6 |
import matplotlib.pyplot as plt
|
7 |
import numpy as np
|
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
|
|
17 |
|
18 |
|
19 |
class VelocityField(Protocol):
|
20 |
-
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
|
21 |
-
...
|
22 |
|
23 |
|
24 |
class Solver:
|
@@ -40,7 +39,9 @@ class Solver:
|
|
40 |
|
41 |
self._camera = None
|
42 |
self._mel_fn = mel_fn
|
43 |
-
self._time_mapping = partial(
|
|
|
|
|
44 |
|
45 |
def configurate_(self, nfe=None, method=None):
|
46 |
if nfe is None:
|
@@ -50,7 +51,9 @@ class Solver:
|
|
50 |
method = self.method
|
51 |
|
52 |
if nfe == 1 and method in ("midpoint", "rk4"):
|
53 |
-
logger.warning(
|
|
|
|
|
54 |
method = "euler"
|
55 |
|
56 |
self.nfe = nfe
|
@@ -105,7 +108,9 @@ class Solver:
|
|
105 |
)
|
106 |
else:
|
107 |
# Spectrogram, b c t
|
108 |
-
plt.imshow(
|
|
|
|
|
109 |
ax = plt.gca()
|
110 |
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
|
111 |
camera.snap()
|
@@ -271,7 +276,7 @@ class CFM(nn.Module):
|
|
271 |
global_dim=self.time_emb_dim,
|
272 |
)
|
273 |
|
274 |
-
def _perturb(self, ψ1: Tensor, t: Tensor
|
275 |
"""
|
276 |
Perturb ψ1 to ψt.
|
277 |
"""
|
@@ -311,7 +316,7 @@ class CFM(nn.Module):
|
|
311 |
"""
|
312 |
return ψ1 - ψ0
|
313 |
|
314 |
-
def _to_v(self, *, ψt, x, t: float
|
315 |
"""
|
316 |
Args:
|
317 |
ψt: (b c t)
|
@@ -364,7 +369,13 @@ class CFM(nn.Module):
|
|
364 |
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
|
365 |
return ψ1
|
366 |
|
367 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
if y is None:
|
369 |
y = self.sample(x, ψ0=ψ0, t0=t0)
|
370 |
else:
|
|
|
1 |
import logging
|
2 |
from dataclasses import dataclass
|
3 |
from functools import partial
|
4 |
+
from typing import Protocol, Union
|
5 |
|
6 |
import matplotlib.pyplot as plt
|
7 |
import numpy as np
|
|
|
17 |
|
18 |
|
19 |
class VelocityField(Protocol):
|
20 |
+
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...
|
|
|
21 |
|
22 |
|
23 |
class Solver:
|
|
|
39 |
|
40 |
self._camera = None
|
41 |
self._mel_fn = mel_fn
|
42 |
+
self._time_mapping = partial(
|
43 |
+
self.exponential_decay_mapping, n=time_mapping_divisor
|
44 |
+
)
|
45 |
|
46 |
def configurate_(self, nfe=None, method=None):
|
47 |
if nfe is None:
|
|
|
51 |
method = self.method
|
52 |
|
53 |
if nfe == 1 and method in ("midpoint", "rk4"):
|
54 |
+
logger.warning(
|
55 |
+
f"1 NFE is not supported for {method}, using euler method instead."
|
56 |
+
)
|
57 |
method = "euler"
|
58 |
|
59 |
self.nfe = nfe
|
|
|
108 |
)
|
109 |
else:
|
110 |
# Spectrogram, b c t
|
111 |
+
plt.imshow(
|
112 |
+
ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none"
|
113 |
+
)
|
114 |
ax = plt.gca()
|
115 |
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
|
116 |
camera.snap()
|
|
|
276 |
global_dim=self.time_emb_dim,
|
277 |
)
|
278 |
|
279 |
+
def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None):
|
280 |
"""
|
281 |
Perturb ψ1 to ψt.
|
282 |
"""
|
|
|
316 |
"""
|
317 |
return ψ1 - ψ0
|
318 |
|
319 |
+
def _to_v(self, *, ψt, x, t: Union[float, Tensor]):
|
320 |
"""
|
321 |
Args:
|
322 |
ψt: (b c t)
|
|
|
369 |
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
|
370 |
return ψ1
|
371 |
|
372 |
+
def forward(
|
373 |
+
self,
|
374 |
+
x: Tensor,
|
375 |
+
y: Union[Tensor, None] = None,
|
376 |
+
ψ0: Union[Tensor, None] = None,
|
377 |
+
t0=0.0,
|
378 |
+
):
|
379 |
if y is None:
|
380 |
y = self.sample(x, ψ0=ψ0, t0=t0)
|
381 |
else:
|
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import logging
|
2 |
from dataclasses import dataclass
|
|
|
3 |
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|
14 |
@dataclass
|
15 |
class IRMAEOutput:
|
16 |
latent: Tensor # latent vector
|
17 |
-
decoded: Tensor
|
18 |
|
19 |
|
20 |
class ResBlock(nn.Sequential):
|
|
|
1 |
import logging
|
2 |
from dataclasses import dataclass
|
3 |
+
from typing import Union
|
4 |
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
|
|
15 |
@dataclass
|
16 |
class IRMAEOutput:
|
17 |
latent: Tensor # latent vector
|
18 |
+
decoded: Union[Tensor, None] # decoder output, include extra dim
|
19 |
|
20 |
|
21 |
class ResBlock(nn.Sequential):
|
modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import logging
|
2 |
from enum import Enum
|
|
|
3 |
|
4 |
import matplotlib.pyplot as plt
|
5 |
import torch
|
@@ -70,19 +71,34 @@ class LCFM(nn.Module):
|
|
70 |
return
|
71 |
|
72 |
plt.subplot(221)
|
73 |
-
plt.imshow(
|
|
|
|
|
|
|
|
|
|
|
74 |
plt.title("GT")
|
75 |
|
76 |
plt.subplot(222)
|
77 |
y_ = y_[:, : y.shape[1]]
|
78 |
-
plt.imshow(
|
|
|
|
|
|
|
|
|
|
|
79 |
plt.title("Posterior")
|
80 |
|
81 |
plt.subplot(223)
|
82 |
z_ = self.cfm(x)
|
83 |
y__ = self.ae.decode(z_)
|
84 |
y__ = y__[:, : y.shape[1]]
|
85 |
-
plt.imshow(
|
|
|
|
|
|
|
|
|
|
|
86 |
plt.title("C-Prior")
|
87 |
del y__
|
88 |
|
@@ -90,7 +106,12 @@ class LCFM(nn.Module):
|
|
90 |
z_ = torch.randn_like(z_)
|
91 |
y__ = self.ae.decode(z_)
|
92 |
y__ = y__[:, : y.shape[1]]
|
93 |
-
plt.imshow(
|
|
|
|
|
|
|
|
|
|
|
94 |
plt.title("Prior")
|
95 |
del z_, y__
|
96 |
|
@@ -109,7 +130,7 @@ class LCFM(nn.Module):
|
|
109 |
def eval_tau_(self, tau):
|
110 |
self._eval_tau = tau
|
111 |
|
112 |
-
def forward(self, x, y: Tensor
|
113 |
"""
|
114 |
Args:
|
115 |
x: (b d t), condition mel
|
@@ -139,14 +160,20 @@ class LCFM(nn.Module):
|
|
139 |
|
140 |
h = self.ae.decode(z)
|
141 |
else:
|
142 |
-
ae_output: IRMAEOutput = self.ae(
|
|
|
|
|
143 |
|
144 |
if self.mode == self.Mode.CFM:
|
145 |
_ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
|
146 |
|
147 |
h = ae_output.decoded
|
148 |
|
149 |
-
if
|
|
|
|
|
|
|
|
|
150 |
self._visualize(x[:1], y[:1], h[:1])
|
151 |
|
152 |
return h
|
|
|
1 |
import logging
|
2 |
from enum import Enum
|
3 |
+
from typing import Union
|
4 |
|
5 |
import matplotlib.pyplot as plt
|
6 |
import torch
|
|
|
71 |
return
|
72 |
|
73 |
plt.subplot(221)
|
74 |
+
plt.imshow(
|
75 |
+
y[0].detach().cpu().numpy(),
|
76 |
+
aspect="auto",
|
77 |
+
origin="lower",
|
78 |
+
interpolation="none",
|
79 |
+
)
|
80 |
plt.title("GT")
|
81 |
|
82 |
plt.subplot(222)
|
83 |
y_ = y_[:, : y.shape[1]]
|
84 |
+
plt.imshow(
|
85 |
+
y_[0].detach().cpu().numpy(),
|
86 |
+
aspect="auto",
|
87 |
+
origin="lower",
|
88 |
+
interpolation="none",
|
89 |
+
)
|
90 |
plt.title("Posterior")
|
91 |
|
92 |
plt.subplot(223)
|
93 |
z_ = self.cfm(x)
|
94 |
y__ = self.ae.decode(z_)
|
95 |
y__ = y__[:, : y.shape[1]]
|
96 |
+
plt.imshow(
|
97 |
+
y__[0].detach().cpu().numpy(),
|
98 |
+
aspect="auto",
|
99 |
+
origin="lower",
|
100 |
+
interpolation="none",
|
101 |
+
)
|
102 |
plt.title("C-Prior")
|
103 |
del y__
|
104 |
|
|
|
106 |
z_ = torch.randn_like(z_)
|
107 |
y__ = self.ae.decode(z_)
|
108 |
y__ = y__[:, : y.shape[1]]
|
109 |
+
plt.imshow(
|
110 |
+
y__[0].detach().cpu().numpy(),
|
111 |
+
aspect="auto",
|
112 |
+
origin="lower",
|
113 |
+
interpolation="none",
|
114 |
+
)
|
115 |
plt.title("Prior")
|
116 |
del z_, y__
|
117 |
|
|
|
130 |
def eval_tau_(self, tau):
|
131 |
self._eval_tau = tau
|
132 |
|
133 |
+
def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None):
|
134 |
"""
|
135 |
Args:
|
136 |
x: (b d t), condition mel
|
|
|
160 |
|
161 |
h = self.ae.decode(z)
|
162 |
else:
|
163 |
+
ae_output: IRMAEOutput = self.ae(
|
164 |
+
y, skip_decoding=self.mode == self.Mode.CFM
|
165 |
+
)
|
166 |
|
167 |
if self.mode == self.Mode.CFM:
|
168 |
_ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
|
169 |
|
170 |
h = ae_output.decoded
|
171 |
|
172 |
+
if (
|
173 |
+
h is not None
|
174 |
+
and self.global_step is not None
|
175 |
+
and self.global_step % 100 == 0
|
176 |
+
):
|
177 |
self._visualize(x[:1], y[:1], h[:1])
|
178 |
|
179 |
return h
|
modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
@@ -50,7 +51,9 @@ class UnivNet(nn.Module):
|
|
50 |
]
|
51 |
)
|
52 |
|
53 |
-
self.conv_pre = weight_norm(
|
|
|
|
|
54 |
|
55 |
self.conv_post = nn.Sequential(
|
56 |
nn.LeakyReLU(0.2),
|
@@ -64,7 +67,7 @@ class UnivNet(nn.Module):
|
|
64 |
def eps(self):
|
65 |
return 1e-5
|
66 |
|
67 |
-
def forward(self, x: Tensor, y: Tensor
|
68 |
"""
|
69 |
Args:
|
70 |
x: (b c t), acoustic features
|
@@ -74,7 +77,9 @@ class UnivNet(nn.Module):
|
|
74 |
"""
|
75 |
assert x.ndim == 3, "x must be 3D tensor"
|
76 |
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
77 |
-
assert
|
|
|
|
|
78 |
assert npad >= 0, "npad must be positive or zero"
|
79 |
|
80 |
x = F.pad(x, (0, npad), "constant", 0)
|
|
|
1 |
+
from typing import Union
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
|
|
51 |
]
|
52 |
)
|
53 |
|
54 |
+
self.conv_pre = weight_norm(
|
55 |
+
nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")
|
56 |
+
)
|
57 |
|
58 |
self.conv_post = nn.Sequential(
|
59 |
nn.LeakyReLU(0.2),
|
|
|
67 |
def eps(self):
|
68 |
return 1e-5
|
69 |
|
70 |
+
def forward(self, x: Tensor, y: Union[Tensor, None] = None, npad=10):
|
71 |
"""
|
72 |
Args:
|
73 |
x: (b c t), acoustic features
|
|
|
77 |
"""
|
78 |
assert x.ndim == 3, "x must be 3D tensor"
|
79 |
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
80 |
+
assert (
|
81 |
+
x.shape[1] == self.d_input
|
82 |
+
), f"x.shape[1] must be {self.d_input}, but got {x.shape}"
|
83 |
assert npad >= 0, "npad must be positive or zero"
|
84 |
|
85 |
x = F.pad(x, (0, npad), "constant", 0)
|
modules/repos_static/resemble_enhance/hparams.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
from dataclasses import asdict, dataclass
|
3 |
from pathlib import Path
|
|
|
4 |
|
5 |
from omegaconf import OmegaConf
|
6 |
from rich.console import Console
|
@@ -102,7 +103,7 @@ class HParams:
|
|
102 |
OmegaConf.save(asdict(self), str(path))
|
103 |
|
104 |
@classmethod
|
105 |
-
def load(cls, run_dir, yaml: Path
|
106 |
hps = []
|
107 |
|
108 |
if (run_dir / "hparams.yaml").exists():
|
@@ -120,7 +121,9 @@ class HParams:
|
|
120 |
for k, v in asdict(hp).items():
|
121 |
if getattr(hps[0], k) != v:
|
122 |
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
123 |
-
raise ValueError(
|
|
|
|
|
124 |
|
125 |
return hps[0]
|
126 |
|
|
|
1 |
import logging
|
2 |
from dataclasses import asdict, dataclass
|
3 |
from pathlib import Path
|
4 |
+
from typing import Union
|
5 |
|
6 |
from omegaconf import OmegaConf
|
7 |
from rich.console import Console
|
|
|
103 |
OmegaConf.save(asdict(self), str(path))
|
104 |
|
105 |
@classmethod
|
106 |
+
def load(cls, run_dir, yaml: Union[Path, None] = None):
|
107 |
hps = []
|
108 |
|
109 |
if (run_dir / "hparams.yaml").exists():
|
|
|
121 |
for k, v in asdict(hp).items():
|
122 |
if getattr(hps[0], k) != v:
|
123 |
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
124 |
+
raise ValueError(
|
125 |
+
f"Found inconsistent hparams: {errors}, consider deleting {run_dir}"
|
126 |
+
)
|
127 |
|
128 |
return hps[0]
|
129 |
|
modules/speaker.py
CHANGED
@@ -29,13 +29,15 @@ class Speaker:
|
|
29 |
speaker.emb = tensor
|
30 |
return speaker
|
31 |
|
32 |
-
def __init__(
|
|
|
|
|
33 |
self.id = uuid.uuid4()
|
34 |
-
self.seed =
|
35 |
self.name = name
|
36 |
self.gender = gender
|
37 |
self.describe = describe
|
38 |
-
self.emb = None
|
39 |
|
40 |
# TODO replace emb => tokens
|
41 |
self.tokens = []
|
|
|
29 |
speaker.emb = tensor
|
30 |
return speaker
|
31 |
|
32 |
+
def __init__(
|
33 |
+
self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
|
34 |
+
):
|
35 |
self.id = uuid.uuid4()
|
36 |
+
self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
|
37 |
self.name = name
|
38 |
self.gender = gender
|
39 |
self.describe = describe
|
40 |
+
self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor
|
41 |
|
42 |
# TODO replace emb => tokens
|
43 |
self.tokens = []
|
modules/ssml_parser/SSMLParser.py
CHANGED
@@ -11,8 +11,8 @@ import copy
|
|
11 |
|
12 |
|
13 |
class SSMLContext(Box):
|
14 |
-
def __init__(self,
|
15 |
-
self.parent: Union[SSMLContext, None] =
|
16 |
|
17 |
self.style = None
|
18 |
self.spk = None
|
@@ -29,18 +29,14 @@ class SSMLContext(Box):
|
|
29 |
self.prompt2 = None
|
30 |
self.prefix = None
|
31 |
|
32 |
-
|
33 |
-
ctx = SSMLContext()
|
34 |
-
for k, v in self.items():
|
35 |
-
ctx[k] = v
|
36 |
-
return ctx
|
37 |
|
38 |
|
39 |
class SSMLSegment(Box):
|
40 |
-
def __init__(self, text: str, attrs=SSMLContext()):
|
41 |
-
self.attrs = attrs
|
42 |
self.text = text
|
43 |
-
self.params =
|
44 |
|
45 |
|
46 |
class SSMLBreak:
|
@@ -68,7 +64,7 @@ class SSMLParser:
|
|
68 |
root = etree.fromstring(ssml)
|
69 |
|
70 |
root_ctx = SSMLContext()
|
71 |
-
segments = []
|
72 |
self.resolve(root, root_ctx, segments)
|
73 |
|
74 |
return segments
|
@@ -89,8 +85,13 @@ def create_ssml_parser():
|
|
89 |
parser = SSMLParser()
|
90 |
|
91 |
@parser.resolver("speak")
|
92 |
-
def tag_speak(
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
version = element.get("version")
|
96 |
if version != "0.1":
|
@@ -100,8 +101,13 @@ def create_ssml_parser():
|
|
100 |
parser.resolve(child, ctx, segments)
|
101 |
|
102 |
@parser.resolver("voice")
|
103 |
-
def tag_voice(
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
ctx.spk = element.get("spk", ctx.spk)
|
107 |
ctx.style = element.get("style", ctx.style)
|
@@ -131,13 +137,23 @@ def create_ssml_parser():
|
|
131 |
segments.append(SSMLSegment(child.tail.strip(), ctx))
|
132 |
|
133 |
@parser.resolver("break")
|
134 |
-
def tag_break(
|
|
|
|
|
|
|
|
|
|
|
135 |
time_ms = int(element.get("time", "0").replace("ms", ""))
|
136 |
segments.append(SSMLBreak(time_ms))
|
137 |
|
138 |
@parser.resolver("prosody")
|
139 |
-
def tag_prosody(
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
ctx.spk = element.get("spk", ctx.spk)
|
143 |
ctx.style = element.get("style", ctx.style)
|
|
|
11 |
|
12 |
|
13 |
class SSMLContext(Box):
|
14 |
+
def __init__(self, *args, **kwargs):
|
15 |
+
self.parent: Union[SSMLContext, None] = None
|
16 |
|
17 |
self.style = None
|
18 |
self.spk = None
|
|
|
29 |
self.prompt2 = None
|
30 |
self.prefix = None
|
31 |
|
32 |
+
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
class SSMLSegment(Box):
|
36 |
+
def __init__(self, text: str, attrs=SSMLContext(), params=None):
|
37 |
+
self.attrs = SSMLContext(**attrs)
|
38 |
self.text = text
|
39 |
+
self.params = params
|
40 |
|
41 |
|
42 |
class SSMLBreak:
|
|
|
64 |
root = etree.fromstring(ssml)
|
65 |
|
66 |
root_ctx = SSMLContext()
|
67 |
+
segments: List[Union[SSMLSegment, SSMLBreak]] = []
|
68 |
self.resolve(root, root_ctx, segments)
|
69 |
|
70 |
return segments
|
|
|
85 |
parser = SSMLParser()
|
86 |
|
87 |
@parser.resolver("speak")
|
88 |
+
def tag_speak(
|
89 |
+
element: etree.Element,
|
90 |
+
context: Box,
|
91 |
+
segments: List[Union[SSMLSegment, SSMLBreak]],
|
92 |
+
parser: SSMLParser,
|
93 |
+
):
|
94 |
+
ctx = context.copy() if context is not None else SSMLContext()
|
95 |
|
96 |
version = element.get("version")
|
97 |
if version != "0.1":
|
|
|
101 |
parser.resolve(child, ctx, segments)
|
102 |
|
103 |
@parser.resolver("voice")
|
104 |
+
def tag_voice(
|
105 |
+
element: etree.Element,
|
106 |
+
context: Box,
|
107 |
+
segments: List[Union[SSMLSegment, SSMLBreak]],
|
108 |
+
parser: SSMLParser,
|
109 |
+
):
|
110 |
+
ctx = context.copy() if context is not None else SSMLContext()
|
111 |
|
112 |
ctx.spk = element.get("spk", ctx.spk)
|
113 |
ctx.style = element.get("style", ctx.style)
|
|
|
137 |
segments.append(SSMLSegment(child.tail.strip(), ctx))
|
138 |
|
139 |
@parser.resolver("break")
|
140 |
+
def tag_break(
|
141 |
+
element: etree.Element,
|
142 |
+
context: Box,
|
143 |
+
segments: List[Union[SSMLSegment, SSMLBreak]],
|
144 |
+
parser: SSMLParser,
|
145 |
+
):
|
146 |
time_ms = int(element.get("time", "0").replace("ms", ""))
|
147 |
segments.append(SSMLBreak(time_ms))
|
148 |
|
149 |
@parser.resolver("prosody")
|
150 |
+
def tag_prosody(
|
151 |
+
element: etree.Element,
|
152 |
+
context: Box,
|
153 |
+
segments: List[Union[SSMLSegment, SSMLBreak]],
|
154 |
+
parser: SSMLParser,
|
155 |
+
):
|
156 |
+
ctx = context.copy() if context is not None else SSMLContext()
|
157 |
|
158 |
ctx.spk = element.get("spk", ctx.spk)
|
159 |
ctx.style = element.get("style", ctx.style)
|
modules/synthesize_audio.py
CHANGED
@@ -7,6 +7,7 @@ from modules import generate_audio as generate
|
|
7 |
|
8 |
|
9 |
from modules.speaker import Speaker
|
|
|
10 |
from modules.utils import audio
|
11 |
|
12 |
|
@@ -23,45 +24,33 @@ def synthesize_audio(
|
|
23 |
prefix: str = "",
|
24 |
batch_size: int = 1,
|
25 |
spliter_threshold: int = 100,
|
|
|
26 |
):
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
)
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
{
|
46 |
-
"text": s,
|
47 |
-
"params": {
|
48 |
-
"text": s,
|
49 |
-
"temperature": temperature,
|
50 |
-
"top_P": top_P,
|
51 |
-
"top_K": top_K,
|
52 |
-
"spk": spk,
|
53 |
-
"infer_seed": infer_seed,
|
54 |
-
"use_decoder": use_decoder,
|
55 |
-
"prompt1": prompt1,
|
56 |
-
"prompt2": prompt2,
|
57 |
-
"prefix": prefix,
|
58 |
-
},
|
59 |
-
}
|
60 |
-
for s in sentences
|
61 |
-
]
|
62 |
-
synthesizer = SynthesizeSegments(batch_size)
|
63 |
-
audio_segments = synthesizer.synthesize_segments(text_segments)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
return audio.pydub_to_np(combined_audio)
|
|
|
7 |
|
8 |
|
9 |
from modules.speaker import Speaker
|
10 |
+
from modules.ssml_parser.SSMLParser import SSMLSegment
|
11 |
from modules.utils import audio
|
12 |
|
13 |
|
|
|
24 |
prefix: str = "",
|
25 |
batch_size: int = 1,
|
26 |
spliter_threshold: int = 100,
|
27 |
+
end_of_sentence="",
|
28 |
):
|
29 |
+
spliter = SentenceSplitter(spliter_threshold)
|
30 |
+
sentences = spliter.parse(text)
|
31 |
+
|
32 |
+
text_segments = [
|
33 |
+
SSMLSegment(
|
34 |
+
text=s,
|
35 |
+
params={
|
36 |
+
"temperature": temperature,
|
37 |
+
"top_P": top_P,
|
38 |
+
"top_K": top_K,
|
39 |
+
"spk": spk,
|
40 |
+
"infer_seed": infer_seed,
|
41 |
+
"use_decoder": use_decoder,
|
42 |
+
"prompt1": prompt1,
|
43 |
+
"prompt2": prompt2,
|
44 |
+
"prefix": prefix,
|
45 |
+
},
|
46 |
)
|
47 |
+
for s in sentences
|
48 |
+
]
|
49 |
+
synthesizer = SynthesizeSegments(
|
50 |
+
batch_size=batch_size, eos=end_of_sentence, spliter_thr=spliter_threshold
|
51 |
+
)
|
52 |
+
audio_segments = synthesizer.synthesize_segments(text_segments)
|
53 |
|
54 |
+
combined_audio = combine_audio_segments(audio_segments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
return audio.pydub_to_np(combined_audio)
|
|
|
|
modules/utils/audio.py
CHANGED
@@ -95,7 +95,11 @@ def pitch_shift(
|
|
95 |
|
96 |
|
97 |
def apply_prosody_to_audio_data(
|
98 |
-
audio_data: np.ndarray,
|
|
|
|
|
|
|
|
|
99 |
) -> np.ndarray:
|
100 |
if rate != 1:
|
101 |
audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate)
|
|
|
95 |
|
96 |
|
97 |
def apply_prosody_to_audio_data(
|
98 |
+
audio_data: np.ndarray,
|
99 |
+
rate: float = 1,
|
100 |
+
volume: float = 0,
|
101 |
+
pitch: float = 0,
|
102 |
+
sr: int = 24000,
|
103 |
) -> np.ndarray:
|
104 |
if rate != 1:
|
105 |
audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate)
|
modules/webui/app.py
CHANGED
@@ -7,6 +7,7 @@ from modules import config
|
|
7 |
from modules.webui import gradio_extensions, webui_config
|
8 |
|
9 |
from modules.webui.changelog_tab import create_changelog_tab
|
|
|
10 |
from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
|
11 |
from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab
|
12 |
from modules.webui.system_tab import create_system_tab
|
@@ -118,6 +119,8 @@ def create_interface():
|
|
118 |
gr.Markdown("🚧 Under construction")
|
119 |
with gr.TabItem("ASR", visible=webui_config.experimental):
|
120 |
gr.Markdown("🚧 Under construction")
|
|
|
|
|
121 |
|
122 |
with gr.TabItem("System"):
|
123 |
create_system_tab()
|
|
|
7 |
from modules.webui import gradio_extensions, webui_config
|
8 |
|
9 |
from modules.webui.changelog_tab import create_changelog_tab
|
10 |
+
from modules.webui.finetune.ft_tab import create_ft_tabs
|
11 |
from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
|
12 |
from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab
|
13 |
from modules.webui.system_tab import create_system_tab
|
|
|
119 |
gr.Markdown("🚧 Under construction")
|
120 |
with gr.TabItem("ASR", visible=webui_config.experimental):
|
121 |
gr.Markdown("🚧 Under construction")
|
122 |
+
with gr.TabItem("Finetune", visible=webui_config.experimental):
|
123 |
+
create_ft_tabs(demo)
|
124 |
|
125 |
with gr.TabItem("System"):
|
126 |
create_system_tab()
|
modules/webui/finetune/ProcessMonitor.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import subprocess
|
4 |
+
import threading
|
5 |
+
|
6 |
+
|
7 |
+
class ProcessMonitor:
|
8 |
+
def __init__(self):
|
9 |
+
self.process = None
|
10 |
+
self.stdout = ""
|
11 |
+
self.stderr = ""
|
12 |
+
self.lock = threading.Lock()
|
13 |
+
|
14 |
+
def start_process(self, command):
|
15 |
+
self.process = subprocess.Popen(
|
16 |
+
command,
|
17 |
+
stdout=subprocess.PIPE,
|
18 |
+
stderr=subprocess.PIPE,
|
19 |
+
bufsize=1,
|
20 |
+
universal_newlines=True,
|
21 |
+
)
|
22 |
+
|
23 |
+
# Set pipes to non-blocking mode
|
24 |
+
fd_out = self.process.stdout.fileno()
|
25 |
+
fd_err = self.process.stderr.fileno()
|
26 |
+
|
27 |
+
if sys.platform != "win32":
|
28 |
+
import fcntl
|
29 |
+
|
30 |
+
fl_out = fcntl.fcntl(fd_out, fcntl.F_GETFL)
|
31 |
+
fl_err = fcntl.fcntl(fd_err, fcntl.F_GETFL)
|
32 |
+
fcntl.fcntl(fd_out, fcntl.F_SETFL, fl_out | os.O_NONBLOCK)
|
33 |
+
fcntl.fcntl(fd_err, fcntl.F_SETFL, fl_err | os.O_NONBLOCK)
|
34 |
+
|
35 |
+
# Start threads to read stdout and stderr
|
36 |
+
threading.Thread(target=self._read_stdout).start()
|
37 |
+
threading.Thread(target=self._read_stderr).start()
|
38 |
+
|
39 |
+
def _read_stdout(self):
|
40 |
+
while self.process is not None and self.process.poll() is None:
|
41 |
+
try:
|
42 |
+
output = self.process.stdout.read()
|
43 |
+
if output:
|
44 |
+
with self.lock:
|
45 |
+
self.stdout += output
|
46 |
+
except:
|
47 |
+
pass
|
48 |
+
|
49 |
+
def _read_stderr(self):
|
50 |
+
while self.process is not None and self.process.poll() is None:
|
51 |
+
try:
|
52 |
+
error = self.process.stderr.read()
|
53 |
+
if error:
|
54 |
+
with self.lock:
|
55 |
+
self.stderr += error
|
56 |
+
except:
|
57 |
+
pass
|
58 |
+
|
59 |
+
def get_output(self):
|
60 |
+
with self.lock:
|
61 |
+
return self.stdout, self.stderr
|
62 |
+
|
63 |
+
def stop_process(self):
|
64 |
+
if self.process:
|
65 |
+
self.process.terminate()
|
66 |
+
self.process = None
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
import time
|
71 |
+
|
72 |
+
pm = ProcessMonitor()
|
73 |
+
pm.start_process(
|
74 |
+
[
|
75 |
+
"python",
|
76 |
+
"-u",
|
77 |
+
"-c",
|
78 |
+
"import time; [print(i) or time.sleep(1) for i in range(5)]",
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
while pm.process and pm.process.poll() is None:
|
83 |
+
stdout, stderr = pm.get_output()
|
84 |
+
if stdout:
|
85 |
+
print("STDOUT:", stdout)
|
86 |
+
if stderr:
|
87 |
+
print("STDERR:", stderr)
|
88 |
+
time.sleep(1)
|
89 |
+
|
90 |
+
stdout, stderr = pm.get_output()
|
91 |
+
print("Final STDOUT:", stdout)
|
92 |
+
print("Final STDERR:", stderr)
|
modules/webui/finetune/ft_tab.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules.webui.finetune.speaker_ft_tab import create_speaker_ft_tab
|
4 |
+
|
5 |
+
|
6 |
+
def create_ft_tabs(demo):
|
7 |
+
with gr.Tabs():
|
8 |
+
with gr.TabItem("Speaker"):
|
9 |
+
create_speaker_ft_tab(demo)
|
10 |
+
with gr.TabItem("GPT"):
|
11 |
+
gr.Markdown("🚧 Under construction")
|
12 |
+
with gr.TabItem("AE"):
|
13 |
+
gr.Markdown("🚧 Under construction")
|
modules/webui/finetune/ft_ui_utils.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import IO, Union
|
3 |
+
from modules.speaker import Speaker, speaker_mgr
|
4 |
+
import subprocess
|
5 |
+
|
6 |
+
|
7 |
+
def get_datasets_dir():
|
8 |
+
"""
|
9 |
+
列出 ./datasets/data_* 文件夹
|
10 |
+
"""
|
11 |
+
dataset_path = "./datasets"
|
12 |
+
dataset_list = os.listdir(dataset_path)
|
13 |
+
dataset_list = [
|
14 |
+
d for d in dataset_list if os.path.isdir(os.path.join(dataset_path, d))
|
15 |
+
]
|
16 |
+
dataset_list = [d for d in dataset_list if d.startswith("data_")]
|
17 |
+
return dataset_list
|
18 |
+
|
19 |
+
|
20 |
+
def get_datasets_listfile():
|
21 |
+
datasets = get_datasets_dir()
|
22 |
+
listfiles = []
|
23 |
+
for d in datasets:
|
24 |
+
dir_path = os.path.join("./datasets", d)
|
25 |
+
files = os.listdir(dir_path)
|
26 |
+
for f in files:
|
27 |
+
if f.endswith(".list"):
|
28 |
+
listfiles.append(os.path.join(dir_path, f))
|
29 |
+
return listfiles
|
30 |
+
|
31 |
+
|
32 |
+
def run_speaker_ft(
|
33 |
+
batch_size: int, epochs: int, train_text: bool, data_path: str, init_speaker: str
|
34 |
+
):
|
35 |
+
command = ["python3", "-m", "modules.finetune.train_speaker"]
|
36 |
+
command += [
|
37 |
+
f"--batch_size={batch_size}",
|
38 |
+
f"--epochs={epochs}",
|
39 |
+
f"--data_path={data_path}",
|
40 |
+
]
|
41 |
+
if train_text:
|
42 |
+
command.append("--train_text")
|
43 |
+
if init_speaker:
|
44 |
+
command.append(f"--init_speaker={init_speaker}")
|
45 |
+
process = subprocess.Popen(
|
46 |
+
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1
|
47 |
+
)
|
48 |
+
|
49 |
+
return process
|
modules/webui/finetune/speaker_ft_tab.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules.Enhancer.ResembleEnhance import unload_enhancer
|
4 |
+
from modules.webui import webui_config
|
5 |
+
from modules.webui.webui_utils import get_speaker_names
|
6 |
+
from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
|
7 |
+
from .ProcessMonitor import ProcessMonitor
|
8 |
+
from modules.speaker import speaker_mgr
|
9 |
+
from modules.models import unload_chat_tts
|
10 |
+
|
11 |
+
|
12 |
+
class SpeakerFt:
|
13 |
+
def __init__(self):
|
14 |
+
self.process_monitor = ProcessMonitor()
|
15 |
+
self.status_str = "idle"
|
16 |
+
|
17 |
+
def unload_main_thread_models(self):
|
18 |
+
unload_chat_tts()
|
19 |
+
unload_enhancer()
|
20 |
+
|
21 |
+
def run(
|
22 |
+
self,
|
23 |
+
batch_size: int,
|
24 |
+
epochs: int,
|
25 |
+
lr: str,
|
26 |
+
train_text: bool,
|
27 |
+
data_path: str,
|
28 |
+
select_speaker: str = "",
|
29 |
+
):
|
30 |
+
if self.process_monitor.process:
|
31 |
+
return
|
32 |
+
self.unload_main_thread_models()
|
33 |
+
spk_path = None
|
34 |
+
if select_speaker != "" and select_speaker != "none":
|
35 |
+
select_speaker = select_speaker.split(" : ")[1].strip()
|
36 |
+
spk = speaker_mgr.get_speaker(select_speaker)
|
37 |
+
if spk is None:
|
38 |
+
return ["Speaker not found"]
|
39 |
+
spk_filename = speaker_mgr.get_speaker_filename(spk.id)
|
40 |
+
spk_path = f"./data/speakers/{spk_filename}"
|
41 |
+
|
42 |
+
command = ["python3", "-m", "modules.finetune.train_speaker"]
|
43 |
+
command += [
|
44 |
+
f"--batch_size={batch_size}",
|
45 |
+
f"--epochs={epochs}",
|
46 |
+
f"--data_path={data_path}",
|
47 |
+
]
|
48 |
+
if train_text:
|
49 |
+
command.append("--train_text")
|
50 |
+
if spk_path:
|
51 |
+
command.append(f"--init_speaker={spk_path}")
|
52 |
+
|
53 |
+
self.status("Training process starting")
|
54 |
+
|
55 |
+
self.process_monitor.start_process(command)
|
56 |
+
|
57 |
+
self.status("Training started")
|
58 |
+
|
59 |
+
def status(self, text: str):
|
60 |
+
self.status_str = text
|
61 |
+
|
62 |
+
def flush(self):
|
63 |
+
stdout, stderr = self.process_monitor.get_output()
|
64 |
+
return f"{self.status_str}\n{stdout}\n{stderr}"
|
65 |
+
|
66 |
+
def clear(self):
|
67 |
+
self.process_monitor.stdout = ""
|
68 |
+
self.process_monitor.stderr = ""
|
69 |
+
self.status("Logs cleared")
|
70 |
+
|
71 |
+
def stop(self):
|
72 |
+
self.process_monitor.stop_process()
|
73 |
+
self.status("Training stopped")
|
74 |
+
|
75 |
+
|
76 |
+
def create_speaker_ft_tab(demo: gr.Blocks):
|
77 |
+
spk_ft = SpeakerFt()
|
78 |
+
speakers, speaker_names = get_speaker_names()
|
79 |
+
speaker_names = ["none"] + speaker_names
|
80 |
+
|
81 |
+
with gr.Row():
|
82 |
+
with gr.Column(scale=2):
|
83 |
+
with gr.Group():
|
84 |
+
gr.Markdown("🎛️hparams")
|
85 |
+
dataset_input = gr.Dropdown(
|
86 |
+
label="Dataset", choices=get_datasets_listfile()
|
87 |
+
)
|
88 |
+
lr_input = gr.Textbox(label="Learning Rate", value="1e-2")
|
89 |
+
epochs_input = gr.Slider(
|
90 |
+
label="Epochs", value=10, minimum=1, maximum=100, step=1
|
91 |
+
)
|
92 |
+
batch_size_input = gr.Slider(
|
93 |
+
label="Batch Size", value=4, minimum=1, maximum=64, step=1
|
94 |
+
)
|
95 |
+
train_text_checkbox = gr.Checkbox(label="Train text_loss", value=True)
|
96 |
+
init_spk_dropdown = gr.Dropdown(
|
97 |
+
label="Initial Speaker",
|
98 |
+
choices=speaker_names,
|
99 |
+
value="none",
|
100 |
+
)
|
101 |
+
|
102 |
+
with gr.Group():
|
103 |
+
start_train_btn = gr.Button("Start Training")
|
104 |
+
stop_train_btn = gr.Button("Stop Training")
|
105 |
+
clear_train_btn = gr.Button("Clear logs")
|
106 |
+
with gr.Column(scale=5):
|
107 |
+
with gr.Group():
|
108 |
+
# log
|
109 |
+
gr.Markdown("📜logs")
|
110 |
+
log_output = gr.Textbox(
|
111 |
+
show_label=False, label="Log", value="", lines=20, interactive=True
|
112 |
+
)
|
113 |
+
|
114 |
+
start_train_btn.click(
|
115 |
+
spk_ft.run,
|
116 |
+
inputs=[
|
117 |
+
batch_size_input,
|
118 |
+
epochs_input,
|
119 |
+
lr_input,
|
120 |
+
train_text_checkbox,
|
121 |
+
dataset_input,
|
122 |
+
init_spk_dropdown,
|
123 |
+
],
|
124 |
+
outputs=[],
|
125 |
+
)
|
126 |
+
stop_train_btn.click(spk_ft.stop)
|
127 |
+
clear_train_btn.click(spk_ft.clear)
|
128 |
+
|
129 |
+
if webui_config.experimental:
|
130 |
+
demo.load(spk_ft.flush, every=1, outputs=[log_output])
|
modules/webui/localization_runtime.py
CHANGED
@@ -7,6 +7,7 @@ class LocalizationVars:
|
|
7 |
|
8 |
self.ssml_examples = []
|
9 |
self.tts_examples = []
|
|
|
10 |
|
11 |
|
12 |
class ZHLocalizationVars(LocalizationVars):
|
@@ -167,6 +168,69 @@ class ZHLocalizationVars(LocalizationVars):
|
|
167 |
},
|
168 |
]
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
class ENLocalizationVars(LocalizationVars):
|
172 |
def __init__(self):
|
@@ -224,3 +288,65 @@ class ENLocalizationVars(LocalizationVars):
|
|
224 |
"text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.",
|
225 |
},
|
226 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
self.ssml_examples = []
|
9 |
self.tts_examples = []
|
10 |
+
self.podcast_default = []
|
11 |
|
12 |
|
13 |
class ZHLocalizationVars(LocalizationVars):
|
|
|
168 |
},
|
169 |
]
|
170 |
|
171 |
+
self.podcast_default = [
|
172 |
+
[
|
173 |
+
1,
|
174 |
+
"female2",
|
175 |
+
"你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。",
|
176 |
+
"podcast",
|
177 |
+
],
|
178 |
+
[
|
179 |
+
2,
|
180 |
+
"Alice",
|
181 |
+
"嗨,我特别期待这个话题!中华料理真的是博大精深。",
|
182 |
+
"podcast",
|
183 |
+
],
|
184 |
+
[
|
185 |
+
3,
|
186 |
+
"Bob",
|
187 |
+
"没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。",
|
188 |
+
"podcast",
|
189 |
+
],
|
190 |
+
[
|
191 |
+
4,
|
192 |
+
"female2",
|
193 |
+
"那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。",
|
194 |
+
"podcast",
|
195 |
+
],
|
196 |
+
[
|
197 |
+
5,
|
198 |
+
"Alice",
|
199 |
+
"对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。",
|
200 |
+
"podcast",
|
201 |
+
],
|
202 |
+
[
|
203 |
+
6,
|
204 |
+
"Bob",
|
205 |
+
"除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。",
|
206 |
+
"podcast",
|
207 |
+
],
|
208 |
+
[
|
209 |
+
7,
|
210 |
+
"female2",
|
211 |
+
"对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。",
|
212 |
+
"podcast",
|
213 |
+
],
|
214 |
+
[
|
215 |
+
8,
|
216 |
+
"Alice",
|
217 |
+
"还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。",
|
218 |
+
"podcast",
|
219 |
+
],
|
220 |
+
[
|
221 |
+
9,
|
222 |
+
"Bob",
|
223 |
+
"不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。",
|
224 |
+
"podcast",
|
225 |
+
],
|
226 |
+
[
|
227 |
+
10,
|
228 |
+
"female2",
|
229 |
+
"对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。",
|
230 |
+
"podcast",
|
231 |
+
],
|
232 |
+
]
|
233 |
+
|
234 |
|
235 |
class ENLocalizationVars(LocalizationVars):
|
236 |
def __init__(self):
|
|
|
288 |
"text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.",
|
289 |
},
|
290 |
]
|
291 |
+
self.podcast_default = [
|
292 |
+
[
|
293 |
+
1,
|
294 |
+
"female2",
|
295 |
+
"Hello, welcome to today's podcast. Today, we're going to talk about global cuisine.",
|
296 |
+
"podcast",
|
297 |
+
],
|
298 |
+
[
|
299 |
+
2,
|
300 |
+
"Alice",
|
301 |
+
"Hi, I'm really excited about this topic! Global cuisine is incredibly diverse and fascinating.",
|
302 |
+
"podcast",
|
303 |
+
],
|
304 |
+
[
|
305 |
+
3,
|
306 |
+
"Bob",
|
307 |
+
"Absolutely, every country has its own unique culinary traditions and specialties.",
|
308 |
+
"podcast",
|
309 |
+
],
|
310 |
+
[
|
311 |
+
4,
|
312 |
+
"female2",
|
313 |
+
"Let's start with Italian cuisine. Italian food is loved worldwide, especially for its pasta and pizza.",
|
314 |
+
"podcast",
|
315 |
+
],
|
316 |
+
[
|
317 |
+
5,
|
318 |
+
"Alice",
|
319 |
+
"Yes, I especially love a good Margherita pizza and a hearty plate of spaghetti carbonara. The flavors are simply amazing.",
|
320 |
+
"podcast",
|
321 |
+
],
|
322 |
+
[
|
323 |
+
6,
|
324 |
+
"Bob",
|
325 |
+
"Besides Italian cuisine, Japanese cuisine is also very popular. Dishes like sushi and ramen have become global favorites.",
|
326 |
+
"podcast",
|
327 |
+
],
|
328 |
+
[
|
329 |
+
7,
|
330 |
+
"female2",
|
331 |
+
"Exactly, Japanese cuisine is known for its emphasis on fresh ingredients and delicate presentation.",
|
332 |
+
"podcast",
|
333 |
+
],
|
334 |
+
[
|
335 |
+
8,
|
336 |
+
"Alice",
|
337 |
+
"And then there's Mexican cuisine, with its bold flavors and colorful dishes like tacos and guacamole.",
|
338 |
+
"podcast",
|
339 |
+
],
|
340 |
+
[
|
341 |
+
9,
|
342 |
+
"Bob",
|
343 |
+
"Not to mention, there's also Indian cuisine, Thai cuisine, French cuisine, and so many more, each with its own distinctive flavors and techniques.",
|
344 |
+
"podcast",
|
345 |
+
],
|
346 |
+
[
|
347 |
+
10,
|
348 |
+
"female2",
|
349 |
+
"Yes, like Indian curry, Thai tom yum soup, and French croissants, these are all mouth-watering dishes that are loved by people all over the world.",
|
350 |
+
"podcast",
|
351 |
+
],
|
352 |
+
]
|
modules/webui/ssml/podcast_tab.py
CHANGED
@@ -3,72 +3,9 @@ import pandas as pd
|
|
3 |
import torch
|
4 |
|
5 |
from modules.normalization import text_normalize
|
6 |
-
from modules.webui import webui_utils
|
7 |
from modules.utils.hf import spaces
|
8 |
|
9 |
-
podcast_default_case = [
|
10 |
-
[
|
11 |
-
1,
|
12 |
-
"female2",
|
13 |
-
"你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
|
14 |
-
"podcast",
|
15 |
-
],
|
16 |
-
[
|
17 |
-
2,
|
18 |
-
"Alice",
|
19 |
-
"嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
|
20 |
-
"podcast",
|
21 |
-
],
|
22 |
-
[
|
23 |
-
3,
|
24 |
-
"Bob",
|
25 |
-
"没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
|
26 |
-
"podcast",
|
27 |
-
],
|
28 |
-
[
|
29 |
-
4,
|
30 |
-
"female2",
|
31 |
-
"那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
|
32 |
-
"podcast",
|
33 |
-
],
|
34 |
-
[
|
35 |
-
5,
|
36 |
-
"Alice",
|
37 |
-
"对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
|
38 |
-
"podcast",
|
39 |
-
],
|
40 |
-
[
|
41 |
-
6,
|
42 |
-
"Bob",
|
43 |
-
"除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
|
44 |
-
"podcast",
|
45 |
-
],
|
46 |
-
[
|
47 |
-
7,
|
48 |
-
"female2",
|
49 |
-
"对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
|
50 |
-
"podcast",
|
51 |
-
],
|
52 |
-
[
|
53 |
-
8,
|
54 |
-
"Alice",
|
55 |
-
"还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
|
56 |
-
"podcast",
|
57 |
-
],
|
58 |
-
[
|
59 |
-
9,
|
60 |
-
"Bob",
|
61 |
-
"不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
|
62 |
-
"podcast",
|
63 |
-
],
|
64 |
-
[
|
65 |
-
10,
|
66 |
-
"female2",
|
67 |
-
"对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
|
68 |
-
"podcast",
|
69 |
-
],
|
70 |
-
]
|
71 |
-
|
72 |
|
73 |
# NOTE: 因为 text_normalize 需要使用 tokenizer
|
74 |
@torch.inference_mode()
|
@@ -133,7 +70,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
|
|
133 |
datatype=["number", "str", "str", "str"],
|
134 |
interactive=True,
|
135 |
wrap=True,
|
136 |
-
value=
|
137 |
row_count=(0, "dynamic"),
|
138 |
col_count=(4, "fixed"),
|
139 |
)
|
|
|
3 |
import torch
|
4 |
|
5 |
from modules.normalization import text_normalize
|
6 |
+
from modules.webui import webui_config, webui_utils
|
7 |
from modules.utils.hf import spaces
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# NOTE: 因为 text_normalize 需要使用 tokenizer
|
11 |
@torch.inference_mode()
|
|
|
70 |
datatype=["number", "str", "str", "str"],
|
71 |
interactive=True,
|
72 |
wrap=True,
|
73 |
+
value=webui_config.localization.podcast_default,
|
74 |
row_count=(0, "dynamic"),
|
75 |
col_count=(4, "fixed"),
|
76 |
)
|
modules/webui/ssml/ssml_tab.py
CHANGED
@@ -22,7 +22,6 @@ def create_ssml_interface():
|
|
22 |
ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
|
23 |
with gr.Column(scale=1):
|
24 |
with gr.Group():
|
25 |
-
# 参数
|
26 |
gr.Markdown("🎛️Parameters")
|
27 |
# batch size
|
28 |
batch_size_input = gr.Slider(
|
@@ -32,6 +31,19 @@ def create_ssml_interface():
|
|
32 |
maximum=webui_config.max_batch_size,
|
33 |
step=1,
|
34 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
with gr.Group():
|
37 |
gr.Markdown("💪🏼Enhance")
|
@@ -49,7 +61,14 @@ def create_ssml_interface():
|
|
49 |
|
50 |
ssml_button.click(
|
51 |
synthesize_ssml,
|
52 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
outputs=ssml_output,
|
54 |
)
|
55 |
|
|
|
22 |
ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
|
23 |
with gr.Column(scale=1):
|
24 |
with gr.Group():
|
|
|
25 |
gr.Markdown("🎛️Parameters")
|
26 |
# batch size
|
27 |
batch_size_input = gr.Slider(
|
|
|
31 |
maximum=webui_config.max_batch_size,
|
32 |
step=1,
|
33 |
)
|
34 |
+
with gr.Group():
|
35 |
+
gr.Markdown("🎛️Spliter")
|
36 |
+
eos_input = gr.Textbox(
|
37 |
+
label="eos",
|
38 |
+
value="[uv_break]",
|
39 |
+
)
|
40 |
+
spliter_thr_input = gr.Slider(
|
41 |
+
label="Spliter Threshold",
|
42 |
+
value=100,
|
43 |
+
minimum=50,
|
44 |
+
maximum=1000,
|
45 |
+
step=1,
|
46 |
+
)
|
47 |
|
48 |
with gr.Group():
|
49 |
gr.Markdown("💪🏼Enhance")
|
|
|
61 |
|
62 |
ssml_button.click(
|
63 |
synthesize_ssml,
|
64 |
+
inputs=[
|
65 |
+
ssml_input,
|
66 |
+
batch_size_input,
|
67 |
+
enable_enhance,
|
68 |
+
enable_de_noise,
|
69 |
+
eos_input,
|
70 |
+
spliter_thr_input,
|
71 |
+
],
|
72 |
outputs=ssml_output,
|
73 |
)
|
74 |
|