Spaces:
Running
Running
kevinwang676
commited on
Commit
•
2d6ed53
1
Parent(s):
e74f0aa
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +8 -0
- .gitattributes +1 -0
- .gitignore +14 -0
- Changelog_CN.md +143 -0
- Docker/damo.sha256 +3 -0
- Docker/download.py +5 -0
- Docker/download.sh +11 -0
- Docker/links.sha256 +12 -0
- Docker/links.txt +34 -0
- Dockerfile +45 -0
- GPT-SoVITS-models/.gitattributes +44 -0
- GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh +12 -0
- GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py +719 -0
- GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh +2 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py +270 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py +157 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py +66 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py +302 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py +179 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md +127 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py +172 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py +2 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py +791 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py +19 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py +220 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py +235 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py +321 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py +232 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py +198 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py +255 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py +35 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py +197 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py +139 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt +10 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py +103 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py +170 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py +128 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py +298 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py +164 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py +0 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py +397 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py +78 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py +85 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py +622 -0
- GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py +388 -0
.dockerignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
docs
|
2 |
+
logs
|
3 |
+
output
|
4 |
+
reference
|
5 |
+
SoVITS_weights
|
6 |
+
GPT_weights
|
7 |
+
TEMP
|
8 |
+
.git
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tools/damo_asr/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
__pycache__
|
3 |
+
*.pyc
|
4 |
+
env
|
5 |
+
runtime
|
6 |
+
.idea
|
7 |
+
output
|
8 |
+
logs
|
9 |
+
reference
|
10 |
+
GPT_weights
|
11 |
+
SoVITS_weights
|
12 |
+
TEMP
|
13 |
+
|
14 |
+
|
Changelog_CN.md
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 20240121更新
|
2 |
+
|
3 |
+
1-config添加is_share,诸如colab等场景可以将此改为True,来使得webui映射到公网
|
4 |
+
|
5 |
+
2-WebUI添加英文系统英文翻译适配
|
6 |
+
|
7 |
+
3-cmd-asr自动判断是否已自带damo模型,如不在默认目录上将从modelscope自带下载
|
8 |
+
|
9 |
+
4-[SoVITS训练报错ZeroDivisionError](https://github.com/RVC-Boss/GPT-SoVITS/issues/79) 尝试修复(过滤长度0的样本等)
|
10 |
+
|
11 |
+
5-清理TEMP文件夹缓存音频等文件
|
12 |
+
|
13 |
+
6-大幅削弱合成音频包含参考音频结尾的问题
|
14 |
+
|
15 |
+
### 20240122更新
|
16 |
+
|
17 |
+
1-修复过短输出文件返回重复参考音频的问题。
|
18 |
+
|
19 |
+
2-经测试,英文日文训练原生支持(日文训练需要根目录不含非英文等特殊字符)。
|
20 |
+
|
21 |
+
3-音频路径检查。如果尝试读取输入错的路径报错路径不存在,而非ffmpeg错误。
|
22 |
+
|
23 |
+
### 20240123更新
|
24 |
+
|
25 |
+
1-解决hubert提取nan导致SoVITS/GPT训练报错ZeroDivisionError的问题
|
26 |
+
|
27 |
+
2-支持推理界面快速切换模型
|
28 |
+
|
29 |
+
3-优化模型文件排序逻辑
|
30 |
+
|
31 |
+
4-中文分词使用jieba_fast代替jieba
|
32 |
+
|
33 |
+
### 20240126更新
|
34 |
+
|
35 |
+
1-支持输出文本中英混合、日英混合
|
36 |
+
|
37 |
+
2-输出可选切分模式
|
38 |
+
|
39 |
+
3-修复uvr5读取到目录自动跳出的问题
|
40 |
+
|
41 |
+
4-修复多个换行导致推理报错
|
42 |
+
|
43 |
+
5-去除推理界面大量冗余log
|
44 |
+
|
45 |
+
6-支持mac训练推理
|
46 |
+
|
47 |
+
7-自动识别不支持半精度的卡强制单精度。cpu推理下强制单精度。
|
48 |
+
|
49 |
+
### 20240128更新
|
50 |
+
|
51 |
+
1-修复数字转汉字念法问题
|
52 |
+
|
53 |
+
2-修复句首少量字容易吞字的问题
|
54 |
+
|
55 |
+
3-通过限制排除不合理的参考音频长度
|
56 |
+
|
57 |
+
4-修复GPT训练不保存ckpt的问题
|
58 |
+
|
59 |
+
5-完善Dockerfile的下载模型流程
|
60 |
+
|
61 |
+
### 20240129更新
|
62 |
+
|
63 |
+
1-16系等半精度训练有问题的显卡把训练配置改为单精度训练
|
64 |
+
|
65 |
+
2-测试更新可用的colab版本
|
66 |
+
|
67 |
+
3-修复git clone modelscope funasr仓库+老版本funasr导致接口不对齐报错的问题
|
68 |
+
|
69 |
+
|
70 |
+
### 20240130更新
|
71 |
+
|
72 |
+
1-所有涉及路径的地方双引号自动去除,小白复制路径带双引号不会报错
|
73 |
+
|
74 |
+
2-修复中英文标点切割问题和句首句尾补标点的问题
|
75 |
+
|
76 |
+
3-增加按标点符号切分
|
77 |
+
|
78 |
+
### 20240201更新
|
79 |
+
|
80 |
+
1-修复uvr5读取格式错误导致分离失败的问题
|
81 |
+
|
82 |
+
2-支持中日英混合多种文本自动切分识别语种
|
83 |
+
|
84 |
+
### 20240202更新
|
85 |
+
|
86 |
+
1-修复asr路径尾缀带/保存文件名报错
|
87 |
+
|
88 |
+
2-引入paddlespeech的Normalizer https://github.com/RVC-Boss/GPT-SoVITS/pull/377 修复一些问题,例如:xx.xx%(带百分号类),元/吨 会读成 元吨 而不是元每吨,下划线不再会报错
|
89 |
+
|
90 |
+
### 20240207更新
|
91 |
+
|
92 |
+
1-修正语种传参混乱导致中文推理效果下降 https://github.com/RVC-Boss/GPT-SoVITS/issues/391
|
93 |
+
|
94 |
+
2-uvr5适配高版本librosa https://github.com/RVC-Boss/GPT-SoVITS/pull/403
|
95 |
+
|
96 |
+
3-修复uvr5 inf everywhere报错的问题(is_half传参未转换bool导致恒定半精度推理,16系显卡会inf) https://github.com/RVC-Boss/GPT-SoVITS/commit/14a285109a521679f8846589c22da8f656a46ad8
|
97 |
+
|
98 |
+
4-优化英文文本前端
|
99 |
+
|
100 |
+
5-修复gradio依赖
|
101 |
+
|
102 |
+
6-支持三连根目录留空自动读取.list全路径
|
103 |
+
|
104 |
+
7-集成faster whisper ASR日文英文
|
105 |
+
|
106 |
+
### 20240208更新
|
107 |
+
|
108 |
+
1-GPT训练卡死(win10 1909)和https://github.com/RVC-Boss/GPT-SoVITS/issues/232 (系统语言繁体)GPT训练报错,[尝试修复](https://github.com/RVC-Boss/GPT-SoVITS/commit/59f35adad85815df27e9c6b33d420f5ebfd8376b)。
|
109 |
+
|
110 |
+
### 20240212更新
|
111 |
+
|
112 |
+
1-faster whisper和funasr逻辑优化。faster whisper转镜像站下载,规避huggingface连不上的问题。
|
113 |
+
|
114 |
+
2-DPO Loss实验性训练选项开启,通过构造负样本训练缓解GPT重复漏字问题。推理界面公开几个推理参数。 https://github.com/RVC-Boss/GPT-SoVITS/pull/457
|
115 |
+
|
116 |
+
### 20240214更新
|
117 |
+
|
118 |
+
1-训练支持中文实验名(原来会报错)
|
119 |
+
|
120 |
+
2-DPO训练改为可勾选选项而非必须。如勾选batch size自动减半。修复推理界面新参数不传参的问题。
|
121 |
+
|
122 |
+
### 20240216更新
|
123 |
+
|
124 |
+
1-支持无参考文本输入
|
125 |
+
|
126 |
+
2-修复中文文本前端bug https://github.com/RVC-Boss/GPT-SoVITS/issues/475
|
127 |
+
|
128 |
+
### 20240221更新
|
129 |
+
|
130 |
+
1-数据处理添加语音降噪选项
|
131 |
+
|
132 |
+
2-中文日文前端处理优化 https://github.com/RVC-Boss/GPT-SoVITS/pull/559 https://github.com/RVC-Boss/GPT-SoVITS/pull/556 https://github.com/RVC-Boss/GPT-SoVITS/pull/532 https://github.com/RVC-Boss/GPT-SoVITS/pull/507 https://github.com/RVC-Boss/GPT-SoVITS/pull/509
|
133 |
+
|
134 |
+
3-mac CPU推理更快因此把推理设备从mps改到CPU
|
135 |
+
|
136 |
+
4-colab修复不开启公网url
|
137 |
+
|
138 |
+
todolist:
|
139 |
+
|
140 |
+
1-中文多音字推理优化
|
141 |
+
|
142 |
+
|
143 |
+
|
Docker/damo.sha256
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025 /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
|
2 |
+
b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5 /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb
|
3 |
+
a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916 /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb
|
Docker/download.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Download moda ASR related models
|
2 |
+
from modelscope import snapshot_download
|
3 |
+
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
|
4 |
+
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
|
5 |
+
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
|
Docker/download.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
set -Eeuo pipefail
|
4 |
+
|
5 |
+
echo "Downloading models..."
|
6 |
+
|
7 |
+
aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue
|
8 |
+
|
9 |
+
echo "Checking SHA256..."
|
10 |
+
|
11 |
+
parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c"
|
Docker/links.sha256
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1 /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
2 |
+
fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8 /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth
|
3 |
+
020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835 /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth
|
4 |
+
24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
|
5 |
+
e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8 /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
6 |
+
39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05 /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth
|
7 |
+
45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28 /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth
|
8 |
+
5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
|
9 |
+
8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2 /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
|
10 |
+
01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
|
11 |
+
56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
|
12 |
+
233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80 /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
Docker/links.txt
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GPT-SoVITS models
|
2 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt
|
3 |
+
out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
4 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth
|
5 |
+
out=GPT_SoVITS/pretrained_models/s2D488k.pth
|
6 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth
|
7 |
+
out=GPT_SoVITS/pretrained_models/s2G488k.pth
|
8 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json
|
9 |
+
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json
|
10 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json
|
11 |
+
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json
|
12 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin
|
13 |
+
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
|
14 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json
|
15 |
+
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json
|
16 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
17 |
+
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
18 |
+
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json
|
19 |
+
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
|
20 |
+
# UVR5
|
21 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth
|
22 |
+
out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth
|
23 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth
|
24 |
+
out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth
|
25 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth
|
26 |
+
out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
|
27 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth
|
28 |
+
out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
|
29 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth
|
30 |
+
out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
|
31 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth
|
32 |
+
out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
|
33 |
+
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
34 |
+
out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
Dockerfile
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base CUDA image
|
2 |
+
FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
|
3 |
+
|
4 |
+
LABEL maintainer="[email protected]"
|
5 |
+
LABEL version="dev-20240209"
|
6 |
+
LABEL description="Docker image for GPT-SoVITS"
|
7 |
+
|
8 |
+
|
9 |
+
# Install 3rd party apps
|
10 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
11 |
+
ENV TZ=Etc/UTC
|
12 |
+
RUN apt-get update && \
|
13 |
+
apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
|
14 |
+
git lfs install && \
|
15 |
+
rm -rf /var/lib/apt/lists/*
|
16 |
+
|
17 |
+
# Copy only requirements.txt initially to leverage Docker cache
|
18 |
+
WORKDIR /workspace
|
19 |
+
COPY requirements.txt /workspace/
|
20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
21 |
+
|
22 |
+
# Define a build-time argument for image type
|
23 |
+
ARG IMAGE_TYPE=full
|
24 |
+
|
25 |
+
# Conditional logic based on the IMAGE_TYPE argument
|
26 |
+
# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
|
27 |
+
COPY ./Docker /workspace/Docker
|
28 |
+
# elite 类型的镜像里面不包含额外的模型
|
29 |
+
RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
|
30 |
+
chmod +x /workspace/Docker/download.sh && \
|
31 |
+
/workspace/Docker/download.sh && \
|
32 |
+
python /workspace/Docker/download.py && \
|
33 |
+
python -m nltk.downloader averaged_perceptron_tagger cmudict; \
|
34 |
+
fi
|
35 |
+
|
36 |
+
|
37 |
+
# Copy the rest of the application
|
38 |
+
COPY . /workspace
|
39 |
+
|
40 |
+
# Copy the rest of the application
|
41 |
+
COPY . /workspace
|
42 |
+
|
43 |
+
EXPOSE 9871 9872 9873 9874 9880
|
44 |
+
|
45 |
+
CMD ["python", "webui.py"]
|
GPT-SoVITS-models/.gitattributes
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
GPT-SoVITS/TEMP/gradio/2bbf387613664982acc3847e4b4970fc6bf09120/audio.wav filter=lfs diff=lfs merge=lfs -text
|
37 |
+
GPT-SoVITS/TEMP/gradio/4e25df8e5470697bd435cc94559e1c34f09bab16/audio.wav filter=lfs diff=lfs merge=lfs -text
|
38 |
+
GPT-SoVITS/TEMP/gradio/6b579cccde8715941d9b9b06a1b9787ce0fdb4db/audio.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
GPT-SoVITS/TEMP/gradio/873c1f03462a87c00222fd2422a8b328244f45da/audio.wav filter=lfs diff=lfs merge=lfs -text
|
40 |
+
GPT-SoVITS/TEMP/gradio/d2c38e2d7f131cfc51fe07c541177b0f5a061cc3/audio.wav filter=lfs diff=lfs merge=lfs -text
|
41 |
+
GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101-0-100.wav filter=lfs diff=lfs merge=lfs -text
|
42 |
+
GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101.wav filter=lfs diff=lfs merge=lfs -text
|
43 |
+
GPT-SoVITS/TEMP/jieba.cache filter=lfs diff=lfs merge=lfs -text
|
44 |
+
GPT-SoVITS/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav filter=lfs diff=lfs merge=lfs -text
|
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
cd /root/autodl-tmp/workdir/GPT-SoVITS
|
3 |
+
rm -rf GPT_weights/*
|
4 |
+
rm -rf SoVITS_weights/*
|
5 |
+
|
6 |
+
rm -rf input/*
|
7 |
+
rm -rf output/asr_opt/*
|
8 |
+
rm -rf output/slicer_opt/*
|
9 |
+
rm -rf output/uvr5_opt/*
|
10 |
+
rm -rf logs/*
|
11 |
+
|
12 |
+
echo 初始化完成
|
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py
ADDED
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json,yaml,warnings,torch
|
2 |
+
warnings.filterwarnings("ignore")
|
3 |
+
torch.manual_seed(233333)
|
4 |
+
import os,pdb,sys
|
5 |
+
now_dir = os.getcwd()
|
6 |
+
tmp = os.path.join(now_dir, "TEMP")
|
7 |
+
os.makedirs(tmp, exist_ok=True)
|
8 |
+
os.environ["TEMP"] = tmp
|
9 |
+
import site
|
10 |
+
site_packages_root="%s/root/miniconda3/lib/python3.10/site-packages"%now_dir
|
11 |
+
for path in site.getsitepackages():
|
12 |
+
if("site-packages"in path):site_packages_root=path
|
13 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "4"
|
14 |
+
os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
|
15 |
+
with open("%s/users.pth"%(site_packages_root),"w")as f:
|
16 |
+
f.write("%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"%(now_dir,now_dir,now_dir,now_dir,now_dir))
|
17 |
+
import traceback
|
18 |
+
sys.path.append(now_dir)
|
19 |
+
import shutil
|
20 |
+
import pdb
|
21 |
+
import gradio as gr
|
22 |
+
from subprocess import Popen
|
23 |
+
import signal
|
24 |
+
from config import python_exec,infer_device,is_half,exp_root
|
25 |
+
from i18n.i18n import I18nAuto
|
26 |
+
i18n = I18nAuto()
|
27 |
+
from scipy.io import wavfile
|
28 |
+
from tools.my_utils import load_audio
|
29 |
+
from multiprocessing import cpu_count
|
30 |
+
n_cpu=cpu_count()
|
31 |
+
|
32 |
+
# 判断是否有能用来训练和加速推理的N卡
|
33 |
+
ngpu = torch.cuda.device_count()
|
34 |
+
gpu_infos = []
|
35 |
+
mem = []
|
36 |
+
if_gpu_ok = False
|
37 |
+
|
38 |
+
if torch.cuda.is_available() or ngpu != 0:
|
39 |
+
for i in range(ngpu):
|
40 |
+
gpu_name = torch.cuda.get_device_name(i)
|
41 |
+
if any(value in gpu_name.upper()for value in ["10","16","20","30","40","A2","A3","A4","P4","A50","500","A60","70","80","90","M4","T4","TITAN","L"]):
|
42 |
+
# A10#A100#V100#A40#P40#M40#K80#A4500
|
43 |
+
if_gpu_ok = True # 至少有一张能用的N卡
|
44 |
+
gpu_infos.append("%s\t%s" % (i, gpu_name))
|
45 |
+
mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
|
46 |
+
if if_gpu_ok and len(gpu_infos) > 0:
|
47 |
+
gpu_info = "\n".join(gpu_infos)
|
48 |
+
default_batch_size = min(mem) // 2
|
49 |
+
else:
|
50 |
+
gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
|
51 |
+
default_batch_size = 1
|
52 |
+
gpus = "-".join([i[0] for i in gpu_infos])
|
53 |
+
|
54 |
+
pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
|
55 |
+
pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
56 |
+
def get_weights_names():
|
57 |
+
SoVITS_names = [pretrained_sovits_name]
|
58 |
+
for name in os.listdir(SoVITS_weight_root):
|
59 |
+
if name.endswith(".pth"):SoVITS_names.append(name)
|
60 |
+
GPT_names = [pretrained_gpt_name]
|
61 |
+
for name in os.listdir(GPT_weight_root):
|
62 |
+
if name.endswith(".ckpt"): GPT_names.append(name)
|
63 |
+
return SoVITS_names,GPT_names
|
64 |
+
SoVITS_weight_root="SoVITS_weights"
|
65 |
+
GPT_weight_root="GPT_weights"
|
66 |
+
SoVITS_names,GPT_names = get_weights_names()
|
67 |
+
|
68 |
+
def change_choices():
|
69 |
+
SoVITS_names, GPT_names = get_weights_names()
|
70 |
+
return {"choices": sorted(SoVITS_names), "__type__": "update"}, {"choices": sorted(GPT_names), "__type__": "update"}
|
71 |
+
|
72 |
+
p_label=None
|
73 |
+
p_uvr5=None
|
74 |
+
p_asr=None
|
75 |
+
p_tts_inference=None
|
76 |
+
|
77 |
+
def kill_process(pid):
|
78 |
+
os.system("taskkill /t /f /pid %s" % pid) # todo:识别linux用kill -9
|
79 |
+
# os.kill(p_label.pid,19)#主进程#控制台进程#python子进程###不好使,连主进程的webui一起关了,辣鸡
|
80 |
+
|
81 |
+
def change_label(if_label,path_list):
|
82 |
+
global p_label
|
83 |
+
if(if_label==True and p_label==None):
|
84 |
+
cmd = '"%s" tools/subfix_webui.py --load_list "%s"'%(python_exec,path_list)
|
85 |
+
yield "打标工具WebUI已开启"
|
86 |
+
print(cmd)
|
87 |
+
p_label = Popen(cmd, shell=True)
|
88 |
+
elif(if_label==False and p_label!=None):
|
89 |
+
kill_process(p_label.pid)
|
90 |
+
p_label=None
|
91 |
+
yield "打标工具WebUI已关闭"
|
92 |
+
|
93 |
+
def change_uvr5(if_uvr5):
|
94 |
+
global p_uvr5
|
95 |
+
if(if_uvr5==True and p_uvr5==None):
|
96 |
+
cmd = '"%s" tools/uvr5/webui.py "%s" %s'%(python_exec,infer_device,is_half)
|
97 |
+
yield "UVR5已开启"
|
98 |
+
print(cmd)
|
99 |
+
p_uvr5 = Popen(cmd, shell=True)
|
100 |
+
elif(if_uvr5==False and p_uvr5!=None):
|
101 |
+
kill_process(p_uvr5.pid)
|
102 |
+
p_uvr5=None
|
103 |
+
yield "UVR5已关闭"
|
104 |
+
|
105 |
+
def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path):
|
106 |
+
global p_tts_inference
|
107 |
+
if(if_tts==True and p_tts_inference==None):
|
108 |
+
os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
|
109 |
+
os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
|
110 |
+
os.environ["cnhubert_base_path"]=cnhubert_base_path
|
111 |
+
os.environ["bert_path"]=bert_path
|
112 |
+
os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
|
113 |
+
os.environ["is_half"]=str(is_half)
|
114 |
+
cmd = '"%s" GPT_SoVITS/inference_webui.py'%(python_exec)
|
115 |
+
yield "TTS推理进程已开启"
|
116 |
+
print(cmd)
|
117 |
+
p_tts_inference = Popen(cmd, shell=True)
|
118 |
+
elif(if_tts==False and p_tts_inference!=None):
|
119 |
+
kill_process(p_tts_inference.pid)
|
120 |
+
p_tts_inference=None
|
121 |
+
yield "TTS推理进程已关闭"
|
122 |
+
|
123 |
+
|
124 |
+
def open_asr(asr_inp_dir):
|
125 |
+
global p_asr
|
126 |
+
if(p_asr==None):
|
127 |
+
cmd = '"%s" tools/damo_asr/cmd-asr.py "%s"'%(python_exec,asr_inp_dir)
|
128 |
+
yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
|
129 |
+
print(cmd)
|
130 |
+
p_asr = Popen(cmd, shell=True)
|
131 |
+
p_asr.wait()
|
132 |
+
p_asr=None
|
133 |
+
yield "ASR任务完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
|
134 |
+
else:
|
135 |
+
yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
|
136 |
+
|
137 |
+
def close_asr():
|
138 |
+
global p_asr
|
139 |
+
if(p_asr!=None):
|
140 |
+
kill_process(p_asr.pid)
|
141 |
+
p_asr=None
|
142 |
+
return "已终止ASR进程",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
|
143 |
+
|
144 |
+
'''
|
145 |
+
button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Bb,button1Ba_open,button1Ba_close])
|
146 |
+
button1Ba_close.click(close1Ba, [], [info1Bb,button1Ba_open,button1Ba_close])
|
147 |
+
'''
|
148 |
+
p_train_SoVITS=None
|
149 |
+
def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D):
|
150 |
+
global p_train_SoVITS
|
151 |
+
if(p_train_SoVITS==None):
|
152 |
+
with open("GPT_SoVITS/configs/s2.json")as f:
|
153 |
+
data=f.read()
|
154 |
+
data=json.loads(data)
|
155 |
+
s2_dir="%s/%s"%(exp_root,exp_name)
|
156 |
+
os.makedirs("%s/logs_s2"%(s2_dir),exist_ok=True)
|
157 |
+
data["train"]["batch_size"]=batch_size
|
158 |
+
data["train"]["epochs"]=total_epoch
|
159 |
+
data["train"]["text_low_lr_rate"]=text_low_lr_rate
|
160 |
+
data["train"]["pretrained_s2G"]=pretrained_s2G
|
161 |
+
data["train"]["pretrained_s2D"]=pretrained_s2D
|
162 |
+
data["train"]["if_save_latest"]=if_save_latest
|
163 |
+
data["train"]["if_save_every_weights"]=if_save_every_weights
|
164 |
+
data["train"]["save_every_epoch"]=save_every_epoch
|
165 |
+
data["train"]["gpu_numbers"]=gpu_numbers1Ba
|
166 |
+
data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir
|
167 |
+
data["save_weight_dir"]=SoVITS_weight_root
|
168 |
+
data["name"]=exp_name
|
169 |
+
tmp_config_path="TEMP/tmp_s2.json"
|
170 |
+
with open(tmp_config_path,"w")as f:f.write(json.dumps(data))
|
171 |
+
|
172 |
+
cmd = '"%s" GPT_SoVITS/s2_train.py --config "%s"'%(python_exec,tmp_config_path)
|
173 |
+
yield "SoVITS训练开始:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
|
174 |
+
print(cmd)
|
175 |
+
p_train_SoVITS = Popen(cmd, shell=True)
|
176 |
+
p_train_SoVITS.wait()
|
177 |
+
p_train_SoVITS=None
|
178 |
+
yield "SoVITS训练完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
|
179 |
+
else:
|
180 |
+
yield "已有正在进行的SoVITS训练任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
|
181 |
+
|
182 |
+
def close1Ba():
    """Terminate the running SoVITS training subprocess, if any, and reset state.

    Returns the Gradio status/button-visibility tuple used by the UI.
    """
    global p_train_SoVITS
    if p_train_SoVITS is not None:
        kill_process(p_train_SoVITS.pid)
        p_train_SoVITS = None
    return "已终止SoVITS训练", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
188 |
+
|
189 |
+
p_train_GPT = None  # handle of the running GPT training subprocess (None = idle)


def open1Bb(batch_size, total_epoch, exp_name, if_save_latest, if_save_every_weights,
            save_every_epoch, gpu_numbers, pretrained_s1):
    """Launch GPT (stage-1, text-to-semantic) training as a subprocess.

    Builds a temporary config from ``GPT_SoVITS/configs/s1longer.yaml`` patched
    with the WebUI parameters, then runs ``GPT_SoVITS/s1_train.py``. GPU
    selection and the 25hz flag are passed through environment variables.
    Yields Gradio ``(status, open_button_update, close_button_update)`` tuples.
    """
    global p_train_GPT
    if p_train_GPT is None:
        with open("GPT_SoVITS/configs/s1longer.yaml") as f:
            data = yaml.load(f.read(), Loader=yaml.FullLoader)
        s1_dir = "%s/%s" % (exp_root, exp_name)
        os.makedirs("%s/logs_s1" % (s1_dir), exist_ok=True)
        # Patch the template with the values chosen in the WebUI.
        data["train"]["batch_size"] = batch_size
        data["train"]["epochs"] = total_epoch
        data["pretrained_s1"] = pretrained_s1
        data["train"]["save_every_n_epoch"] = save_every_epoch
        data["train"]["if_save_every_weights"] = if_save_every_weights
        data["train"]["if_save_latest"] = if_save_latest
        data["train"]["half_weights_save_dir"] = GPT_weight_root
        data["train"]["exp_name"] = exp_name
        data["train_semantic_path"] = "%s/6-name2semantic.tsv" % s1_dir
        data["train_phoneme_path"] = "%s/2-name2text.txt" % s1_dir
        data["output_dir"] = "%s/logs_s1" % s1_dir

        # "-"-separated card list becomes a ","-separated CUDA device list.
        os.environ["_CUDA_VISIBLE_DEVICES"] = gpu_numbers.replace("-", ",")
        os.environ["hz"] = "25hz"
        tmp_config_path = "TEMP/tmp_s1.yaml"
        with open(tmp_config_path, "w") as f:
            f.write(yaml.dump(data, default_flow_style=False))
        # cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" --train_semantic_path "%s/6-name2semantic.tsv" --train_phoneme_path "%s/2-name2text.txt" --output_dir "%s/logs_s1"'%(python_exec,tmp_config_path,s1_dir,s1_dir,s1_dir)
        cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" ' % (python_exec, tmp_config_path)
        yield "GPT训练开始:%s" % cmd, {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
        print(cmd)
        p_train_GPT = Popen(cmd, shell=True)
        p_train_GPT.wait()
        p_train_GPT = None
        yield "GPT训练完成", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的GPT训练任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
224 |
+
|
225 |
+
def close1Bb():
    """Terminate the running GPT training subprocess, if any, and reset state.

    Returns the Gradio status/button-visibility tuple used by the UI.
    """
    global p_train_GPT
    if p_train_GPT is not None:
        kill_process(p_train_GPT.pid)
        p_train_GPT = None
    return "已终止GPT训练", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
231 |
+
|
232 |
+
ps_slice = []  # Popen handles of the running audio-slicing workers


def open_slice(inp, opt_root, threshold, min_length, min_interval, hop_size,
               max_sil_kept, _max, alpha, n_parts):
    """Split long audio into slices via ``tools/slice_audio.py`` workers.

    ``inp`` may be a single audio file or a directory; a file is always handled
    by one worker, a directory is sharded across ``n_parts`` workers. Yields
    Gradio ``(status, open_button_update, close_button_update)`` tuples.
    """
    global ps_slice
    if not os.path.exists(inp):
        yield "输入路径不存在", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
        return
    if os.path.isfile(inp):
        n_parts = 1  # a single file cannot be sharded
    elif not os.path.isdir(inp):
        yield "输入路径存在但既不是文件也不是文件夹", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
        return
    if not ps_slice:
        for i_part in range(n_parts):
            cmd = '"%s" tools/slice_audio.py "%s" "%s" %s %s %s %s %s %s %s %s %s' % (
                python_exec, inp, opt_root, threshold, min_length, min_interval,
                hop_size, max_sil_kept, _max, alpha, i_part, n_parts)
            print(cmd)
            p = Popen(cmd, shell=True)
            ps_slice.append(p)
        yield "切割执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
        for p in ps_slice:
            p.wait()
        ps_slice = []
        yield "切割结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的切割任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
256 |
+
|
257 |
+
def close_slice():
    """Kill all running slicing workers (best effort) and reset the list."""
    global ps_slice
    if ps_slice:
        for p_slice in ps_slice:
            try:
                kill_process(p_slice.pid)
            except Exception:
                # Best effort: a worker may already have exited.
                traceback.print_exc()
        ps_slice = []
    return "已终止所有切割进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
267 |
+
|
268 |
+
'''
|
269 |
+
inp_text= os.environ.get("inp_text")
|
270 |
+
inp_wav_dir= os.environ.get("inp_wav_dir")
|
271 |
+
exp_name= os.environ.get("exp_name")
|
272 |
+
i_part= os.environ.get("i_part")
|
273 |
+
all_parts= os.environ.get("all_parts")
|
274 |
+
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
|
275 |
+
opt_dir= os.environ.get("opt_dir")#"/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
276 |
+
bert_pretrained_dir= os.environ.get("bert_pretrained_dir")#"/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
|
277 |
+
'''
|
278 |
+
ps1a = []  # Popen handles of the stage-1a (text/BERT feature) workers


def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
    """Stage 1a: extract phoneme text + BERT features, one worker per GPU.

    Workers (``GPT_SoVITS/prepare_datasets/1-get-text.py``) receive all of
    their parameters through environment variables (the ``config`` dict below).
    ``gpu_numbers`` is a "-"-separated list of card ids; each card gets one
    shard. Yields Gradio ``(status, open_btn_update, close_btn_update)`` tuples.
    """
    global ps1a
    if not ps1a:
        config = {
            "inp_text": inp_text,
            "inp_wav_dir": inp_wav_dir,
            "exp_name": exp_name,
            "opt_dir": "%s/%s" % (exp_root, exp_name),
            "bert_pretrained_dir": bert_pretrained_dir,
        }
        gpu_names = gpu_numbers.split("-")
        all_parts = len(gpu_names)
        for i_part in range(all_parts):
            # Per-shard parameters; workers read them from os.environ.
            config.update(
                {
                    "i_part": str(i_part),
                    "all_parts": str(all_parts),
                    "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                    "is_half": str(is_half)
                }
            )
            os.environ.update(config)
            cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
            print(cmd)
            p = Popen(cmd, shell=True)
            ps1a.append(p)
        yield "文本进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
        for p in ps1a:
            p.wait()
        ps1a = []
        yield "文本进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的文本任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
312 |
+
|
313 |
+
def close1a():
    """Kill all running stage-1a workers (best effort) and reset the list."""
    global ps1a
    if ps1a:
        for p1a in ps1a:
            try:
                kill_process(p1a.pid)
            except Exception:
                # Best effort: a worker may already have exited.
                traceback.print_exc()
        ps1a = []
    return "已终止所有1a进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
323 |
+
'''
|
324 |
+
inp_text= os.environ.get("inp_text")
|
325 |
+
inp_wav_dir= os.environ.get("inp_wav_dir")
|
326 |
+
exp_name= os.environ.get("exp_name")
|
327 |
+
i_part= os.environ.get("i_part")
|
328 |
+
all_parts= os.environ.get("all_parts")
|
329 |
+
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
|
330 |
+
opt_dir= os.environ.get("opt_dir")
|
331 |
+
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
|
332 |
+
'''
|
333 |
+
ps1b = []  # Popen handles of the stage-1b (SSL feature extraction) workers


def open1b(inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained_dir):
    """Stage 1b: extract cnhubert SSL features / 32k wav, one worker per GPU.

    Workers (``GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py``) receive
    their parameters through environment variables. Yields Gradio
    ``(status, open_btn_update, close_btn_update)`` tuples.
    """
    global ps1b
    if not ps1b:
        config = {
            "inp_text": inp_text,
            "inp_wav_dir": inp_wav_dir,
            "exp_name": exp_name,
            "opt_dir": "%s/%s" % (exp_root, exp_name),
            "cnhubert_base_dir": ssl_pretrained_dir,
            "is_half": str(is_half)
        }
        gpu_names = gpu_numbers.split("-")
        all_parts = len(gpu_names)
        for i_part in range(all_parts):
            # Per-shard parameters; workers read them from os.environ.
            config.update(
                {
                    "i_part": str(i_part),
                    "all_parts": str(all_parts),
                    "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                }
            )
            os.environ.update(config)
            cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
            print(cmd)
            p = Popen(cmd, shell=True)
            ps1b.append(p)
        yield "SSL提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
        for p in ps1b:
            p.wait()
        ps1b = []
        yield "SSL提取进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的SSL提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
367 |
+
|
368 |
+
def close1b():
    """Kill all running stage-1b workers (best effort) and reset the list."""
    global ps1b
    if ps1b:
        for p1b in ps1b:
            try:
                kill_process(p1b.pid)
            except Exception:
                # Best effort: a worker may already have exited.
                traceback.print_exc()
        ps1b = []
    return "已终止所有1b进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
378 |
+
'''
|
379 |
+
inp_text= os.environ.get("inp_text")
|
380 |
+
exp_name= os.environ.get("exp_name")
|
381 |
+
i_part= os.environ.get("i_part")
|
382 |
+
all_parts= os.environ.get("all_parts")
|
383 |
+
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
|
384 |
+
opt_dir= os.environ.get("opt_dir")
|
385 |
+
pretrained_s2G= os.environ.get("pretrained_s2G")
|
386 |
+
'''
|
387 |
+
ps1c = []  # Popen handles of the stage-1c (semantic token extraction) workers


def open1c(inp_text, exp_name, gpu_numbers, pretrained_s2G_path):
    """Stage 1c: extract semantic tokens with the pretrained s2G, one worker per GPU.

    Workers (``GPT_SoVITS/prepare_datasets/3-get-semantic.py``) receive their
    parameters through environment variables. Yields Gradio
    ``(status, open_btn_update, close_btn_update)`` tuples.
    """
    global ps1c
    if not ps1c:
        config = {
            "inp_text": inp_text,
            "exp_name": exp_name,
            "opt_dir": "%s/%s" % (exp_root, exp_name),
            "pretrained_s2G": pretrained_s2G_path,
            "s2config_path": "GPT_SoVITS/configs/s2.json",
            "is_half": str(is_half)
        }
        gpu_names = gpu_numbers.split("-")
        all_parts = len(gpu_names)
        for i_part in range(all_parts):
            # Per-shard parameters; workers read them from os.environ.
            config.update(
                {
                    "i_part": str(i_part),
                    "all_parts": str(all_parts),
                    "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                }
            )
            os.environ.update(config)
            cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
            print(cmd)
            p = Popen(cmd, shell=True)
            ps1c.append(p)
        yield "语义token提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
        for p in ps1c:
            p.wait()
        ps1c = []
        yield "语义token提取进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的语义token提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
421 |
+
|
422 |
+
def close1c():
    """Kill all running stage-1c workers (best effort) and reset the list."""
    global ps1c
    if ps1c:
        for p1c in ps1c:
            try:
                kill_process(p1c.pid)
            except Exception:
                # Best effort: a worker may already have exited.
                traceback.print_exc()
        ps1c = []
    return "已终止所有语义token进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
432 |
+
#####inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G
ps1abc=[]  # Popen handles of the combined 1a+1b+1c pipeline workers
def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
    # "One-click" dataset formatting: runs stage 1a (text/BERT), 1b (SSL/hubert)
    # and 1c (semantic tokens) sequentially. Each stage is sharded across the
    # GPUs listed in its "-"-separated gpu_numbers string; per-shard output
    # files are merged afterwards. Workers read all parameters from os.environ.
    # Yields Gradio (status, open_button_update, close_button_update) tuples.
    global ps1abc
    if (ps1abc == []):
        opt_dir="%s/%s"%(exp_root,exp_name)
        try:
            #############################1a
            path_text="%s/2-name2text.txt" % opt_dir
            if(os.path.exists(path_text)==False):  # skip 1a if its merged output already exists
                config={
                    "inp_text":inp_text,
                    "inp_wav_dir":inp_wav_dir,
                    "exp_name":exp_name,
                    "opt_dir":opt_dir,
                    "bert_pretrained_dir":bert_pretrained_dir,
                    "is_half": str(is_half)
                }
                gpu_names=gpu_numbers1a.split("-")
                all_parts=len(gpu_names)
                for i_part in range(all_parts):
                    # Per-shard parameters passed to the worker via environment.
                    config.update(
                        {
                            "i_part": str(i_part),
                            "all_parts": str(all_parts),
                            "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                        }
                    )
                    os.environ.update(config)
                    cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
                    print(cmd)
                    p = Popen(cmd, shell=True)
                    ps1abc.append(p)
                yield "进度:1a-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
                for p in ps1abc:p.wait()

                # Merge the per-shard outputs into one file, then delete shards.
                opt = []
                for i_part in range(all_parts):#txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
                    txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
                    with open(txt_path, "r",encoding="utf8") as f:
                        opt += f.read().strip("\n").split("\n")
                    os.remove(txt_path)
                with open(path_text, "w",encoding="utf8") as f:
                    f.write("\n".join(opt) + "\n")

            yield "进度:1a-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
            ps1abc=[]
            #############################1b
            config={
                "inp_text":inp_text,
                "inp_wav_dir":inp_wav_dir,
                "exp_name":exp_name,
                "opt_dir":opt_dir,
                "cnhubert_base_dir":ssl_pretrained_dir,
            }
            gpu_names=gpu_numbers1Ba.split("-")
            all_parts=len(gpu_names)
            for i_part in range(all_parts):
                config.update(
                    {
                        "i_part": str(i_part),
                        "all_parts": str(all_parts),
                        "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                    }
                )
                os.environ.update(config)
                cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
                print(cmd)
                p = Popen(cmd, shell=True)
                ps1abc.append(p)
            yield "进度:1a-done, 1b-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
            for p in ps1abc:p.wait()
            yield "进度:1a1b-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
            ps1abc=[]
            #############################1c
            path_semantic = "%s/6-name2semantic.tsv" % opt_dir
            if(os.path.exists(path_semantic)==False):  # skip 1c if its merged output already exists
                config={
                    "inp_text":inp_text,
                    "exp_name":exp_name,
                    "opt_dir":opt_dir,
                    "pretrained_s2G":pretrained_s2G_path,
                    "s2config_path":"GPT_SoVITS/configs/s2.json",
                }
                gpu_names=gpu_numbers1c.split("-")
                all_parts=len(gpu_names)
                for i_part in range(all_parts):
                    config.update(
                        {
                            "i_part": str(i_part),
                            "all_parts": str(all_parts),
                            "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
                        }
                    )
                    os.environ.update(config)
                    cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
                    print(cmd)
                    p = Popen(cmd, shell=True)
                    ps1abc.append(p)
                yield "进度:1a1b-done, 1cing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
                for p in ps1abc:p.wait()

                # Merge shard TSVs under a single header row, then delete shards.
                opt = ["item_name semantic_audio"]
                for i_part in range(all_parts):
                    semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
                    with open(semantic_path, "r",encoding="utf8") as f:
                        opt += f.read().strip("\n").split("\n")
                    os.remove(semantic_path)
                with open(path_semantic, "w",encoding="utf8") as f:
                    f.write("\n".join(opt) + "\n")
                yield "进度:all-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
            ps1abc = []
            yield "一键三连进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
        except:
            # Any failure: report, kill remaining workers, surface the error.
            traceback.print_exc()
            close1abc()
            yield "一键三连中途报错", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
    else:
        yield "已有正在进行的一键三连任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
|
551 |
+
|
552 |
+
def close1abc():
    """Kill all running one-click pipeline workers (best effort) and reset."""
    global ps1abc
    if ps1abc:
        for p1abc in ps1abc:
            try:
                kill_process(p1abc.pid)
            except Exception:
                # Best effort: a worker may already have exited.
                traceback.print_exc()
        ps1abc = []
    return "已终止所有一键三连进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
|
562 |
+
|
563 |
+
# Gradio UI definition. Tab 0 holds dataset-preparation tools (UVR5, slicer,
# ASR, labelling); Tab 1 holds training-set formatting, fine-tuning and
# inference; Tab 2 is a placeholder. Event handlers are the open_*/close_*
# generators defined above, which stream status + button-visibility updates.
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value=
        "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
    )
    with gr.Tabs():
        with gr.TabItem("0-前置数据集获取工具"):#提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
            gr.Markdown(value="0a-UVR5人声伴奏分离&去混响去延迟工具")
            with gr.Row():
                if_uvr5 = gr.Checkbox(label="是否开启UVR5-WebUI",show_label=True)
                uvr5_info = gr.Textbox(label="UVR5进程输出信息")
            gr.Markdown(value="0b-语音切分工具")
            with gr.Row():
                with gr.Row():
                    slice_inp_path=gr.Textbox(label="音频自动切分输入路径,可文件可文件夹",value="")
                    slice_opt_root=gr.Textbox(label="切分后的子音频的输出根目录",value="output/slicer_opt")
                    threshold=gr.Textbox(label="threshold:音量小于这个值视作静音的备选切割点",value="-34")
                    min_length=gr.Textbox(label="min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值",value="4000")
                    min_interval=gr.Textbox(label="min_interval:最短切割间隔",value="300")
                    hop_size=gr.Textbox(label="hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)",value="10")
                    max_sil_kept=gr.Textbox(label="max_sil_kept:切完后静音最多留多长",value="500")
                with gr.Row():
                    open_slicer_button=gr.Button("开启语音切割", variant="primary",visible=True)
                    close_slicer_button=gr.Button("终止语音切割", variant="primary",visible=False)
                    _max=gr.Slider(minimum=0,maximum=1,step=0.05,label="max:归一化后最大值多少",value=0.9,interactive=True)
                    alpha=gr.Slider(minimum=0,maximum=1,step=0.05,label="alpha_mix:混多少比例归一化后音频进来",value=0.25,interactive=True)
                    n_process=gr.Slider(minimum=1,maximum=n_cpu,step=1,label="切割使用的进程数",value=4,interactive=True)
                    slicer_info = gr.Textbox(label="语音切割进程输出信息")
            gr.Markdown(value="0c-中文批量离线ASR工具")
            with gr.Row():
                open_asr_button = gr.Button("开启离线批量ASR", variant="primary",visible=True)
                close_asr_button = gr.Button("终止ASR进程", variant="primary",visible=False)
                asr_inp_dir = gr.Textbox(
                    label="批量ASR(中文only)输入文件夹路径",
                    value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx",
                    interactive=True,
                )
                asr_info = gr.Textbox(label="ASR进程输出信息")
            gr.Markdown(value="0d-语音文本校对标注工具")
            with gr.Row():
                if_label = gr.Checkbox(label="是否开启打标WebUI",show_label=True)
                path_list = gr.Textbox(
                    label="打标数据标注文件路径",
                    value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list",
                    interactive=True,
                )
                label_info = gr.Textbox(label="打标工具进程输出信息")
            # Wiring for tab 0 (handlers defined earlier in this module).
            if_label.change(change_label, [if_label,path_list], [label_info])
            if_uvr5.change(change_uvr5, [if_uvr5], [uvr5_info])
            open_asr_button.click(open_asr, [asr_inp_dir], [asr_info,open_asr_button,close_asr_button])
            close_asr_button.click(close_asr, [], [asr_info,open_asr_button,close_asr_button])
            open_slicer_button.click(open_slice, [slice_inp_path,slice_opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_process], [slicer_info,open_slicer_button,close_slicer_button])
            close_slicer_button.click(close_slice, [], [slicer_info,open_slicer_button,close_slicer_button])
        with gr.TabItem("1-GPT-SoVITS-TTS"):
            with gr.Row():
                exp_name = gr.Textbox(label="*实验/模型名", value="xxx", interactive=True)
                # NOTE(review): this rebinds the module-level `gpu_info` string
                # to the Textbox component — later reads see the component.
                gpu_info = gr.Textbox(label="显卡信息", value=gpu_info, visible=True, interactive=False)
                pretrained_s2G = gr.Textbox(label="预训练的SoVITS-G模型路径", value="GPT_SoVITS/pretrained_models/s2G488k.pth", interactive=True)
                pretrained_s2D = gr.Textbox(label="预训练的SoVITS-D模型路径", value="GPT_SoVITS/pretrained_models/s2D488k.pth", interactive=True)
                pretrained_s1 = gr.Textbox(label="预训练的GPT模型路径", value="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", interactive=True)
            with gr.TabItem("1A-训练集格式化工具"):
                gr.Markdown(value="输出logs/实验名目录下应有23456开头的文件和文件夹")
                with gr.Row():
                    inp_text = gr.Textbox(label="*文本标注文件",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",interactive=True)
                    inp_wav_dir = gr.Textbox(label="*训练集音频文件目录",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",interactive=True)
                gr.Markdown(value="1Aa-文本内容")
                with gr.Row():
                    gpu_numbers1a = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
                    bert_pretrained_dir = gr.Textbox(label="预训练的中文BERT模型路径",value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",interactive=False)
                    button1a_open = gr.Button("开启文本获取", variant="primary",visible=True)
                    button1a_close = gr.Button("终止文本获取进程", variant="primary",visible=False)
                    info1a=gr.Textbox(label="文本进程输出信息")
                gr.Markdown(value="1Ab-SSL自监督特征提取")
                with gr.Row():
                    gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
                    cnhubert_base_dir = gr.Textbox(label="预训练的SSL模型路径",value="GPT_SoVITS/pretrained_models/chinese-hubert-base",interactive=False)
                    button1b_open = gr.Button("开启SSL提取", variant="primary",visible=True)
                    button1b_close = gr.Button("终止SSL提取进程", variant="primary",visible=False)
                    info1b=gr.Textbox(label="SSL进程输出信息")
                gr.Markdown(value="1Ac-语义token提取")
                with gr.Row():
                    gpu_numbers1c = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
                    button1c_open = gr.Button("开启语义token提取", variant="primary",visible=True)
                    button1c_close = gr.Button("终止语义token提取进程", variant="primary",visible=False)
                    info1c=gr.Textbox(label="语义token提取进程输出信息")
                gr.Markdown(value="1Aabc-训练集格式化一键三连")
                with gr.Row():
                    button1abc_open = gr.Button("开启一键三连", variant="primary",visible=True)
                    button1abc_close = gr.Button("终止一键三连", variant="primary",visible=False)
                    info1abc=gr.Textbox(label="一键三连进程输出信息")
                # Wiring for tab 1A.
                button1a_open.click(open1a, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,bert_pretrained_dir], [info1a,button1a_open,button1a_close])
                button1a_close.click(close1a, [], [info1a,button1a_open,button1a_close])
                button1b_open.click(open1b, [inp_text,inp_wav_dir,exp_name,gpu_numbers1Ba,cnhubert_base_dir], [info1b,button1b_open,button1b_close])
                button1b_close.click(close1b, [], [info1b,button1b_open,button1b_close])
                button1c_open.click(open1c, [inp_text,exp_name,gpu_numbers1c,pretrained_s2G], [info1c,button1c_open,button1c_close])
                button1c_close.click(close1c, [], [info1c,button1c_open,button1c_close])
                button1abc_open.click(open1abc, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G], [info1abc,button1abc_open,button1abc_close])
                button1abc_close.click(close1abc, [], [info1abc,button1abc_open,button1abc_close])
            with gr.TabItem("1B-微调训练"):
                gr.Markdown(value="1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。")
                with gr.Row():
                    batch_size = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
                    total_epoch = gr.Slider(minimum=2,maximum=100,step=1,label=i18n("总训练轮数total_epoch,不建议太高"),value=10,interactive=True)
                    text_low_lr_rate = gr.Slider(minimum=0.2,maximum=0.6,step=0.05,label="文本模块学习率权重",value=0.4,interactive=True)
                    save_every_epoch = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
                    if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
                    if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
                    # NOTE(review): rebinds the 1A textbox name `gpu_numbers1Ba`
                    # to a new component; 1A wiring above captured the old one.
                    gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
                with gr.Row():
                    button1Ba_open = gr.Button("开启SoVITS训练", variant="primary",visible=True)
                    button1Ba_close = gr.Button("终止SoVITS训练", variant="primary",visible=False)
                    info1Ba=gr.Textbox(label="SoVITS训练进程输出信息")
                gr.Markdown(value="1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。")
                with gr.Row():
                    batch_size1Bb = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
                    total_epoch1Bb = gr.Slider(minimum=2,maximum=200,step=1,label=i18n("总训练轮数total_epoch"),value=15,interactive=True)
                    if_save_latest1Bb = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
                    if_save_every_weights1Bb = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
                    save_every_epoch1Bb = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
                    gpu_numbers1Bb = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
                with gr.Row():
                    button1Bb_open = gr.Button("开启GPT训练", variant="primary",visible=True)
                    button1Bb_close = gr.Button("终止GPT训练", variant="primary",visible=False)
                    info1Bb=gr.Textbox(label="GPT训练进程输出信息")
                # Wiring for tab 1B.
                button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close])
                button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
                button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
                button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
            with gr.TabItem("1C-推理"):
                gr.Markdown(value="选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。")
                with gr.Row():
                    GPT_dropdown = gr.Dropdown(label="*GPT模型列表", choices=sorted(GPT_names),value=pretrained_gpt_name)
                    SoVITS_dropdown = gr.Dropdown(label="*SoVITS模型列表", choices=sorted(SoVITS_names),value=pretrained_sovits_name)
                    gpu_number_1C=gr.Textbox(label="GPU卡号,只能填1个整数", value=gpus, interactive=True)
                    refresh_button = gr.Button("刷新模型路径", variant="primary")
                    refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
                with gr.Row():
                    if_tts = gr.Checkbox(label="是否开启TTS推理WebUI", show_label=True)
                    tts_info = gr.Textbox(label="TTS推理WebUI进程输出信息")
                if_tts.change(change_tts_inference, [if_tts,bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info])
        with gr.TabItem("2-GPT-SoVITS-变声"):gr.Markdown(value="施工中,请静候佳音")
|
704 |
+
|
705 |
+
'''
|
706 |
+
os.environ["gpt_path"]=gpt_path
|
707 |
+
os.environ["sovits_path"]=sovits_path#bert_pretrained_dir
|
708 |
+
os.environ["cnhubert_base_path"]=cnhubert_base_path#cnhubert_base_dir
|
709 |
+
os.environ["bert_path"]=bert_path
|
710 |
+
os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
|
711 |
+
'''
|
712 |
+
|
713 |
+
# Start the Gradio server. queue() is required for the generator-based
# streaming handlers above; the very large concurrency_count/max_size
# effectively disable queue throttling.
app.queue(concurrency_count=511, max_size=1022).launch(
    share=True,
    server_name="0.0.0.0",  # listen on all interfaces
    inbrowser=True,
    server_port=7890,
    quiet=True,
)
|
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
# Launch the GPT-SoVITS WebUI (path is specific to the autodl deployment).
python /root/autodl-tmp/workdir/GPT-SoVITS/webui.py
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
|
3 |
+
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
|
4 |
+
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
|
5 |
+
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
|
6 |
+
if("_CUDA_VISIBLE_DEVICES"in os.environ):
|
7 |
+
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
|
8 |
+
is_half=eval(os.environ.get("is_half","True"))
|
9 |
+
import gradio as gr
|
10 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
11 |
+
import sys,torch,numpy as np
|
12 |
+
from pathlib import Path
|
13 |
+
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
|
14 |
+
# torch.backends.cuda.sdp_kernel("flash")
|
15 |
+
# torch.backends.cuda.enable_flash_sdp(True)
|
16 |
+
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
|
17 |
+
# torch.backends.cuda.enable_math_sdp(True)
|
18 |
+
from random import shuffle
|
19 |
+
from AR.utils import get_newest_ckpt
|
20 |
+
from glob import glob
|
21 |
+
from tqdm import tqdm
|
22 |
+
from feature_extractor import cnhubert
|
23 |
+
cnhubert.cnhubert_base_path=cnhubert_base_path
|
24 |
+
from io import BytesIO
|
25 |
+
from module.models import SynthesizerTrn
|
26 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
27 |
+
from AR.utils.io import load_yaml_config
|
28 |
+
from text import cleaned_text_to_sequence
|
29 |
+
from text.cleaner import text_to_sequence, clean_text
|
30 |
+
from time import time as ttime
|
31 |
+
from module.mel_processing import spectrogram_torch
|
32 |
+
from my_utils import load_audio
|
33 |
+
|
34 |
+
device="cuda"
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
36 |
+
bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
|
37 |
+
if(is_half==True):bert_model=bert_model.half().to(device)
|
38 |
+
else:bert_model=bert_model.to(device)
|
39 |
+
# bert_model=bert_model.to(device)
|
40 |
+
def get_bert_feature(text, word2ph):
    """Compute a phone-level BERT feature matrix for *text*.

    *word2ph* gives, per character of *text*, how many phones that character
    expands to; each character's hidden vector is repeated that many times.
    Returns a tensor of shape (hidden_dim, total_phones).
    """
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt")
        for field in encoded:
            # ids are integer tensors, so precision follows bert_model itself
            encoded[field] = encoded[field].to(device)
        outputs = bert_model(**encoded, output_hidden_states=True)
        # Take the third-to-last hidden layer and drop the [CLS]/[SEP] slots.
        char_features = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    expanded = [
        char_features[position].repeat(count, 1)
        for position, count in enumerate(word2ph)
    ]
    # if(is_half==True):phone_level_feature=phone_level_feature.half()
    return torch.cat(expanded, dim=0).T
55 |
+
|
56 |
+
n_semantic = 1024
# Load the SoVITS checkpoint on CPU first; weights are moved to `device`
# after the model object is built below.  `sovits_path` is defined earlier
# in this file (outside this view).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
|
59 |
+
class DictToAttrRecursive(dict):
    """A dict whose entries are also reachable as attributes, recursively.

    Subclassing ``dict`` is required: the configuration wrapped by this class
    is later splatted with ``**hps.model`` into ``SynthesizerTrn``, and the
    ``**`` operator needs the mapping protocol (``keys``/``__getitem__``),
    which a plain object does not provide.  Attribute access is unchanged,
    so this stays backward compatible with all existing ``hps.x.y`` reads.
    """

    def __init__(self, input_dict):
        super().__init__()
        for key, value in input_dict.items():
            if isinstance(value, dict):
                # Wrap nested dicts so chained attribute access works too.
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)
|
67 |
+
|
68 |
+
# ---- build and load the three inference models ----------------------------
# Wrap the checkpoint config so it can be read as attributes (hps.data.x).
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate="25hz"
# GPT (stage-1) checkpoint; `gpt_path` is defined earlier in this file.
dict_s1=torch.load(gpt_path,map_location="cpu")
config=dict_s1["config"]
# Self-supervised speech encoder (cnhubert) used to tokenize reference audio.
ssl_model=cnhubert.get_model()
if(is_half==True):ssl_model=ssl_model.half().to(device)
else:ssl_model=ssl_model.to(device)

# SoVITS synthesizer (stage-2 decoder).
# NOTE(review): `**hps.model` requires the mapping protocol on the wrapped
# config object — confirm DictToAttrRecursive supports it.
vq_model = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model)
if(is_half==True):vq_model=vq_model.half().to(device)
else:vq_model=vq_model.to(device)
vq_model.eval()
# strict=False: checkpoint may omit some buffers; the print shows what was
# missing/unexpected for debugging.
print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
hz = 50
max_sec = config['data']['max_sec']
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
# Text-to-semantic (GPT) model; "ojbk" is a placeholder output dir argument.
t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if(is_half==True):t2s_model=t2s_model.half()
t2s_model=t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
|
95 |
+
def get_spepc(hps, filename):
    """Load *filename* at the configured sampling rate and return its
    linear spectrogram as a batched tensor (1, n_bins, n_frames)."""
    waveform = load_audio(filename, int(hps.data.sampling_rate))
    batched = torch.FloatTensor(waveform).unsqueeze(0)
    return spectrogram_torch(
        batched,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )
|
102 |
+
|
103 |
+
# Maps the UI's language labels (shown in the dropdowns below) to the
# internal language codes expected by clean_text().
dict_language={
    "中文":"zh",
    "英文":"en",
    "日文":"ja"
}
|
108 |
+
def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
    """Synthesize speech for *text* in the voice of the reference clip.

    Pipeline: (1) encode the reference audio with cnhubert and quantize it to
    semantic tokens; (2) clean both prompt and target text to phonemes (with
    BERT features for Chinese); (3) run the GPT model to predict semantic
    tokens for the target; (4) decode with the SoVITS vocoder.  *text* is
    synthesized line by line and the pieces are concatenated with 0.3 s of
    silence between them.  Yields (sampling_rate, int16 numpy waveform) for
    the gradio Audio output.
    """
    t0 = ttime()
    prompt_text=prompt_text.strip("\n")
    prompt_language,text=prompt_language,text.strip("\n")
    with torch.no_grad():
        # Reference audio must be 16 kHz for the SSL encoder.
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        wav16k = torch.from_numpy(wav16k)
        if(is_half==True):wav16k=wav16k.half().to(device)
        else:wav16k=wav16k.to(device)
        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
        # Quantize SSL features into discrete semantic tokens for the prompt.
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
    t1 = ttime()
    prompt_language=dict_language[prompt_language]
    text_language=dict_language[text_language]
    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
    phones1=cleaned_text_to_sequence(phones1)
    texts=text.split("\n")
    audio_opt = []
    # 0.3 s of silence appended after every synthesized line.
    zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
    for text in texts:
        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
        phones2 = cleaned_text_to_sequence(phones2)
        # BERT features are only available for Chinese; other languages get
        # a zero matrix of matching shape/dtype.
        if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1)
        else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
        if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2)
        else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
        bert = torch.cat([bert1, bert2], 1)

        # Prompt and target phonemes are fed as one sequence to the GPT.
        all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
        prompt = prompt_semantic.unsqueeze(0).to(device)
        t2 = ttime()
        with torch.no_grad():
            # pred_semantic = t2s_model.model.infer(
            pred_semantic,idx = t2s_model.model.infer_panel(
                all_phoneme_ids,
                all_phoneme_len,
                prompt,
                bert,
                # prompt_phone_len=ph_offset,
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        t3 = ttime()
        # print(pred_semantic.shape,idx)
        # Keep only the newly generated tokens (the last `idx` positions);
        # the extra unsqueeze adds the batch dim the decoder expects.
        pred_semantic = pred_semantic[:,-idx:].unsqueeze(0)
        refer = get_spepc(hps, ref_wav_path)#.to(device)
        if(is_half==True):refer=refer.half().to(device)
        else:refer=refer.to(device)
        # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
        # Decode using only the target phonemes, i.e. reconstruct without the
        # prompt portion.
        audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
    # Timing breakdown: ref-encode, text-prep, GPT inference, vocoder decode.
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
    yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
|
165 |
+
|
166 |
+
|
167 |
+
# Punctuation marks (CJK and ASCII) that terminate a sentence segment in
# split() below.  (Ellipsis is not treated specially.)
splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}
|
168 |
+
def split(todo_text):
    """Split *todo_text* into punctuation-terminated segments.

    Doubled Chinese ellipsis/dash sequences are normalized to a single
    terminator, and a trailing full stop is appended when the text does not
    already end with a mark from `splits`, so every returned segment carries
    its closing punctuation.  Returns a list of segment strings (empty list
    for empty input).
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        # Fix: empty input previously raised IndexError on todo_text[-1].
        return []
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        # The text is guaranteed to end with punctuation, so the final
        # segment has already been appended when the head runs off the end.
        if i_split_head >= len_text:
            break
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts
|
183 |
+
def cut1(inp):
    """Regroup *inp* into chunks of five sentences each, one chunk per line."""
    text = inp.strip("\n")
    sentences = split(text)
    boundaries = list(range(0, len(sentences), 5))
    # Make the final slice open-ended so the tail sentences are kept.
    boundaries[-1] = None
    if len(boundaries) > 1:
        chunks = [
            "".join(sentences[start:stop])
            for start, stop in zip(boundaries, boundaries[1:])
        ]
    else:
        chunks = [text]
    return "\n".join(chunks)
|
195 |
+
|
196 |
+
def cut2(inp):
    """Regroup sentences of *inp* into chunks of roughly 50 characters,
    one chunk per line.

    Always returns a string.  (Previously the short-input branch returned a
    list, inconsistent with cut1/cut3 and with the single gradio Textbox this
    function feeds.)
    """
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        # Fix: was `return [inp]` — a list, not a string.
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        # Close a chunk once it exceeds 50 characters.
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # If the last chunk is too short, merge it into the previous one.
    # Fix: guard len(opts) > 1 — a single short chunk used to raise
    # IndexError on opts[-2].
    if len(opts) > 1 and len(opts[-1]) < 50:
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    return "\n".join(opts)
|
215 |
+
|
216 |
+
def cut3(inp):
    """Break *inp* at every Chinese full stop, one sentence per line,
    re-appending the full stop to each piece."""
    trimmed = inp.strip("\n").strip("。")
    pieces = trimmed.split("。")
    return "\n".join("%s。" % piece for piece in pieces)
|
219 |
+
|
220 |
+
# ---- gradio UI ------------------------------------------------------------
# One page: reference-audio inputs, synthesis inputs/output, and the text
# pre-splitting helpers (cut1/cut2/cut3).
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value=
        "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
    )
    # with gr.Tabs():
    #     with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
    with gr.Group():
        gr.Markdown(
            value=
            "*请上传并填写参考信息"
        )
        # Reference clip + its transcript and language drive voice cloning.
        with gr.Row():
            inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
            prompt_text= gr.Textbox(label="参考音频的文本",value="")
            prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"])
        gr.Markdown(
            value=
            "*请填写需要合成的目标文本"
        )
        with gr.Row():
            text=gr.Textbox(label="需要合成的文本",value="")
            text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"])
            inference_button=gr.Button("合成语音", variant="primary")
            output = gr.Audio(label="输出的语音")
        inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])

        gr.Markdown(
            value=
            "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
        )
        # Pre-splitting helpers: each button rewrites the raw text into
        # newline-separated chunks that get_tts_wav synthesizes one by one.
        with gr.Row():
            text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
            button1 = gr.Button("凑五句一切", variant="primary")
            button2 = gr.Button("凑50字一切", variant="primary")
            button3 = gr.Button("按中文句号。切", variant="primary")
            text_opt = gr.Textbox(label="切分后文本", value="")
            button1.click(cut1,[text_inp],[text_opt])
            button2.click(cut2,[text_inp],[text_opt])
            button3.click(cut3,[text_inp],[text_opt])
    gr.Markdown(
        value=
        "后续将支持混合语种编码文本输入。"
    )

# Queued launch so long syntheses don't block the UI; listens on all
# interfaces at port 6006.
app.queue(concurrency_count=511, max_size=1022).launch(
    server_name="0.0.0.0",
    inbrowser=True,
    server_port=6006,
    quiet=True,
)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py
ADDED
File without changes
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py
ADDED
File without changes
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
|
2 |
+
import itertools
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from random import shuffle
|
6 |
+
from typing import Iterator
|
7 |
+
from typing import Optional
|
8 |
+
from typing import TypeVar
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from torch.utils.data import Sampler
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"DistributedBucketSampler",
|
17 |
+
]
|
18 |
+
|
19 |
+
T_co = TypeVar('T_co', covariant=True)
|
20 |
+
|
21 |
+
|
22 |
+
class DistributedBucketSampler(Sampler[T_co]):
    r"""Distributed sampler that groups samples of similar length.

    sort the dataset wrt. input length
    divide samples into buckets
    sort within buckets
    divide buckets into batches
    sort batches

    Each replica sees ``len(dataset) / num_replicas`` indices per epoch;
    batching similar-length samples reduces padding waste.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int]=None,
                 rank: Optional[int]=None,
                 shuffle: bool=True,
                 seed: int=0,
                 drop_last: bool=False,
                 batch_size: int=32) -> None:
        # When replica count / rank are omitted, read them from the default
        # process group (requires torch.distributed to be initialized).
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
            # NOTE(review): binding CUDA device to rank assumes one GPU per
            # process numbered like the rank — confirm for multi-node runs.
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval"
                             " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(
                self.
                dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) /
                self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        # Per-sample durations (seconds), sorted ascending, and the derived
        # 2-second-wide buckets of sample ids.
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        """Return [(sample_id, duration_sec), ...] sorted by duration."""
        id_with_lengths = []
        for i in range(len(self.dataset)):
            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float=2.0):
        """Partition the duration-sorted ids into consecutive buckets,
        each spanning *bucket_width* seconds of duration."""
        buckets = []
        cur = []
        max_sec = bucket_width
        for id, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(id)
            else:
                buckets.append(cur)
                cur = [id]
                max_sec += bucket_width
        if len(cur) > 0:
            buckets.append(cur)
        return buckets

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            random.seed(self.epoch + self.seed)
            # Shuffle within each bucket, keep bucket order (short -> long).
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            # Carve the flattened ids into global batches (all replicas),
            # then shuffle the batches so batch *order* is random while each
            # batch still contains similar-length samples.
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [
                shuffled_bucket[b * grouped_batch_size:(b + 1) *
                                grouped_batch_size] for b in range(n_batch)
            ]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            # type: ignore[arg-type]
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size /
                                                len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample: every replica takes a strided slice of the global order
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
|
2 |
+
from pytorch_lightning import LightningDataModule
|
3 |
+
from AR.data.bucket_sampler import DistributedBucketSampler
|
4 |
+
from AR.data.dataset import Text2SemanticDataset
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
|
8 |
+
class Text2SemanticDataModule(LightningDataModule):
    """Lightning data module for text-token -> semantic-token training.

    Wraps :class:`Text2SemanticDataset` and serves train batches through a
    :class:`DistributedBucketSampler` so similar-length samples are batched
    together.  Note: the dev/val/test loaders currently reuse the *training*
    dataset (the real dev dataset construction is commented out below).
    """

    def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        # Dev paths are accepted but unused while the dev dataset is disabled.
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config['data']['num_workers']

    def prepare_data(self):
        # Nothing to download/pre-process; datasets read existing files.
        pass

    def setup(self, stage=None, output_logs=False):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config['data']['max_sec'],
            pad_val=self.config['data']['pad_val'])
        # Dev set intentionally aliases the train set for now (see below).
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
        #     semantic_path=self.dev_semantic_path,
        #     max_sample=self.config['data']['max_eval_sample'],
        #     max_sec=self.config['data']['max_sec'],
        #     pad_val=self.config['data']['pad_val'])

    def train_dataloader(self):
        batch_size = self.config['train']['batch_size']
        # Length-bucketed, distributed-aware sampling.
        sampler = DistributedBucketSampler(
            self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16
        )

    def val_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers,12),
            persistent_workers=True,
            prefetch_factor=16
        )

    # Will this ever be used?  (kept for Lightning's API completeness)
    def test_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
|
2 |
+
import pdb
|
3 |
+
import sys
|
4 |
+
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
5 |
+
import traceback,os
|
6 |
+
from typing import Dict
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import torch,json
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
from text import cleaned_text_to_sequence
|
17 |
+
# from config import exp_dir
|
18 |
+
|
19 |
+
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    """Right-pad every array in *sequences* along *axis* to the longest
    length, then stack them into one batch array.

    All arrays must share rank and dtype; *pad_value* is cast to that dtype.
    Returns an array of shape (len(sequences), ..., max_len, ...).
    """
    template = sequences[0]
    rank = template.ndim
    if axis < 0:
        axis += rank
    fill = template.dtype.type(pad_value)
    lengths = [arr.shape[axis] for arr in sequences]
    target_len = np.max(lengths)

    padded = []
    for arr, length in zip(sequences, lengths):
        # Pad only at the tail of *axis*; all other axes are untouched.
        spec = ([(0, 0)] * axis
                + [(0, target_len - length)]
                + [(0, 0)] * (rank - axis - 1))
        padded.append(np.pad(arr, spec, mode='constant', constant_values=fill))
    return np.stack(padded)
|
38 |
+
|
39 |
+
class Text2SemanticDataset(Dataset):
    """dataset class for text tokens to semantic model training.

    Reads a phoneme annotation file (tab-separated: name, phonemes, word2ph,
    text) and a semantic-token TSV, filters implausible pairs, and serves
    (phoneme_ids, semantic_ids, optional BERT feature) items.
    """

    def __init__(self,
                 phoneme_path: str,
                 semantic_path: str,
                 max_sample: int = None,
                 max_sec: int = 100,
                 pad_val: int = 1024,
                 # min value of phoneme/sec
                 min_ps_ratio: int = 3,
                 # max value of phoneme/sec
                 max_ps_ratio: int = 25) -> None:
        super().__init__()

        self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
        # get dict
        self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
        # NOTE(review): os.path.basename() yields the *file name*, so path3
        # becomes "<filename>/3-bert" — dirname() looks intended here.
        # Missing .pt files are tolerated in __getitem__, which may be why
        # this has gone unnoticed.  Confirm against the experiment layout.
        self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
        self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)
        self.phoneme_data={}
        with open(self.path2,"r",encoding="utf8")as f:
            lines=f.read().strip("\n").split("\n")

        # Each valid line: item_name \t phonemes \t word2ph \t text
        for line in lines:
            tmp=line.split("\t")
            if(len(tmp)!=4):continue
            self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]

        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
        # pad for semantic tokens
        self.PAD: int = pad_val
        # self.hz = 25
        # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
        # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
        # self.hz=int(data[:-2])#
        # Semantic token rate, e.g. "25hz" -> 25 tokens per second.
        self.hz=int(os.environ.get("hz","25hz")[:-2])

        # max seconds of semantic token
        self.max_sec = max_sec
        self.min_ps_ratio = min_ps_ratio
        self.max_ps_ratio = max_ps_ratio

        if max_sample is not None:
            self.semantic_data = self.semantic_data[:max_sample]

        # {idx: (semantic, phoneme)}
        # semantic list, phoneme list
        self.semantic_phoneme = []
        self.item_names = []

        self.inited = False

        if not self.inited:
            # Run the one-off filtering/indexing pass, then drop the raw
            # tables to free memory.
            self.init_batch()
            self.inited = True
            del self.semantic_data
            del self.phoneme_data
        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")


    def init_batch(self):
        """Join the semantic and phoneme tables and filter bad samples.

        Drops items missing from the phoneme table, longer than max_sec, or
        with an implausible phonemes-per-second ratio.  Tiny datasets are
        duplicated up to a minimum size so training loops still work.
        """
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())
        print("semantic_data_len:", semantic_data_len)
        print("phoneme_data_len:", phoneme_data_len)
        idx = 0
        num_not_in = 0
        num_deleted_bigger = 0
        num_deleted_ps = 0
        for i in range(semantic_data_len):
            # Iterate the semantic table row by row.
            item_name = self.semantic_data['item_name'][i]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue

            semantic_str = self.semantic_data['semantic_audio'][i]
            # get token list
            semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
            # (T); kept 1-D because the length is needed for filtering.
            # Drop samples whose estimated duration (token count / hz)
            # exceeds max_sec (configured, e.g. 60 s).
            if len(semantic_ids) > self.max_sec * self.hz:
                num_deleted_bigger += 1
                continue
            # (T, ) — cheap enough to do once here instead of per __getitem__.
            phoneme = phoneme.split(' ')

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme)
            # NOTE(review): bare `except:` also swallows KeyboardInterrupt;
            # `except Exception:` (as above) would be safer.
            except:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue
            # if len(phoneme_ids) >400:  (old fixed cap, replaced by the
            # semantic-length-derived cap below: at most semantic_len / 2.5)
            if len(phoneme_ids) >self.max_sec * self.hz/2.5:
                num_deleted_ps += 1
                continue
            # if len(semantic_ids) > 1000:
            #     num_deleted_bigger += 1
            #     continue

            # Phonemes per second of audio; 3–25 is the plausible range.
            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:
                num_deleted_ps += 1
                # print(item_name)
                continue

            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            idx += 1
            self.item_names.append(item_name)

        # Tiny datasets are tiled up to min_num items (100) so a training
        # epoch has enough batches.
        min_num=100
        leng =len(self.semantic_phoneme)
        if(leng<min_num):
            tmp1=self.semantic_phoneme
            tmp2=self.item_names
            self.semantic_phoneme=[]
            self.item_names=[]
            for _ in range(max(2,int(min_num/leng))):
                self.semantic_phoneme+=tmp1
                self.item_names+=tmp2
        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:
            print(
                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
            )
        if num_deleted_ps > 0:
            # Around 4702 for LibriTTS; labeled data still needs this filter
            # because of extreme outliers (ratio up to 100 observed).
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
            )
        '''
        there are 31 semantic datas not in phoneme datas
        deleted 34 audios who's duration are bigger than 54 seconds
        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        dataset.__len__(): 366463

        '''
        # 345410 for LibriTTS
        print("dataset.__len__():", self.__len__())

    def __get_item_names__(self) -> List[str]:
        return self.item_names

    def __len__(self) -> int:
        return len(self.semantic_phoneme)

    def __getitem__(self, idx: int) -> Dict:
        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
        item_name = self.item_names[idx]
        phoneme_ids_len = len(phoneme_ids)
        # semantic tokens target
        semantic_ids_len = len(semantic_ids)

        # BERT features are optional: missing .pt files yield bert_feature
        # None, and collate() substitutes zeros.
        flag=0
        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
        else:flag=1
        if(flag==1):
            # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
            bert_feature=None
        else:
            assert bert_feature.shape[-1] == len(phoneme_ids)
        return {
            'idx': idx,
            'phoneme_ids': phoneme_ids,
            'phoneme_ids_len': phoneme_ids_len,
            'semantic_ids': semantic_ids,
            'semantic_ids_len': semantic_ids_len,
            'bert_feature': bert_feature,
        }

    def get_sample_length(self, idx: int):
        """Duration of sample *idx* in seconds (token count / token rate);
        used by the bucket sampler."""
        semantic_ids = self.semantic_phoneme[idx][0]
        sec = 1.0 * len(semantic_ids) / self.hz
        return sec

    def collate(self, examples: List[Dict]) -> Dict:
        """Pad a list of __getitem__ dicts into one batch dict of tensors."""
        sample_index: List[int] = []
        phoneme_ids: List[torch.Tensor] = []
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[torch.Tensor] = []
        semantic_ids_lens: List[int] = []
        # return


        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
            phoneme_ids_lens.append(item["phoneme_ids_len"])
            semantic_ids_lens.append(item["semantic_ids_len"])

        # pad 0 for phonemes, PAD token for semantic ids
        phoneme_ids = batch_sequences(phoneme_ids)
        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

        # # convert each batch to torch.tensor
        phoneme_ids = torch.tensor(phoneme_ids)
        semantic_ids = torch.tensor(semantic_ids)
        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
        semantic_ids_lens = torch.tensor(semantic_ids_lens)
        # Zero-filled BERT buffer; rows with an actual feature are overwritten.
        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
        bert_padded.zero_()

        for idx, item in enumerate(examples):
            bert = item['bert_feature']
            # NOTE(review): `bert != None` on a tensor relies on the fallback
            # to identity comparison; `bert is not None` is the explicit form.
            if(bert!=None):
                bert_padded[idx, :, :bert.shape[-1]] = bert

        return {
            # List[int]
            "ids": sample_index,
            # torch.Tensor (B, max_phoneme_length)
            "phoneme_ids": phoneme_ids,
            # torch.Tensor (B)
            "phoneme_ids_len": phoneme_ids_lens,
            # torch.Tensor (B, max_semantic_ids_length)
            "semantic_ids": semantic_ids,
            # torch.Tensor (B)
            "semantic_ids_len": semantic_ids_lens,
            # torch.Tensor (B, 1024, max_phoneme_length)
            "bert_feature": bert_padded,
        }
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == '__main__':
    # Smoke test: build the dataset from a fixed dump directory and iterate
    # all batches, printing progress every 1000 batches.
    root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
    dataset = Text2SemanticDataset(
        phoneme_path=root_dir + 'phoneme_train.npy',
        semantic_path=root_dir + 'semantic_train.tsv')

    batch_size = 12
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate,
        shuffle=False)
    for i, batch in enumerate(dataloader):
        if(i%1000==0):print(i)
        # if i == 0:
        #     print('batch["ids"]:', batch["ids"])
        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],
        #           batch["phoneme_ids"].shape)
        #     print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
        #           batch["phoneme_ids_len"].shape)
        #     print('batch["semantic_ids"]:', batch["semantic_ids"],
        #           batch["semantic_ids"].shape)
        #     print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
        #           batch["semantic_ids_len"].shape)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py
ADDED
File without changes
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import logging
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
15 |
+
from torch.nn import LayerNorm
|
16 |
+
|
17 |
+
from .backbone import TransformerEncoder
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class BEATsConfig:
    """Hyper-parameter container for the BEATs model.

    All attributes default to the base configuration; any subset may be
    overridden by passing a dict to the constructor or to :meth:`update`.
    """

    def __init__(self, cfg=None):
        # --- patch-embedding front-end ---
        self.input_patch_size: int = -1  # side length of each spectrogram patch
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include a bias term in the conv encoder

        # --- transformer encoder ---
        self.encoder_layers: int = 12  # number of transformer layers
        self.encoder_embed_dim: int = 768  # transformer hidden size
        self.encoder_ffn_embed_dim: int = 3072  # feed-forward hidden size
        self.encoder_attention_heads: int = 12  # attention heads per layer
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # per-layer gradient decay ratio
        self.layer_norm_first: bool = False  # pre-norm instead of post-norm
        self.deep_norm: bool = False  # apply deep_norm in the transformer

        # --- dropouts ---
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout on attention weights
        self.activation_dropout: float = 0.0  # dropout after the FFN activation
        self.encoder_layerdrop: float = 0.0  # probability of dropping a whole transformer layer
        self.dropout_input: float = 0.0  # dropout applied to the input (after feature extraction)

        # --- convolutional positional embedding ---
        self.conv_pos: int = 128  # number of filters
        self.conv_pos_groups: int = 16  # number of groups

        # --- relative position embedding ---
        self.relative_position_embedding: bool = False  # enable relative position bias
        self.num_buckets: int = 320  # bucket count for relative positions
        self.max_distance: int = 1280  # maximum relative distance
        self.gru_rel_pos: bool = False  # gated relative position embedding

        # --- label predictor (fine-tuned checkpoints only) ---
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model
        self.predictor_dropout: float = 0.1  # dropout before the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Overwrite attributes in place from *cfg*."""
        self.__dict__.update(cfg)
|
65 |
+
|
66 |
+
|
67 |
+
class BEATs(nn.Module):
    """BEATs audio encoder (Audio Pre-Training with Acoustic Tokenizers).

    A Conv2d patch-embedding front-end over log-mel fbank features, followed
    by a transformer encoder.  When ``cfg.finetuned_model`` is set, a linear
    classification head is attached and :meth:`extract_features` returns
    per-class probabilities instead of frame-level representations.
    """

    def __init__(
            self,
            cfg: BEATsConfig, ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when the two
        # dimensions differ; otherwise skip the extra linear layer.
        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
                                  if self.embed != cfg.encoder_embed_dim else
                                  None)

        self.input_patch_size = cfg.input_patch_size
        # Non-overlapping square patches (kernel == stride) over the
        # single-channel (B, 1, time, mel) spectrogram.
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive options.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            # Classification head used by fine-tuned checkpoints.
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim,
                                       cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor, ) -> torch.Tensor:
        """Downsample a per-sample padding mask to the feature frame rate.

        A frame is marked as padding only when *all* input positions that
        map onto it are padding.
        """
        # Drop trailing positions that do not fill a whole frame.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float=15.41663,
            fbank_std: float=6.55582, ) -> torch.Tensor:
        """Convert a batch of 16 kHz waveforms to normalized 128-bin fbanks."""
        fbanks = []
        for waveform in source:
            # Kaldi-style fbank expects 16-bit integer amplitude range.
            waveform = waveform.unsqueeze(0) * 2**15
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        # Normalize with fixed dataset-level statistics (default values
        # presumably computed on the pre-training corpus — not shown here).
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor]=None,
            fbank_mean: float=15.41663,
            fbank_std: float=6.55582, ):
        """Encode raw waveforms.

        Returns ``(representation, padding_mask)`` for a pre-trained model,
        or ``(class_probabilities, padding_mask)`` when the classification
        head (``self.predictor``) is present.
        """
        fbank = self.preprocess(
            source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # (B, T, F) -> (B, 1, T, F) so the Conv2d patch embedding applies.
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            # Re-derive the mask at the patch frame rate.
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask, )

        if self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                # Mean-pool logits over non-padded frames only: zero padded
                # positions, sum, then divide by the valid-frame count.
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(
                    dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            # sigmoid rather than softmax: per-class independent probabilities.
            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# BEATs
|
3 |
+
|
4 |
+
[**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
|
5 |
+
|
6 |
+
Official PyTorch implementation and pretrained models of BEATs
|
7 |
+
|
8 |
+
## Pre-Trained and Fine-Tuned Tokenizers and Models
|
9 |
+
Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
|
10 |
+
|---|---|---|---|---
|
11 |
+
Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
12 |
+
Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
13 |
+
Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
14 |
+
Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
15 |
+
Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
|
16 |
+
|
17 |
+
|
18 |
+
### Load Tokenizers
|
19 |
+
|
20 |
+
```python
|
21 |
+
import torch
|
22 |
+
from Tokenizers import TokenizersConfig, Tokenizers
|
23 |
+
|
24 |
+
# load the pre-trained checkpoints
|
25 |
+
checkpoint = torch.load('/path/to/tokenizer.pt')
|
26 |
+
|
27 |
+
cfg = TokenizersConfig(checkpoint['cfg'])
|
28 |
+
BEATs_tokenizer = Tokenizers(cfg)
|
29 |
+
BEATs_tokenizer.load_state_dict(checkpoint['model'])
|
30 |
+
BEATs_tokenizer.eval()
|
31 |
+
|
32 |
+
# tokenize the audio and generate the labels
|
33 |
+
audio_input_16khz = torch.randn(1, 10000)
|
34 |
+
padding_mask = torch.zeros(1, 10000).bool()
|
35 |
+
|
36 |
+
labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Load Pre-Trained Models
|
41 |
+
|
42 |
+
```python
|
43 |
+
import torch
|
44 |
+
from BEATs import BEATs, BEATsConfig
|
45 |
+
|
46 |
+
# load the pre-trained checkpoints
|
47 |
+
checkpoint = torch.load('/path/to/model.pt')
|
48 |
+
|
49 |
+
cfg = BEATsConfig(checkpoint['cfg'])
|
50 |
+
BEATs_model = BEATs(cfg)
|
51 |
+
BEATs_model.load_state_dict(checkpoint['model'])
|
52 |
+
BEATs_model.eval()
|
53 |
+
|
54 |
+
# extract the audio representation
|
55 |
+
audio_input_16khz = torch.randn(1, 10000)
|
56 |
+
padding_mask = torch.zeros(1, 10000).bool()
|
57 |
+
|
58 |
+
representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
|
59 |
+
```
|
60 |
+
|
61 |
+
|
62 |
+
### Load Fine-tuned Models
|
63 |
+
|
64 |
+
```python
|
65 |
+
import torch
|
66 |
+
from BEATs import BEATs, BEATsConfig
|
67 |
+
|
68 |
+
# load the fine-tuned checkpoints
|
69 |
+
checkpoint = torch.load('/path/to/model.pt')
|
70 |
+
|
71 |
+
cfg = BEATsConfig(checkpoint['cfg'])
|
72 |
+
BEATs_model = BEATs(cfg)
|
73 |
+
BEATs_model.load_state_dict(checkpoint['model'])
|
74 |
+
BEATs_model.eval()
|
75 |
+
|
76 |
+
# predict the classification probability of each class
|
77 |
+
audio_input_16khz = torch.randn(3, 10000)
|
78 |
+
padding_mask = torch.zeros(3, 10000).bool()
|
79 |
+
|
80 |
+
probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
|
81 |
+
|
82 |
+
for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
|
83 |
+
top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
|
84 |
+
print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
|
85 |
+
```
|
86 |
+
|
87 |
+
## Evaluation Results
|
88 |
+
|
89 |
+
### Comparing with the SOTA Single Models
|
90 |
+
![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
|
91 |
+
|
92 |
+
|
93 |
+
### Comparing with the SOTA Ensemble Models
|
94 |
+
![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
|
95 |
+
|
96 |
+
|
97 |
+
### Comparing Different BEATs Tokenizers
|
98 |
+
![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
|
99 |
+
|
100 |
+
|
101 |
+
### Comparing Different Pre-Training Targets
|
102 |
+
![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
|
103 |
+
|
104 |
+
|
105 |
+
## License
|
106 |
+
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
|
107 |
+
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project.
|
108 |
+
|
109 |
+
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
|
110 |
+
|
111 |
+
|
112 |
+
### Reference
|
113 |
+
If you find our work is useful in your research, please cite the following paper:
|
114 |
+
``` latex
|
115 |
+
@article{Chen2022beats,
|
116 |
+
title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
|
117 |
+
author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
|
118 |
+
eprint={2212.09058},
|
119 |
+
archivePrefix={arXiv},
|
120 |
+
year={2022}
|
121 |
+
}
|
122 |
+
```
|
123 |
+
### Contact Information
|
124 |
+
|
125 |
+
For help or issues using BEATs models, please submit a GitHub issue.
|
126 |
+
|
127 |
+
For other communications related to BEATs, please contact Yu Wu (`[email protected]`).
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import logging
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
15 |
+
from backbone import (
|
16 |
+
TransformerEncoder, )
|
17 |
+
from quantizer import (
|
18 |
+
NormEMAVectorQuantizer, )
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class TokenizersConfig:
    """Hyper-parameter container for the BEATs acoustic tokenizer.

    Attributes default to the base configuration; any subset may be
    overridden by passing a dict to the constructor or to :meth:`update`.
    """

    def __init__(self, cfg=None):
        # --- patch-embedding front-end ---
        self.input_patch_size: int = -1  # side length of each spectrogram patch
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include a bias term in the conv encoder

        # --- transformer encoder ---
        self.encoder_layers: int = 12  # number of transformer layers
        self.encoder_embed_dim: int = 768  # transformer hidden size
        self.encoder_ffn_embed_dim: int = 3072  # feed-forward hidden size
        self.encoder_attention_heads: int = 12  # attention heads per layer
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # pre-norm instead of post-norm
        self.deep_norm: bool = False  # apply deep_norm in the transformer

        # --- dropouts ---
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout on attention weights
        self.activation_dropout: float = 0.0  # dropout after the FFN activation
        self.encoder_layerdrop: float = 0.0  # probability of dropping a whole transformer layer
        self.dropout_input: float = 0.0  # dropout applied to the input (after feature extraction)

        # --- convolutional positional embedding ---
        self.conv_pos: int = 128  # number of filters
        self.conv_pos_groups: int = 16  # number of groups

        # --- relative position embedding ---
        self.relative_position_embedding: bool = False  # enable relative position bias
        self.num_buckets: int = 320  # bucket count for relative positions
        self.max_distance: int = 1280  # maximum relative distance
        self.gru_rel_pos: bool = False  # gated relative position embedding

        # --- vector quantizer ---
        self.quant_n: int = 1024  # codebook size (number of entries)
        self.quant_dim: int = 256  # codebook entry dimension

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Overwrite attributes in place from *cfg*."""
        self.__dict__.update(cfg)
|
65 |
+
|
66 |
+
|
67 |
+
class Tokenizers(nn.Module):
    """BEATs acoustic tokenizer.

    Shares the patch-embedding + transformer backbone with the BEATs model,
    then maps encoder outputs through a small projection into an EMA vector
    quantizer; :meth:`extract_labels` returns discrete codebook indices.
    """

    def __init__(
            self,
            cfg: TokenizersConfig, ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        # Project patch embeddings to the encoder width only when the two
        # dimensions differ; otherwise skip the extra linear layer.
        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
                                  if self.embed != cfg.encoder_embed_dim else
                                  None)

        self.input_patch_size = cfg.input_patch_size
        # Non-overlapping square patches (kernel == stride) over the
        # single-channel (B, 1, time, mel) spectrogram.
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive options.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        # EMA-updated vector quantizer that produces the discrete labels.
        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n,
            embedding_dim=cfg.quant_dim,
            beta=1.0,
            kmeans_init=True,
            decay=0.99, )
        self.quant_n = cfg.quant_n
        # Projection from encoder space into the quantizer codebook space.
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
        )

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor, ) -> torch.Tensor:
        """Downsample a per-sample padding mask to the feature frame rate.

        A frame is marked as padding only when *all* input positions that
        map onto it are padding.
        """
        # Drop trailing positions that do not fill a whole frame.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float=15.41663,
            fbank_std: float=6.55582, ) -> torch.Tensor:
        """Convert a batch of 16 kHz waveforms to normalized 128-bin fbanks."""
        fbanks = []
        for waveform in source:
            # Kaldi-style fbank expects 16-bit integer amplitude range.
            waveform = waveform.unsqueeze(0) * 2**15
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        # Normalize with fixed dataset-level statistics (default values
        # presumably computed on the pre-training corpus — not shown here).
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor]=None,
            fbank_mean: float=15.41663,
            fbank_std: float=6.55582, ):
        """Return discrete codebook indices for a batch of raw waveforms."""
        fbank = self.preprocess(
            source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # (B, T, F) -> (B, 1, T, F) so the Conv2d patch embedding applies.
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            # Re-derive the mask at the patch frame rate.
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask, )

        # Project to codebook space and quantize; only the indices are kept.
        quantize_input = self.quantize_layer(x)
        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

        return embed_ind
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# this folder is modified from https://github.com/microsoft/unilm/tree/master/beats
|
2 |
+
# ontology.json is from https://github.com/audioset/ontology/
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import math
|
10 |
+
from typing import Dict
|
11 |
+
from typing import Optional
|
12 |
+
from typing import Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch import nn
|
18 |
+
from torch import Tensor
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from torch.nn import Parameter
|
21 |
+
|
22 |
+
from .modules import get_activation_fn
|
23 |
+
from .modules import GLU_Linear
|
24 |
+
from .modules import GradMultiply
|
25 |
+
from .modules import quant_noise
|
26 |
+
from .modules import SamePad
|
27 |
+
|
28 |
+
|
29 |
+
class TransformerEncoder(nn.Module):
    """Transformer encoder stack used by the BEATs audio model.

    Combines a convolutional (relative) positional embedding with a stack of
    ``TransformerSentenceEncoderLayer`` blocks.  Supports optional T5-style
    bucketed relative position bias (shared across layers), DeepNet-style
    initialization, LayerDrop, and layer-wise gradient decay.

    ``args`` is a config namespace; only the attributes read below are
    required (``encoder_embed_dim``, ``conv_pos``, ``conv_pos_groups``,
    ``encoder_layers``, ``dropout``, ``encoder_layerdrop``, ...).
    """

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional positional embedding (wav2vec 2.0 style): a grouped
        # 1-D conv over time whose output is added to the input.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups, )
        # `dropout` here is a constant kept only to mirror the fairseq
        # initialization formula for the conv weights.
        dropout = 0
        std = math.sqrt(
            (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        # Weight-normalize over the kernel dimension, then trim the extra
        # frame an even kernel produces (SamePad) and apply GELU.
        self.pos_conv = nn.utils.weight_norm(
            self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv,
                                      SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                deep_norm=args.deep_norm,
                has_relative_attention_bias=self.relative_position_embedding,
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
                gru_rel_pos=args.gru_rel_pos,
                encoder_layers=args.encoder_layers, )
            for i in range(args.encoder_layers)
        ])
        # Share a single relative-attention-bias table across all layers:
        # layer 0 owns it, every other layer aliases it.
        if self.relative_position_embedding:
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[
                    0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        # LayerDrop probability: chance of skipping a whole layer in training.
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        # DeepNet-style re-initialization: v/out/fc projections scaled by
        # beta = (8 * N) ** (-1/4) to stabilize very deep post-LN stacks.
        if args.deep_norm:
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(
                    self.layers[i].self_attn.out_proj.weight,
                    gain=deep_norm_beta)
                nn.init.xavier_normal_(
                    self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(
                    self.layers[i].fc2.weight, gain=deep_norm_beta)

        # Gradient through layer i is multiplied by ratio**i (applied once
        # per layer in extract_features); 1 disables the mechanism.
        self.layer_wise_gradient_decay_ratio = getattr(
            args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x`` (B x T x C); returns (features, per-layer results).

        When ``layer`` is given, features are taken from that layer and the
        final LayerNorm of the pre-LN path is skipped.
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the conv positional embedding and the layer stack.

        padding_mask: bool (B x T) mask, True at padded positions.
        tgt_layer: if set, stop after that layer and collect intermediate
        (x, attn) pairs in ``layer_results``.
        """

        # Zero out padded positions so they do not leak into the conv.
        if padding_mask is not None:
            x[padding_mask] = 0

        # Conv1d wants B x C x T; add the positional signal residually.
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        # Post-LN stacks normalize the input once here; pre-LN stacks
        # normalize at the end of forward() instead.
        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            # LayerDrop: during training, randomly skip whole layers.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                # pos_bias is computed by the first layer and threaded
                # through so later layers reuse it.
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
|
166 |
+
|
167 |
+
|
168 |
+
class TransformerSentenceEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a FFN.

    Two residual arrangements are supported:
      * pre-LN  (``layer_norm_first=True``): LN -> sublayer -> residual;
      * post-LN (``layer_norm_first=False``): sublayer -> residual -> LN,
        with the residual scaled by ``deep_norm_alpha`` (DeepNet) when
        ``deep_norm=True``.

    When ``activation_fn == "glu"``, fc1 is a GLU_Linear that already
    contains its own (swish) gating, so no extra activation is applied.
    """

    def __init__(
            self,
            embedding_dim: int=768,
            ffn_embedding_dim: int=3072,
            num_attention_heads: int=8,
            dropout: float=0.1,
            attention_dropout: float=0.1,
            activation_dropout: float=0.1,
            activation_fn: str="relu",
            layer_norm_first: bool=False,
            deep_norm: bool=False,
            has_relative_attention_bias: bool=False,
            num_buckets: int=0,
            max_distance: int=0,
            rescale_init: bool=False,
            gru_rel_pos: bool=False,
            encoder_layers: int=0, ) -> None:

        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos, )

        # dropout1: after attention; dropout2: after FFN activation;
        # dropout3: after the second FFN projection.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            # GLU_Linear applies its own swish gate internally.
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
                                  "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            # DeepNet residual scale: alpha = (2 * N) ** (1/4).
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(self,
                x: torch.Tensor,
                self_attn_mask: torch.Tensor=None,
                self_attn_padding_mask: torch.Tensor=None,
                need_weights: bool=False,
                pos_bias=None):
        """Apply the layer to ``x`` (T x B x C).

        Returns (x, attn, pos_bias); ``pos_bias`` is passed back in by the
        caller so the relative position bias is computed only once.
        """
        residual = x

        if self.layer_norm_first:
            # Pre-LN path: normalize, attend, add residual.
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias)
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-LN / deep-norm path: attend, scaled residual, normalize.
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias)

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
|
286 |
+
|
287 |
+
|
288 |
+
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Extends the standard fairseq implementation with T5-style bucketed
    relative position bias, WavLM's gated relative position bias
    (``gru_rel_pos``), and an overflow-safe softmax rescaling trick
    (see ``alpha`` in :meth:`forward`).
    """

    def __init__(
            self,
            embed_dim,
            num_heads,
            kdim=None,
            vdim=None,
            dropout=0.0,
            bias=True,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=False,
            encoder_decoder_attention=False,
            q_noise=0.0,
            qn_block_size=8,
            has_relative_attention_bias=False,
            num_buckets=32,
            max_distance=128,
            gru_rel_pos=False,
            rescale_init=False, ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        # T5-style relative position bias: one learned value per
        # (distance bucket, head).
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (self.head_dim * num_heads == self.embed_dim
                ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and "
            "value to be of the same size")

        # rescale_init drops the key-projection bias (WavLM convention).
        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        # quant_noise is a no-op when q_noise == 0.0 (returns the module).
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise,
            qn_block_size)
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise,
            qn_block_size)

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        # Gated relative position bias: a small gate computed from the
        # query modulates the shared position bias per head.
        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions,
                                   bidirectional=True):
        """Map signed relative positions to T5-style bucket indices.

        Half the buckets cover exact small distances; the other half cover
        logarithmically larger distances up to ``max_distance``.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            # Split buckets between negative and positive offsets.
            num_buckets = num_buckets // 2
            relative_buckets += (
                relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(
                relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        # Larger distances fall into logarithmically spaced buckets.
        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact) / math.log(
                max_distance / max_exact) *
            (num_buckets - max_exact)).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_positions,
                                        relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Return the relative position bias, shape (num_heads, q_len, k_len)."""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(
            self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(self,
                query,
                key: Optional[Tensor],
                value: Optional[Tensor],
                key_padding_mask: Optional[Tensor]=None,
                incremental_state: Optional[Dict[str, Dict[str, Optional[
                    Tensor]]]]=None,
                need_weights: bool=True,
                static_kv: bool=False,
                attn_mask: Optional[Tensor]=None,
                before_softmax: bool=False,
                need_head_weights: bool=False,
                position_bias: Optional[Tensor]=None
                ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # Fix of an upstream fairseq bug: the original line
                # `assert src_len, bsz == value.shape[:2]` parses as
                # `assert src_len` with a boolean "message" and never
                # validated the value shape. Check it for real.
                assert (src_len, bsz) == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(
                bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            # NOTE(review): _get_input_buffer relies on
            # get_incremental_state, which is provided by fairseq's
            # incremental-decoding mixin; only reachable when
            # incremental_state is passed — confirm it exists in this port.
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        # Overflow-safe softmax trick: shrink the logits by `alpha` here
        # and re-expand them (after subtracting the row max) right before
        # the softmax, keeping intermediate values in a safe range.
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1, )

        # Fold heads into the batch dimension: (B*H) x T x head_dim.
        q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim)
             .transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim)
                 .transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim)
                 .transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv, )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat(
                [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat(
                [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1, )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        # Undo the 1/alpha shrink after subtracting the per-row max; this
        # is numerically equivalent to plain scaled dot-product attention.
        attn_weights = (
            attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
                                              bsz)

        assert list(
            attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"), )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask,
                                                        float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                # Gated relative position bias (WavLM): recover the
                # unscaled query and derive a per-position gate for the
                # shared bias.
                query_layer = q.view(bsz, self.num_heads, tgt_len,
                                     self.q_head_dim) * alpha / self.scaling
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(
                    self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(
                        -1, keepdim=False)).chunk(
                            2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len,
                                                  1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
                                                   src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
            key_padding_mask: Optional[Tensor],
            prev_key_padding_mask: Optional[Tensor],
            batch_size: int,
            src_len: int,
            static_kv: bool, ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()],
                dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device, )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1)
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device, )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1)
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
            self,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        # NOTE(review): get_incremental_state comes from fairseq's
        # incremental-decoding utilities — confirm it is available here.
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
            self,
            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
            buffer: Dict[str, Optional[Tensor]], ):
        return self.set_incremental_state(incremental_state, "attn_state",
                                          buffer)

    def apply_sparse_mask(self,
                          attn_weights,
                          tgt_len: int,
                          src_len: int,
                          bsz: int):
        # Hook for subclasses that implement sparse attention; identity here.
        return attn_weights
|
759 |
+
|
760 |
+
|
761 |
+
def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).

    Intended for use with ``module.apply(init_bert_params)``; every
    matching submodule is re-initialized from N(0, 0.02).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        # Keep the padding embedding at zero after re-initialization.
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        # q/k/v projections are nn.Linear, so they were already handled by
        # the first branch; this re-draws them (upstream behavior kept).
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
import os

# Directory containing this script, so the ontology file is found
# regardless of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Filename of the AudioSet ontology JSON shipped next to this module.
json_filename = "ontology.json"

# Full path to the JSON file.
json_path = os.path.join(script_dir, json_filename)

# Maps an AudioSet class id to its human-readable name.
id_name_dict = {}

# Explicit UTF-8 so the mapping is decoded identically on every platform
# (the default locale encoding is not guaranteed to be UTF-8).
with open(json_path, 'r', encoding='utf-8') as f:
    json_items = json.load(f)
    # '/m/0dgw9r' -> 'Human sounds' and etc.
    for item in json_items:
        id_name_dict[item['id']] = item['name']
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by ``scale``
    in the backward pass (used to damp or boost gradients flowing into a
    sub-network)."""

    @staticmethod
    def forward(ctx, x, scale):
        # Stash the scale factor for backward.
        ctx.scale = scale
        # Return a copy of x so autograd sees a new tensor.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # No gradient with respect to the `scale` argument.
        return grad * ctx.scale, None
|
27 |
+
|
28 |
+
|
29 |
+
class SamePad(nn.Module):
    """Trim trailing timesteps left over by symmetric padding of a 1D
    convolution so the output length matches the input ("same" padding).

    For causal convolutions all ``kernel_size - 1`` padded steps are on
    the right and must be removed; for non-causal convolutions only even
    kernels leave one surplus step.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        # Number of trailing timesteps to drop in forward().
        self.remove = kernel_size - 1 if causal else int(kernel_size % 2 == 0)

    def forward(self, x):
        # x: (batch, channels, time); drop surplus trailing steps, if any.
        if self.remove > 0:
            x = x[:, :, :-self.remove]
        return x
|
41 |
+
|
42 |
+
|
43 |
+
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` (a.k.a. SiLU)."""

    def __init__(self):
        super(Swish, self).__init__()
        # Kept as a submodule to mirror the original implementation.
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        gate = self.act(x)
        return x * gate
|
50 |
+
|
51 |
+
|
52 |
+
class GLU_Linear(nn.Module):
    """Gated linear unit: one linear layer producing ``2 * output_dim``
    features, split into a value half and a gate half; the gate half is
    passed through ``glu_act`` (or used raw for ``glu_type='bilinear'``)
    and multiplied with the value half.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 glu_type="sigmoid",
                 bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        # NOTE(review): 'bilinear' deliberately assigns no glu_act — forward()
        # multiplies the two halves directly in that case. Unknown glu_type
        # values also leave glu_act unset and would fail in forward().
        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        # Single projection to 2 * output_dim: [value | gate].
        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)

    def forward(self, x):
        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
        x = self.linear(x)

        if self.glu_type == "bilinear":
            x = (x[:, :, 0:self.output_dim] *
                 x[:, :, self.output_dim:self.output_dim * 2])
        else:
            x = (x[:, :, 0:self.output_dim] *
                 self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))

        return x
|
89 |
+
|
90 |
+
|
91 |
+
def gelu_accurate(x):
    """Tanh-approximation GELU (Hendrycks & Gimpel):
    ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``.
    """
    if not hasattr(gelu_accurate, "_a"):
        # Cache sqrt(2/pi) on the function object on first use.
        gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
|
96 |
+
|
97 |
+
|
98 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU computed in float32, then cast back to the input dtype
    (keeps half-precision inputs numerically stable)."""
    out = torch.nn.functional.gelu(x.float())
    return out.type_as(x)
|
100 |
+
|
101 |
+
|
102 |
+
def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""
    # Guard-clause chain; unknown names raise RuntimeError.
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return gelu
    if activation == "gelu_fast":
        # Legacy alias kept for backwards compatibility.
        warnings.warn(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate")
        return gelu_accurate
    if activation == "gelu_accurate":
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation in ("linear", "glu"):
        # Both map to the identity here; GLU gating is applied elsewhere.
        return lambda x: x
    raise RuntimeError(
        "--activation-fn {} not supported".format(activation))
|
124 |
+
|
125 |
+
|
126 |
+
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) %
            block_size == 0), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (module.in_channels % block_size == 0
                    ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                # mask entry == 1 (with probability p) marks a block to zero out
                mask = torch.zeros(
                    in_features // block_size * out_features,
                    device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1,
                                                                   in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device, )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(
                        -1, in_channels)
                else:
                    # drop whole (out_channel, in_channel) kernels
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2).unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))

            # scale weights and apply mask
            # scale surviving weights by 1/(1-p) so expectation is preserved
            mask = mask.to(
                torch.
                bool)  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on VQGAN code bases
|
7 |
+
# https://github.com/CompVis/taming-transformers
|
8 |
+
# --------------------------------------------------------'
|
9 |
+
import torch
|
10 |
+
import torch.distributed as distributed
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
try:
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, dim=-1, p=2)
|
22 |
+
|
23 |
+
|
24 |
+
def ema_inplace(moving_avg, new, decay):
    """In-place EMA update: ``avg <- decay * avg + (1 - decay) * new``."""
    moving_avg.data.mul_(decay).add_(new, alpha=1 - decay)
|
26 |
+
|
27 |
+
|
28 |
+
def sample_vectors(samples, num):
    """Pick *num* rows from *samples*: without replacement when there are
    enough rows, otherwise uniformly with replacement."""
    pool_size = samples.shape[0]
    device = samples.device

    if pool_size >= num:
        indices = torch.randperm(pool_size, device=device)[:num]
    else:
        indices = torch.randint(0, pool_size, (num, ), device=device)

    return samples[indices]
|
37 |
+
|
38 |
+
|
39 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Lloyd's kmeans over row vectors of *samples*.

    Returns ``(means, bins)`` where *means* is ``(num_clusters, dim)``
    and *bins* is the final per-cluster assignment count. Requires the
    ``einops`` ``rearrange``/``repeat`` helpers imported at module level.
    """
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    # Initialize centroids by sampling input rows.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            # Dot-product similarity (assumes rows are already normalized).
            dists = samples @ means.t()
        else:
            # Negative squared Euclidean distance, via broadcasting.
            diffs = rearrange(samples, 'n d -> n () d') \
                - rearrange(means, 'c d -> () c d')
            dists = -(diffs**2).sum(dim=-1)

        # Assign each sample to its closest centroid.
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid division by zero for empty clusters.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum samples per cluster, then average.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        # Keep the old centroid for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
|
67 |
+
|
68 |
+
|
69 |
+
class EmbeddingEMA(nn.Module):
    """Codebook whose entries are maintained by exponential moving
    averages (EMA) rather than gradient descent.

    Args:
        num_tokens: number of codebook entries.
        codebook_dim: dimensionality of each entry.
        decay: EMA decay factor.
        eps: Laplace-smoothing constant used in ``weight_update``.
        kmeans_init: if True, entries start at zero and are initialized
            lazily by a kmeans pass over the first data seen
            (``init_embed_``).
        codebook_init_path: optional checkpoint path to load initial
            weights from; when given it overrides ``kmeans_init``.
    """

    def __init__(self,
                 num_tokens,
                 codebook_dim,
                 decay=0.99,
                 eps=1e-5,
                 kmeans_init=True,
                 codebook_init_path=''):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps
        if codebook_init_path == '':
            if not kmeans_init:
                weight = torch.randn(num_tokens, codebook_dim)
                weight = l2norm(weight)
            else:
                # Placeholder until init_embed_ runs kmeans on real data.
                weight = torch.zeros(num_tokens, codebook_dim)
            # 'initted' tracks whether the codebook has been initialized.
            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        else:
            print(f"load init codebook weight from {codebook_init_path}")
            codebook_ckpt_weight = torch.load(
                codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            self.register_buffer('initted', torch.Tensor([True]))

        # Frozen parameters: updated via EMA methods below, not autograd.
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(
            torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        # When False, callers skip EMA updates for this codebook.
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        """Lazily initialize the codebook with kmeans over *data*.

        No-op once ``initted`` is set.
        """
        if self.initted:
            return
        # Fixed message typo: was "Performing Kemans init for codebook".
        print("Performing kmeans init for codebook")
        embed, cluster_size = kmeans(
            data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        """Look up codebook vectors for integer ids ``embed_id``."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """EMA-update the per-entry usage counts."""
        self.cluster_size.data.mul_(self.decay).add_(
            new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        """EMA-update the running sum of vectors assigned to each entry."""
        self.embed_avg.data.mul_(self.decay).add_(
            new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        """Recompute codebook entries from EMA statistics, with Laplace
        smoothing of cluster sizes to avoid division by zero."""
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n)
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)
|
133 |
+
|
134 |
+
|
135 |
+
def norm_ema_inplace(moving_avg, new, decay):
    """EMA update followed by re-projection onto the unit sphere."""
    moving_avg.data.mul_(decay).add_(new, alpha=1 - decay)
    normalized = l2norm(moving_avg.data)
    moving_avg.data.copy_(normalized)
|
138 |
+
|
139 |
+
|
140 |
+
class NormEMAVectorQuantizer(nn.Module):
    """Vector quantizer with an EMA-updated, unit-norm codebook
    (VQGAN-style, as used by the BEATs tokenizer).

    ``forward`` l2-normalizes the input, snaps each vector to its nearest
    codebook entry, returns the quantized vectors with a straight-through
    gradient, a commitment loss scaled by ``beta``, and the entry indices.
    """

    def __init__(self,
                 n_embed,
                 embedding_dim,
                 beta,
                 decay=0.99,
                 eps=1e-5,
                 statistic_code_usage=True,
                 kmeans_init=False,
                 codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        # learnable = True if orthogonal_reg_weight > 0 else False
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay,
                                      eps, kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        # Under DDP, sync usage statistics across ranks; otherwise a no-op.
        if distributed.is_available() and distributed.is_initialized():
            print(
                "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
            )
            self.all_reduce_fn = distributed.all_reduce
        else:
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        # Zero the usage statistics (e.g. at the start of an epoch).
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        # z = rearrange(z, 'b c h w -> b h w c')
        # z = z.transpose(1, 2)
        # Inputs and codebook both live on the unit sphere.
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)

        # Lazy kmeans initialization of the codebook (no-op once done).
        self.embedding.init_embed_(z_flattened)

        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e to every entry.
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # In eval mode only track usage statistics.
        if not self.training:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # EMA cluster size

            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            # self.embedding.cluster_size_ema_update(bins)
            ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            # Sum of assigned vectors per codebook entry.
            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)

            # Entries with no assignments this step keep their old value.
            embed_normalized = torch.where(
                zero_mask[..., None], self.embedding.weight, embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized,
                             self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        # z_q = rearrange(z_q, 'b h w c -> b c h w')
        # z_q = z_q.transpose(1, 2)
        return z_q, loss, encoding_indices
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'
|
2 |
+
# non_speech.npy, 存储一个 python dict 表示非 speech 类型的音频的 tag, 更小,加载和搜索速度更快
|
3 |
+
# audio_tag 目录存储 {utt_id}.txt, 第一行是小写的 top1 tag
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import traceback
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import librosa
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import tqdm
|
15 |
+
from AR.exps.beats.BEATs import BEATs
|
16 |
+
from AR.exps.beats.BEATs import BEATsConfig
|
17 |
+
from AR.exps.beats.config import id_name_dict
|
18 |
+
from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
|
19 |
+
from soundstorm.utils import check_txt_file
|
20 |
+
|
21 |
+
|
22 |
+
def get_BEATs_top1(wav,
                   BEATs_model,
                   BEATs_label_dict,
                   device: str='cpu',
                   topk: int=1):
    """Return the human-readable name of the top-1 BEATs audio tag for a
    single waveform.

    *wav* is a 1-D array; it is batched, run through
    ``BEATs_model.extract_features`` with an all-False padding mask, and
    the top-k class ids are mapped through ``BEATs_label_dict`` and then
    ``id_name_dict``. Only the top-1 name is returned.
    """
    wav = torch.tensor(wav).unsqueeze(0).to(device)
    padding_mask = torch.zeros(wav.shape).bool().to(device)
    probs = BEATs_model.extract_features(wav, padding_mask=padding_mask)[0]
    # single-utterance inference: take the only row of the batch
    probs = probs[0]
    topk_label_prob, topk_label_idx = probs.topk(k=topk)
    topk_label = [
        BEATs_label_dict[label_idx.item()] for label_idx in topk_label_idx
    ]
    topk_label_name = [id_name_dict[label] for label in topk_label]
    top1_label = topk_label_name[0]
    return top1_label
|
39 |
+
|
40 |
+
|
41 |
+
def process_sentence(args,
                     fp: Path,
                     train_dump_dir: Path,
                     dev_dump_dir: Path,
                     test_dump_dir: Path,
                     VAD_dict,
                     BEATs_model,
                     BEATs_label_dict,
                     device: str='cpu'):
    """Tag every VAD sub-segment of one big wav with its top-1 BEATs tag.

    For each segment a ``{split_name}.txt`` holding the tag is written
    under the appropriate train/dev/test ``audio_tag`` directory (the last
    segment goes to test, the second-to-last to dev, the rest to train).
    Returns a list of ``{"utt_id", "audio_tag_path", "subset"}`` dicts,
    which may be empty or partial on error.
    """
    utt_id = fp.stem
    sr = args.sr
    record = []
    train_audio_tag_dir = train_dump_dir / "audio_tag"
    train_audio_tag_dir.mkdir(parents=True, exist_ok=True)

    dev_audio_tag_dir = dev_dump_dir / "audio_tag"
    dev_audio_tag_dir.mkdir(parents=True, exist_ok=True)

    test_audio_tag_dir = test_dump_dir / "audio_tag"
    test_audio_tag_dir.mkdir(parents=True, exist_ok=True)

    try:
        # get info for path
        wav_path_list = str(fp).strip().split('/')
        sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
            -3], wav_path_list[-2]
        # strip the ".flac" extension (5 chars)
        wav_name = wav_path_list[-1][:-5]
        assert wav_name == utt_id
        # key_name for big wav
        key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
        # Case: the VAD dict has no entry for this audio
        if key_name not in VAD_dict.keys():
            print(key_name, 'not in VAD_dict !')
            return record
        wav = None
        sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
        len_dict = len(sorted_split_VAD_dict)
        for index, item in enumerate(sorted_split_VAD_dict):
            split_name, value = item
            start, end = value
            # train | dev | test
            if index == len_dict - 1:
                subset = 'test'
                audio_tag_path = test_audio_tag_dir / (split_name + ".txt")
            elif index == len_dict - 2:
                subset = 'dev'
                audio_tag_path = dev_audio_tag_dir / (split_name + ".txt")
            else:
                subset = 'train'
                audio_tag_path = train_audio_tag_dir / (split_name + ".txt")

            if os.path.exists(audio_tag_path) and check_txt_file(
                    audio_tag_path):
                # print(audio_tag_path, 'exits!')
                pass
            else:
                # Guard so the big wav is loaded at most once in this loop.
                if wav is None:
                    # load big wav
                    # Loading before the loop would waste time when every
                    # sub-wav result already exists.
                    wav, _ = librosa.load(str(fp), sr=sr)
                sub_wav = wav[int(start * sr):int(end * sr)]
                audio_tag_top1 = get_BEATs_top1(
                    wav=sub_wav,
                    BEATs_model=BEATs_model,
                    BEATs_label_dict=BEATs_label_dict,
                    device=device)

                with open(audio_tag_path, 'w') as f:
                    f.write(audio_tag_top1)

            sub_record = {
                "utt_id": split_name,
                "audio_tag_path": audio_tag_path,
                "subset": subset
            }
            # record becomes a List of Dict
            record.append(sub_record)
    except Exception:
        print("occur Exception")
        traceback.print_exc()
        # record may be an incomplete list at this point
        return record
    return record
|
125 |
+
|
126 |
+
|
127 |
+
def process_sentences(args,
                      fps: Path,
                      train_dump_dir: Path,
                      dev_dump_dir: Path,
                      test_dump_dir: Path,
                      VAD_dict,
                      BEATs_model,
                      BEATs_label_dict,
                      device: str='cpu',
                      nprocs: int=1):
    """Run ``process_sentence`` over all wav files (serially or with a
    thread pool) and save one ``non_speech_{rank}_{nshard}.npy`` per
    train/dev/test split, mapping utt_id -> top-1 tag for every segment
    whose tag does not contain 'speech'."""
    print("nprocs:", nprocs)
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(
                args=args,
                fp=fp,
                train_dump_dir=train_dump_dir,
                dev_dump_dir=dev_dump_dir,
                test_dump_dir=test_dump_dir,
                VAD_dict=VAD_dict,
                BEATs_model=BEATs_model,
                BEATs_label_dict=BEATs_label_dict,
                device=device)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, args, fp,
                                         train_dump_dir, dev_dump_dir,
                                         test_dump_dir, VAD_dict, BEATs_model,
                                         BEATs_label_dict, device)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    # torch.save() to a large `.pth` file
    non_speech_dict = dict()
    non_speech_dict['train'] = {}
    non_speech_dict['dev'] = {}
    non_speech_dict['test'] = {}
    # results is a List of records: one record per big wav, one sub_record
    # per sub wav
    print(f"start to save {args.rank}_{args.nshard}.npy ...")
    save_start_time = time.time()
    for record in tqdm.tqdm(results, total=len(results), colour='green'):
        for sub_record in record:
            # try here because the txt file may be corrupted
            try:
                utt_id = sub_record["utt_id"]
                subset = sub_record["subset"]
                audio_tag_top1 = check_txt_file(sub_record["audio_tag_path"])
                if audio_tag_top1 is not False:
                    if 'speech' not in audio_tag_top1.lower():
                        non_speech_dict[subset][utt_id] = audio_tag_top1
                    else:
                        # print(f'audio tag result of {utt_id} is speech')
                        pass
                else:
                    print(f'audio tag result of {utt_id} is False')
            except Exception:
                print(f"{utt_id} occur Exception")
                traceback.print_exc()
                continue

    train_filename = train_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
    dev_filename = dev_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
    test_filename = test_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
    np.save(train_filename, non_speech_dict['train'])
    print(f"npy file '{train_filename}' write down")

    np.save(dev_filename, non_speech_dict['dev'])
    print(f"npy file '{dev_filename}' write down")

    np.save(test_filename, non_speech_dict['test'])
    print(f"npy file '{test_filename}' write down")
    print('time of save stage:', time.time() - save_start_time)
|
211 |
+
|
212 |
+
|
213 |
+
def main():
    """CLI entry point: shard the LibriLight speaker list, load the BEATs
    checkpoint, and tag every VAD sub-segment, saving per-split
    non-speech dictionaries."""
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'."
    )

    parser.add_argument(
        "--data_dir", default=None, type=str, help="directory to dataset.")

    parser.add_argument(
        "--dump_dir",
        type=str,
        required=True,
        help="directory to dump feature files.")

    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of process.")

    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of model')

    # For LibriLight dataset
    parser.add_argument(
        "--sub_dataset",
        default="small",
        type=str,
        help="name of sub dataset of LibriLight",
        choices=['small', 'medium', 'large', 'duplicate'], )
    parser.add_argument(
        "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
    parser.add_argument("--nshard", type=int, default=3)
    parser.add_argument("--rank", type=int, default=0)

    # for BEATs
    parser.add_argument(
        "--BEATs_ckpt_path",
        type=str,
        default='./pretrained_model/BEATs_iter1_finetuned_on_AS2M_cpt1.pt')

    args = parser.parse_args()

    data_dir = Path(args.data_dir).expanduser()
    dump_dir = Path(args.dump_dir).expanduser()
    # use absolute path
    dump_dir = dump_dir.resolve()
    dump_dir.mkdir(parents=True, exist_ok=True)

    assert data_dir.is_dir()

    # sub_dataset here
    sub_dataset_dir = data_dir / args.sub_dataset
    # only spk_id in list, sort by lexicographical order
    speaker_list = sorted(os.listdir(sub_dataset_dir))
    # shard the speakers across (nshard, rank) workers
    start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
    # speaker_list for this rank
    speaker_list = speaker_list[start:end]

    all_wav_files = []

    for speaker in speaker_list:
        wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
        # filter out ._*.flac
        wav_files = [
            file for file in wav_files if not file.name.startswith('._')
        ]
        all_wav_files += wav_files

    print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
    # get VAD info
    VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()

    sub_dataset_dump_dir = dump_dir / args.sub_dataset
    sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
    train_dump_dir = sub_dataset_dump_dir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = sub_dataset_dump_dir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = sub_dataset_dump_dir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    BEATs_ckpt = torch.load(args.BEATs_ckpt_path)

    BEATs_cfg = BEATsConfig(BEATs_ckpt['cfg'])
    BEATs_model = BEATs(BEATs_cfg)
    BEATs_model.load_state_dict(BEATs_ckpt['model'])
    BEATs_model.eval()
    # cpu or cuda
    device = 'cpu'
    BEATs_model.to(device)

    BEATs_label_dict = BEATs_ckpt['label_dict']

    # Each big wav yields one dev and one test segment; the split ratio is
    # roughly 96:2:2.
    if all_wav_files:
        process_sentences(
            args=args,
            fps=all_wav_files,
            train_dump_dir=train_dump_dir,
            dev_dump_dir=dev_dump_dir,
            test_dump_dir=test_dump_dir,
            VAD_dict=VAD_dict,
            BEATs_model=BEATs_model,
            BEATs_label_dict=BEATs_label_dict,
            device=device,
            nprocs=args.num_cpu)


if __name__ == "__main__":
    main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
1. read text of dataset
|
3 |
+
2. text -> IPA by GruutPhonemizer
|
4 |
+
3. save out a *.npy dict for all text
|
5 |
+
my_dict = {"utt_id1": text1, "utt_id2": text2}
|
6 |
+
np.save(output_filename, my_dict)
|
7 |
+
my_dict = np.load(output_filename, allow_pickle=True).item()
|
8 |
+
"""
|
9 |
+
import argparse
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import List

import numpy as np
import tqdm
from AR.text_processing.phonemizer import GruutPhonemizer
|
19 |
+
|
20 |
+
|
21 |
+
def read_txt(txt_file):
    """Read the first line of *txt_file* and pair it with its utterance id.

    The utterance id is the file stem truncated at the first dot (so
    ``a.normalized.txt`` yields ``a``). Returns ``{"utt_id": ..., "txt": ...}``
    or ``None`` when the file cannot be read.
    """
    utt_id = txt_file.stem.split('.')[0]
    try:
        with open(txt_file, 'r') as fh:
            first_line = fh.readline()
    except Exception:
        print("occur Exception")
        traceback.print_exc()
        return None
    return {"utt_id": utt_id, "txt": first_line}
|
33 |
+
|
34 |
+
|
35 |
+
def read_txts(txt_files: List[Path], nprocs: int=1):
    """Read every file in *txt_files* and return sorted ``(utt_id, txt)`` pairs.

    Files that fail to read are dropped. With ``nprocs > 1`` the reads are
    fanned out over a thread pool; the result is always sorted by utt_id.
    """
    if nprocs == 1:
        # Sequential path with a simple progress bar.
        records = [
            read_txt(txt_file=f)
            for f in tqdm.tqdm(txt_files, total=len(txt_files))
        ]
    else:
        # Fan out over a thread pool; tick the bar as each future finishes.
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(txt_files)) as progress:
                for f in txt_files:
                    fut = pool.submit(read_txt, f)
                    fut.add_done_callback(lambda p: progress.update())
                    futures.append(fut)
            records = [fut.result() for fut in futures]

    kept = [r for r in records if r]
    kept.sort(key=itemgetter("utt_id"))
    return [(r["utt_id"], r["txt"]) for r in kept]
|
62 |
+
|
63 |
+
|
64 |
+
def process_sentence(item, phonemizer):
    """Phonemize one ``(utt_id, text)`` pair.

    Returns ``{"utt_id": ..., "phonemes": ...}``, or ``None`` when the
    phonemizer raises.
    """
    utt_id, text = item
    try:
        ipa = phonemizer.phonemize(text, espeak=False)
    except Exception:
        print("occur Exception")
        traceback.print_exc()
        return None
    return {"utt_id": utt_id, "phonemes": ipa}
|
74 |
+
|
75 |
+
|
76 |
+
def process_sentences(items, phonemizer, output_dir, nprocs: int=1):
    """Phonemize ``(utt_id, text)`` pairs and dump them as one .npy dict.

    Args:
        items: list of ``(utt_id, text)`` pairs.
        phonemizer: object with a ``phonemize(text, espeak=False)`` method.
        output_dir: directory (``pathlib.Path``) receiving ``phonemes.npy``.
        nprocs: number of worker threads; 1 means sequential.
    """
    if nprocs == 1:
        results = []
        for item in tqdm.tqdm(items, total=len(items)):
            record = process_sentence(item=item, phonemizer=phonemizer)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(items)) as progress:
                for item in items:
                    future = pool.submit(process_sentence, item, phonemizer)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

            results = []
            for ft in futures:
                record = ft.result()
                if record:
                    results.append(record)
    results.sort(key=itemgetter("utt_id"))
    npy_dict = {}
    for item in results:
        npy_dict[item["utt_id"]] = item["phonemes"]
    filename = output_dir / 'phonemes.npy'
    np.save(filename, npy_dict)
    # BUG FIX: the message previously interpolated nothing; report the
    # actual output path, consistent with the sibling librilight scripts.
    print(f"npy file '{filename}' write down")
|
106 |
+
|
107 |
+
|
108 |
+
def main():
    """CLI entry point: read dataset transcripts, phonemize, dump .npy dicts."""
    # parse config and args
    parser = argparse.ArgumentParser(description="Get phones for datasets")

    parser.add_argument(
        "--dataset",
        default="ljspeech",
        type=str,
        help="name of dataset, should in {ljspeech, libritts} now")

    parser.add_argument(
        "--data_dir", default=None, type=str, help="directory to dataset.")

    parser.add_argument(
        "--dump_dir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of process.")

    args = parser.parse_args()

    data_dir = Path(args.data_dir).expanduser()
    dump_dir = Path(args.dump_dir).expanduser()
    # use absolute path
    dump_dir = dump_dir.resolve()
    dump_dir.mkdir(parents=True, exist_ok=True)

    assert data_dir.is_dir()

    if args.dataset == "ljspeech":
        # LJSpeech ships a single metadata.csv of '|'-separated fields;
        # the last field is the transcript.
        data_dict = {}
        text_path = data_dir / 'metadata.csv'
        with open(text_path, 'r') as rf:
            for line in rf:
                line_list = line.strip().split('|')
                utt_id = line_list[0]
                raw_text = line_list[-1]
                data_dict[utt_id] = raw_text

        sorted_dict = sorted(data_dict.items())

        # Deterministic split by lexicographic order of utt_id.
        num_train = 12900
        num_dev = 100
        # (utt_id, txt)
        train_txts = sorted_dict[:num_train]
        dev_txts = sorted_dict[num_train:num_train + num_dev]
        test_txts = sorted_dict[num_train + num_dev:]

    elif args.dataset == "libritts":
        '''
        we use train-clean-100、train-clean-360、train-other-500 here
        and split dev and test from them, don't use test-* and dev-* cause the speakers are disjoint
        the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav
        there are about 2311 in these subsets, we split 1 dev and 1 test wav out from each speaker
        '''
        txt_files = []
        train_txt_files = []
        dev_txt_files = []
        test_txt_files = []
        sub_num_dev = 1
        for sub_dataset_name in {
                "train-clean-100", "train-clean-360", "train-other-500"
        }:
            sub_dataset_dir = data_dir / sub_dataset_name
            # filter out hidden files
            speaker_list = [
                file for file in os.listdir(sub_dataset_dir)
                if not file.startswith('.')
            ]
            for speaker in speaker_list:
                txt_files = sorted(
                    list((sub_dataset_dir / speaker).rglob(
                        "*/*.normalized.txt")))
                # filter out ._*.wav
                txt_files = [
                    file for file in txt_files if not file.name.startswith('._')
                ]
                # Last 2*sub_num_dev files per speaker become dev/test.
                train_txt_files += txt_files[:-sub_num_dev * 2]
                dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev]
                test_txt_files += txt_files[-sub_num_dev:]
        print("len(train_txt_files):", len(train_txt_files))
        print("len(dev_txt_files):", len(dev_txt_files))
        print("len(test_txt_files):", len(test_txt_files))

        train_txts = read_txts(train_txt_files)
        dev_txts = read_txts(dev_txt_files)
        test_txts = read_txts(test_txt_files)

    else:
        # NOTE(review): execution falls through after this message, so an
        # unknown dataset raises NameError on train_txts below — confirm
        # whether an early return/exit is intended here.
        print("dataset should in {ljspeech, libritts} now!")

    train_dump_dir = dump_dir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dump_dir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dump_dir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    phonemizer = GruutPhonemizer(language='en-us')

    # process for the 3 sections
    if train_txts:
        process_sentences(
            items=train_txts,
            output_dir=train_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
    if dev_txts:
        process_sentences(
            items=dev_txts,
            output_dir=dev_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
    if test_txts:
        process_sentences(
            items=test_txts,
            output_dir=test_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
|
229 |
+
|
230 |
+
|
231 |
+
if __name__ == "__main__":
|
232 |
+
main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
1. read text of dataset, for LibriLight read txt_*.npy -> 需要整理成 list(utt_id, txt) 的形式
|
3 |
+
2. text -> IPA by GruutPhonemizer
|
4 |
+
3. save out a *.npy dict for all text
|
5 |
+
4. LibriLight 每个 split 分开处理
|
6 |
+
my_dict = {"utt_id1": text1, "utt_id2": text2}
|
7 |
+
np.save(output_filename, my_dict)
|
8 |
+
my_dict = np.load(output_filename, allow_pickle=True).item()
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import traceback
|
14 |
+
from concurrent.futures import ThreadPoolExecutor
|
15 |
+
from operator import itemgetter
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import tqdm
|
20 |
+
from AR.text_processing.phonemizer import GruutPhonemizer
|
21 |
+
from soundstorm.utils import check_txt_file
|
22 |
+
|
23 |
+
|
24 |
+
def read_txts(txt_file: Path, nprocs: int=1):
    """Load a pickled npy dict ``{"utt_id": text, ...}`` as ``(utt_id, text)`` pairs.

    ``nprocs`` is accepted for interface symmetry with the other exps
    scripts but is unused here.
    """
    loaded = np.load(txt_file, allow_pickle=True).item()
    # [(utt_id, txt), ...]
    return list(loaded.items())
|
32 |
+
|
33 |
+
|
34 |
+
def process_sentence(item, phonemizer, output_dir):
    """Phonemize one ``(utt_id, text)`` pair and cache the result on disk.

    A valid existing ``<output_dir>/phonemes/<utt_id>.txt`` is reused
    instead of being recomputed. Returns
    ``{"utt_id": ..., "phonemes_path": ...}`` or ``None`` on failure.
    """
    utt_id, text = item
    cache_dir = output_dir / "phonemes"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / (utt_id + ".txt")
    try:
        already_cached = os.path.exists(cache_path) and check_txt_file(
            cache_path)
        if not already_cached:
            result = phonemizer.phonemize(text, espeak=False)
            with open(cache_path, 'w') as f:
                f.write(result)
    except Exception:
        print("occur Exception")
        traceback.print_exc()
        return None
    return {"utt_id": utt_id, "phonemes_path": cache_path}
|
53 |
+
|
54 |
+
|
55 |
+
def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1):
    """Phonemize all items of one shard and dump them into a single .npy dict.

    Args:
        args: parsed CLI namespace; only ``rank``/``nshard`` are used here,
            to name the output file ``phonemes_{rank}_{nshard}.npy``.
        items: list of ``(utt_id, text)`` pairs.
        phonemizer: object with a ``phonemize(text, espeak=False)`` method.
        output_dir: destination directory (``pathlib.Path``).
        nprocs: number of worker threads; 1 means sequential.
    """
    print("nprocs:", nprocs)
    if nprocs == 1:
        results = []
        for item in tqdm.tqdm(items, total=len(items)):
            record = process_sentence(
                item=item, phonemizer=phonemizer, output_dir=output_dir)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(items)) as progress:
                for item in items:
                    future = pool.submit(process_sentence, item, phonemizer,
                                         output_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

            results = []
            for ft in futures:
                record = ft.result()
                if record:
                    results.append(record)

    results.sort(key=itemgetter("utt_id"))

    npy_dict = {}
    print(f"start to save {args.rank}_{args.nshard}.npy ...")
    save_start_time = time.time()
    for item in tqdm.tqdm(results, total=len(results), colour='green'):
        # Guard each read: a cached phoneme txt file may be corrupted.
        try:
            utt_id = item["utt_id"]
            phonemes = check_txt_file(item["phonemes_path"])
            if phonemes is not False:
                npy_dict[utt_id] = phonemes
            else:
                print(f'phonemes of {utt_id} is False')
        except Exception:
            print(f"{utt_id} occur Exception")
            traceback.print_exc()
            continue

    filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy'
    np.save(filename, npy_dict)
    # BUG FIX: interpolate the actual output path (the message previously
    # printed a literal placeholder), matching get_txt_librilight.py.
    print(f"npy file '{filename}' write down")
    print('time of save stage:', time.time() - save_start_time)
|
103 |
+
|
104 |
+
|
105 |
+
def main():
    """CLI entry point: phonemize one shard of LibriLight transcripts."""
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Get phones for LibriLight dataset from txt_*.npy")

    parser.add_argument(
        "--dump_dir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of process.")

    parser.add_argument(
        '--train_txt_dir',
        type=str,
        default='dump/small/train/',
        help='dir of train txt files')
    parser.add_argument(
        '--dev_txt_dir',
        type=str,
        default='dump/small/dev/',
        help='dir of dev txt files')
    parser.add_argument(
        '--test_txt_dir',
        type=str,
        default='dump/small/test/',
        help='dir of test txt files')

    parser.add_argument(
        "--sub_dataset",
        default="small",
        type=str,
        help="name of sub dataset of LibriLight",
        choices=['small', 'medium', 'large', 'duplicate'], )
    # Sharding: this process handles shard `rank` out of `nshard` shards.
    parser.add_argument("--nshard", type=int, default=3)
    parser.add_argument("--rank", type=int, default=0)

    args = parser.parse_args()
    print(f"nshard: {args.nshard}, rank: {args.rank}")

    train_txt_dir = Path(args.train_txt_dir)
    dev_txt_dir = Path(args.dev_txt_dir)
    test_txt_dir = Path(args.test_txt_dir)

    dump_dir = Path(args.dump_dir).expanduser()
    # use absolute path
    dump_dir = dump_dir.resolve()
    dump_dir.mkdir(parents=True, exist_ok=True)

    # Per-shard transcript dicts produced by get_txt_librilight.py.
    train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
    dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
    test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'

    train_txts = read_txts(train_txt_file)
    dev_txts = read_txts(dev_txt_file)
    test_txts = read_txts(test_txt_file)

    sub_dataset_dump_dir = dump_dir / args.sub_dataset
    sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
    train_dump_dir = sub_dataset_dump_dir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = sub_dataset_dump_dir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = sub_dataset_dump_dir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)
    phonemizer = GruutPhonemizer(language='en-us')

    # process for the 3 sections
    if train_txts:
        process_sentences(
            args=args,
            items=train_txts,
            output_dir=train_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
    if dev_txts:
        process_sentences(
            args=args,
            items=dev_txts,
            output_dir=dev_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
    if test_txts:
        process_sentences(
            args=args,
            items=test_txts,
            output_dir=test_dump_dir,
            phonemizer=phonemizer,
            nprocs=args.num_cpu)
|
195 |
+
|
196 |
+
|
197 |
+
if __name__ == "__main__":
|
198 |
+
main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import tqdm
|
11 |
+
import whisper
|
12 |
+
from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
|
13 |
+
from soundstorm.utils import check_txt_file
|
14 |
+
|
15 |
+
|
16 |
+
def process_sentence(args,
                     fp: Path,
                     train_dump_dir: Path,
                     dev_dump_dir: Path,
                     test_dump_dir: Path,
                     VAD_dict):
    """Transcribe the VAD segments of one big LibriLight flac with Whisper.

    Each segment's ASR text is cached as ``<split>/txt/<segment>.txt``; per
    big wav the last segment goes to test, the second-to-last to dev, the
    rest to train. Returns a list of ``{"utt_id", "txt_path", "subset"}``
    dicts (possibly incomplete if an exception interrupts the loop).
    """
    # NOTE(review): the Whisper model is (re)loaded on every call — once
    # per wav file; confirm whether hoisting it out of the per-file path
    # is intended.
    asr_model = whisper.load_model("tiny.en")
    utt_id = fp.stem
    sr = args.sr
    record = []
    train_txt_dir = train_dump_dir / "txt"
    train_txt_dir.mkdir(parents=True, exist_ok=True)

    dev_txt_dir = dev_dump_dir / "txt"
    dev_txt_dir.mkdir(parents=True, exist_ok=True)

    test_txt_dir = test_dump_dir / "txt"
    test_txt_dir.mkdir(parents=True, exist_ok=True)

    try:
        # get info for path
        wav_path_list = str(fp).strip().split('/')
        sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
            -3], wav_path_list[-2]
        # Strip the '.flac' extension (5 chars).
        wav_name = wav_path_list[-1][:-5]
        assert wav_name == utt_id
        # key_name for big wav
        key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
        # Handle audio files that have no entry in the VAD dict.
        if key_name not in VAD_dict.keys():
            print(key_name, 'not in VAD_dict !')
            return record
        wav = None
        sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
        len_dict = len(sorted_split_VAD_dict)
        for index, item in enumerate(sorted_split_VAD_dict):
            split_name, value = item
            start, end = value
            # train | dev | test
            if index == len_dict - 1:
                subset = 'test'
                txt_path = test_txt_dir / (split_name + ".txt")
            elif index == len_dict - 2:
                subset = 'dev'
                txt_path = dev_txt_dir / (split_name + ".txt")
            else:
                subset = 'train'
                txt_path = train_txt_dir / (split_name + ".txt")

            if os.path.exists(txt_path) and check_txt_file(txt_path):
                # print(txt_path, 'exits!')
                pass
            else:
                # Guard so the big wav is loaded at most once per call.
                if wav is None:
                    # Load the big wav lazily: loading it up front would
                    # waste time when all sub-wav results already exist.
                    wav, _ = librosa.load(str(fp), sr=sr)
                sub_wav = wav[int(start * sr):int(end * sr)]
                asr_result = asr_model.transcribe(sub_wav)["text"]
                with open(txt_path, 'w') as f:
                    f.write(asr_result)

            sub_record = {
                "utt_id": split_name,
                "txt_path": txt_path,
                "subset": subset
            }
            # record becomes a List of Dict
            record.append(sub_record)
    except Exception:
        print("occur Exception")
        traceback.print_exc()
        # record may be an incomplete list here
        return record
    return record
|
92 |
+
|
93 |
+
|
94 |
+
def process_sentences(args,
                      fps: Path,
                      train_dump_dir: Path,
                      dev_dump_dir: Path,
                      test_dump_dir: Path,
                      VAD_dict,
                      nprocs: int=1):
    """Run process_sentence over all wavs, then dump per-split txt_*.npy dicts.

    Args:
        args: parsed CLI namespace (uses ``sr``, ``rank`` and ``nshard``).
        fps: list of big-wav flac paths for this shard.
        train_dump_dir / dev_dump_dir / test_dump_dir: split output dirs.
        VAD_dict: ``{key_name: {split_name: (start, end)}}`` segment table.
        nprocs: number of worker threads; 1 means sequential.
    """
    print("nprocs:", nprocs)
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(
                args=args,
                fp=fp,
                train_dump_dir=train_dump_dir,
                dev_dump_dir=dev_dump_dir,
                test_dump_dir=test_dump_dir,
                VAD_dict=VAD_dict)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, args, fp,
                                         train_dump_dir, dev_dump_dir,
                                         test_dump_dir, VAD_dict)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

            results = []
            for ft in futures:
                record = ft.result()
                if record:
                    results.append(record)

    # torch.save() to a large `.pth` file
    txt_dict = dict()
    txt_dict['train'] = {}
    txt_dict['dev'] = {}
    txt_dict['test'] = {}
    # results is a List of Dict: one record per big wav, one sub_record
    # per sub wav.
    print(f"start to save {args.rank}_{args.nshard}.npy ...")
    save_start_time = time.time()
    for record in tqdm.tqdm(results, total=len(results), colour='green'):
        for sub_record in record:
            # Guard each read: a cached txt file may be corrupted.
            try:
                utt_id = sub_record["utt_id"]
                subset = sub_record["subset"]
                asr_result = check_txt_file(sub_record["txt_path"])
                if asr_result is not False:
                    txt_dict[subset][utt_id] = asr_result
                else:
                    print(f'asr result of {utt_id} is False')
            except Exception:
                print(f"{utt_id} occur Exception")
                traceback.print_exc()
                continue

    train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
    np.save(train_filename, txt_dict['train'])
    print(f"npy file '{train_filename}' write down")

    np.save(dev_filename, txt_dict['dev'])
    print(f"npy file '{dev_filename}' write down")

    np.save(test_filename, txt_dict['test'])
    print(f"npy file '{test_filename}' write down")
    print('time of save stage:', time.time() - save_start_time)
|
167 |
+
|
168 |
+
|
169 |
+
def main():
    """CLI entry point: ASR-transcribe one shard of LibriLight audio."""
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features for LibriLight.")

    parser.add_argument(
        "--data_dir", default=None, type=str, help="directory to dataset.")

    parser.add_argument(
        "--dump_dir",
        type=str,
        required=True,
        help="directory to dump feature files.")

    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of process.")

    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of model')

    # For LibriLight dataset
    parser.add_argument(
        "--sub_dataset",
        default="small",
        type=str,
        help="name of sub dataset of LibriLight",
        choices=['small', 'medium', 'large', 'duplicate'], )
    parser.add_argument(
        "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
    # Sharding: this process handles shard `rank` out of `nshard` shards.
    parser.add_argument("--nshard", type=int, default=3)
    parser.add_argument("--rank", type=int, default=0)

    args = parser.parse_args()

    data_dir = Path(args.data_dir).expanduser()
    dump_dir = Path(args.dump_dir).expanduser()
    # use absolute path
    dump_dir = dump_dir.resolve()
    dump_dir.mkdir(parents=True, exist_ok=True)

    assert data_dir.is_dir()

    # sub_dataset here
    sub_dataset_dir = data_dir / args.sub_dataset
    # only spk_id in list, sorted by lexicographical order
    speaker_list = sorted(os.listdir(sub_dataset_dir))
    start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
    # speaker_list for this rank
    speaker_list = speaker_list[start:end]

    all_wav_files = []

    for speaker in speaker_list:
        wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
        # filter out ._*.flac
        wav_files = [
            file for file in wav_files if not file.name.startswith('._')
        ]
        all_wav_files += wav_files

    print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
    # get VAD info
    VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()

    sub_dataset_dump_dir = dump_dir / args.sub_dataset
    sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
    train_dump_dir = sub_dataset_dump_dir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = sub_dataset_dump_dir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = sub_dataset_dump_dir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Every big wav yields one dev and one test segment; the split ratio
    # is roughly 96:2:2.
    if all_wav_files:
        process_sentences(
            args=args,
            fps=all_wav_files,
            train_dump_dir=train_dump_dir,
            dev_dump_dir=dev_dump_dir,
            test_dump_dir=test_dump_dir,
            VAD_dict=VAD_dict,
            nprocs=args.num_cpu)
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == "__main__":
|
255 |
+
main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Split dump/semantic.tsv and dump/phoneme.npy into train/dev subsets.

Randomly holds out ``dev_num`` utterances for dev; everything else goes to
train. The phoneme dict is partitioned by the same item names.
"""
import numpy
import pandas

semantic_path = 'dump/semantic.tsv'
phoneme_path = 'dump/phoneme.npy'
train_semantic_path = 'dump/semantic_train.tsv'
train_phoneme_path = 'dump/phoneme_train.npy'
dev_semantic_path = 'dump/semantic_dev.tsv'
dev_phoneme_path = 'dump/phoneme_dev.npy'

# Load the full semantic table and the pickled phoneme dict.
semantic_df = pandas.read_csv(semantic_path, sep='\t')
phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item()

dev_num = 20
# Randomly sample the dev rows; whatever remains is the train split.
dev_df = semantic_df.sample(n=dev_num)
train_df = semantic_df.drop(dev_df.index)
dev_df.to_csv(dev_semantic_path, sep='\t', index=False)
train_df.to_csv(train_semantic_path, sep='\t', index=False)

# Partition the phoneme dict by the dev item names.
dev_item_names = dev_df['item_name'].tolist()
dev_phoneme_dict = {
    name: phoneme_dict[name]
    for name in dev_item_names if name in phoneme_dict
}
train_phoneme_dict = {
    name: phonemes
    for name, phonemes in phoneme_dict.items()
    if name not in dev_item_names
}

numpy.save(dev_phoneme_path, dev_phoneme_dict)
numpy.save(train_phoneme_path, train_phoneme_dict)
|
33 |
+
|
34 |
+
|
35 |
+
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# text to semantic
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import whisper
|
12 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
13 |
+
from AR.text_processing.phonemizer import GruutPhonemizer
|
14 |
+
from AR.utils.io import load_yaml_config
|
15 |
+
|
16 |
+
|
17 |
+
def get_batch(text, phonemizer):
    """Build a single-sample batch of phoneme ids for *text*.

    Only ``phoneme_ids`` and ``phoneme_ids_len`` are needed downstream.
    """
    ipa = phonemizer.phonemize(text, espeak=False)
    ids = phonemizer.transform(ipa)
    n_ids = len(ids)
    # add batch axis here: (T,) -> (1, T)
    ids_tensor = torch.tensor(np.array(ids)).unsqueeze(0)
    len_tensor = torch.tensor([n_ids])
    print("phoneme:", ipa)
    return {
        # torch.Tensor (B, max_phoneme_length)
        "phoneme_ids": ids_tensor,
        # torch.Tensor (B)
        "phoneme_ids_len": len_tensor
    }
|
34 |
+
|
35 |
+
|
36 |
+
def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
    """Extract prompt phonemes and prompt semantic tokens from a prompt wav.

    The wav tail is transcribed with *asr_model*, phonemized, and tokenized
    with *semantic_tokenizer* (on CUDA). Returns a dict with keys
    ``prompt_name``, ``prompt_phoneme_ids``, ``prompt_semantic_tokens``
    and ``prompt_phoneme_ids_len``.
    """
    sample_rate = 16000
    # to get prompt
    prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
    wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
    # Take the last 3 s, but drop the final 0.1 s so AR S1 inference does
    # not stop early.
    wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
    # Trailing silence must be trimmed from the wav, otherwise inference
    # may also stop early.
    prompt_text = asr_model.transcribe(wav)["text"]
    # Remove the final period: keeping it may cause AR S1 inference to
    # stop early or insert a pause.
    prompt_text = prompt_text.replace(".", "")
    prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
    prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
    prompt_phoneme_ids_len = len(prompt_phoneme_ids)
    # get prompt_semantic
    # (T) -> (1, T)
    wav = torch.tensor(wav).unsqueeze(0)
    wav = wav.cuda()
    # (1, T)
    prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
    prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
    prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])

    result = {
        'prompt_name': prompt_name,
        'prompt_phoneme_ids': prompt_phoneme_ids,
        'prompt_semantic_tokens': prompt_semantic_tokens,
        'prompt_phoneme_ids_len': prompt_phoneme_ids_len
    }

    return result
|
67 |
+
|
68 |
+
|
69 |
+
def parse_args():
    """Parse CLI arguments for running the SoundStorm AR S1 model."""
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for input text file")

    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')

    parser.add_argument(
        "--text_file",
        type=str,
        help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
    )

    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')

    parser.add_argument(
        '--prompt_wav_path',
        type=str,
        default=None,
        help='extract prompt semantic and prompt phonemes from prompt wav')

    # to get semantic tokens from prompt_wav
    parser.add_argument("--hubert_path", type=str, default=None)
    parser.add_argument("--quantizer_path", type=str, default=None)

    parser.add_argument("--output_dir", type=str, help="output dir.")

    args = parser.parse_args()
    return args
|
106 |
+
|
107 |
+
|
108 |
+
def main():
    """Synthesize semantic tokens for each sentence in a text file.

    Loads the AR S1 model, builds a prompt (phonemes + semantic tokens) from
    a reference wav, then for every 'utt_id sentence' line generates semantic
    tokens conditioned on that prompt and writes one TSV per utterance.
    """
    args = parse_args()
    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 50 semantic tokens per second of audio; used for the early-stop budget
    hz = 50
    max_sec = config['data']['max_sec']

    # get models
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')

    # models for the prompt: ASR transcribes the prompt wav, and the
    # HuBERT+quantizer tokenizer turns it into semantic tokens
    asr_model = whisper.load_model("tiny.en")

    semantic_tokenizer = SemanticTokenizer(
        hubert_path=args.hubert_path,
        quantizer_path=args.quantizer_path,
        duplicate=True)

    prompt_result = get_prompt(
        prompt_wav_path=args.prompt_wav_path,
        asr_model=asr_model,
        phonemizer=phonemizer,
        semantic_tokenizer=semantic_tokenizer)

    # zero prompt => the generated semantic content is correct but the
    # timbre is scrambled
    # (B, 1)
    # prompt = torch.ones(
    #     batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0

    prompt = prompt_result['prompt_semantic_tokens']
    prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
    prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']

    # each non-empty line of the text file is a "utt_id sentence" pair
    sentences = []
    with open(args.text_file, 'rt', encoding='utf-8') as f:
        for line in f:
            if line.strip() != "":
                items = re.split(r"\s+", line.strip(), 1)
                utt_id = items[0]
                sentence = " ".join(items[1:])
                sentences.append((utt_id, sentence))
    semantic_data = [['item_name', 'semantic_audio']]
    # NOTE(review): `sentences[1:]` skips the first line of the text file —
    # confirm this is intentional (e.g. a header line) and not an off-by-one.
    for utt_id, sentence in sentences[1:]:
        # manually build a pseudo batch as model input
        batch = get_batch(sentence, phonemizer)
        # concatenate the prompt with the real input
        all_phoneme_ids = torch.cat(
            [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
        # alternatively, shape[-1] of all_phoneme_ids could be used directly
        all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
        st = time.time()
        with torch.no_grad():
            pred_semantic = t2s_model.model.infer(
                all_phoneme_ids.cuda(),
                all_phoneme_len.cuda(),
                prompt.cuda(),
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        print(f'{time.time() - st} sec used in T2S')

        # drop the part of the output that corresponds to the prompt
        prompt_len = prompt.shape[-1]
        pred_semantic = pred_semantic[:, prompt_len:]

        # bs = 1
        pred_semantic = pred_semantic[0]
        semantic_token = pred_semantic.detach().cpu().numpy().tolist()
        semantic_token_str = ' '.join(str(x) for x in semantic_token)
        semantic_data.append([utt_id, semantic_token_str])

        delimiter = '\t'
        filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
        with open(filename, 'w', encoding='utf-8') as writer:
            for row in semantic_data:
                line = delimiter.join(row)
                writer.write(line + '\n')
        # reset the buffer (header row only) for the next sentence
        semantic_data = [['item_name', 'semantic_audio']]


if __name__ == "__main__":
    main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# test from dump file
|
2 |
+
import argparse
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from AR.data.dataset import Text2SemanticDataset
|
9 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
10 |
+
from AR.utils.io import load_yaml_config
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
|
14 |
+
def parse_args():
    """Build and parse the CLI options for evaluating the S1 model on a
    dumped test set.

    Returns:
        argparse.Namespace with config path, test-set dump paths, the model
        checkpoint, and the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for test set.")

    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')

    # args for dataset
    parser.add_argument(
        '--test_semantic_path',
        type=str,
        default='dump/test/semantic_token.tsv')
    parser.add_argument(
        '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')

    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')

    parser.add_argument("--output_dir", type=str, help="output dir.")

    return parser.parse_args()
|
43 |
+
|
44 |
+
|
45 |
+
def main():
    """Evaluate the AR S1 model on the first batch of a dumped test set.

    Computes forward-pass top-3 accuracy, runs inference with the first half
    (capped at 150 tokens) of the ground-truth semantic sequence as prompt,
    and writes the predicted semantic tokens to ``semantic_token.tsv``.
    """
    args = parse_args()

    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    batch_size = 1
    # 50 semantic tokens per second of audio
    hz = 50
    max_sec = config['data']['max_sec']

    # get dataset
    test_dataset = Text2SemanticDataset(
        phoneme_path=args.test_phoneme_path,
        semantic_path=args.test_semantic_path,
        # max_sec should match the training value, otherwise quality may
        # degrade (repetitions / dropped characters); but a small value here
        # would filter out long samples, so use a large one and truncate at
        # inference time instead
        max_sec=100,
        max_sample=8,
        pad_val=config['data']['pad_val'])
    # get model
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    # create the DataLoader with the dataset's own collate function
    dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset.collate)

    item_names = test_dataset.__get_item_names__()

    # with bs=1 and shuffle=False, batch i corresponds to item_names[i]
    semantic_data = [['item_name', 'semantic_audio']]
    for i, batch in enumerate(dataloader):
        # requires batch_size == 1 (bs > 1 would introduce zero padding)
        utt_id = item_names[i]
        if i == 0:
            print("utt_id:", utt_id)
            # kept consistent with validation_step()
            semantic_len = batch['semantic_ids'].size(1)
            # use the first (up to) 150 tokens of batch['semantic_ids'] as
            # the prompt; repeated synthesis reproduces the prompt prefix
            prompt_len = min(int(semantic_len * 0.5), 150)
            # what should the prompt be for pure-text input? => see t2s.py
            prompt = batch['semantic_ids'][:, :prompt_len]
            # a zero prompt also yields textually correct semantic tokens,
            # but with scrambled timbre — evidence that semantic tokens
            # still carry speaker information
            # prompt = torch.ones(
            #     batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
            # print("prompt:", prompt)
            # print("prompt.shape:", prompt.shape)
            np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())

            st = time.time()
            with torch.no_grad():
                # calculate acc for test
                loss, acc = t2s_model.model.forward(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    batch['semantic_ids'].cuda(),
                    batch['semantic_ids_len'].cuda())
                print("top_3_acc of this batch:", acc)
                pred_semantic = t2s_model.model.infer(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    prompt.cuda(),
                    top_k=config['inference']['top_k'],
                    # hz * max_sec in train dataloader
                    # observed output length was 1002, presumably padded
                    early_stop_num=hz * max_sec)
            # bs = 1
            pred_semantic = pred_semantic[0]
            print(f'{time.time() - st} sec used in T2S')
            semantic_token = pred_semantic.detach().cpu().numpy().tolist()
            semantic_token_str = ' '.join(str(x) for x in semantic_token)
            semantic_data.append([utt_id, semantic_token_str])
        else:
            # only the first batch is evaluated
            break
    delimiter = '\t'
    filename = output_dir / "semantic_token.tsv"
    with open(filename, 'w', encoding='utf-8') as writer:
        for row in semantic_data:
            line = delimiter.join(row)
            writer.write(line + '\n')


if __name__ == "__main__":
    main()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
001 Life was like a box of chocolates, you never know what you're gonna get.
|
2 |
+
002 With great power there must come great responsibility.
|
3 |
+
003 To be or not to be, that’s a question.
|
4 |
+
004 A man can be destroyed but not defeated
|
5 |
+
005 Do not, for one repulse, give up the purpose that you resolved to effort.
|
6 |
+
006 Death is just a part of life, something we're all destined to do.
|
7 |
+
007 I think it's hard winning a war with words.
|
8 |
+
008 Don’t argue with the people of strong determination, because they may change the fact!
|
9 |
+
009 Love you three thousand times.
|
10 |
+
010 tidy tiger tied a tie tighter to tidy her tiny tall.
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from pytorch_lightning import Trainer
|
10 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
11 |
+
from pytorch_lightning.loggers import WandbLogger
|
12 |
+
from pytorch_lightning.strategies import DDPStrategy
|
13 |
+
from AR.data.data_module import Text2SemanticDataModule
|
14 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
15 |
+
from soundstorm.utils.io import load_yaml_config
|
16 |
+
logging.getLogger('numba').setLevel(logging.WARNING)
|
17 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
18 |
+
torch.set_float32_matmul_precision('high')
|
19 |
+
from soundstorm.utils import get_newest_ckpt
|
20 |
+
|
21 |
+
|
22 |
+
def main(args):
    """Train the AR S1 (text-to-semantic) model with PyTorch Lightning.

    Creates the output/checkpoint directories, wires up per-epoch
    checkpointing and W&B logging, and resumes from the newest checkpoint in
    the output dir when one exists.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)

    seed_everything(config["train"]["seed"], workers=True)
    # save_top_k=-1 keeps every periodic checkpoint
    ckpt_callback: ModelCheckpoint = ModelCheckpoint(
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir)
    logger = WandbLogger(
        project="AR_S1",
        name=output_dir.stem,
        save_dir=output_dir,
        # resume the loss curve
        resume=True,
        # id='k19kvsq8'
    )
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(find_unused_parameters=True),
        precision=config["train"]["precision"],
        logger=logger,
        callbacks=[ckpt_callback])

    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)

    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_path=args.train_semantic_path,
        train_phoneme_path=args.train_phoneme_path,
        dev_semantic_path=args.dev_semantic_path,
        dev_phoneme_path=args.dev_phoneme_path)

    try:
        # match the numeric part of the checkpoint filenames with a regex and
        # sort numerically to find the newest one; any failure (e.g. empty
        # dir) deliberately falls back to training from scratch
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        ckpt_path = None
    print("ckpt_path:", ckpt_path)
    trainer.fit(model, data_module, ckpt_path=ckpt_path)


# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    # args for dataset
    parser.add_argument(
        '--train_semantic_path',
        type=str,
        default='dump/train/semantic_token.tsv')
    parser.add_argument(
        '--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
    parser.add_argument(
        '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
    parser.add_argument(
        '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='exp/default',
        help='directory to save the results')

    args = parser.parse_args()
    logging.info(str(args))
    main(args)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
+
import argparse
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from pytorch_lightning import Trainer
|
10 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
11 |
+
from pytorch_lightning.loggers import WandbLogger
|
12 |
+
from pytorch_lightning.strategies import DDPStrategy
|
13 |
+
from AR.data.data_module_librilight_6k import Text2SemanticDataModule
|
14 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
15 |
+
from soundstorm.utils import get_newest_ckpt
|
16 |
+
from soundstorm.utils.io import load_yaml_config
|
17 |
+
|
18 |
+
logging.getLogger('numba').setLevel(logging.WARNING)
|
19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
20 |
+
torch.set_float32_matmul_precision('high')
|
21 |
+
|
22 |
+
|
23 |
+
def main(args):
    """Train the AR S1 model on LibriLight-style dumps.

    Like ``train.py`` but takes *lists of directories* for the train/dev
    semantic, phoneme and (optional) non-speech data, and checkpoints every
    N training steps instead of every N epochs.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)

    seed_everything(config["train"]["seed"], workers=True)

    # step-based checkpointing (the per-epoch variant is used in train.py)
    ckpt_callback: ModelCheckpoint = ModelCheckpoint(
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_train_steps=config["train"]["every_n_train_steps"],
        dirpath=ckpt_dir)
    logger = WandbLogger(
        project="AR_S1_LibriLight",
        name=output_dir.stem,
        save_dir=output_dir,
        # resume the loss curve
        resume=True,
        # id='k19kvsq8'
    )
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(find_unused_parameters=True),
        precision=config["train"]["precision"],
        logger=logger,
        callbacks=[ckpt_callback])

    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)

    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_dirs=args.train_semantic_dirs,
        train_phoneme_dirs=args.train_phoneme_dirs,
        dev_semantic_dirs=args.dev_semantic_dirs,
        dev_phoneme_dirs=args.dev_phoneme_dirs,
        train_non_speech_dirs=args.train_non_speech_dirs,
        dev_non_speech_dirs=args.dev_non_speech_dirs)
    try:
        # resume from the newest checkpoint; any failure (e.g. empty dir)
        # deliberately falls back to training from scratch
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        ckpt_path = None

    print("ckpt_path:", ckpt_path)
    trainer.fit(model, data_module, ckpt_path=ckpt_path)
|
77 |
+
|
78 |
+
|
79 |
+
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
80 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    # args for dataset
    # FIX: these options were declared with ``type=list``, which makes
    # argparse convert every CLI value into a list of single characters; the
    # old code then re-joined the characters back into path strings with
    # ``''.join(item)`` loops. Declaring them as ``type=str`` (with
    # nargs='+') accepts exactly the same command lines and yields the same
    # list-of-strings values, so the re-join scaffolding is gone.
    parser.add_argument(
        '--train_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train semantic')
    parser.add_argument(
        '--train_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/train/"],
        help='dirs of train phoneme')
    parser.add_argument(
        '--dev_semantic_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev semantic')
    parser.add_argument(
        '--dev_phoneme_dirs',
        type=str,
        nargs='+',
        default=["dump/small/dev/"],
        help='dirs of dev phoneme')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='exp/default',
        help='directory to save the results')

    parser.add_argument(
        '--train_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of train non_speech data')

    parser.add_argument(
        '--dev_non_speech_dirs',
        type=str,
        nargs='+',
        default=None,
        help='dirs of dev non_speech data')

    args = parser.parse_args()

    logging.info(str(args))
    main(args)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py
ADDED
File without changes
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
|
2 |
+
import os,sys
|
3 |
+
now_dir = os.getcwd()
|
4 |
+
sys.path.append(now_dir)
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning import LightningModule
|
9 |
+
from AR.models.t2s_model import Text2SemanticDecoder
|
10 |
+
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
11 |
+
from AR.modules.optim import ScaledAdam
|
12 |
+
|
13 |
+
|
14 |
+
class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper around the AR text-to-semantic decoder (S1 stage).

    Training uses manual optimization and accumulates gradients over 4
    batches before each optimizer/scheduler step. Validation is disabled
    (the original implementation is kept commented out for reference).
    """

    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        # top-k used by the accuracy metric logged during training
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if (pretrained_s1 and is_train):
            # warm-start from a pretrained S1 checkpoint; the weights live
            # under the "weight" key of the checkpoint dict
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
            print(self.load_state_dict(
                torch.load(pretrained_s1, map_location="cpu")["weight"]))
        if is_train:
            # manual optimization: training_step drives opt/scheduler itself
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / 'eval'
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """One manual-optimization step with 4-batch gradient accumulation."""
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch['phoneme_ids'], batch['phoneme_ids_len'],
            batch['semantic_ids'], batch['semantic_ids_len'],
            batch['bert_feature'])
        self.manual_backward(loss)
        # step only on every 4th batch (gradient accumulation)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)

    # Validation is intentionally a no-op.
    def validation_step(self, batch: Dict, batch_idx: int): return
    # # get loss
    # loss, acc = self.model.forward(
    #     batch['phoneme_ids'], batch['phoneme_ids_len'],
    #     batch['semantic_ids'], batch['semantic_ids_len'],
    #     batch['bert_feature']
    # )
    #
    # self.log(
    #     "val_total_loss",
    #     loss,
    #     on_step=True,
    #     on_epoch=True,
    #     prog_bar=True,
    #     sync_dist=True)
    # self.log(
    #     f"val_top_{self.top_k}_acc",
    #     acc,
    #     on_step=True,
    #     on_epoch=True,
    #     prog_bar=True,
    #     sync_dist=True)
    #
    # # get infer output
    # semantic_len = batch['semantic_ids'].size(1)
    # prompt_len = min(int(semantic_len * 0.5), 150)
    # prompt = batch['semantic_ids'][:, :prompt_len]
    # pred_semantic = self.model.infer(batch['phoneme_ids'],
    #                                  batch['phoneme_ids_len'], prompt,
    #                                  batch['bert_feature']
    #                                  )
    # save_name = f'semantic_toks_{batch_idx}.pt'
    # save_path = os.path.join(self.eval_dir, save_name)
    # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        """Configure ScaledAdam plus a warmup-cosine LR schedule."""
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append([
            name_param_pair[0]
            for name_param_pair in self.model.named_parameters()
        ])
        # NOTE(review): lr=0.01 is the optimizer's base value; presumably the
        # effective LR is driven by WarmupCosineLRSchedule below — confirm.
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000, )

        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler":
                WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config['optimizer']['lr_init'],
                    peak_lr=self.config['optimizer']['lr'],
                    end_lr=self.config['optimizer']['lr_end'],
                    warmup_steps=self.config['optimizer']['warmup_steps'],
                    total_steps=self.config['optimizer']['decay_steps'])
            }
        }
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from AR.models.utils import make_pad_mask
|
6 |
+
from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
|
7 |
+
from AR.modules.embedding import SinePositionalEmbedding
|
8 |
+
from AR.modules.embedding import TokenEmbedding
|
9 |
+
from AR.modules.transformer import LayerNorm
|
10 |
+
from AR.modules.transformer import TransformerEncoder
|
11 |
+
from AR.modules.transformer import TransformerEncoderLayer
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
from torchmetrics.classification import MulticlassAccuracy
|
15 |
+
|
16 |
+
# Reference hyper-parameters for the decoder.
# NOTE(review): not referenced by the visible code — Text2SemanticDecoder
# reads all values from the config passed in; confirm before removing.
default_config = {
    "embedding_dim": 512,
    "hidden_dim": 512,
    "num_head": 8,
    "num_layers": 12,
    "num_codebook": 8,
    "p_dropout": 0.0,
    # +1 reserves the last id for EOS (see the assert in __init__)
    "vocab_size": 1024 + 1,
    "phoneme_vocab_size": 512,
    "EOS": 1024
}
|
27 |
+
|
28 |
+
|
29 |
+
class Text2SemanticDecoder(nn.Module):
|
30 |
+
    def __init__(self, config, norm_first=False, top_k=3):
        """Build the AR decoder: embeddings + positional encodings, a
        Transformer encoder stack, and the semantic-token prediction head.

        Args:
            config: dict-like config; hyper-parameters are read from
                ``config['model']``.
            norm_first: use pre-norm Transformer layers (with a final
                LayerNorm) when True.
            top_k: top-k used by the training accuracy metric.
        """
        super(Text2SemanticDecoder, self).__init__()
        self.model_dim = config['model']["hidden_dim"]
        self.embedding_dim = config['model']["embedding_dim"]
        self.num_head = config['model']["head"]
        self.num_layers = config['model']["n_layer"]
        self.norm_first = norm_first
        self.vocab_size = config['model']["vocab_size"]
        self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
        self.p_dropout = config['model']["dropout"]
        self.EOS = config['model']["EOS"]
        self.norm_first = norm_first
        # EOS must be the last id of the semantic vocabulary
        assert self.EOS == self.vocab_size - 1
        # should be same as num of kmeans bin
        # assert self.EOS == 1024
        # projects 1024-dim BERT features into the embedding space
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(
            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True)
        self.ar_audio_embedding = TokenEmbedding(
            self.embedding_dim, self.vocab_size, self.p_dropout)
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True)

        self.h = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=self.model_dim,
                nhead=self.num_head,
                dim_feedforward=self.model_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first, ),
            num_layers=self.num_layers,
            norm=LayerNorm(self.model_dim) if norm_first else None, )

        # logits over semantic vocabulary (incl. EOS)
        self.ar_predict_layer = nn.Linear(
            self.model_dim, self.vocab_size, bias=False)
        # summed cross-entropy; presumably normalized in forward() — confirm
        self.loss_fct = nn.CrossEntropyLoss(reduction='sum')

        # EOS positions are excluded from the accuracy computation
        self.ar_accuracy_metric = MulticlassAccuracy(
            self.vocab_size,
            top_k=top_k,
            average="micro",
            multidim_average="global",
            ignore_index=self.EOS, )
|
76 |
+
|
77 |
+
def forward(self, x, x_lens, y, y_lens, bert_feature):
    '''
    Training-time forward pass of the text-to-semantic AR decoder.

    x: phoneme_ids, (B, Tx) int tensor of phoneme token ids
    x_lens: (B,) valid lengths of x
    y: semantic_ids, (B, Ty) int tensor of semantic (audio) token ids
    y_lens: (B,) valid lengths of y
    bert_feature: BERT features projected onto the text embedding
                  # assumes shape (B, 1024, Tx) — transposed to (B, Tx, 1024) below; TODO confirm

    Returns (loss, acc): summed cross-entropy over all target tokens and
    top-k accuracy (EOS positions ignored by the metric).
    '''
    # Embed phonemes, add projected BERT features, add positional encoding.
    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1,2))
    x = self.ar_text_position(x)
    x_mask = make_pad_mask(x_lens)

    y_mask = make_pad_mask(y_lens)
    y_mask_int = y_mask.type(torch.int64)
    # Zero out padded positions of the target codes.
    codes = y.type(torch.int64) * (1 - y_mask_int)

    # Training
    # AR Decoder
    # Append EOS at the padded tail and shift by one -> (input, target).
    y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
    x_len = x_lens.max()
    y_len = y_lens.max()
    y_emb = self.ar_audio_embedding(y)
    y_pos = self.ar_audio_position(y_emb)

    # Padding mask over the concatenated (text, audio) sequence.
    xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
    ar_xy_padding_mask = xy_padding_mask

    # Text positions attend to all text, never to audio:
    # (x_len, x_len) zeros padded with True over the y columns.
    x_attn_mask = F.pad(
        torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
        (0, y_len),
        value=True, )
    # Audio positions: causal (upper-triangular masked) over audio,
    # full visibility of the text prefix.
    y_attn_mask = F.pad(
        torch.triu(
            torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
            diagonal=1, ),
        (x_len, 0),
        value=False, )
    xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
    bsz, src_len = x.shape[0], x_len + y_len
    # Broadcast the per-sequence padding mask across attention heads.
    _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
                        .expand(-1, self.num_head, -1, -1)
                        .reshape(bsz * self.num_head, 1, src_len))
    xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
    # Convert the boolean mask to an additive float mask (-inf on masked).
    new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
    new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
    xy_attn_mask = new_attn_mask
    # x and the complete y are fed to the model in a single pass
    xy_pos = torch.concat([x, y_pos], dim=1)
    xy_dec, _ = self.h(
        (xy_pos, None),
        mask=xy_attn_mask, )
    # Only the audio part of the decoder output produces logits;
    # permute to (B, vocab, T) for F.cross_entropy.
    logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
    # loss
    # from feiteng: longer durations should yield more gradient signal,
    # hence reduction='sum' instead of 'mean'
    loss = F.cross_entropy(logits, targets, reduction='sum')
    acc = self.ar_accuracy_metric(logits.detach(), targets).item()
    return loss, acc
|
132 |
+
|
133 |
+
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
134 |
+
# need to compare this with forward(), and decide what to feed as prompts
# when no semantic tokens are available
def infer(self,
          x,
          x_lens,
          prompts,
          bert_feature,
          top_k: int=-100,
          early_stop_num: int=-1,
          temperature: float=1.0):
    """Autoregressive greedy/sampled decoding without a KV cache.

    x: phoneme token ids; prompts: reference-audio semantic tokens used
    as the decoding prefix; bert_feature: text-side BERT features.
    Decodes up to 1500 steps, stopping on EOS (argmax or sampled) or
    after `early_stop_num` newly generated tokens (-1 disables).
    Returns the full token sequence y (prompt prefix included).
    """

    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1,2))
    x = self.ar_text_position(x)

    # AR Decoder
    y = prompts
    prefix_len = y.shape[1]
    x_len = x.shape[1]
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    stop = False
    for _ in tqdm(range(1500)):
        # Re-embed the whole (growing) y each step — no cache in this path.
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)
        # feed x together with the progressively growing y
        xy_pos = torch.concat([x, y_pos], dim=1)
        y_len = y.shape[1]
        # Text block: attends to text only (True = masked over y columns).
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),
            value=True, )
        # Audio block: causal over audio, open to the text prefix.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
            (x_len, 0),
            value=False, )
        xy_attn_mask = torch.concat(
            [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)

        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask, )
        # Only the last position is needed to predict the next token.
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = topk_sampling(
            logits, top_k=top_k, top_p=1.0, temperature=temperature)

        if early_stop_num != -1 and (y.shape[1] - prefix_len
                                     ) > early_stop_num:
            print("use early stop num:", early_stop_num)
            stop = True

        # Stop when either the argmax or the drawn sample is EOS.
        if torch.argmax(
                logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
            stop = True
        if stop:
            # If nothing was generated at all, pad one zero token so the
            # caller never receives an empty continuation.
            if prompts.shape[1] == y.shape[1]:
                y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                print('bad zero prediction')
            print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
            break
        # the newly generated semantic_ids and the previous y form the new y
        # print(samples.shape)#[1,1]# the first 1 is the batch size
        # import os
        # os._exit(2333)
        y = torch.concat([y, samples], dim=1)
    return y
|
199 |
+
|
200 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Append `eos_id` at the padded tail of `y` and return the pair
    (decoder input, prediction target), shifted by one position."""
    padded_codes = F.pad(y, (0, 1), value=0)
    eos_positions = F.pad(y_mask_int, (0, 1), value=1)
    targets = padded_codes + eos_id * eos_positions
    # one-step shift: input tokens vs. next-token targets
    return targets[:, :-1], targets[:, 1:]
|
206 |
+
|
207 |
+
def infer_panel(self,
                x,  ##### all text tokens
                x_lens,
                prompts,  #### reference-audio semantic tokens
                bert_feature,
                top_k: int=-100,
                early_stop_num: int=-1,
                temperature: float=1.0):
    """KV-cached autoregressive decoding (fast path used by the WebUI).

    Same contract as `infer`, but keeps per-layer key/value caches so that
    after the first step only the newest position is fed through the
    transformer. Returns (y, idx): the token sequence and the index of the
    last decoding step.
    """

    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1,2))
    x = self.ar_text_position(x)

    # AR Decoder
    y = prompts
    prefix_len = y.shape[1]
    x_len = x.shape[1]
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    stop = False
    # print(1111111,self.num_layers)
    cache={
        "all_stage":self.num_layers,
        "k":[None]*self.num_layers,  ### hand-written per the layer count in the config
        "v":[None]*self.num_layers,
        # "xy_pos":None,  ## y positional encodings differ every step, so they can't
        # be cached; xy_pos is rebuilt each step. Mostly a code-structure issue —
        # the history could be shared, but the cost is negligible, so leave it.
        "y_emb":None,  ## only the newest sample needs embedding; it is concatenated onto the history
        # "logits":None,  ### the original already computes logits for the tail only; nothing to do
        # "xy_dec":None,  ### not needed — only the last position feeds the logits
        "first_infer":1,  # 1 on the very first step (full-sequence pass), 0 afterwards
        "stage":0
    }
    for idx in tqdm(range(1500)):
        # Embedding: full y on the first pass, incremental afterwards.
        if(cache["first_infer"]==1):
            y_emb = self.ar_audio_embedding(y)
        else:
            y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
        cache["y_emb"]=y_emb
        y_pos = self.ar_audio_position(y_emb)
        # feed x together with the progressively growing y
        if(cache["first_infer"]==1):
            xy_pos = torch.concat([x, y_pos], dim=1)
        else:
            # Cached path: only the newest position enters the transformer.
            xy_pos=y_pos[:,-1:]
        y_len = y_pos.shape[1]
        ### the following 3 masks are not cached
        if (cache["first_infer"] == 1):
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),  ### all-zero (x,x) block extended with an all-ones xy block -> (x, x+y)
                value=True, )
            y_attn_mask = F.pad(  ### upper-right-1 (y,y) block extended left with zeros over x -> (y, x+y)
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False, )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
        else:
            ### rightmost column only (this is wrong)
            # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
            # xy_attn_mask[:,-1]=False
            ### bottom row only (this is correct): the new token attends to everything
            xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
        # pdb.set_trace()
        ### here is where the cache does the heavy lifting
        # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,cache=cache )
        logits = self.ar_predict_layer(xy_dec[:, -1])  ## unchanged: with the cache there is only one frame anyway, same as taking the last frame
        # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
        if early_stop_num != -1 and (y.shape[1] - prefix_len
                                     ) > early_stop_num:
            print("use early stop num:", early_stop_num)
            stop = True

        # Stop when either the argmax or the drawn sample is EOS.
        if torch.argmax(
                logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
            stop = True
        if stop:
            # Guard against an empty continuation: pad one zero token.
            if prompts.shape[1] == y.shape[1]:
                y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                print('bad zero prediction')
            print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
            break
        # the newly generated semantic_ids and the previous y form the new y
        # print(samples.shape)#[1,1]# the first 1 is the batch size
        y = torch.concat([y, samples], dim=1)
        cache["first_infer"]=0
    return y,idx
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
|
7 |
+
def sequence_mask(length, max_length=None):
    """Return a (B, T) boolean mask that is True at valid positions
    (position index < length) and False at padding."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
|
12 |
+
|
13 |
+
|
14 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked (padding) positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    #>>> lengths = torch.tensor([1, 3, 2, 5])
    #>>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # never shorter than the longest sequence
    max_len = max(max_len, lengths.max())
    batch = lengths.size(0)
    positions = torch.arange(0, max_len, device=lengths.device)
    grid = positions.unsqueeze(0).expand(batch, max_len)
    # True exactly where the position index has run past the valid length
    return grid >= lengths.unsqueeze(-1)
|
40 |
+
|
41 |
+
|
42 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
43 |
+
def top_k_top_p_filtering(logits,
                          top_k=0,
                          top_p=1.0,
                          filter_value=-float("Inf"),
                          min_tokens_to_keep=1):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output.
    Note: modifies `logits` in place and also returns it.
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # safety check: clamp into [min_tokens_to_keep, vocab size]
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # drop everything strictly below the k-th best logit
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)

        # tokens whose cumulative probability exceeds the threshold
        # (the token that crosses it is kept after the shift below)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # guarantee min_tokens_to_keep survivors (shift adds back the first)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # shift right so the first token above the threshold is kept too
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # map the removal mask back to the original (unsorted) positions
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
|
84 |
+
|
85 |
+
|
86 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Draw one token per row from `logits` after temperature scaling
    and top-k / top-p filtering.

    temperature: (`optional`) float — modulates the next-token probabilities;
        must be strictly positive (higher => more low-probability samples).
    top_k: (`optional`) int — number of highest-probability tokens kept.
    top_p: (`optional`) float — cumulative-probability cutoff for nucleus
        sampling, in (0, 1].
    Returns a (batch, 1) tensor of sampled token ids.
    """
    # temperature scaling (no-op at the default of 1.0)
    if temperature != 1.0:
        logits = logits / temperature
    # top-p / top-k filtering
    filtered = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # sample from the renormalized surviving tokens
    return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
|
102 |
+
|
103 |
+
|
104 |
+
from typing import Optional, Tuple
|
105 |
+
def multinomial_sample_one_no_sync(
    probs_sort,
):  # multinomial sampling via the exponential-race trick, no CUDA sync
    """Draw one index with probability proportional to `probs_sort`,
    avoiding the device synchronization of `torch.multinomial`."""
    noise = torch.empty_like(probs_sort).exponential_(1)
    winner = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
|
110 |
+
|
111 |
+
|
112 |
+
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[int] = None,
    repetition_penalty: float = 1.0,
):
    """Convert 1-D `logits` into a probability distribution, optionally
    applying a repetition penalty, nucleus (top-p) and top-k filtering,
    and temperature scaling.

    Args:
        logits: 1-D tensor of unnormalized scores (mutated in place when
            the repetition penalty fires).
        previous_tokens: token ids already generated; used only when
            repetition_penalty != 1.0. May be None.
        temperature: softmax temperature (clamped to >= 1e-5).
        top_k / top_p: standard sampling filters; None disables each.
        repetition_penalty: >1.0 discourages repeating previous tokens.

    Returns:
        1-D tensor of probabilities summing to 1.

    Bug fix vs. the original: `previous_tokens.squeeze()` was called
    unconditionally before the `is not None` check, so the documented
    default `previous_tokens=None` raised AttributeError.  The squeeze is
    now guarded.
    """
    if previous_tokens is not None:
        previous_tokens = previous_tokens.squeeze()
    # print(logits.shape,previous_tokens.shape)
    # pdb.set_trace()
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        # pull the scores of already-generated tokens ...
        score = torch.gather(logits, dim=0, index=previous_tokens)
        # ... and push them toward -inf: multiply negatives, divide positives
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    # clamp so temperature=0 cannot divide by zero
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # k-th best value; everything below it is masked out
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
|
152 |
+
|
153 |
+
|
154 |
+
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one next-token id from `logits`.

    Forwards `sampling_kwargs` (temperature, top_k, top_p,
    repetition_penalty) to `logits_to_probs`, then draws a single index.
    Returns (token, probs).
    """
    probs = logits_to_probs(
        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
    )
    token = multinomial_sample_one_no_sync(probs)
    return token, probs
|
164 |
+
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py
ADDED
File without changes
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
2 |
+
from typing import Optional
|
3 |
+
from typing import Tuple
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Linear
|
7 |
+
from torch.nn import Module
|
8 |
+
from torch.nn.init import constant_
|
9 |
+
from torch.nn.init import xavier_normal_
|
10 |
+
from torch.nn.init import xavier_uniform_
|
11 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
12 |
+
from torch.nn.parameter import Parameter
|
13 |
+
|
14 |
+
from torch.nn import functional as F
|
15 |
+
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
16 |
+
F.multi_head_attention_forward=multi_head_attention_forward_patched
|
17 |
+
|
18 |
+
class MultiheadAttention(Module):
|
19 |
+
r"""Allows the model to jointly attend to information
|
20 |
+
from different representation subspaces as described in the paper:
|
21 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
22 |
+
|
23 |
+
Multi-Head Attention is defined as:
|
24 |
+
|
25 |
+
.. math::
|
26 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
27 |
+
|
28 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
29 |
+
|
30 |
+
``forward()`` will use a special optimized implementation if all of the following
|
31 |
+
conditions are met:
|
32 |
+
|
33 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
34 |
+
restriction will be loosened in the future.)
|
35 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
36 |
+
- training is disabled (using ``.eval()``)
|
37 |
+
- dropout is 0
|
38 |
+
- ``add_bias_kv`` is ``False``
|
39 |
+
- ``add_zero_attn`` is ``False``
|
40 |
+
- ``batch_first`` is ``True`` and the input is batched
|
41 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
42 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
43 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
44 |
+
nor ``attn_mask`` is passed
|
45 |
+
|
46 |
+
If the optimized implementation is in use, a
|
47 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
48 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
49 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
50 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
51 |
+
that is padding can be expected.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
embed_dim: Total dimension of the model.
|
55 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
56 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
57 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
58 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
59 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
60 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
61 |
+
Default: ``False``.
|
62 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
63 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
64 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
65 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
66 |
+
|
67 |
+
Examples::
|
68 |
+
|
69 |
+
>>> # xdoctest: +SKIP
|
70 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
71 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
72 |
+
|
73 |
+
"""
|
74 |
+
__constants__ = ["batch_first"]
|
75 |
+
bias_k: Optional[torch.Tensor]
|
76 |
+
bias_v: Optional[torch.Tensor]
|
77 |
+
|
78 |
+
def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,  # class used for the input (QKV) projection
        linear2_cls=Linear,  # class used for the output projection
        device=None,
        dtype=None, ) -> None:
    """Build the attention module.

    Mirrors torch.nn.MultiheadAttention, with two extra hooks:
    `linear1_cls` / `linear2_cls` let callers swap in custom Linear
    implementations for the in/out projections.  When a custom
    `linear1_cls` is used, separate q/k/v projection weights are not
    supported (kdim/vdim must equal embed_dim).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super(MultiheadAttention, self).__init__()
    self.embed_dim = embed_dim
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    # single fused QKV weight is only possible when all dims agree
    self._qkv_same_embed_dim = (self.kdim == embed_dim and
                                self.vdim == embed_dim)

    self.num_heads = num_heads
    self.dropout = dropout
    self.batch_first = batch_first
    self.head_dim = embed_dim // num_heads
    assert (self.head_dim * num_heads == self.embed_dim
            ), "embed_dim must be divisible by num_heads"

    if add_bias_kv:
        self.bias_k = Parameter(
            torch.empty((1, 1, embed_dim), **factory_kwargs))
        self.bias_v = Parameter(
            torch.empty((1, 1, embed_dim), **factory_kwargs))
    else:
        self.bias_k = self.bias_v = None

    if linear1_cls == Linear:
        # Standard path: raw Parameters, like torch.nn.MultiheadAttention.
        if not self._qkv_same_embed_dim:
            # separate projections when key/value dims differ
            self.q_proj_weight = Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter("in_proj_weight", None)
        else:
            # fused (3*E, E) projection for q, k and v
            self.in_proj_weight = Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs)

        self._reset_parameters()
    else:
        # Custom linear class: build real Linear modules and alias their
        # weight/bias so the rest of the code can keep using
        # in_proj_weight / in_proj_bias.
        if not self._qkv_same_embed_dim:
            raise NotImplementedError
        else:
            self.in_proj_linear = linear1_cls(
                embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
            self.in_proj_weight = self.in_proj_linear.weight

            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = self.in_proj_linear.bias
            else:
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs)

            # NOTE: _reset_parameters() is not called on this path; only
            # the extra bias_k/bias_v tensors are re-initialized here.
            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

    self.add_zero_attn = add_zero_attn
|
167 |
+
|
168 |
+
def _reset_parameters(self):
    """Initialize projection weights (xavier uniform), biases (zero) and
    the optional bias_k / bias_v tensors (xavier normal)."""
    # fused QKV weight gets one init; separate q/k/v each get their own
    if self._qkv_same_embed_dim:
        xavier_uniform_(self.in_proj_weight)
    else:
        for weight in (self.q_proj_weight, self.k_proj_weight,
                       self.v_proj_weight):
            xavier_uniform_(weight)

    if self.in_proj_bias is not None:
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    if self.bias_k is not None:
        xavier_normal_(self.bias_k)
    if self.bias_v is not None:
        xavier_normal_(self.bias_v)
|
184 |
+
|
185 |
+
def __setstate__(self, state):
    """Restore pickled state, backfilling the `_qkv_same_embed_dim` flag
    missing from MultiheadAttention checkpoints created by v1.1.0."""
    state.setdefault("_qkv_same_embed_dim", True)
    super(MultiheadAttention, self).__setstate__(state)
|
191 |
+
|
192 |
+
def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor]=None,
        need_weights: bool=True,
        attn_mask: Optional[Tensor]=None,
        average_attn_weights: bool=True,cache=None
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Compute multi-head attention for (query, key, value).

    Shapes follow ``torch.nn.MultiheadAttention``: ``(L, N, E)`` when
    ``batch_first=False``, ``(N, L, E)`` when ``batch_first=True``, or
    unbatched 2-D inputs. ``key_padding_mask`` (bool or float, ``(N, S)``)
    marks key positions to ignore; ``attn_mask`` (2-D ``(L, S)`` or 3-D
    ``(N*num_heads, L, S)``) blocks attention to positions.

    Returns ``(attn_output, attn_output_weights)``; weights are ``None``
    unless ``need_weights`` is True, and are averaged over heads when
    ``average_attn_weights`` is True.

    ``cache`` is a GPT-SoVITS extension threaded through to the patched
    ``F.multi_head_attention_forward`` for incremental (KV-cache) decoding.
    NOTE(review): when ``cache`` is not None the torch fast path below could
    bypass it; presumably callers only pass ``cache`` in configurations that
    miss the fast path (e.g. grad enabled / attn_mask set) — verify.
    """
    is_batched = query.dim() == 3
    # Only bool or floating key_padding_mask dtypes are accepted.
    if key_padding_mask is not None:
        _kpm_dtype = key_padding_mask.dtype
        if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )
    # Accumulate the first reason the native C++ fast path cannot be used;
    # an empty string at the end means the fast path is eligible.
    why_not_fast_path = ""
    if not is_batched:
        why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
    elif query is not key or key is not value:
        # When lifting this restriction, don't forget to either
        # enforce that the dtypes all match or test cases where
        # they don't!
        why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
    elif (self.in_proj_bias is not None and
          query.dtype != self.in_proj_bias.dtype):
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
    elif (self.in_proj_weight is not None and
          query.dtype != self.in_proj_weight.dtype):
        # this case will fail anyway, but at least they'll get a useful error message.
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
    elif self.training:
        why_not_fast_path = "training is enabled"
    elif not self.batch_first:
        why_not_fast_path = "batch_first was not True"
    elif self.bias_k is not None:
        why_not_fast_path = "self.bias_k was not None"
    elif self.bias_v is not None:
        why_not_fast_path = "self.bias_v was not None"
    elif self.dropout:
        why_not_fast_path = f"dropout was {self.dropout}, required zero"
    elif self.add_zero_attn:
        why_not_fast_path = "add_zero_attn was enabled"
    elif not self._qkv_same_embed_dim:
        why_not_fast_path = "_qkv_same_embed_dim was not True"
    elif attn_mask is not None:
        why_not_fast_path = "attn_mask was not None"
    elif query.is_nested and key_padding_mask is not None:
        why_not_fast_path = (
            "key_padding_mask is not supported with NestedTensor input")
    elif self.num_heads % 2 == 1:
        why_not_fast_path = "num_heads is odd"
    elif torch.is_autocast_enabled():
        why_not_fast_path = "autocast is enabled"

    if not why_not_fast_path:
        tensor_args = (query, key, value, self.in_proj_weight,
                       self.in_proj_bias, self.out_proj.weight,
                       self.out_proj.bias, )
        # We have to use list comprehensions below because TorchScript does not support
        # generator expressions.
        if torch.overrides.has_torch_function(tensor_args):
            why_not_fast_path = "some Tensor argument has_torch_function"
        elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
                      for x in tensor_args]):
            why_not_fast_path = (
                "some Tensor argument is neither CUDA nor CPU")
        elif torch.is_grad_enabled() and any(
            [x is not None and x.requires_grad for x in tensor_args]):
            why_not_fast_path = (
                "grad is enabled and at least one of query or the "
                "input/output projection weights or biases requires_grad")
        if not why_not_fast_path:
            # Eligible: dispatch to the fused native kernel. Note `cache`
            # is NOT forwarded on this path.
            return torch._native_multi_head_attention(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
                key_padding_mask
                if key_padding_mask is not None else attn_mask,
                need_weights,
                average_attn_weights,
                1 if key_padding_mask is not None else 0
                if attn_mask is not None else None, )

    # NestedTensor inputs are only supported on the fast path above.
    any_nested = query.is_nested or key.is_nested or value.is_nested
    assert not any_nested, (
        "MultiheadAttention does not support NestedTensor outside of its fast path. "
        + f"The fast path was not hit because {why_not_fast_path}")

    # The slow path works in (L, N, E) layout; transpose batch-first inputs.
    if self.batch_first and is_batched:
        # make sure that the transpose op does not affect the "is" property
        if key is value:
            if query is key:
                query = key = value = query.transpose(1, 0)
            else:
                query, key = [x.transpose(1, 0) for x in (query, key)]
                value = key
        else:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

    if not self._qkv_same_embed_dim:
        # Separate Q/K/V projection weights.
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
            average_attn_weights=average_attn_weights,cache=cache )
    else:
        # Packed in_proj_weight covering Q, K and V.
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,cache=cache )
    # Undo the earlier transpose so callers get batch-first output back.
    if self.batch_first and is_batched:
        return attn_output.transpose(1, 0), attn_output_weights
    else:
        return attn_output, attn_output_weights
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class TokenEmbedding(nn.Module):
    """Token-id lookup table followed by dropout.

    Maps integer token ids to dense vectors of size `embedding_dim`.
    """

    def __init__(
            self,
            embedding_dim: int,
            vocab_size: int,
            dropout: float=0.0, ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        # The raw (vocab_size, embedding_dim) embedding matrix.
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        # Single row of the table; the length-1 slice keeps it 2-D.
        return self.word_embeddings.weight[index:index + 1]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for `x` (integer tensor) and apply dropout."""
        return self.dropout(self.word_embeddings(x))
|
33 |
+
|
34 |
+
|
35 |
+
class SinePositionalEmbedding(nn.Module):
    """Additive sinusoidal positional encoding.

    Adds `alpha * PE[:seq_len]` to the (optionally sqrt(d)-scaled) input,
    then applies dropout. The table is rebuilt lazily whenever a longer
    sequence than previously seen arrives.
    """

    def __init__(
            self,
            embedding_dim: int,
            dropout: float=0.0,
            scale: bool=False,
            alpha: bool=False, ):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Optional sqrt(d) input scaling, as in "Attention Is All You Need".
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        # Learnable gain on the positional term; frozen unless alpha=True.
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        # Pre-build the table for sequences of up to 4000 positions.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """(Re)build self.pe so it covers at least x.size(1) positions."""
        needed = x.size(1)
        if self.pe is not None and self.pe.size(1) >= needed:
            # Table already long enough; only align dtype/device if required.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        table = torch.zeros(needed, self.embedding_dim)
        if self.reverse:
            position = torch.arange(
                needed - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(
                0, needed, dtype=torch.float32).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d) for even dims i.
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.embedding_dim))
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to `x` ((B, T, D) or (B, T)) and dropout."""
        self.extend_pe(x)
        out = x.unsqueeze(-1) if x.ndim == 2 else x
        out = out * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
        return self.dropout(out)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
from torch import nn
|
7 |
+
from torch.optim import Adam
|
8 |
+
|
9 |
+
|
10 |
+
class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
    """
    Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.

    WARNING: as shipped, the warmup/cosine computation is dead code -- step()
    hard-locks the learning rate to 0.002 (see the HACK comment below), and
    set_lr() ignores its argument, pinning every param group to self.end_lr.
    """

    def __init__(self,
                 optimizer,
                 init_lr,
                 peak_lr,
                 end_lr,
                 warmup_steps=10000,
                 total_steps=400000,
                 current_step=0):
        self.init_lr = init_lr
        self.peak_lr = peak_lr
        self.end_lr = end_lr
        self.optimizer = optimizer
        # Per-step slopes for the (nominal) linear warmup / linear decay.
        self._warmup_rate = (peak_lr - init_lr) / warmup_steps
        self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
        self._current_step = current_step
        self.lr = init_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self._last_lr = [self.lr]

    def set_lr(self, lr):
        """Apply a learning rate to all param groups (argument is ignored; see below)."""
        # Record the LRs that were in effect before this update.
        self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
        for g in self.optimizer.param_groups:
            # g['lr'] = lr
            g['lr'] = self.end_lr  # (translated) "locked; use linear" -- the `lr` argument is deliberately ignored

    def step(self):
        """Advance one step and return the learning rate actually applied.

        Nominally: linear warmup to peak_lr, cosine decay to end_lr, then
        flat end_lr -- but the result is unconditionally overridden below.
        """
        if self._current_step < self.warmup_steps:
            lr = self.init_lr + self._warmup_rate * self._current_step

        elif self._current_step > self.total_steps:
            lr = self.end_lr

        else:
            decay_ratio = (self._current_step - self.warmup_steps) / (
                self.total_steps - self.warmup_steps)
            if decay_ratio < 0.0 or decay_ratio > 1.0:
                raise RuntimeError(
                    "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
                )
            coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
            lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)

        self.lr=lr=self.end_lr=0.002  # HACK (translated): "locked; use linear -- it won't behave, so lock it outright!" Hard-pins the LR to 0.002 every step, overriding the schedule above and mutating end_lr.
        self.set_lr(lr)
        self.lr = lr
        self._current_step += 1
        return self.lr
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == '__main__':
    # Smoke test: drive the scheduler for 25k steps on a dummy model and
    # plot the resulting learning-rate trajectory.
    m = nn.Linear(10, 10)
    opt = Adam(m.parameters(), lr=1e-4)
    s = WarmupCosineLRSchedule(
        opt,
        1e-6,
        2e-4,
        1e-6,
        warmup_steps=2000,
        total_steps=20000,
        current_step=0)
    lrs = []
    for i in range(25000):
        s.step()
        lrs.append(s.lr)
        print(s.lr)

    # Fix: the original called plt.plot(lrs) AND plt.plot(range(...), lrs),
    # drawing the identical series twice; one explicit-x plot is enough.
    plt.plot(range(0, 25000), lrs)
    plt.show()
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import contextlib
|
17 |
+
import logging
|
18 |
+
from collections import defaultdict
|
19 |
+
from typing import List
|
20 |
+
from typing import Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import Tensor
|
24 |
+
from torch.optim import Optimizer
|
25 |
+
|
26 |
+
|
27 |
+
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension. This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: parameters or param_groups, as for torch.optim.Optimizer
      defaults: default hyperparameter dict, as for torch.optim.Optimizer
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
          of tuples (p, state, p_names), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters, the last one that has any particular shape and dtype).

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations on exit.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state, p_names in batches:
                 ...
        </code>

        Args:
          param_group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in param_group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        # Sort batches by key for a deterministic iteration order.
        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i])
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        # `tuples` will contain tuples of (stacked_param, state, stacked_params_names),
        # one for each batch in `batches`.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            # Missing grads are treated as zero so stacking never fails.
            grad = torch.stack([
                torch.zeros_like(p) if p.grad is None else p.grad for p in batch
            ])
            p_stacked.grad = grad
            # Fix: removed the dead `stacked_params_dict[key] = p_stacked`
            # assignment -- the dict was never read, and `key` was a stale
            # leftover from the grouping loop above (wrong key for every
            # batch but the last).
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        # Write the (possibly updated) stacked values back to the real params.
        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
|
121 |
+
|
122 |
+
|
123 |
+
class ScaledAdam(BatchedOptimizer):
|
124 |
+
"""
|
125 |
+
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
126 |
+
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
127 |
+
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
128 |
+
param = underlying_param * log_scale.exp())
|
129 |
+
|
130 |
+
|
131 |
+
Args:
|
132 |
+
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
133 |
+
lr: The learning rate. We will typically use a learning rate schedule that starts
|
134 |
+
at 0.03 and decreases over time, i.e. much higher than other common
|
135 |
+
optimizers.
|
136 |
+
clipping_scale: (e.g. 2.0)
|
137 |
+
A scale for gradient-clipping: if specified, the normalized gradients
|
138 |
+
over the whole model will be clipped to have 2-norm equal to
|
139 |
+
`clipping_scale` times the median 2-norm over the most recent period
|
140 |
+
of `clipping_update_period` minibatches. By "normalized gradients",
|
141 |
+
we mean after multiplying by the rms parameter value for this tensor
|
142 |
+
[for non-scalars]; this is appropriate because our update is scaled
|
143 |
+
by this quantity.
|
144 |
+
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
145 |
+
Must satisfy 0 < beta <= beta2 < 1.
|
146 |
+
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
147 |
+
scale of each parameter tensor and scalar parameters of the mode..
|
148 |
+
If each parameter were decomposed
|
149 |
+
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
150 |
+
would be a the scaling factor on the learning rate of p_scale.
|
151 |
+
eps: A general-purpose epsilon to prevent division by zero
|
152 |
+
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
153 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
154 |
+
parameter tensor to be >= this value)
|
155 |
+
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
156 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
157 |
+
parameter tensor to be <= this value)
|
158 |
+
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
159 |
+
model has any parameters with numel() == 1).
|
160 |
+
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
161 |
+
of the parameter tensor. This is provided to save a little time
|
162 |
+
in the update.
|
163 |
+
clipping_update_period: if clipping_scale is specified, this is the period
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True, ):
    """Construct a ScaledAdam optimizer.

    `parameters_names` is mandatory: a List[List[str]] with one inner list
    per param group, naming each parameter (used for diagnostics such as
    reporting dominant parameters in gradient clipping).
    """
    assert parameters_names is not None, (
        "Please prepare parameters_names,"
        "which is a List[List[str]]. Each List[str] is for a group"
        "and each str is for a parameter")
    # Hyperparameters stored per param group, as torch.optim expects.
    defaults = {
        "lr": lr,
        "clipping_scale": clipping_scale,
        "betas": betas,
        "scalar_lr_scale": scalar_lr_scale,
        "eps": eps,
        "param_min_rms": param_min_rms,
        "param_max_rms": param_max_rms,
        "scalar_max": scalar_max,
        "size_update_period": size_update_period,
        "clipping_update_period": clipping_update_period,
    }

    super(ScaledAdam, self).__init__(params, defaults)
    assert len(self.param_groups) == len(parameters_names)
    self.parameters_names = parameters_names
    self.show_dominant_parameters = show_dominant_parameters
|
202 |
+
|
203 |
+
def __setstate__(self, state):
    # No extra unpickling logic needed beyond what the base Optimizer
    # (via BatchedOptimizer) already performs.
    super(ScaledAdam, self).__setstate__(state)
|
205 |
+
|
206 |
+
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        # Re-enable grad so the closure can backprop, even though the
        # update itself runs under no_grad.
        with torch.enable_grad():
            loss = closure()

    batch = True  # NOTE(review): never read; appears to be leftover.

    for group, group_params_names in zip(self.param_groups,
                                         self.parameters_names):

        with self.batched_params(group["params"],
                                 group_params_names) as batches:

            # batches is list of pairs (stacked_param, state). stacked_param is like
            # a regular parameter, and will have a .grad, but the 1st dim corresponds to
            # a stacking dim, it is not a real dim.

            if (len(batches[0][1]) ==
                    0):  # if len(first state) == 0: not yet initialized
                # First step: states are empty, so no clipping statistics
                # exist yet; use a neutral scale of 1.
                clipping_scale = 1
            else:
                clipping_scale = self._get_clipping_scale(group, batches)

            for p, state, _ in batches:
                # Perform optimization step.
                # grad is not going to be None, we handled that when creating the batches.
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "ScaledAdam optimizer does not support sparse gradients"
                    )
                # State initialization
                if len(state) == 0:
                    self._init_state(group, p, state)

                self._step_one_batch(group, p, state, clipping_scale)

    return loss
|
252 |
+
|
253 |
+
def _init_state(self, group: dict, p: Tensor, state: dict):
|
254 |
+
"""
|
255 |
+
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
256 |
+
is actually the batch dimension, corresponding to batched-together
|
257 |
+
parameters of a given shape.
|
258 |
+
|
259 |
+
|
260 |
+
Args:
|
261 |
+
group: Dict to look up configuration values.
|
262 |
+
p: The parameter that we are initializing the state for
|
263 |
+
state: Dict from string to whatever state we are initializing
|
264 |
+
"""
|
265 |
+
size_update_period = group["size_update_period"]
|
266 |
+
|
267 |
+
state["step"] = 0
|
268 |
+
|
269 |
+
kwargs = {"device": p.device, "dtype": p.dtype}
|
270 |
+
|
271 |
+
# 'delta' implements conventional momentum. There are
|
272 |
+
# several different kinds of update going on, so rather than
|
273 |
+
# compute "exp_avg" like in Adam, we store and decay a
|
274 |
+
# parameter-change "delta", which combines all forms of
|
275 |
+
# update. this is equivalent to how it's done in Adam,
|
276 |
+
# except for the first few steps.
|
277 |
+
state["delta"] = torch.zeros_like(
|
278 |
+
p, memory_format=torch.preserve_format)
|
279 |
+
|
280 |
+
batch_size = p.shape[0]
|
281 |
+
numel = p.numel() // batch_size
|
282 |
+
numel = p.numel()
|
283 |
+
|
284 |
+
if numel > 1:
|
285 |
+
# "param_rms" just periodically records the scalar root-mean-square value of
|
286 |
+
# the parameter tensor.
|
287 |
+
# it has a shape like (batch_size, 1, 1, 1, 1)
|
288 |
+
param_rms = (
|
289 |
+
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
290 |
+
state["param_rms"] = param_rms
|
291 |
+
|
292 |
+
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
293 |
+
state["scale_grads"] = torch.zeros(size_update_period,
|
294 |
+
*param_rms.shape, **kwargs)
|
295 |
+
|
296 |
+
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
297 |
+
state["exp_avg_sq"] = torch.zeros_like(
|
298 |
+
p, memory_format=torch.preserve_format)
|
299 |
+
|
300 |
+
def _get_clipping_scale(self,
|
301 |
+
group: dict,
|
302 |
+
tuples: List[Tuple[Tensor, dict, List[str]]]
|
303 |
+
) -> float:
|
304 |
+
"""
|
305 |
+
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
306 |
+
by this amount before applying the rest of the update.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
group: the parameter group, an item in self.param_groups
|
310 |
+
tuples: a list of tuples of (param, state, param_names)
|
311 |
+
where param is a batched set of parameters,
|
312 |
+
with a .grad (1st dim is batch dim)
|
313 |
+
and state is the state-dict where optimization parameters are kept.
|
314 |
+
param_names is a List[str] while each str is name for a parameter
|
315 |
+
in batched set of parameters "param".
|
316 |
+
"""
|
317 |
+
assert len(tuples) >= 1
|
318 |
+
clipping_scale = group["clipping_scale"]
|
319 |
+
(first_p, first_state, _) = tuples[0]
|
320 |
+
step = first_state["step"]
|
321 |
+
if clipping_scale is None or step == 0:
|
322 |
+
# no clipping. return early on step == 0 because the other
|
323 |
+
# parameters' state won't have been initialized yet.
|
324 |
+
return 1.0
|
325 |
+
clipping_update_period = group["clipping_update_period"]
|
326 |
+
|
327 |
+
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
328 |
+
for (p, state, param_names) in tuples:
|
329 |
+
grad = p.grad
|
330 |
+
if grad.is_sparse:
|
331 |
+
raise RuntimeError(
|
332 |
+
"ScaledAdam optimizer does not support sparse gradients")
|
333 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
334 |
+
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
335 |
+
else:
|
336 |
+
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
337 |
+
|
338 |
+
tot_norm = tot_sumsq.sqrt()
|
339 |
+
if "model_norms" not in first_state:
|
340 |
+
first_state["model_norms"] = torch.zeros(
|
341 |
+
clipping_update_period, device=p.device)
|
342 |
+
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
343 |
+
|
344 |
+
if step % clipping_update_period == 0:
|
345 |
+
# Print some stats.
|
346 |
+
# We don't reach here if step == 0 because we would have returned
|
347 |
+
# above.
|
348 |
+
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
349 |
+
quartiles = []
|
350 |
+
for n in range(0, 5):
|
351 |
+
index = min(
|
352 |
+
clipping_update_period - 1,
|
353 |
+
(clipping_update_period // 4) * n, )
|
354 |
+
quartiles.append(sorted_norms[index].item())
|
355 |
+
|
356 |
+
median = quartiles[2]
|
357 |
+
threshold = clipping_scale * median
|
358 |
+
first_state["model_norm_threshold"] = threshold
|
359 |
+
percent_clipped = (first_state["num_clipped"] * 100.0 /
|
360 |
+
clipping_update_period
|
361 |
+
if "num_clipped" in first_state else 0.0)
|
362 |
+
first_state["num_clipped"] = 0
|
363 |
+
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
364 |
+
logging.info(
|
365 |
+
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
366 |
+
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
367 |
+
)
|
368 |
+
|
369 |
+
if step < clipping_update_period:
|
370 |
+
return 1.0 # We have not yet estimated a norm to clip to.
|
371 |
+
else:
|
372 |
+
try:
|
373 |
+
model_norm_threshold = first_state["model_norm_threshold"]
|
374 |
+
except KeyError:
|
375 |
+
logging.info(
|
376 |
+
"Warning: model_norm_threshold not in state: possibly "
|
377 |
+
"you changed config when restarting, adding clipping_scale option?"
|
378 |
+
)
|
379 |
+
return 1.0
|
380 |
+
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
381 |
+
if ans < 1.0:
|
382 |
+
first_state["num_clipped"] += 1
|
383 |
+
if ans < 0.1:
|
384 |
+
logging.warn(
|
385 |
+
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
386 |
+
)
|
387 |
+
if self.show_dominant_parameters:
|
388 |
+
assert p.shape[0] == len(param_names)
|
389 |
+
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
390 |
+
return ans
|
391 |
+
|
392 |
+
def _show_gradient_dominating_parameter(
|
393 |
+
self, tuples: List[Tuple[Tensor, dict, List[str]]],
|
394 |
+
tot_sumsq: Tensor):
|
395 |
+
"""
|
396 |
+
Show information of parameter wihch dominanting tot_sumsq.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
tuples: a list of tuples of (param, state, param_names)
|
400 |
+
where param is a batched set of parameters,
|
401 |
+
with a .grad (1st dim is batch dim)
|
402 |
+
and state is the state-dict where optimization parameters are kept.
|
403 |
+
param_names is a List[str] while each str is name for a parameter
|
404 |
+
in batched set of parameters "param".
|
405 |
+
tot_sumsq: sumsq of all parameters. Though it's could be calculated
|
406 |
+
from tuples, we still pass it to save some time.
|
407 |
+
"""
|
408 |
+
all_sumsq_orig = {}
|
409 |
+
for (p, state, batch_param_names) in tuples:
|
410 |
+
# p is a stacked batch parameters.
|
411 |
+
batch_grad = p.grad
|
412 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
413 |
+
batch_sumsq_orig = batch_grad**2
|
414 |
+
# Dummpy values used by following `zip` statement.
|
415 |
+
batch_rms_orig = torch.ones(p.shape[0])
|
416 |
+
else:
|
417 |
+
batch_rms_orig = state["param_rms"]
|
418 |
+
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
|
419 |
+
dim=list(range(1, batch_grad.ndim)))
|
420 |
+
|
421 |
+
for name, sumsq_orig, rms, grad in zip(batch_param_names,
|
422 |
+
batch_sumsq_orig,
|
423 |
+
batch_rms_orig, batch_grad):
|
424 |
+
|
425 |
+
proportion_orig = sumsq_orig / tot_sumsq
|
426 |
+
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
427 |
+
|
428 |
+
assert torch.isclose(
|
429 |
+
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
430 |
+
torch.tensor(1.0), )
|
431 |
+
sorted_by_proportion = {
|
432 |
+
k: v
|
433 |
+
for k, v in sorted(
|
434 |
+
all_sumsq_orig.items(),
|
435 |
+
key=lambda item: item[1][0],
|
436 |
+
reverse=True, )
|
437 |
+
}
|
438 |
+
dominant_param_name = next(iter(sorted_by_proportion))
|
439 |
+
(dominant_proportion, dominant_sumsq, dominant_rms,
|
440 |
+
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
|
441 |
+
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
442 |
+
f" with proportion {dominant_proportion:.2f},"
|
443 |
+
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
444 |
+
f"={dominant_sumsq:.3e},"
|
445 |
+
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
446 |
+
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
|
447 |
+
|
448 |
+
def _step_one_batch(self,
                    group: dict,
                    p: Tensor,
                    state: dict,
                    clipping_scale: float):
    """
    Do the step for one parameter, which is actually going to be a batch of
    `real` parameters, with dim 0 as the batch dim.

    Args:
        group: dict to look up configuration values
        p: parameter to update (actually multiple parameters stacked
            together as a batch)
        state: state-dict for p, to look up the optimizer state
        clipping_scale: factor in (0, 1] by which gradients are scaled
            before the update; 1.0 means no clipping.
    """
    # NOTE(review): the original also read group["lr"] into an unused
    # local; removed (lr is consumed by _step / _step_scalar / _size_update).
    size_update_period = group["size_update_period"]
    beta1 = group["betas"][0]

    grad = p.grad
    if clipping_scale != 1.0:
        grad = grad * clipping_scale
    step = state["step"]
    delta = state["delta"]

    # Decay the accumulated parameter-change (momentum).
    delta.mul_(beta1)
    batch_size = p.shape[0]
    numel = p.numel() // batch_size  # elements per real parameter
    if numel > 1:
        # Update the size/scale of p, and set param_rms
        scale_grads = state["scale_grads"]
        scale_grads[step % size_update_period] = (p * grad).sum(
            dim=list(range(1, p.ndim)), keepdim=True)
        if step % size_update_period == size_update_period - 1:
            param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
            param_rms.copy_((p**2)
                            .mean(dim=list(range(1, p.ndim)), keepdim=True)
                            .sqrt())
            if step > 0:
                # self._size_update() learns the overall scale on the
                # parameter, by shrinking or expanding it.
                self._size_update(group, scale_grads, p, state)

    if numel == 1:
        # For parameters with 1 element we just use regular Adam.
        # Updates delta.
        self._step_scalar(group, p, state)
    else:
        self._step(group, p, state)

    state["step"] = step + 1
def _size_update(self,
|
500 |
+
group: dict,
|
501 |
+
scale_grads: Tensor,
|
502 |
+
p: Tensor,
|
503 |
+
state: dict) -> None:
|
504 |
+
"""
|
505 |
+
Called only where p.numel() > 1, this updates the scale of the parameter.
|
506 |
+
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
507 |
+
gradient descent on underlying param and on scale, this function does the update
|
508 |
+
on `scale`.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
group: dict to look up configuration values
|
512 |
+
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
513 |
+
grads w.r.t. the scales.
|
514 |
+
p: The parameter to update
|
515 |
+
state: The state-dict of p
|
516 |
+
"""
|
517 |
+
|
518 |
+
param_rms = state["param_rms"]
|
519 |
+
beta1, beta2 = group["betas"]
|
520 |
+
size_lr = group["lr"] * group["scalar_lr_scale"]
|
521 |
+
param_min_rms = group["param_min_rms"]
|
522 |
+
param_max_rms = group["param_max_rms"]
|
523 |
+
eps = group["eps"]
|
524 |
+
step = state["step"]
|
525 |
+
batch_size = p.shape[0]
|
526 |
+
|
527 |
+
size_update_period = scale_grads.shape[0]
|
528 |
+
# correct beta2 for the size update period: we will have
|
529 |
+
# faster decay at this level.
|
530 |
+
beta2_corr = beta2**size_update_period
|
531 |
+
|
532 |
+
scale_exp_avg_sq = state[
|
533 |
+
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
534 |
+
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
535 |
+
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
536 |
+
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
|
537 |
+
|
538 |
+
# The 1st time we reach here is when size_step == 1.
|
539 |
+
size_step = (step + 1) // size_update_period
|
540 |
+
bias_correction2 = 1 - beta2_corr**size_step
|
541 |
+
# we don't bother with bias_correction1; this will help prevent divergence
|
542 |
+
# at the start of training.
|
543 |
+
|
544 |
+
denom = scale_exp_avg_sq.sqrt() + eps
|
545 |
+
|
546 |
+
scale_step = (-size_lr * (bias_correction2**0.5) *
|
547 |
+
scale_grads.sum(dim=0) / denom)
|
548 |
+
|
549 |
+
is_too_small = param_rms < param_min_rms
|
550 |
+
is_too_large = param_rms > param_max_rms
|
551 |
+
|
552 |
+
# when the param gets too small, just don't shrink it any further.
|
553 |
+
scale_step.masked_fill_(is_too_small, 0.0)
|
554 |
+
# when it gets too large, stop it from getting any larger.
|
555 |
+
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
556 |
+
delta = state["delta"]
|
557 |
+
# the factor of (1-beta1) relates to momentum.
|
558 |
+
delta.add_(p * scale_step, alpha=(1 - beta1))
|
559 |
+
|
560 |
+
def _step(self, group: dict, p: Tensor, state: dict):
|
561 |
+
"""
|
562 |
+
This function does the core update of self.step(), in the case where the members of
|
563 |
+
the batch have more than 1 element.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
group: A dict which will be used to look up configuration values
|
567 |
+
p: The parameter to be updated
|
568 |
+
grad: The grad of p
|
569 |
+
state: The state-dict corresponding to parameter p
|
570 |
+
|
571 |
+
This function modifies p.
|
572 |
+
"""
|
573 |
+
grad = p.grad
|
574 |
+
lr = group["lr"]
|
575 |
+
beta1, beta2 = group["betas"]
|
576 |
+
eps = group["eps"]
|
577 |
+
param_min_rms = group["param_min_rms"]
|
578 |
+
step = state["step"]
|
579 |
+
|
580 |
+
exp_avg_sq = state["exp_avg_sq"]
|
581 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
582 |
+
|
583 |
+
this_step = state["step"] - (state["zero_step"]
|
584 |
+
if "zero_step" in state else 0)
|
585 |
+
bias_correction2 = 1 - beta2**(this_step + 1)
|
586 |
+
if bias_correction2 < 0.99:
|
587 |
+
# note: not in-place.
|
588 |
+
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
589 |
+
|
590 |
+
denom = exp_avg_sq.sqrt()
|
591 |
+
denom += eps
|
592 |
+
grad = grad / denom
|
593 |
+
|
594 |
+
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
|
595 |
+
|
596 |
+
delta = state["delta"]
|
597 |
+
delta.add_(grad * alpha)
|
598 |
+
p.add_(delta)
|
599 |
+
|
600 |
+
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
601 |
+
"""
|
602 |
+
A simplified form of the core update for scalar tensors, where we cannot get a good
|
603 |
+
estimate of the parameter rms.
|
604 |
+
"""
|
605 |
+
beta1, beta2 = group["betas"]
|
606 |
+
scalar_max = group["scalar_max"]
|
607 |
+
eps = group["eps"]
|
608 |
+
lr = group["lr"] * group["scalar_lr_scale"]
|
609 |
+
grad = p.grad
|
610 |
+
|
611 |
+
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
612 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
613 |
+
|
614 |
+
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
615 |
+
# slower update at the start will help stability anyway.
|
616 |
+
bias_correction2 = 1 - beta2**(state["step"] + 1)
|
617 |
+
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
618 |
+
|
619 |
+
delta = state["delta"]
|
620 |
+
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
|
621 |
+
p.clamp_(min=-scalar_max, max=scalar_max)
|
622 |
+
p.add_(delta)
|
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn.functional import *
|
2 |
+
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
|
3 |
+
# import torch
|
4 |
+
# Tensor = torch.Tensor
|
5 |
+
# from typing import Callable, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
def multi_head_attention_forward_patched(
|
8 |
+
query: Tensor,
|
9 |
+
key: Tensor,
|
10 |
+
value: Tensor,
|
11 |
+
embed_dim_to_check: int,
|
12 |
+
num_heads: int,
|
13 |
+
in_proj_weight: Optional[Tensor],
|
14 |
+
in_proj_bias: Optional[Tensor],
|
15 |
+
bias_k: Optional[Tensor],
|
16 |
+
bias_v: Optional[Tensor],
|
17 |
+
add_zero_attn: bool,
|
18 |
+
dropout_p: float,
|
19 |
+
out_proj_weight: Tensor,
|
20 |
+
out_proj_bias: Optional[Tensor],
|
21 |
+
training: bool = True,
|
22 |
+
key_padding_mask: Optional[Tensor] = None,
|
23 |
+
need_weights: bool = True,
|
24 |
+
attn_mask: Optional[Tensor] = None,
|
25 |
+
use_separate_proj_weight: bool = False,
|
26 |
+
q_proj_weight: Optional[Tensor] = None,
|
27 |
+
k_proj_weight: Optional[Tensor] = None,
|
28 |
+
v_proj_weight: Optional[Tensor] = None,
|
29 |
+
static_k: Optional[Tensor] = None,
|
30 |
+
static_v: Optional[Tensor] = None,
|
31 |
+
average_attn_weights: bool = True,
|
32 |
+
is_causal: bool = False,cache=None
|
33 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
34 |
+
r"""
|
35 |
+
Args:
|
36 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
37 |
+
See "Attention Is All You Need" for more details.
|
38 |
+
embed_dim_to_check: total dimension of the model.
|
39 |
+
num_heads: parallel attention heads.
|
40 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
41 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
42 |
+
add_zero_attn: add a new batch of zeros to the key and
|
43 |
+
value sequences at dim=1.
|
44 |
+
dropout_p: probability of an element to be zeroed.
|
45 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
46 |
+
training: apply dropout if is ``True``.
|
47 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
48 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
49 |
+
the corresponding value on the attention layer will be filled with -inf.
|
50 |
+
need_weights: output attn_output_weights.
|
51 |
+
Default: `True`
|
52 |
+
Note: `needs_weight` defaults to `True`, but should be set to `False`
|
53 |
+
For best performance when attention weights are not nedeeded.
|
54 |
+
*Setting needs_weights to `True`
|
55 |
+
leads to a significant performance degradation.*
|
56 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
57 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
58 |
+
is_causal: If specified, applies a causal mask as attention mask, and ignores
|
59 |
+
attn_mask for computing scaled dot product attention.
|
60 |
+
Default: ``False``.
|
61 |
+
.. warning::
|
62 |
+
is_causal is provides a hint that the attn_mask is the
|
63 |
+
causal mask.Providing incorrect hints can result in
|
64 |
+
incorrect execution, including forward and backward
|
65 |
+
compatibility.
|
66 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
67 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
68 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
69 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
70 |
+
static_k, static_v: static key and value used for attention operators.
|
71 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
|
72 |
+
Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
|
73 |
+
when ``need_weights=True.``. Default: True
|
74 |
+
|
75 |
+
|
76 |
+
Shape:
|
77 |
+
Inputs:
|
78 |
+
- query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
79 |
+
the embedding dimension.
|
80 |
+
- key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
81 |
+
the embedding dimension.
|
82 |
+
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
83 |
+
the embedding dimension.
|
84 |
+
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
85 |
+
If a FloatTensor is provided, it will be directly added to the value.
|
86 |
+
If a BoolTensor is provided, the positions with the
|
87 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
88 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
89 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
90 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
91 |
+
positions. If a BoolTensor is provided, positions with ``True``
|
92 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
93 |
+
is provided, it will be added to the attention weight.
|
94 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
95 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
96 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
97 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
98 |
+
|
99 |
+
Outputs:
|
100 |
+
- attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
101 |
+
E is the embedding dimension.
|
102 |
+
- attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
|
103 |
+
attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
104 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
105 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
106 |
+
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
|
107 |
+
"""
|
108 |
+
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
|
109 |
+
if has_torch_function(tens_ops):
|
110 |
+
return handle_torch_function(
|
111 |
+
multi_head_attention_forward,
|
112 |
+
tens_ops,
|
113 |
+
query,
|
114 |
+
key,
|
115 |
+
value,
|
116 |
+
embed_dim_to_check,
|
117 |
+
num_heads,
|
118 |
+
in_proj_weight,
|
119 |
+
in_proj_bias,
|
120 |
+
bias_k,
|
121 |
+
bias_v,
|
122 |
+
add_zero_attn,
|
123 |
+
dropout_p,
|
124 |
+
out_proj_weight,
|
125 |
+
out_proj_bias,
|
126 |
+
training=training,
|
127 |
+
key_padding_mask=key_padding_mask,
|
128 |
+
need_weights=need_weights,
|
129 |
+
attn_mask=attn_mask,
|
130 |
+
is_causal=is_causal,
|
131 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
132 |
+
q_proj_weight=q_proj_weight,
|
133 |
+
k_proj_weight=k_proj_weight,
|
134 |
+
v_proj_weight=v_proj_weight,
|
135 |
+
static_k=static_k,
|
136 |
+
static_v=static_v,
|
137 |
+
average_attn_weights=average_attn_weights,cache=cache
|
138 |
+
)
|
139 |
+
|
140 |
+
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
141 |
+
|
142 |
+
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
143 |
+
# is batched, run the computation and before returning squeeze the
|
144 |
+
# batch dimension so that the output doesn't carry this temporary batch dimension.
|
145 |
+
if not is_batched:
|
146 |
+
# unsqueeze if the input is unbatched
|
147 |
+
query = query.unsqueeze(1)
|
148 |
+
key = key.unsqueeze(1)
|
149 |
+
value = value.unsqueeze(1)
|
150 |
+
if key_padding_mask is not None:
|
151 |
+
key_padding_mask = key_padding_mask.unsqueeze(0)
|
152 |
+
|
153 |
+
# set up shape vars
|
154 |
+
tgt_len, bsz, embed_dim = query.shape
|
155 |
+
src_len, _, _ = key.shape
|
156 |
+
|
157 |
+
key_padding_mask = _canonical_mask(
|
158 |
+
mask=key_padding_mask,
|
159 |
+
mask_name="key_padding_mask",
|
160 |
+
other_type=_none_or_dtype(attn_mask),
|
161 |
+
other_name="attn_mask",
|
162 |
+
target_type=query.dtype
|
163 |
+
)
|
164 |
+
|
165 |
+
if is_causal and attn_mask is None:
|
166 |
+
raise RuntimeError(
|
167 |
+
"Need attn_mask if specifying the is_causal hint. "
|
168 |
+
"You may use the Transformer module method "
|
169 |
+
"`generate_square_subsequent_mask` to create this mask."
|
170 |
+
)
|
171 |
+
|
172 |
+
if is_causal and key_padding_mask is None and not need_weights:
|
173 |
+
# when we have a kpm or need weights, we need attn_mask
|
174 |
+
# Otherwise, we use the is_causal hint go as is_causal
|
175 |
+
# indicator to SDPA.
|
176 |
+
attn_mask = None
|
177 |
+
else:
|
178 |
+
attn_mask = _canonical_mask(
|
179 |
+
mask=attn_mask,
|
180 |
+
mask_name="attn_mask",
|
181 |
+
other_type=None,
|
182 |
+
other_name="",
|
183 |
+
target_type=query.dtype,
|
184 |
+
check_other=False,
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
if key_padding_mask is not None:
|
189 |
+
# We have the attn_mask, and use that to merge kpm into it.
|
190 |
+
# Turn off use of is_causal hint, as the merged mask is no
|
191 |
+
# longer causal.
|
192 |
+
is_causal = False
|
193 |
+
|
194 |
+
assert embed_dim == embed_dim_to_check, \
|
195 |
+
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
196 |
+
if isinstance(embed_dim, torch.Tensor):
|
197 |
+
# embed_dim can be a tensor when JIT tracing
|
198 |
+
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
|
199 |
+
else:
|
200 |
+
head_dim = embed_dim // num_heads
|
201 |
+
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
202 |
+
if use_separate_proj_weight:
|
203 |
+
# allow MHA to have different embedding dimensions when separate projection weights are used
|
204 |
+
assert key.shape[:2] == value.shape[:2], \
|
205 |
+
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
206 |
+
else:
|
207 |
+
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
208 |
+
|
209 |
+
#
|
210 |
+
# compute in-projection
|
211 |
+
#
|
212 |
+
if not use_separate_proj_weight:
|
213 |
+
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
214 |
+
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
215 |
+
else:
|
216 |
+
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
217 |
+
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
218 |
+
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
219 |
+
if in_proj_bias is None:
|
220 |
+
b_q = b_k = b_v = None
|
221 |
+
else:
|
222 |
+
b_q, b_k, b_v = in_proj_bias.chunk(3)
|
223 |
+
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
|
224 |
+
if(cache!=None):
|
225 |
+
if(cache["first_infer"]==1):
|
226 |
+
cache["k"][cache["stage"]]=k
|
227 |
+
# print(0,cache["k"].shape)
|
228 |
+
cache["v"][cache["stage"]]=v
|
229 |
+
else:###12个layer每个都要留自己的cache_kv
|
230 |
+
# print(1,cache["k"].shape)
|
231 |
+
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
|
232 |
+
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
|
233 |
+
# print(2, cache["k"].shape)
|
234 |
+
src_len = cache["k"][cache["stage"]].shape[0]
|
235 |
+
k=cache["k"][cache["stage"]]
|
236 |
+
v=cache["v"][cache["stage"]]
|
237 |
+
# if attn_mask is not None:
|
238 |
+
# attn_mask=attn_mask[-1:,]
|
239 |
+
# print(attn_mask.shape,attn_mask)
|
240 |
+
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
241 |
+
# print(2333,cache)
|
242 |
+
# prep attention mask
|
243 |
+
|
244 |
+
attn_mask = _canonical_mask(
|
245 |
+
mask=attn_mask,
|
246 |
+
mask_name="attn_mask",
|
247 |
+
other_type=None,
|
248 |
+
other_name="",
|
249 |
+
target_type=q.dtype,
|
250 |
+
check_other=False,
|
251 |
+
)
|
252 |
+
|
253 |
+
if attn_mask is not None:
|
254 |
+
# ensure attn_mask's dim is 3
|
255 |
+
if attn_mask.dim() == 2:
|
256 |
+
correct_2d_size = (tgt_len, src_len)
|
257 |
+
if attn_mask.shape != correct_2d_size:
|
258 |
+
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
|
259 |
+
attn_mask = attn_mask.unsqueeze(0)
|
260 |
+
elif attn_mask.dim() == 3:
|
261 |
+
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
|
262 |
+
if attn_mask.shape != correct_3d_size:
|
263 |
+
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
|
264 |
+
else:
|
265 |
+
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
266 |
+
|
267 |
+
# add bias along batch dimension (currently second)
|
268 |
+
if bias_k is not None and bias_v is not None:
|
269 |
+
assert static_k is None, "bias cannot be added to static key."
|
270 |
+
assert static_v is None, "bias cannot be added to static value."
|
271 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
272 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
273 |
+
if attn_mask is not None:
|
274 |
+
attn_mask = pad(attn_mask, (0, 1))
|
275 |
+
if key_padding_mask is not None:
|
276 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
277 |
+
else:
|
278 |
+
assert bias_k is None
|
279 |
+
assert bias_v is None
|
280 |
+
|
281 |
+
#
|
282 |
+
# reshape q, k, v for multihead attention and make em batch first
|
283 |
+
#
|
284 |
+
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
285 |
+
if static_k is None:
|
286 |
+
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
287 |
+
else:
|
288 |
+
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
289 |
+
assert static_k.size(0) == bsz * num_heads, \
|
290 |
+
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
291 |
+
assert static_k.size(2) == head_dim, \
|
292 |
+
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
293 |
+
k = static_k
|
294 |
+
if static_v is None:
|
295 |
+
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
296 |
+
else:
|
297 |
+
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
298 |
+
assert static_v.size(0) == bsz * num_heads, \
|
299 |
+
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
300 |
+
assert static_v.size(2) == head_dim, \
|
301 |
+
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
302 |
+
v = static_v
|
303 |
+
|
304 |
+
# add zero attention along batch dimension (now first)
|
305 |
+
if add_zero_attn:
|
306 |
+
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
307 |
+
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
308 |
+
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
309 |
+
if attn_mask is not None:
|
310 |
+
attn_mask = pad(attn_mask, (0, 1))
|
311 |
+
if key_padding_mask is not None:
|
312 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
313 |
+
|
314 |
+
# update source sequence length after adjustments
|
315 |
+
src_len = k.size(1)
|
316 |
+
|
317 |
+
# merge key padding and attention masks
|
318 |
+
if key_padding_mask is not None:
|
319 |
+
assert key_padding_mask.shape == (bsz, src_len), \
|
320 |
+
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
321 |
+
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
|
322 |
+
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
323 |
+
if attn_mask is None:
|
324 |
+
attn_mask = key_padding_mask
|
325 |
+
else:
|
326 |
+
attn_mask = attn_mask + key_padding_mask
|
327 |
+
|
328 |
+
# adjust dropout probability
|
329 |
+
if not training:
|
330 |
+
dropout_p = 0.0
|
331 |
+
|
332 |
+
#
|
333 |
+
# (deep breath) calculate attention and out projection
|
334 |
+
#
|
335 |
+
|
336 |
+
if need_weights:
|
337 |
+
B, Nt, E = q.shape
|
338 |
+
q_scaled = q / math.sqrt(E)
|
339 |
+
|
340 |
+
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
341 |
+
|
342 |
+
if attn_mask is not None:
|
343 |
+
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
344 |
+
else:
|
345 |
+
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
346 |
+
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
347 |
+
if dropout_p > 0.0:
|
348 |
+
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
|
349 |
+
|
350 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
351 |
+
|
352 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
353 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
354 |
+
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
355 |
+
|
356 |
+
# optionally average attention weights over heads
|
357 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
358 |
+
if average_attn_weights:
|
359 |
+
attn_output_weights = attn_output_weights.mean(dim=1)
|
360 |
+
|
361 |
+
if not is_batched:
|
362 |
+
# squeeze the output if input was unbatched
|
363 |
+
attn_output = attn_output.squeeze(1)
|
364 |
+
attn_output_weights = attn_output_weights.squeeze(0)
|
365 |
+
return attn_output, attn_output_weights
|
366 |
+
else:
|
367 |
+
# attn_mask can be either (L,S) or (N*num_heads, L, S)
|
368 |
+
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
|
369 |
+
# in order to match the input for SDPA of (N, num_heads, L, S)
|
370 |
+
if attn_mask is not None:
|
371 |
+
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
|
372 |
+
attn_mask = attn_mask.unsqueeze(0)
|
373 |
+
else:
|
374 |
+
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
|
375 |
+
|
376 |
+
q = q.view(bsz, num_heads, tgt_len, head_dim)
|
377 |
+
k = k.view(bsz, num_heads, src_len, head_dim)
|
378 |
+
v = v.view(bsz, num_heads, src_len, head_dim)
|
379 |
+
|
380 |
+
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
381 |
+
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
382 |
+
|
383 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
384 |
+
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
385 |
+
if not is_batched:
|
386 |
+
# squeeze the output if input was unbatched
|
387 |
+
attn_output = attn_output.squeeze(1)
|
388 |
+
return attn_output, None
|