kevinwang676 committed on
Commit
2d6ed53
1 Parent(s): e74f0aa

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .dockerignore +8 -0
  2. .gitattributes +1 -0
  3. .gitignore +14 -0
  4. Changelog_CN.md +143 -0
  5. Docker/damo.sha256 +3 -0
  6. Docker/download.py +5 -0
  7. Docker/download.sh +11 -0
  8. Docker/links.sha256 +12 -0
  9. Docker/links.txt +34 -0
  10. Dockerfile +45 -0
  11. GPT-SoVITS-models/.gitattributes +44 -0
  12. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh +12 -0
  13. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py +719 -0
  14. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh +2 -0
  15. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py +270 -0
  16. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py +0 -0
  17. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py +0 -0
  18. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py +157 -0
  19. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py +66 -0
  20. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py +302 -0
  21. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py +0 -0
  22. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py +179 -0
  23. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md +127 -0
  24. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py +172 -0
  25. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py +2 -0
  26. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py +791 -0
  27. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py +19 -0
  28. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py +220 -0
  29. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json +0 -0
  30. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py +235 -0
  31. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py +321 -0
  32. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py +232 -0
  33. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py +198 -0
  34. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py +255 -0
  35. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py +35 -0
  36. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py +197 -0
  37. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py +139 -0
  38. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt +10 -0
  39. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py +103 -0
  40. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py +170 -0
  41. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py +0 -0
  42. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py +128 -0
  43. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py +298 -0
  44. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py +164 -0
  45. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py +0 -0
  46. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py +397 -0
  47. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py +78 -0
  48. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py +85 -0
  49. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py +622 -0
  50. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py +388 -0
.dockerignore ADDED
@@ -0,0 +1,8 @@
+ docs
+ logs
+ output
+ reference
+ SoVITS_weights
+ GPT_weights
+ TEMP
+ .git
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tools/damo_asr/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ .DS_Store
+ __pycache__
+ *.pyc
+ env
+ runtime
+ .idea
+ output
+ logs
+ reference
+ GPT_weights
+ SoVITS_weights
+ TEMP
+
+
Changelog_CN.md ADDED
@@ -0,0 +1,143 @@
+ ### 20240121 update
+
+ 1-Added is_share to config; in scenarios such as Colab it can be set to True to map the WebUI to a public URL
+
+ 2-Added an English translation of the WebUI for English-language systems
+
+ 3-cmd-asr now automatically checks whether the damo models are already present; if they are not in the default directory, they are downloaded from ModelScope
+
+ 4-Attempted fix for the [SoVITS training ZeroDivisionError](https://github.com/RVC-Boss/GPT-SoVITS/issues/79) (filtering zero-length samples, etc.)
+
+ 5-Clean up cached audio and other files in the TEMP folder
+
+ 6-Greatly reduced the problem of synthesized audio containing the tail of the reference audio
+
+ ### 20240122 update
+
+ 1-Fixed the issue where overly short output files returned the reference audio repeatedly.
+
+ 2-Tested native support for English and Japanese training (Japanese training requires the root directory to be free of non-English special characters).
+
+ 3-Audio path checking: if a mistyped input path is read, the error now reports that the path does not exist instead of raising an ffmpeg error.
+
+ ### 20240123 update
+
+ 1-Fixed the ZeroDivisionError in SoVITS/GPT training caused by NaNs in hubert feature extraction
+
+ 2-Support quick model switching in the inference UI
+
+ 3-Optimized the model file sorting logic
+
+ 4-Use jieba_fast instead of jieba for Chinese word segmentation
+
+ ### 20240126 update
+
+ 1-Support mixed Chinese-English and Japanese-English output text
+
+ 2-Optional text-splitting modes for the output
+
+ 3-Fixed uvr5 exiting automatically when reading a directory
+
+ 4-Fixed inference errors caused by multiple consecutive line breaks
+
+ 5-Removed a large amount of redundant logging from the inference UI
+
+ 6-Support training and inference on Mac
+
+ 7-Automatically detect GPUs that do not support half precision and force single precision; force single precision for CPU inference.
+
+ ### 20240128 update
+
+ 1-Fixed the pronunciation of numbers converted to Chinese characters
+
+ 2-Fixed the issue where a few characters at the start of a sentence were easily swallowed
+
+ 3-Excluded unreasonable reference audio lengths by enforcing limits
+
+ 4-Fixed GPT training not saving checkpoints
+
+ 5-Improved the model download flow in the Dockerfile
+
+ ### 20240129 update
+
+ 1-Switched the training config to single precision for GPUs such as the 16 series that have problems with half-precision training
+
+ 2-Tested and updated a working Colab version
+
+ 3-Fixed the interface-mismatch error caused by git-cloning the ModelScope funasr repository with an old funasr version
+
+
+ ### 20240130 update
+
+ 1-Double quotes are automatically stripped from all path inputs, so beginners copying paths with quotes no longer get errors
+
+ 2-Fixed Chinese/English punctuation splitting and the padding of punctuation at sentence boundaries
+
+ 3-Added splitting by punctuation
+
+ ### 20240201 update
+
+ 1-Fixed separation failures in uvr5 caused by reading the wrong format
+
+ 2-Support automatic splitting and language detection for mixed Chinese/Japanese/English text
+
+ ### 20240202 update
+
+ 1-Fixed the filename error when the ASR path ends with /
+
+ 2-Introduced the paddlespeech Normalizer https://github.com/RVC-Boss/GPT-SoVITS/pull/377 to fix several issues, e.g. xx.xx% (percentages), 元/吨 now being read as "元每吨" (yuan per ton) rather than "元吨", and underscores no longer causing errors
+
+ ### 20240207 update
+
+ 1-Fixed the drop in Chinese inference quality caused by mixed-up language parameters https://github.com/RVC-Boss/GPT-SoVITS/issues/391
+
+ 2-Adapted uvr5 to newer librosa versions https://github.com/RVC-Boss/GPT-SoVITS/pull/403
+
+ 3-Fixed the uvr5 "inf everywhere" error (the is_half parameter was not converted to bool, forcing constant half-precision inference, which produces inf on 16-series GPUs) https://github.com/RVC-Boss/GPT-SoVITS/commit/14a285109a521679f8846589c22da8f656a46ad8
+
+ 4-Improved the English text frontend
+
+ 5-Fixed the gradio dependency
+
+ 6-When the one-click formatting root directory is left empty, full paths are now read automatically from the .list file
+
+ 7-Integrated faster whisper ASR for Japanese and English
+
+ ### 20240208 update
+
+ 1-[Attempted fix](https://github.com/RVC-Boss/GPT-SoVITS/commit/59f35adad85815df27e9c6b33d420f5ebfd8376b) for GPT training hanging (Windows 10 1909) and for the GPT training error in https://github.com/RVC-Boss/GPT-SoVITS/issues/232 (Traditional Chinese system locale).
+
+ ### 20240212 update
+
+ 1-Optimized the faster whisper and funasr logic; faster whisper now downloads from a mirror site to avoid failures connecting to Hugging Face.
+
+ 2-Enabled an experimental DPO Loss training option, which mitigates GPT repetition and dropped characters by training on constructed negative samples; several inference parameters are now exposed in the inference UI. https://github.com/RVC-Boss/GPT-SoVITS/pull/457
+
+ ### 20240214 update
+
+ 1-Training now supports Chinese experiment names (this previously raised errors)
+
+ 2-DPO training is now an optional checkbox instead of mandatory; when enabled, the batch size is automatically halved. Fixed new inference-UI parameters not being passed through.
+
+ ### 20240216 update
+
+ 1-Support input without reference text
+
+ 2-Fixed a bug in the Chinese text frontend https://github.com/RVC-Boss/GPT-SoVITS/issues/475
+
+ ### 20240221 update
+
+ 1-Added a speech denoising option to data processing
+
+ 2-Optimized Chinese and Japanese frontend processing https://github.com/RVC-Boss/GPT-SoVITS/pull/559 https://github.com/RVC-Boss/GPT-SoVITS/pull/556 https://github.com/RVC-Boss/GPT-SoVITS/pull/532 https://github.com/RVC-Boss/GPT-SoVITS/pull/507 https://github.com/RVC-Boss/GPT-SoVITS/pull/509
+
+ 3-Mac CPU inference is faster, so the inference device was switched from mps to CPU
+
+ 4-Fixed Colab not opening a public URL
+
+ todolist:
+
+ 1-Optimize inference for Chinese polyphonic characters
+
+
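For context on the 20240212 DPO entry above: the linked pull request contains the actual implementation, but the idea of training on constructed negative samples can be illustrated with a hedged, generic sketch of a DPO-style preference loss (the function and tensor names below are illustrative assumptions, not the repository's API):

```python
import torch
import torch.nn.functional as F

def dpo_style_loss(pos_logp, neg_logp, ref_pos_logp, ref_neg_logp, beta=0.1):
    """Illustrative DPO-style preference loss.

    pos_logp / neg_logp: summed token log-probabilities of the ground-truth
    (positive) sequence and of a constructed negative sequence (e.g. with
    repeated or dropped tokens) under the model being trained.
    ref_*: the same quantities under a frozen reference model.
    """
    pos_margin = beta * (pos_logp - ref_pos_logp)
    neg_margin = beta * (neg_logp - ref_neg_logp)
    # Push the model to prefer the positive sequence over the negative one.
    return -F.logsigmoid(pos_margin - neg_margin).mean()

# Example call with a dummy batch of 4 sequences
loss = dpo_style_loss(torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4))
```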
Docker/damo.sha256 ADDED
@@ -0,0 +1,3 @@
+ 5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025 /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
+ b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5 /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb
+ a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916 /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb
Docker/download.py ADDED
@@ -0,0 +1,5 @@
+ # Download the damo ASR models via ModelScope
+ from modelscope import snapshot_download
+ model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
+ model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
+ model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
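Docker/damo.sha256 above records the expected hashes for these checkpoints. As a minimal sketch (not part of this commit), and assuming the files have been placed at the /workspace paths listed in that file (snapshot_download itself writes to the ModelScope cache directory, so the on-disk location may differ), the same check could be done from Python:

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

# Compare each "<hash>  <path>" entry of Docker/damo.sha256 with the file on disk.
with open("Docker/damo.sha256", encoding="utf-8") as f:
    for line in f:
        expected, path = line.split(maxsplit=1)
        path = path.strip()
        ok = sha256_of(path) == expected
        print(f"{'OK  ' if ok else 'FAIL'} {path}")
```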
Docker/download.sh ADDED
@@ -0,0 +1,11 @@
+ #!/usr/bin/env bash
+
+ set -Eeuo pipefail
+
+ echo "Downloading models..."
+
+ aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue
+
+ echo "Checking SHA256..."
+
+ parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c"
Docker/links.sha256 ADDED
@@ -0,0 +1,12 @@
+ b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1 /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+ fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8 /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth
+ 020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835 /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth
+ 24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
+ e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8 /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
+ 39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05 /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth
+ 45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28 /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth
+ 5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
+ 8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2 /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
+ 01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
+ 56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
+ 233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80 /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
Docker/links.txt ADDED
@@ -0,0 +1,34 @@
+ # GPT-SoVITS models
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt
+ out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth
+ out=GPT_SoVITS/pretrained_models/s2D488k.pth
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth
+ out=GPT_SoVITS/pretrained_models/s2G488k.pth
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json
+ out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json
+ out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin
+ out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json
+ out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin
+ out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
+ https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json
+ out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
+ # UVR5
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth
+ out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth
+ out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth
+ out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth
+ out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth
+ out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth
+ out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
+ https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
+ out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
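Docker/download.sh feeds this file to aria2c, which reads each URL line followed by its out= option. As a hedged fallback sketch (not part of this commit) for environments without aria2c, the same URL/output pairs could be parsed and fetched with the Python standard library:

```python
import os
import urllib.request

def parse_aria2_input(path="Docker/links.txt"):
    """Yield (url, output_path) pairs from an aria2c input file."""
    url = None
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("out="):
                yield url, line[len("out="):]
                url = None
            else:
                url = line

for url, out_path in parse_aria2_input():
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # Rough stand-in for aria2c --continue: skip files that already exist.
    if not os.path.exists(out_path):
        urllib.request.urlretrieve(url, out_path)
```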
Dockerfile ADDED
@@ -0,0 +1,45 @@
+ # Base CUDA image
+ FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
+
+ LABEL maintainer="[email protected]"
+ LABEL version="dev-20240209"
+ LABEL description="Docker image for GPT-SoVITS"
+
+
+ # Install 3rd party apps
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV TZ=Etc/UTC
+ RUN apt-get update && \
+ apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
+ git lfs install && \
+ rm -rf /var/lib/apt/lists/*
+
+ # Copy only requirements.txt initially to leverage Docker cache
+ WORKDIR /workspace
+ COPY requirements.txt /workspace/
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Define a build-time argument for image type
+ ARG IMAGE_TYPE=full
+
+ # Conditional logic based on the IMAGE_TYPE argument
+ # Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
+ COPY ./Docker /workspace/Docker
+ # The "elite" image type does not bundle the extra models
+ RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
+ chmod +x /workspace/Docker/download.sh && \
+ /workspace/Docker/download.sh && \
+ python /workspace/Docker/download.py && \
+ python -m nltk.downloader averaged_perceptron_tagger cmudict; \
+ fi
+
+
+ # Copy the rest of the application
+ COPY . /workspace
+
+ EXPOSE 9871 9872 9873 9874 9880
+
+ CMD ["python", "webui.py"]
GPT-SoVITS-models/.gitattributes ADDED
@@ -0,0 +1,44 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/2bbf387613664982acc3847e4b4970fc6bf09120/audio.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/4e25df8e5470697bd435cc94559e1c34f09bab16/audio.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/6b579cccde8715941d9b9b06a1b9787ce0fdb4db/audio.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/873c1f03462a87c00222fd2422a8b328244f45da/audio.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/d2c38e2d7f131cfc51fe07c541177b0f5a061cc3/audio.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101-0-100.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101.wav filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/TEMP/jieba.cache filter=lfs diff=lfs merge=lfs -text
+ GPT-SoVITS/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav filter=lfs diff=lfs merge=lfs -text
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh ADDED
@@ -0,0 +1,12 @@
+ #!/bin/bash
+ cd /root/autodl-tmp/workdir/GPT-SoVITS
+ rm -rf GPT_weights/*
+ rm -rf SoVITS_weights/*
+
+ rm -rf input/*
+ rm -rf output/asr_opt/*
+ rm -rf output/slicer_opt/*
+ rm -rf output/uvr5_opt/*
+ rm -rf logs/*
+
+ echo "Initialization complete"
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py ADDED
@@ -0,0 +1,719 @@
1
+ import json,yaml,warnings,torch
2
+ warnings.filterwarnings("ignore")
3
+ torch.manual_seed(233333)
4
+ import os,pdb,sys
5
+ now_dir = os.getcwd()
6
+ tmp = os.path.join(now_dir, "TEMP")
7
+ os.makedirs(tmp, exist_ok=True)
8
+ os.environ["TEMP"] = tmp
9
+ import site
10
+ site_packages_root="%s/root/miniconda3/lib/python3.10/site-packages"%now_dir
11
+ for path in site.getsitepackages():
12
+ if("site-packages"in path):site_packages_root=path
13
+ os.environ["OPENBLAS_NUM_THREADS"] = "4"
14
+ os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
15
+ with open("%s/users.pth"%(site_packages_root),"w")as f:
16
+ f.write("%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"%(now_dir,now_dir,now_dir,now_dir,now_dir))
17
+ import traceback
18
+ sys.path.append(now_dir)
19
+ import shutil
20
+ import pdb
21
+ import gradio as gr
22
+ from subprocess import Popen
23
+ import signal
24
+ from config import python_exec,infer_device,is_half,exp_root
25
+ from i18n.i18n import I18nAuto
26
+ i18n = I18nAuto()
27
+ from scipy.io import wavfile
28
+ from tools.my_utils import load_audio
29
+ from multiprocessing import cpu_count
30
+ n_cpu=cpu_count()
31
+
32
+ # 判断是否有能用来训练和加速推理的N卡
33
+ ngpu = torch.cuda.device_count()
34
+ gpu_infos = []
35
+ mem = []
36
+ if_gpu_ok = False
37
+
38
+ if torch.cuda.is_available() or ngpu != 0:
39
+ for i in range(ngpu):
40
+ gpu_name = torch.cuda.get_device_name(i)
41
+ if any(value in gpu_name.upper()for value in ["10","16","20","30","40","A2","A3","A4","P4","A50","500","A60","70","80","90","M4","T4","TITAN","L"]):
42
+ # A10#A100#V100#A40#P40#M40#K80#A4500
43
+ if_gpu_ok = True # 至少有一张能用的N卡
44
+ gpu_infos.append("%s\t%s" % (i, gpu_name))
45
+ mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
46
+ if if_gpu_ok and len(gpu_infos) > 0:
47
+ gpu_info = "\n".join(gpu_infos)
48
+ default_batch_size = min(mem) // 2
49
+ else:
50
+ gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
51
+ default_batch_size = 1
52
+ gpus = "-".join([i[0] for i in gpu_infos])
53
+
54
+ pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
55
+ pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
56
+ def get_weights_names():
57
+ SoVITS_names = [pretrained_sovits_name]
58
+ for name in os.listdir(SoVITS_weight_root):
59
+ if name.endswith(".pth"):SoVITS_names.append(name)
60
+ GPT_names = [pretrained_gpt_name]
61
+ for name in os.listdir(GPT_weight_root):
62
+ if name.endswith(".ckpt"): GPT_names.append(name)
63
+ return SoVITS_names,GPT_names
64
+ SoVITS_weight_root="SoVITS_weights"
65
+ GPT_weight_root="GPT_weights"
66
+ SoVITS_names,GPT_names = get_weights_names()
67
+
68
+ def change_choices():
69
+ SoVITS_names, GPT_names = get_weights_names()
70
+ return {"choices": sorted(SoVITS_names), "__type__": "update"}, {"choices": sorted(GPT_names), "__type__": "update"}
71
+
72
+ p_label=None
73
+ p_uvr5=None
74
+ p_asr=None
75
+ p_tts_inference=None
76
+
77
+ def kill_process(pid):
78
+ os.system("taskkill /t /f /pid %s" % pid) # todo:识别linux用kill -9
79
+ # os.kill(p_label.pid,19)#主进程#控制台进程#python子进程###不好使,连主进程的webui一起关了,辣鸡
80
+
81
+ def change_label(if_label,path_list):
82
+ global p_label
83
+ if(if_label==True and p_label==None):
84
+ cmd = '"%s" tools/subfix_webui.py --load_list "%s"'%(python_exec,path_list)
85
+ yield "打标工具WebUI已开启"
86
+ print(cmd)
87
+ p_label = Popen(cmd, shell=True)
88
+ elif(if_label==False and p_label!=None):
89
+ kill_process(p_label.pid)
90
+ p_label=None
91
+ yield "打标工具WebUI已关闭"
92
+
93
+ def change_uvr5(if_uvr5):
94
+ global p_uvr5
95
+ if(if_uvr5==True and p_uvr5==None):
96
+ cmd = '"%s" tools/uvr5/webui.py "%s" %s'%(python_exec,infer_device,is_half)
97
+ yield "UVR5已开启"
98
+ print(cmd)
99
+ p_uvr5 = Popen(cmd, shell=True)
100
+ elif(if_uvr5==False and p_uvr5!=None):
101
+ kill_process(p_uvr5.pid)
102
+ p_uvr5=None
103
+ yield "UVR5已关闭"
104
+
105
+ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path):
106
+ global p_tts_inference
107
+ if(if_tts==True and p_tts_inference==None):
108
+ os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
109
+ os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
110
+ os.environ["cnhubert_base_path"]=cnhubert_base_path
111
+ os.environ["bert_path"]=bert_path
112
+ os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
113
+ os.environ["is_half"]=str(is_half)
114
+ cmd = '"%s" GPT_SoVITS/inference_webui.py'%(python_exec)
115
+ yield "TTS推理进程已开启"
116
+ print(cmd)
117
+ p_tts_inference = Popen(cmd, shell=True)
118
+ elif(if_tts==False and p_tts_inference!=None):
119
+ kill_process(p_tts_inference.pid)
120
+ p_tts_inference=None
121
+ yield "TTS推理进程已关闭"
122
+
123
+
124
+ def open_asr(asr_inp_dir):
125
+ global p_asr
126
+ if(p_asr==None):
127
+ cmd = '"%s" tools/damo_asr/cmd-asr.py "%s"'%(python_exec,asr_inp_dir)
128
+ yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
129
+ print(cmd)
130
+ p_asr = Popen(cmd, shell=True)
131
+ p_asr.wait()
132
+ p_asr=None
133
+ yield "ASR任务完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
134
+ else:
135
+ yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
136
+
137
+ def close_asr():
138
+ global p_asr
139
+ if(p_asr!=None):
140
+ kill_process(p_asr.pid)
141
+ p_asr=None
142
+ return "已终止ASR进程",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
143
+
144
+ '''
145
+ button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Bb,button1Ba_open,button1Ba_close])
146
+ button1Ba_close.click(close1Ba, [], [info1Bb,button1Ba_open,button1Ba_close])
147
+ '''
148
+ p_train_SoVITS=None
149
+ def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D):
150
+ global p_train_SoVITS
151
+ if(p_train_SoVITS==None):
152
+ with open("GPT_SoVITS/configs/s2.json")as f:
153
+ data=f.read()
154
+ data=json.loads(data)
155
+ s2_dir="%s/%s"%(exp_root,exp_name)
156
+ os.makedirs("%s/logs_s2"%(s2_dir),exist_ok=True)
157
+ data["train"]["batch_size"]=batch_size
158
+ data["train"]["epochs"]=total_epoch
159
+ data["train"]["text_low_lr_rate"]=text_low_lr_rate
160
+ data["train"]["pretrained_s2G"]=pretrained_s2G
161
+ data["train"]["pretrained_s2D"]=pretrained_s2D
162
+ data["train"]["if_save_latest"]=if_save_latest
163
+ data["train"]["if_save_every_weights"]=if_save_every_weights
164
+ data["train"]["save_every_epoch"]=save_every_epoch
165
+ data["train"]["gpu_numbers"]=gpu_numbers1Ba
166
+ data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir
167
+ data["save_weight_dir"]=SoVITS_weight_root
168
+ data["name"]=exp_name
169
+ tmp_config_path="TEMP/tmp_s2.json"
170
+ with open(tmp_config_path,"w")as f:f.write(json.dumps(data))
171
+
172
+ cmd = '"%s" GPT_SoVITS/s2_train.py --config "%s"'%(python_exec,tmp_config_path)
173
+ yield "SoVITS训练开始:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
174
+ print(cmd)
175
+ p_train_SoVITS = Popen(cmd, shell=True)
176
+ p_train_SoVITS.wait()
177
+ p_train_SoVITS=None
178
+ yield "SoVITS训练完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
179
+ else:
180
+ yield "已有正在进行的SoVITS训练任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
181
+
182
+ def close1Ba():
183
+ global p_train_SoVITS
184
+ if(p_train_SoVITS!=None):
185
+ kill_process(p_train_SoVITS.pid)
186
+ p_train_SoVITS=None
187
+ return "已终止SoVITS训练",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
188
+
189
+ p_train_GPT=None
190
+ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers,pretrained_s1):
191
+ global p_train_GPT
192
+ if(p_train_GPT==None):
193
+ with open("GPT_SoVITS/configs/s1longer.yaml")as f:
194
+ data=f.read()
195
+ data=yaml.load(data, Loader=yaml.FullLoader)
196
+ s1_dir="%s/%s"%(exp_root,exp_name)
197
+ os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
198
+ data["train"]["batch_size"]=batch_size
199
+ data["train"]["epochs"]=total_epoch
200
+ data["pretrained_s1"]=pretrained_s1
201
+ data["train"]["save_every_n_epoch"]=save_every_epoch
202
+ data["train"]["if_save_every_weights"]=if_save_every_weights
203
+ data["train"]["if_save_latest"]=if_save_latest
204
+ data["train"]["half_weights_save_dir"]=GPT_weight_root
205
+ data["train"]["exp_name"]=exp_name
206
+ data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
207
+ data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
208
+ data["output_dir"]="%s/logs_s1"%s1_dir
209
+
210
+ os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
211
+ os.environ["hz"]="25hz"
212
+ tmp_config_path="TEMP/tmp_s1.yaml"
213
+ with open(tmp_config_path, "w") as f:f.write(yaml.dump(data, default_flow_style=False))
214
+ # cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" --train_semantic_path "%s/6-name2semantic.tsv" --train_phoneme_path "%s/2-name2text.txt" --output_dir "%s/logs_s1"'%(python_exec,tmp_config_path,s1_dir,s1_dir,s1_dir)
215
+ cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" '%(python_exec,tmp_config_path)
216
+ yield "GPT训练开始:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
217
+ print(cmd)
218
+ p_train_GPT = Popen(cmd, shell=True)
219
+ p_train_GPT.wait()
220
+ p_train_GPT=None
221
+ yield "GPT训练完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
222
+ else:
223
+ yield "已有正在进行的GPT训练任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
224
+
225
+ def close1Bb():
226
+ global p_train_GPT
227
+ if(p_train_GPT!=None):
228
+ kill_process(p_train_GPT.pid)
229
+ p_train_GPT=None
230
+ return "已终止GPT训练",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
231
+
232
+ ps_slice=[]
233
+ def open_slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_parts):
234
+ global ps_slice
235
+ if(os.path.exists(inp)==False):
236
+ yield "输入路径不存在",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
237
+ return
238
+ if os.path.isfile(inp):n_parts=1
239
+ elif os.path.isdir(inp):pass
240
+ else:
241
+ yield "输入路径存在但既不是文件也不是文件夹",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
242
+ return
243
+ if (ps_slice == []):
244
+ for i_part in range(n_parts):
245
+ cmd = '"%s" tools/slice_audio.py "%s" "%s" %s %s %s %s %s %s %s %s %s''' % (python_exec,inp, opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, i_part, n_parts)
246
+ print(cmd)
247
+ p = Popen(cmd, shell=True)
248
+ ps_slice.append(p)
249
+ yield "切割执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
250
+ for p in ps_slice:
251
+ p.wait()
252
+ ps_slice=[]
253
+ yield "切割结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
254
+ else:
255
+ yield "已有正在进行的切割任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
256
+
257
+ def close_slice():
258
+ global ps_slice
259
+ if (ps_slice != []):
260
+ for p_slice in ps_slice:
261
+ try:
262
+ kill_process(p_slice.pid)
263
+ except:
264
+ traceback.print_exc()
265
+ ps_slice=[]
266
+ return "已终止所有切割进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
267
+
268
+ '''
269
+ inp_text= os.environ.get("inp_text")
270
+ inp_wav_dir= os.environ.get("inp_wav_dir")
271
+ exp_name= os.environ.get("exp_name")
272
+ i_part= os.environ.get("i_part")
273
+ all_parts= os.environ.get("all_parts")
274
+ os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
275
+ opt_dir= os.environ.get("opt_dir")#"/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
276
+ bert_pretrained_dir= os.environ.get("bert_pretrained_dir")#"/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
277
+ '''
278
+ ps1a=[]
279
+ def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
280
+ global ps1a
281
+ if (ps1a == []):
282
+ config={
283
+ "inp_text":inp_text,
284
+ "inp_wav_dir":inp_wav_dir,
285
+ "exp_name":exp_name,
286
+ "opt_dir":"%s/%s"%(exp_root,exp_name),
287
+ "bert_pretrained_dir":bert_pretrained_dir,
288
+ }
289
+ gpu_names=gpu_numbers.split("-")
290
+ all_parts=len(gpu_names)
291
+ for i_part in range(all_parts):
292
+ config.update(
293
+ {
294
+ "i_part": str(i_part),
295
+ "all_parts": str(all_parts),
296
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
297
+ "is_half": str(is_half)
298
+ }
299
+ )
300
+ os.environ.update(config)
301
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
302
+ print(cmd)
303
+ p = Popen(cmd, shell=True)
304
+ ps1a.append(p)
305
+ yield "文本进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
306
+ for p in ps1a:
307
+ p.wait()
308
+ ps1a=[]
309
+ yield "文本进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
310
+ else:
311
+ yield "已有正在进行的文本任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
312
+
313
+ def close1a():
314
+ global ps1a
315
+ if (ps1a != []):
316
+ for p1a in ps1a:
317
+ try:
318
+ kill_process(p1a.pid)
319
+ except:
320
+ traceback.print_exc()
321
+ ps1a=[]
322
+ return "已终止所有1a进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
323
+ '''
324
+ inp_text= os.environ.get("inp_text")
325
+ inp_wav_dir= os.environ.get("inp_wav_dir")
326
+ exp_name= os.environ.get("exp_name")
327
+ i_part= os.environ.get("i_part")
328
+ all_parts= os.environ.get("all_parts")
329
+ os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
330
+ opt_dir= os.environ.get("opt_dir")
331
+ cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
332
+ '''
333
+ ps1b=[]
334
+ def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
335
+ global ps1b
336
+ if (ps1b == []):
337
+ config={
338
+ "inp_text":inp_text,
339
+ "inp_wav_dir":inp_wav_dir,
340
+ "exp_name":exp_name,
341
+ "opt_dir":"%s/%s"%(exp_root,exp_name),
342
+ "cnhubert_base_dir":ssl_pretrained_dir,
343
+ "is_half": str(is_half)
344
+ }
345
+ gpu_names=gpu_numbers.split("-")
346
+ all_parts=len(gpu_names)
347
+ for i_part in range(all_parts):
348
+ config.update(
349
+ {
350
+ "i_part": str(i_part),
351
+ "all_parts": str(all_parts),
352
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
353
+ }
354
+ )
355
+ os.environ.update(config)
356
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
357
+ print(cmd)
358
+ p = Popen(cmd, shell=True)
359
+ ps1b.append(p)
360
+ yield "SSL提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
361
+ for p in ps1b:
362
+ p.wait()
363
+ ps1b=[]
364
+ yield "SSL提取进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
365
+ else:
366
+ yield "已有正在进行的SSL提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
367
+
368
+ def close1b():
369
+ global ps1b
370
+ if (ps1b != []):
371
+ for p1b in ps1b:
372
+ try:
373
+ kill_process(p1b.pid)
374
+ except:
375
+ traceback.print_exc()
376
+ ps1b=[]
377
+ return "已终止所有1b进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
378
+ '''
379
+ inp_text= os.environ.get("inp_text")
380
+ exp_name= os.environ.get("exp_name")
381
+ i_part= os.environ.get("i_part")
382
+ all_parts= os.environ.get("all_parts")
383
+ os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
384
+ opt_dir= os.environ.get("opt_dir")
385
+ pretrained_s2G= os.environ.get("pretrained_s2G")
386
+ '''
387
+ ps1c=[]
388
+ def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
389
+ global ps1c
390
+ if (ps1c == []):
391
+ config={
392
+ "inp_text":inp_text,
393
+ "exp_name":exp_name,
394
+ "opt_dir":"%s/%s"%(exp_root,exp_name),
395
+ "pretrained_s2G":pretrained_s2G_path,
396
+ "s2config_path":"GPT_SoVITS/configs/s2.json",
397
+ "is_half": str(is_half)
398
+ }
399
+ gpu_names=gpu_numbers.split("-")
400
+ all_parts=len(gpu_names)
401
+ for i_part in range(all_parts):
402
+ config.update(
403
+ {
404
+ "i_part": str(i_part),
405
+ "all_parts": str(all_parts),
406
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
407
+ }
408
+ )
409
+ os.environ.update(config)
410
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
411
+ print(cmd)
412
+ p = Popen(cmd, shell=True)
413
+ ps1c.append(p)
414
+ yield "语义token提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
415
+ for p in ps1c:
416
+ p.wait()
417
+ ps1c=[]
418
+ yield "语义token提取进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
419
+ else:
420
+ yield "已有正在进行的语义token提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
421
+
422
+ def close1c():
423
+ global ps1c
424
+ if (ps1c != []):
425
+ for p1c in ps1c:
426
+ try:
427
+ kill_process(p1c.pid)
428
+ except:
429
+ traceback.print_exc()
430
+ ps1c=[]
431
+ return "已终止所有语义token进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
432
+ #####inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G
433
+ ps1abc=[]
434
+ def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
435
+ global ps1abc
436
+ if (ps1abc == []):
437
+ opt_dir="%s/%s"%(exp_root,exp_name)
438
+ try:
439
+ #############################1a
440
+ path_text="%s/2-name2text.txt" % opt_dir
441
+ if(os.path.exists(path_text)==False):
442
+ config={
443
+ "inp_text":inp_text,
444
+ "inp_wav_dir":inp_wav_dir,
445
+ "exp_name":exp_name,
446
+ "opt_dir":opt_dir,
447
+ "bert_pretrained_dir":bert_pretrained_dir,
448
+ "is_half": str(is_half)
449
+ }
450
+ gpu_names=gpu_numbers1a.split("-")
451
+ all_parts=len(gpu_names)
452
+ for i_part in range(all_parts):
453
+ config.update(
454
+ {
455
+ "i_part": str(i_part),
456
+ "all_parts": str(all_parts),
457
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
458
+ }
459
+ )
460
+ os.environ.update(config)
461
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
462
+ print(cmd)
463
+ p = Popen(cmd, shell=True)
464
+ ps1abc.append(p)
465
+ yield "进度:1a-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
466
+ for p in ps1abc:p.wait()
467
+
468
+ opt = []
469
+ for i_part in range(all_parts):#txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
470
+ txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
471
+ with open(txt_path, "r",encoding="utf8") as f:
472
+ opt += f.read().strip("\n").split("\n")
473
+ os.remove(txt_path)
474
+ with open(path_text, "w",encoding="utf8") as f:
475
+ f.write("\n".join(opt) + "\n")
476
+
477
+ yield "进度:1a-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
478
+ ps1abc=[]
479
+ #############################1b
480
+ config={
481
+ "inp_text":inp_text,
482
+ "inp_wav_dir":inp_wav_dir,
483
+ "exp_name":exp_name,
484
+ "opt_dir":opt_dir,
485
+ "cnhubert_base_dir":ssl_pretrained_dir,
486
+ }
487
+ gpu_names=gpu_numbers1Ba.split("-")
488
+ all_parts=len(gpu_names)
489
+ for i_part in range(all_parts):
490
+ config.update(
491
+ {
492
+ "i_part": str(i_part),
493
+ "all_parts": str(all_parts),
494
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
495
+ }
496
+ )
497
+ os.environ.update(config)
498
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
499
+ print(cmd)
500
+ p = Popen(cmd, shell=True)
501
+ ps1abc.append(p)
502
+ yield "进度:1a-done, 1b-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
503
+ for p in ps1abc:p.wait()
504
+ yield "进度:1a1b-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
505
+ ps1abc=[]
506
+ #############################1c
507
+ path_semantic = "%s/6-name2semantic.tsv" % opt_dir
508
+ if(os.path.exists(path_semantic)==False):
509
+ config={
510
+ "inp_text":inp_text,
511
+ "exp_name":exp_name,
512
+ "opt_dir":opt_dir,
513
+ "pretrained_s2G":pretrained_s2G_path,
514
+ "s2config_path":"GPT_SoVITS/configs/s2.json",
515
+ }
516
+ gpu_names=gpu_numbers1c.split("-")
517
+ all_parts=len(gpu_names)
518
+ for i_part in range(all_parts):
519
+ config.update(
520
+ {
521
+ "i_part": str(i_part),
522
+ "all_parts": str(all_parts),
523
+ "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
524
+ }
525
+ )
526
+ os.environ.update(config)
527
+ cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
528
+ print(cmd)
529
+ p = Popen(cmd, shell=True)
530
+ ps1abc.append(p)
531
+ yield "进度:1a1b-done, 1cing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
532
+ for p in ps1abc:p.wait()
533
+
534
+ opt = ["item_name semantic_audio"]
535
+ for i_part in range(all_parts):
536
+ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
537
+ with open(semantic_path, "r",encoding="utf8") as f:
538
+ opt += f.read().strip("\n").split("\n")
539
+ os.remove(semantic_path)
540
+ with open(path_semantic, "w",encoding="utf8") as f:
541
+ f.write("\n".join(opt) + "\n")
542
+ yield "进度:all-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
543
+ ps1abc = []
544
+ yield "一键三连进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
545
+ except:
546
+ traceback.print_exc()
547
+ close1abc()
548
+ yield "一键三连中途报错", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
549
+ else:
550
+ yield "已有正在进行的一键三连任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
551
+
552
+ def close1abc():
553
+ global ps1abc
554
+ if (ps1abc != []):
555
+ for p1abc in ps1abc:
556
+ try:
557
+ kill_process(p1abc.pid)
558
+ except:
559
+ traceback.print_exc()
560
+ ps1abc=[]
561
+ return "已终止所有一键三连进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
562
+
563
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
564
+ gr.Markdown(
565
+ value=
566
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
567
+ )
568
+ with gr.Tabs():
569
+ with gr.TabItem("0-前置数据集获取工具"):#提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
570
+ gr.Markdown(value="0a-UVR5人声伴奏分离&去混响去延迟工具")
571
+ with gr.Row():
572
+ if_uvr5 = gr.Checkbox(label="是否开启UVR5-WebUI",show_label=True)
573
+ uvr5_info = gr.Textbox(label="UVR5进程输出信息")
574
+ gr.Markdown(value="0b-语音切分工具")
575
+ with gr.Row():
576
+ with gr.Row():
577
+ slice_inp_path=gr.Textbox(label="音频自动切分输入路径,可文件可文件夹",value="")
578
+ slice_opt_root=gr.Textbox(label="切分后的子音频的输出根目录",value="output/slicer_opt")
579
+ threshold=gr.Textbox(label="threshold:音量小于这个值视作静音的备选切割点",value="-34")
580
+ min_length=gr.Textbox(label="min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值",value="4000")
581
+ min_interval=gr.Textbox(label="min_interval:最短切割间隔",value="300")
582
+ hop_size=gr.Textbox(label="hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)",value="10")
583
+ max_sil_kept=gr.Textbox(label="max_sil_kept:切完后静音最多留多长",value="500")
584
+ with gr.Row():
585
+ open_slicer_button=gr.Button("开启语音切割", variant="primary",visible=True)
586
+ close_slicer_button=gr.Button("终止语音切割", variant="primary",visible=False)
587
+ _max=gr.Slider(minimum=0,maximum=1,step=0.05,label="max:归一化后最大值多少",value=0.9,interactive=True)
588
+ alpha=gr.Slider(minimum=0,maximum=1,step=0.05,label="alpha_mix:混多少比例归一化后音频进来",value=0.25,interactive=True)
589
+ n_process=gr.Slider(minimum=1,maximum=n_cpu,step=1,label="切割使用的进程数",value=4,interactive=True)
590
+ slicer_info = gr.Textbox(label="语音切割进程输出信息")
591
+ gr.Markdown(value="0c-中文批量离线ASR工具")
592
+ with gr.Row():
593
+ open_asr_button = gr.Button("开启离线批量ASR", variant="primary",visible=True)
594
+ close_asr_button = gr.Button("终止ASR进程", variant="primary",visible=False)
595
+ asr_inp_dir = gr.Textbox(
596
+ label="批量ASR(中文only)输入文件夹路径",
597
+ value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx",
598
+ interactive=True,
599
+ )
600
+ asr_info = gr.Textbox(label="ASR进程输出信息")
601
+ gr.Markdown(value="0d-语音文本校对标注工具")
602
+ with gr.Row():
603
+ if_label = gr.Checkbox(label="是否开启打标WebUI",show_label=True)
604
+ path_list = gr.Textbox(
605
+ label="打标数据标注文件路径",
606
+ value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list",
607
+ interactive=True,
608
+ )
609
+ label_info = gr.Textbox(label="打标工具进程输出信息")
610
+ if_label.change(change_label, [if_label,path_list], [label_info])
611
+ if_uvr5.change(change_uvr5, [if_uvr5], [uvr5_info])
612
+ open_asr_button.click(open_asr, [asr_inp_dir], [asr_info,open_asr_button,close_asr_button])
613
+ close_asr_button.click(close_asr, [], [asr_info,open_asr_button,close_asr_button])
614
+ open_slicer_button.click(open_slice, [slice_inp_path,slice_opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_process], [slicer_info,open_slicer_button,close_slicer_button])
615
+ close_slicer_button.click(close_slice, [], [slicer_info,open_slicer_button,close_slicer_button])
616
+ with gr.TabItem("1-GPT-SoVITS-TTS"):
617
+ with gr.Row():
618
+ exp_name = gr.Textbox(label="*实验/模型名", value="xxx", interactive=True)
619
+ gpu_info = gr.Textbox(label="显卡信息", value=gpu_info, visible=True, interactive=False)
620
+ pretrained_s2G = gr.Textbox(label="预训练的SoVITS-G模型路径", value="GPT_SoVITS/pretrained_models/s2G488k.pth", interactive=True)
621
+ pretrained_s2D = gr.Textbox(label="预训练的SoVITS-D模型路径", value="GPT_SoVITS/pretrained_models/s2D488k.pth", interactive=True)
622
+ pretrained_s1 = gr.Textbox(label="预训练的GPT模型路径", value="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", interactive=True)
623
+ with gr.TabItem("1A-训练集格式化工具"):
624
+ gr.Markdown(value="输出logs/实验名目录下应有23456开头的文件和文件夹")
625
+ with gr.Row():
626
+ inp_text = gr.Textbox(label="*文本标注文件",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",interactive=True)
627
+ inp_wav_dir = gr.Textbox(label="*训练集音频文件目录",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",interactive=True)
628
+ gr.Markdown(value="1Aa-文本内容")
629
+ with gr.Row():
630
+ gpu_numbers1a = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
631
+ bert_pretrained_dir = gr.Textbox(label="预训练的中文BERT模型路径",value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",interactive=False)
632
+ button1a_open = gr.Button("开启文本获取", variant="primary",visible=True)
633
+ button1a_close = gr.Button("终止文本获取进程", variant="primary",visible=False)
634
+ info1a=gr.Textbox(label="文本进程输出信息")
635
+ gr.Markdown(value="1Ab-SSL自监督特征提取")
636
+ with gr.Row():
637
+ gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
638
+ cnhubert_base_dir = gr.Textbox(label="预训练的SSL模型路径",value="GPT_SoVITS/pretrained_models/chinese-hubert-base",interactive=False)
639
+ button1b_open = gr.Button("开启SSL提取", variant="primary",visible=True)
640
+ button1b_close = gr.Button("终止SSL提取进程", variant="primary",visible=False)
641
+ info1b=gr.Textbox(label="SSL进程输出信息")
642
+ gr.Markdown(value="1Ac-语义token提取")
643
+ with gr.Row():
644
+ gpu_numbers1c = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
645
+ button1c_open = gr.Button("开启语义token提取", variant="primary",visible=True)
646
+ button1c_close = gr.Button("终止语义token提取进程", variant="primary",visible=False)
647
+ info1c=gr.Textbox(label="语义token提取进程输出信息")
648
+ gr.Markdown(value="1Aabc-训练集格式化一键三连")
649
+ with gr.Row():
650
+ button1abc_open = gr.Button("开启一键三连", variant="primary",visible=True)
651
+ button1abc_close = gr.Button("终止一键三连", variant="primary",visible=False)
652
+ info1abc=gr.Textbox(label="一键三连进程输出信息")
653
+ button1a_open.click(open1a, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,bert_pretrained_dir], [info1a,button1a_open,button1a_close])
654
+ button1a_close.click(close1a, [], [info1a,button1a_open,button1a_close])
655
+ button1b_open.click(open1b, [inp_text,inp_wav_dir,exp_name,gpu_numbers1Ba,cnhubert_base_dir], [info1b,button1b_open,button1b_close])
656
+ button1b_close.click(close1b, [], [info1b,button1b_open,button1b_close])
657
+ button1c_open.click(open1c, [inp_text,exp_name,gpu_numbers1c,pretrained_s2G], [info1c,button1c_open,button1c_close])
658
+ button1c_close.click(close1c, [], [info1c,button1c_open,button1c_close])
659
+ button1abc_open.click(open1abc, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G], [info1abc,button1abc_open,button1abc_close])
660
+ button1abc_close.click(close1abc, [], [info1abc,button1abc_open,button1abc_close])
661
+ with gr.TabItem("1B-微调训练"):
662
+ gr.Markdown(value="1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。")
663
+ with gr.Row():
664
+ batch_size = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
665
+ total_epoch = gr.Slider(minimum=2,maximum=100,step=1,label=i18n("总训练轮数total_epoch,不建议太高"),value=10,interactive=True)
666
+ text_low_lr_rate = gr.Slider(minimum=0.2,maximum=0.6,step=0.05,label="文本模块学习率权重",value=0.4,interactive=True)
667
+ save_every_epoch = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
668
+ if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
669
+ if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
670
+ gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
671
+ with gr.Row():
672
+ button1Ba_open = gr.Button("开启SoVITS训练", variant="primary",visible=True)
673
+ button1Ba_close = gr.Button("终止SoVITS训练", variant="primary",visible=False)
674
+ info1Ba=gr.Textbox(label="SoVITS训练进程输出信息")
675
+ gr.Markdown(value="1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。")
676
+ with gr.Row():
677
+ batch_size1Bb = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
678
+ total_epoch1Bb = gr.Slider(minimum=2,maximum=200,step=1,label=i18n("总训练轮数total_epoch"),value=15,interactive=True)
679
+ if_save_latest1Bb = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
680
+ if_save_every_weights1Bb = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
681
+ save_every_epoch1Bb = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
682
+ gpu_numbers1Bb = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
683
+ with gr.Row():
684
+ button1Bb_open = gr.Button("开启GPT训练", variant="primary",visible=True)
685
+ button1Bb_close = gr.Button("终止GPT训练", variant="primary",visible=False)
686
+ info1Bb=gr.Textbox(label="GPT训练进程输出信息")
687
+ button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close])
688
+ button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
689
+ button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
690
+ button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
691
+ with gr.TabItem("1C-推理"):
692
+ gr.Markdown(value="选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。")
693
+ with gr.Row():
694
+ GPT_dropdown = gr.Dropdown(label="*GPT模型列表", choices=sorted(GPT_names),value=pretrained_gpt_name)
695
+ SoVITS_dropdown = gr.Dropdown(label="*SoVITS模型列表", choices=sorted(SoVITS_names),value=pretrained_sovits_name)
696
+ gpu_number_1C=gr.Textbox(label="GPU卡号,只能填1个整数", value=gpus, interactive=True)
697
+ refresh_button = gr.Button("刷新模型路径", variant="primary")
698
+ refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
699
+ with gr.Row():
700
+ if_tts = gr.Checkbox(label="是否开启TTS推理WebUI", show_label=True)
701
+ tts_info = gr.Textbox(label="TTS推理WebUI进程输出信息")
702
+ if_tts.change(change_tts_inference, [if_tts,bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info])
703
+ with gr.TabItem("2-GPT-SoVITS-变声"):gr.Markdown(value="施工中,请静候佳音")
704
+
705
+ '''
706
+ os.environ["gpt_path"]=gpt_path
707
+ os.environ["sovits_path"]=sovits_path#bert_pretrained_dir
708
+ os.environ["cnhubert_base_path"]=cnhubert_base_path#cnhubert_base_dir
709
+ os.environ["bert_path"]=bert_path
710
+ os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
711
+ '''
712
+
713
+ app.queue(concurrency_count=511, max_size=1022).launch(
714
+ share=True,
715
+ server_name="0.0.0.0",
716
+ inbrowser=True,
717
+ server_port=7890,
718
+ quiet=True,
719
+ )
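One note on the file above: its kill_process helper shells out to taskkill, which only exists on Windows, and the inline comment leaves Linux handling as a TODO. A hedged cross-platform variant (an illustration, not the repository's actual fix) could dispatch on the platform; on POSIX it assumes the child was started in its own process group (e.g. Popen(..., preexec_fn=os.setsid)), otherwise killing the group would also take down the WebUI itself, which is exactly what the original comment warns about:

```python
import os
import signal
import platform

def kill_process(pid):
    if platform.system() == "Windows":
        # /t terminates the whole process tree, /f forces termination.
        os.system("taskkill /t /f /pid %s" % pid)
    else:
        try:
            # Kill the child's process group so grandchildren spawned via
            # shell=True are terminated too (requires preexec_fn=os.setsid).
            os.killpg(os.getpgid(pid), signal.SIGKILL)
        except ProcessLookupError:
            pass  # process already exited
```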
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ python /root/autodl-tmp/workdir/GPT-SoVITS/webui.py
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py ADDED
@@ -0,0 +1,270 @@
1
+ import os
2
+ gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
3
+ sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
4
+ cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
5
+ bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
6
+ if("_CUDA_VISIBLE_DEVICES"in os.environ):
7
+ os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
8
+ is_half=eval(os.environ.get("is_half","True"))
9
+ import gradio as gr
10
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
11
+ import sys,torch,numpy as np
12
+ from pathlib import Path
13
+ import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
14
+ # torch.backends.cuda.sdp_kernel("flash")
15
+ # torch.backends.cuda.enable_flash_sdp(True)
16
+ # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
17
+ # torch.backends.cuda.enable_math_sdp(True)
18
+ from random import shuffle
19
+ from AR.utils import get_newest_ckpt
20
+ from glob import glob
21
+ from tqdm import tqdm
22
+ from feature_extractor import cnhubert
23
+ cnhubert.cnhubert_base_path=cnhubert_base_path
24
+ from io import BytesIO
25
+ from module.models import SynthesizerTrn
26
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
27
+ from AR.utils.io import load_yaml_config
28
+ from text import cleaned_text_to_sequence
29
+ from text.cleaner import text_to_sequence, clean_text
30
+ from time import time as ttime
31
+ from module.mel_processing import spectrogram_torch
32
+ from my_utils import load_audio
33
+
34
+ device="cuda"
35
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
36
+ bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
37
+ if(is_half==True):bert_model=bert_model.half().to(device)
38
+ else:bert_model=bert_model.to(device)
39
+ # bert_model=bert_model.to(device)
40
+ def get_bert_feature(text, word2ph):
41
+ with torch.no_grad():
42
+ inputs = tokenizer(text, return_tensors="pt")
43
+ for i in inputs:
44
+ inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
45
+ res = bert_model(**inputs, output_hidden_states=True)
46
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
47
+ assert len(word2ph) == len(text)
48
+ phone_level_feature = []
49
+ for i in range(len(word2ph)):
50
+ repeat_feature = res[i].repeat(word2ph[i], 1)
51
+ phone_level_feature.append(repeat_feature)
52
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
53
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
54
+ return phone_level_feature.T
55
+
56
+ n_semantic = 1024
57
+ dict_s2=torch.load(sovits_path,map_location="cpu")
58
+ hps=dict_s2["config"]
59
+ class DictToAttrRecursive:
60
+ def __init__(self, input_dict):
61
+ for key, value in input_dict.items():
62
+ if isinstance(value, dict):
63
+ # if the value is a dict, call the constructor recursively
64
+ setattr(self, key, DictToAttrRecursive(value))
65
+ else:
66
+ setattr(self, key, value)
67
+
68
+ hps = DictToAttrRecursive(hps)
69
+ hps.model.semantic_frame_rate="25hz"
70
+ dict_s1=torch.load(gpt_path,map_location="cpu")
71
+ config=dict_s1["config"]
72
+ ssl_model=cnhubert.get_model()
73
+ if(is_half==True):ssl_model=ssl_model.half().to(device)
74
+ else:ssl_model=ssl_model.to(device)
75
+
76
+ vq_model = SynthesizerTrn(
77
+ hps.data.filter_length // 2 + 1,
78
+ hps.train.segment_size // hps.data.hop_length,
79
+ n_speakers=hps.data.n_speakers,
80
+ **hps.model)
81
+ if(is_half==True):vq_model=vq_model.half().to(device)
82
+ else:vq_model=vq_model.to(device)
83
+ vq_model.eval()
84
+ print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
85
+ hz = 50
86
+ max_sec = config['data']['max_sec']
87
+ # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
88
+ t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
89
+ t2s_model.load_state_dict(dict_s1["weight"])
90
+ if(is_half==True):t2s_model=t2s_model.half()
91
+ t2s_model=t2s_model.to(device)
92
+ t2s_model.eval()
93
+ total = sum([param.nelement() for param in t2s_model.parameters()])
94
+ print("Number of parameter: %.2fM" % (total / 1e6))
95
+ def get_spepc(hps, filename):
96
+ audio=load_audio(filename,int(hps.data.sampling_rate))
97
+ audio=torch.FloatTensor(audio)
98
+ audio_norm = audio
99
+ audio_norm = audio_norm.unsqueeze(0)
100
+ spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
101
+ return spec
102
+
103
+ dict_language={
104
+ "中文":"zh",
105
+ "英文":"en",
106
+ "日文":"ja"
107
+ }
108
+ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
109
+ t0 = ttime()
110
+ prompt_text=prompt_text.strip("\n")
111
+ prompt_language,text=prompt_language,text.strip("\n")
112
+ with torch.no_grad():
113
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
114
+ wav16k = torch.from_numpy(wav16k)
115
+ if(is_half==True):wav16k=wav16k.half().to(device)
116
+ else:wav16k=wav16k.to(device)
117
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
118
+ codes = vq_model.extract_latent(ssl_content)
119
+ prompt_semantic = codes[0, 0]
120
+ t1 = ttime()
121
+ prompt_language=dict_language[prompt_language]
122
+ text_language=dict_language[text_language]
123
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
124
+ phones1=cleaned_text_to_sequence(phones1)
125
+ texts=text.split("\n")
126
+ audio_opt = []
127
+ zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
128
+ for text in texts:
129
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
130
+ phones2 = cleaned_text_to_sequence(phones2)
131
+ if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1)
132
+ else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
133
+ if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2)
134
+ else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
135
+ bert = torch.cat([bert1, bert2], 1)
136
+
137
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
138
+ bert = bert.to(device).unsqueeze(0)
139
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
140
+ prompt = prompt_semantic.unsqueeze(0).to(device)
141
+ t2 = ttime()
142
+ with torch.no_grad():
143
+ # pred_semantic = t2s_model.model.infer(
144
+ pred_semantic,idx = t2s_model.model.infer_panel(
145
+ all_phoneme_ids,
146
+ all_phoneme_len,
147
+ prompt,
148
+ bert,
149
+ # prompt_phone_len=ph_offset,
150
+ top_k=config['inference']['top_k'],
151
+ early_stop_num=hz * max_sec)
152
+ t3 = ttime()
153
+ # print(pred_semantic.shape,idx)
154
+ pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)  # mq needs one extra unsqueeze
155
+ refer = get_spepc(hps, ref_wav_path)#.to(device)
156
+ if(is_half==True):refer=refer.half().to(device)
157
+ else:refer=refer.to(device)
158
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
159
+ audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]  # try reconstructing without including the prompt part
160
+ audio_opt.append(audio)
161
+ audio_opt.append(zero_wav)
162
+ t4 = ttime()
163
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
164
+ yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
165
+
166
+
167
+ splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}  # ellipsis is not handled specially
168
+ def split(todo_text):
169
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
170
+ if (todo_text[-1] not in splits): todo_text += "。"
171
+ i_split_head = i_split_tail = 0
172
+ len_text = len(todo_text)
173
+ todo_texts = []
174
+ while (1):
175
+ if (i_split_head >= len_text): break # the text always ends with punctuation, so we can just break; the final segment was already appended in the previous step
176
+ if (todo_text[i_split_head] in splits):
177
+ i_split_head += 1
178
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
179
+ i_split_tail = i_split_head
180
+ else:
181
+ i_split_head += 1
182
+ return todo_texts
183
+ def cut1(inp):
184
+ inp=inp.strip("\n")
185
+ inps=split(inp)
186
+ split_idx=list(range(0,len(inps),5))
187
+ split_idx[-1]=None
188
+ if(len(split_idx)>1):
189
+ opts=[]
190
+ for idx in range(len(split_idx)-1):
191
+ opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
192
+ else:
193
+ opts=[inp]
194
+ return "\n".join(opts)
195
+
196
+ def cut2(inp):
197
+ inp=inp.strip("\n")
198
+ inps=split(inp)
199
+ if(len(inps)<2):return [inp]
200
+ opts=[]
201
+ summ=0
202
+ tmp_str=""
203
+ for i in range(len(inps)):
204
+ summ+=len(inps[i])
205
+ tmp_str+=inps[i]
206
+ if(summ>50):
207
+ summ=0
208
+ opts.append(tmp_str)
209
+ tmp_str=""
210
+ if(tmp_str!=""):opts.append(tmp_str)
211
+ if(len(opts[-1])<50):  # if the last chunk is too short, merge it into the previous one
212
+ opts[-2]=opts[-2]+opts[-1]
213
+ opts=opts[:-1]
214
+ return "\n".join(opts)
215
+
216
+ def cut3(inp):
217
+ inp=inp.strip("\n")
218
+ return "\n".join(["%s。"%item for item in inp.strip("。").split("。")])
219
+
220
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
221
+ gr.Markdown(
222
+ value=
223
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
224
+ )
225
+ # with gr.Tabs():
226
+ # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
227
+ with gr.Group():
228
+ gr.Markdown(
229
+ value=
230
+ "*请上传并填写参考信息"
231
+ )
232
+ with gr.Row():
233
+ inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
234
+ prompt_text= gr.Textbox(label="参考音频的文本",value="")
235
+ prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"])
236
+ gr.Markdown(
237
+ value=
238
+ "*请填写需要合成的目标文本"
239
+ )
240
+ with gr.Row():
241
+ text=gr.Textbox(label="需要合成的文本",value="")
242
+ text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"])
243
+ inference_button=gr.Button("合成语音", variant="primary")
244
+ output = gr.Audio(label="输出的语音")
245
+ inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
246
+
247
+ gr.Markdown(
248
+ value=
249
+ "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
250
+ )
251
+ with gr.Row():
252
+ text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
253
+ button1 = gr.Button("凑五句一切", variant="primary")
254
+ button2 = gr.Button("凑50字一切", variant="primary")
255
+ button3 = gr.Button("按中文句号。切", variant="primary")
256
+ text_opt = gr.Textbox(label="切分后文本", value="")
257
+ button1.click(cut1,[text_inp],[text_opt])
258
+ button2.click(cut2,[text_inp],[text_opt])
259
+ button3.click(cut3,[text_inp],[text_opt])
260
+ gr.Markdown(
261
+ value=
262
+ "后续将支持混合语种编码文本输入。"
263
+ )
264
+
265
+ app.queue(concurrency_count=511, max_size=1022).launch(
266
+ server_name="0.0.0.0",
267
+ inbrowser=True,
268
+ server_port=6006,
269
+ quiet=True,
270
+ )
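This inference script resolves all of its model paths and device settings from environment variables before it builds the Gradio UI. As a rough launch sketch, assuming the script lives at `GPT_SoVITS/inference_webui.py` relative to the working directory; the checkpoint paths below are illustrative placeholders, not values shipped with this upload:

```python
# Minimal launch sketch: inference_webui.py reads these variables at import time,
# so they have to be set before the script starts.
import os
import subprocess

env = dict(
    os.environ,
    gpt_path="GPT_weights/my_model.ckpt",        # hypothetical fine-tuned GPT checkpoint
    sovits_path="SoVITS_weights/my_model.pth",   # hypothetical fine-tuned SoVITS checkpoint
    cnhubert_base_path="pretrained_models/chinese-hubert-base",
    bert_path="pretrained_models/chinese-roberta-wwm-ext-large",
    _CUDA_VISIBLE_DEVICES="0",                   # forwarded to CUDA_VISIBLE_DEVICES by the script
    is_half="True",                              # parsed with eval(), so pass "True"/"False"
)
subprocess.run(["python", "GPT_SoVITS/inference_webui.py"], env=env, check=True)
```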
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py ADDED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py ADDED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,157 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
2
+ import itertools
3
+ import math
4
+ import random
5
+ from random import shuffle
6
+ from typing import Iterator
7
+ from typing import Optional
8
+ from typing import TypeVar
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch.utils.data import Dataset
13
+ from torch.utils.data import Sampler
14
+
15
+ __all__ = [
16
+ "DistributedBucketSampler",
17
+ ]
18
+
19
+ T_co = TypeVar('T_co', covariant=True)
20
+
21
+
22
+ class DistributedBucketSampler(Sampler[T_co]):
23
+ r"""
24
+ sort the dataset wrt. input length
25
+ divide samples into buckets
26
+ sort within buckets
27
+ divide buckets into batches
28
+ sort batches
29
+ """
30
+
31
+ def __init__(self,
32
+ dataset: Dataset,
33
+ num_replicas: Optional[int]=None,
34
+ rank: Optional[int]=None,
35
+ shuffle: bool=True,
36
+ seed: int=0,
37
+ drop_last: bool=False,
38
+ batch_size: int=32) -> None:
39
+ if num_replicas is None:
40
+ if not dist.is_available():
41
+ raise RuntimeError(
42
+ "Requires distributed package to be available")
43
+ num_replicas = dist.get_world_size()
44
+ if rank is None:
45
+ if not dist.is_available():
46
+ raise RuntimeError(
47
+ "Requires distributed package to be available")
48
+ rank = dist.get_rank()
49
+ torch.cuda.set_device(rank)
50
+ if rank >= num_replicas or rank < 0:
51
+ raise ValueError("Invalid rank {}, rank should be in the interval"
52
+ " [0, {}]".format(rank, num_replicas - 1))
53
+ self.dataset = dataset
54
+ self.num_replicas = num_replicas
55
+ self.rank = rank
56
+ self.epoch = 0
57
+ self.drop_last = drop_last
58
+ # If the dataset length is evenly divisible by # of replicas, then there
59
+ # is no need to drop any data, since the dataset will be split equally.
60
+ if self.drop_last and len(
61
+ self.
62
+ dataset) % self.num_replicas != 0: # type: ignore[arg-type]
63
+ # Split to nearest available length that is evenly divisible.
64
+ # This is to ensure each rank receives the same amount of data when
65
+ # using this Sampler.
66
+ self.num_samples = math.ceil(
67
+ (len(self.dataset) - self.num_replicas) /
68
+ self.num_replicas # type: ignore[arg-type]
69
+ )
70
+ else:
71
+ self.num_samples = math.ceil(
72
+ len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
73
+ self.total_size = self.num_samples * self.num_replicas
74
+ self.shuffle = shuffle
75
+ self.seed = seed
76
+ self.batch_size = batch_size
77
+ self.id_with_length = self._get_sample_lengths()
78
+ self.id_buckets = self.make_buckets(bucket_width=2.0)
79
+
80
+ def _get_sample_lengths(self):
81
+ id_with_lengths = []
82
+ for i in range(len(self.dataset)):
83
+ id_with_lengths.append((i, self.dataset.get_sample_length(i)))
84
+ id_with_lengths.sort(key=lambda x: x[1])
85
+ return id_with_lengths
86
+
87
+ def make_buckets(self, bucket_width: float=2.0):
88
+ buckets = []
89
+ cur = []
90
+ max_sec = bucket_width
91
+ for id, sec in self.id_with_length:
92
+ if sec < max_sec:
93
+ cur.append(id)
94
+ else:
95
+ buckets.append(cur)
96
+ cur = [id]
97
+ max_sec += bucket_width
98
+ if len(cur) > 0:
99
+ buckets.append(cur)
100
+ return buckets
101
+
102
+ def __iter__(self) -> Iterator[T_co]:
103
+ if self.shuffle:
104
+ # deterministically shuffle based on epoch and seed
105
+ g = torch.Generator()
106
+ g.manual_seed(self.seed + self.epoch)
107
+ random.seed(self.epoch + self.seed)
108
+ shuffled_bucket = []
109
+ for buc in self.id_buckets:
110
+ buc_copy = buc.copy()
111
+ shuffle(buc_copy)
112
+ shuffled_bucket.append(buc_copy)
113
+ grouped_batch_size = self.batch_size * self.num_replicas
114
+ shuffled_bucket = list(itertools.chain(*shuffled_bucket))
115
+ n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
116
+ batches = [
117
+ shuffled_bucket[b * grouped_batch_size:(b + 1) *
118
+ grouped_batch_size] for b in range(n_batch)
119
+ ]
120
+ shuffle(batches)
121
+ indices = list(itertools.chain(*batches))
122
+ else:
123
+ # type: ignore[arg-type]
124
+ indices = list(range(len(self.dataset)))
125
+
126
+ if not self.drop_last:
127
+ # add extra samples to make it evenly divisible
128
+ padding_size = self.total_size - len(indices)
129
+ if padding_size <= len(indices):
130
+ indices += indices[:padding_size]
131
+ else:
132
+ indices += (indices * math.ceil(padding_size /
133
+ len(indices)))[:padding_size]
134
+ else:
135
+ # remove tail of data to make it evenly divisible.
136
+ indices = indices[:self.total_size]
137
+ assert len(indices) == self.total_size
138
+
139
+ # subsample
140
+ indices = indices[self.rank:self.total_size:self.num_replicas]
141
+ assert len(indices) == self.num_samples
142
+
143
+ return iter(indices)
144
+
145
+ def __len__(self) -> int:
146
+ return self.num_samples
147
+
148
+ def set_epoch(self, epoch: int) -> None:
149
+ r"""
150
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
151
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
152
+ sampler will yield the same ordering.
153
+
154
+ Args:
155
+ epoch (int): Epoch number.
156
+ """
157
+ self.epoch = epoch
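The sampler above sorts items by duration, groups them into 2-second-wide buckets, and only then forms and shuffles batches, so each batch holds clips of similar length. A minimal usage sketch, assuming a CUDA device is present (the constructor calls `torch.cuda.set_device`) and that the `GPT_SoVITS` sources are on `sys.path`; the toy dataset is hypothetical:

```python
import torch
from torch.utils.data import Dataset

from AR.data.bucket_sampler import DistributedBucketSampler


class ToyDataset(Dataset):
    """Hypothetical dataset whose only job is to report per-item duration in seconds."""

    def __init__(self, durations):
        self.durations = durations

    def __len__(self):
        return len(self.durations)

    def __getitem__(self, idx):
        return idx

    def get_sample_length(self, idx):  # required by DistributedBucketSampler
        return self.durations[idx]


ds = ToyDataset([0.8, 3.2, 1.1, 5.0, 2.4, 4.7])
# Passing num_replicas/rank explicitly avoids needing torch.distributed to be initialised.
sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=2)
print(list(iter(sampler)))  # indices ordered so similar-length clips share a batch
```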
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py ADDED
@@ -0,0 +1,66 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2
+ from pytorch_lightning import LightningDataModule
3
+ from AR.data.bucket_sampler import DistributedBucketSampler
4
+ from AR.data.dataset import Text2SemanticDataset
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ class Text2SemanticDataModule(LightningDataModule):
9
+ def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
10
+ super().__init__()
11
+ self.config = config
12
+ self.train_semantic_path = train_semantic_path
13
+ self.train_phoneme_path = train_phoneme_path
14
+ self.dev_semantic_path = dev_semantic_path
15
+ self.dev_phoneme_path = dev_phoneme_path
16
+ self.num_workers = self.config['data']['num_workers']
17
+
18
+ def prepare_data(self):
19
+ pass
20
+
21
+ def setup(self, stage=None, output_logs=False):
22
+ self._train_dataset = Text2SemanticDataset(
23
+ phoneme_path=self.train_phoneme_path,
24
+ semantic_path=self.train_semantic_path,
25
+ max_sec=self.config['data']['max_sec'],
26
+ pad_val=self.config['data']['pad_val'])
27
+ self._dev_dataset = self._train_dataset
28
+ # self._dev_dataset = Text2SemanticDataset(
29
+ # phoneme_path=self.dev_phoneme_path,
30
+ # semantic_path=self.dev_semantic_path,
31
+ # max_sample=self.config['data']['max_eval_sample'],
32
+ # max_sec=self.config['data']['max_sec'],
33
+ # pad_val=self.config['data']['pad_val'])
34
+
35
+ def train_dataloader(self):
36
+ batch_size = self.config['train']['batch_size']
37
+ sampler = DistributedBucketSampler(
38
+ self._train_dataset, batch_size=batch_size)
39
+ return DataLoader(
40
+ self._train_dataset,
41
+ batch_size=batch_size,
42
+ sampler=sampler,
43
+ collate_fn=self._train_dataset.collate,
44
+ num_workers=self.num_workers,
45
+ persistent_workers=True,
46
+ prefetch_factor=16
47
+ )
48
+
49
+ def val_dataloader(self):
50
+ return DataLoader(
51
+ self._dev_dataset,
52
+ batch_size=1,
53
+ shuffle=False,
54
+ collate_fn=self._train_dataset.collate,
55
+ num_workers=max(self.num_workers,12),
56
+ persistent_workers=True,
57
+ prefetch_factor=16
58
+ )
59
+
60
+ # will this ever actually be used?
61
+ def test_dataloader(self):
62
+ return DataLoader(
63
+ self._dev_dataset,
64
+ batch_size=1,
65
+ shuffle=False,
66
+ collate_fn=self._train_dataset.collate)
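For reference, the data module only touches a handful of config keys (`data.num_workers`, `data.max_sec`, `data.pad_val`, `train.batch_size`). A rough wiring sketch, with hypothetical paths that must already exist on disk; note that `train_dataloader()` builds the `DistributedBucketSampler` shown earlier, so it expects a CUDA/distributed context:

```python
# Rough sketch of the pieces Text2SemanticDataModule actually consumes
# (paths and values below are illustrative, not defaults from this repo).
from AR.data.data_module import Text2SemanticDataModule

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},  # read in __init__/setup()
    "train": {"batch_size": 8},                                   # read in train_dataloader()
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/exp/6-name2semantic.tsv",  # hypothetical path
    train_phoneme_path="logs/exp/2-name2text.txt",       # hypothetical path
)
dm.setup()
loader = dm.train_dataloader()  # DistributedBucketSampler + Text2SemanticDataset.collate
```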
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py ADDED
@@ -0,0 +1,302 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
2
+ import pdb
3
+ import sys
4
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
5
+ import traceback,os
6
+ from typing import Dict
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch,json
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data import Dataset
14
+ from transformers import AutoTokenizer
15
+
16
+ from text import cleaned_text_to_sequence
17
+ # from config import exp_dir
18
+
19
+ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
20
+ seq = sequences[0]
21
+ ndim = seq.ndim
22
+ if axis < 0:
23
+ axis += ndim
24
+ dtype = seq.dtype
25
+ pad_value = dtype.type(pad_value)
26
+ seq_lengths = [seq.shape[axis] for seq in sequences]
27
+ max_length = np.max(seq_lengths)
28
+
29
+ padded_sequences = []
30
+ for seq, length in zip(sequences, seq_lengths):
31
+ padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
32
+ ndim - axis - 1)
33
+ padded_seq = np.pad(
34
+ seq, padding, mode='constant', constant_values=pad_value)
35
+ padded_sequences.append(padded_seq)
36
+ batch = np.stack(padded_sequences)
37
+ return batch
38
+
39
+ class Text2SemanticDataset(Dataset):
40
+ """dataset class for text tokens to semantic model training."""
41
+
42
+ def __init__(self,
43
+ phoneme_path: str,
44
+ semantic_path: str,
45
+ max_sample: int = None,
46
+ max_sec: int = 100,
47
+ pad_val: int = 1024,
48
+ # min value of phoneme/sec
49
+ min_ps_ratio: int = 3,
50
+ # max value of phoneme/sec
51
+ max_ps_ratio: int = 25) -> None:
52
+ super().__init__()
53
+
54
+ self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
55
+ # get dict
56
+ self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
57
+ self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
58
+ self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
59
+ assert os.path.exists(self.path2)
60
+ assert os.path.exists(self.path6)
61
+ self.phoneme_data={}
62
+ with open(self.path2,"r",encoding="utf8")as f:
63
+ lines=f.read().strip("\n").split("\n")
64
+
65
+ for line in lines:
66
+ tmp=line.split("\t")
67
+ if(len(tmp)!=4):continue
68
+ self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
69
+
70
+ # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
71
+ # pad for semantic tokens
72
+ self.PAD: int = pad_val
73
+ # self.hz = 25
74
+ # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
75
+ # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
76
+ # self.hz=int(data[:-2])#
77
+ self.hz=int(os.environ.get("hz","25hz")[:-2])
78
+
79
+ # max seconds of semantic token
80
+ self.max_sec = max_sec
81
+ self.min_ps_ratio = min_ps_ratio
82
+ self.max_ps_ratio = max_ps_ratio
83
+
84
+ if max_sample is not None:
85
+ self.semantic_data = self.semantic_data[:max_sample]
86
+
87
+ # {idx: (semantic, phoneme)}
88
+ # semantic list, phoneme list
89
+ self.semantic_phoneme = []
90
+ self.item_names = []
91
+
92
+ self.inited = False
93
+
94
+ if not self.inited:
95
+ # run the initialization routine
96
+ self.init_batch()
97
+ self.inited = True
98
+ del self.semantic_data
99
+ del self.phoneme_data
100
+ # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
101
+ # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
102
+
103
+
104
+ def init_batch(self):
105
+ semantic_data_len = len(self.semantic_data)
106
+ phoneme_data_len = len(self.phoneme_data.keys())
107
+ print("semantic_data_len:", semantic_data_len)
108
+ print("phoneme_data_len:", phoneme_data_len)
109
+ idx = 0
110
+ num_not_in = 0
111
+ num_deleted_bigger = 0
112
+ num_deleted_ps = 0
113
+ for i in range(semantic_data_len):
114
+ # iterate over the entries one by one
115
+ # get str
116
+ item_name = self.semantic_data['item_name'][i]
117
+ # print(self.phoneme_data)
118
+ try:
119
+ phoneme, word2ph, text = self.phoneme_data[item_name]
120
+ except Exception:
121
+ traceback.print_exc()
122
+ # print(f"{item_name} not in self.phoneme_data !")
123
+ num_not_in += 1
124
+ continue
125
+
126
+ semantic_str = self.semantic_data['semantic_audio'][i]
127
+ # get token list
128
+ semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
129
+ # (T); no need to reshape to (1, T), since len() is still needed
130
+ # filter out samples that are too long
131
+ if len(semantic_ids) > self.max_sec * self.hz:  # 1: estimate total duration from the token count and drop clips longer than max_sec (60s in the config); 40*25=1k
132
+ num_deleted_bigger += 1
133
+ continue
134
+ # (T,); this is fast enough to do once up front, so there is no need to handle items individually in __getitem__
135
+ phoneme = phoneme.split(' ')
136
+
137
+ try:
138
+ phoneme_ids = cleaned_text_to_sequence(phoneme)
139
+ except:
140
+ traceback.print_exc()
141
+ # print(f"{item_name} not in self.phoneme_data !")
142
+ num_not_in += 1
143
+ continue
144
+ # if len(phoneme_ids) >400:  # 2: replaced with a fixed limit of semantic/2.5
145
+ if len(phoneme_ids) >self.max_sec * self.hz/2.5:  # 2: fixed limit of semantic/2.5
146
+ num_deleted_ps += 1
147
+ continue
148
+ # if len(semantic_ids) > 1000:###########3
149
+ # num_deleted_bigger += 1
150
+ # continue
151
+
152
+ ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
153
+
154
+ if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  # 4: phonemes per second must fall within 3~25
155
+ num_deleted_ps += 1
156
+ # print(item_name)
157
+ continue
158
+
159
+ self.semantic_phoneme.append((semantic_ids, phoneme_ids))
160
+ idx += 1
161
+ self.item_names.append(item_name)
162
+
163
+ min_num=100  # at 20 no duplication is done; at 30 it is duplicated but still no ckpt is saved
164
+ leng =len(self.semantic_phoneme)
165
+ if(leng<min_num):
166
+ tmp1=self.semantic_phoneme
167
+ tmp2=self.item_names
168
+ self.semantic_phoneme=[]
169
+ self.item_names=[]
170
+ for _ in range(max(2,int(min_num/leng))):
171
+ self.semantic_phoneme+=tmp1
172
+ self.item_names+=tmp2
173
+ if num_not_in > 0:
174
+ print(f"there are {num_not_in} semantic datas not in phoneme datas")
175
+ if num_deleted_bigger > 0:
176
+ print(
177
+ f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
178
+ )
179
+ if num_deleted_ps > 0:
180
+ # 4702 for LibriTTS; LibriTTS is annotated data, does it still need filtering? => yes, there are extreme values of 100
181
+ print(
182
+ f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
183
+ )
184
+ '''
185
+ there are 31 semantic datas not in phoneme datas
186
+ deleted 34 audios who's duration are bigger than 54 seconds
187
+ deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
188
+ dataset.__len__(): 366463
189
+
190
+ '''
191
+ # 345410 for LibriTTS
192
+ print("dataset.__len__():", self.__len__())
193
+
194
+ def __get_item_names__(self) -> List[str]:
195
+ return self.item_names
196
+
197
+ def __len__(self) -> int:
198
+ return len(self.semantic_phoneme)
199
+
200
+ def __getitem__(self, idx: int) -> Dict:
201
+ semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
202
+ item_name = self.item_names[idx]
203
+ phoneme_ids_len = len(phoneme_ids)
204
+ # semantic tokens target
205
+ semantic_ids_len = len(semantic_ids)
206
+
207
+ flag=0
208
+ path_bert = "%s/%s.pt" % (self.path3, item_name)
209
+ if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
210
+ else:flag=1
211
+ if(flag==1):
212
+ # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
213
+ bert_feature=None
214
+ else:
215
+ assert bert_feature.shape[-1] == len(phoneme_ids)
216
+ return {
217
+ 'idx': idx,
218
+ 'phoneme_ids': phoneme_ids,
219
+ 'phoneme_ids_len': phoneme_ids_len,
220
+ 'semantic_ids': semantic_ids,
221
+ 'semantic_ids_len': semantic_ids_len,
222
+ 'bert_feature': bert_feature,
223
+ }
224
+
225
+ def get_sample_length(self, idx: int):
226
+ semantic_ids = self.semantic_phoneme[idx][0]
227
+ sec = 1.0 * len(semantic_ids) / self.hz
228
+ return sec
229
+
230
+ def collate(self, examples: List[Dict]) -> Dict:
231
+ sample_index: List[int] = []
232
+ phoneme_ids: List[torch.Tensor] = []
233
+ phoneme_ids_lens: List[int] = []
234
+ semantic_ids: List[torch.Tensor] = []
235
+ semantic_ids_lens: List[int] = []
236
+ # return
237
+
238
+
239
+ for item in examples:
240
+ sample_index.append(item["idx"])
241
+ phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
242
+ semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
243
+ phoneme_ids_lens.append(item["phoneme_ids_len"])
244
+ semantic_ids_lens.append(item["semantic_ids_len"])
245
+
246
+ # pad 0
247
+ phoneme_ids = batch_sequences(phoneme_ids)
248
+ semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
249
+
250
+ # # convert each batch to torch.tensor
251
+ phoneme_ids = torch.tensor(phoneme_ids)
252
+ semantic_ids = torch.tensor(semantic_ids)
253
+ phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
254
+ semantic_ids_lens = torch.tensor(semantic_ids_lens)
255
+ bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
256
+ bert_padded.zero_()
257
+
258
+ for idx, item in enumerate(examples):
259
+ bert = item['bert_feature']
260
+ if(bert!=None):
261
+ bert_padded[idx, :, :bert.shape[-1]] = bert
262
+
263
+ return {
264
+ # List[int]
265
+ "ids": sample_index,
266
+ # torch.Tensor (B, max_phoneme_length)
267
+ "phoneme_ids": phoneme_ids,
268
+ # torch.Tensor (B)
269
+ "phoneme_ids_len": phoneme_ids_lens,
270
+ # torch.Tensor (B, max_semantic_ids_length)
271
+ "semantic_ids": semantic_ids,
272
+ # torch.Tensor (B)
273
+ "semantic_ids_len": semantic_ids_lens,
274
+ # torch.Tensor (B, 1024, max_phoneme_length)
275
+ "bert_feature": bert_padded,
276
+ }
277
+
278
+
279
+ if __name__ == '__main__':
280
+ root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
281
+ dataset = Text2SemanticDataset(
282
+ phoneme_path=root_dir + 'phoneme_train.npy',
283
+ semantic_path=root_dir + 'semantic_train.tsv')
284
+
285
+ batch_size = 12
286
+ dataloader = DataLoader(
287
+ dataset,
288
+ batch_size=batch_size,
289
+ collate_fn=dataset.collate,
290
+ shuffle=False)
291
+ for i, batch in enumerate(dataloader):
292
+ if(i%1000==0):print(i)
293
+ # if i == 0:
294
+ # print('batch["ids"]:', batch["ids"])
295
+ # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
296
+ # batch["phoneme_ids"].shape)
297
+ # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
298
+ # batch["phoneme_ids_len"].shape)
299
+ # print('batch["semantic_ids"]:', batch["semantic_ids"],
300
+ # batch["semantic_ids"].shape)
301
+ # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
302
+ # batch["semantic_ids_len"].shape)
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py ADDED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py ADDED
@@ -0,0 +1,179 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+ from torch.nn import LayerNorm
16
+
17
+ from .backbone import TransformerEncoder
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class BEATsConfig:
23
+ def __init__(self, cfg=None):
24
+ self.input_patch_size: int = -1 # patch size of patch embedding
25
+ self.embed_dim: int = 512 # patch embedding dimension
26
+ self.conv_bias: bool = False # include bias in conv encoder
27
+
28
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
29
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
30
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
31
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
32
+ self.activation_fn: str = "gelu" # activation function to use
33
+
34
+ self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
35
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
36
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
37
+
38
+ # dropouts
39
+ self.dropout: float = 0.1 # dropout probability for the transformer
40
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
41
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
42
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
43
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
44
+
45
+ # positional embeddings
46
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
47
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
48
+
49
+ # relative position embedding
50
+ self.relative_position_embedding: bool = False # apply relative position embedding
51
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
52
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
53
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
54
+
55
+ # label predictor
56
+ self.finetuned_model: bool = False # whether the model is a fine-tuned model.
57
+ self.predictor_dropout: float = 0.1 # dropout probability for the predictor
58
+ self.predictor_class: int = 527 # target class number for the predictor
59
+
60
+ if cfg is not None:
61
+ self.update(cfg)
62
+
63
+ def update(self, cfg: dict):
64
+ self.__dict__.update(cfg)
65
+
66
+
67
+ class BEATs(nn.Module):
68
+ def __init__(
69
+ self,
70
+ cfg: BEATsConfig, ) -> None:
71
+ super().__init__()
72
+ logger.info(f"BEATs Config: {cfg.__dict__}")
73
+
74
+ self.cfg = cfg
75
+
76
+ self.embed = cfg.embed_dim
77
+ self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
+ if self.embed != cfg.encoder_embed_dim else
79
+ None)
80
+
81
+ self.input_patch_size = cfg.input_patch_size
82
+ self.patch_embedding = nn.Conv2d(
83
+ 1,
84
+ self.embed,
85
+ kernel_size=self.input_patch_size,
86
+ stride=self.input_patch_size,
87
+ bias=cfg.conv_bias)
88
+
89
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
90
+
91
+ assert not cfg.deep_norm or not cfg.layer_norm_first
92
+ self.encoder = TransformerEncoder(cfg)
93
+ self.layer_norm = LayerNorm(self.embed)
94
+
95
+ if cfg.finetuned_model:
96
+ self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
97
+ self.predictor = nn.Linear(cfg.encoder_embed_dim,
98
+ cfg.predictor_class)
99
+ else:
100
+ self.predictor = None
101
+
102
+ def forward_padding_mask(
103
+ self,
104
+ features: torch.Tensor,
105
+ padding_mask: torch.Tensor, ) -> torch.Tensor:
106
+ extra = padding_mask.size(1) % features.size(1)
107
+ if extra > 0:
108
+ padding_mask = padding_mask[:, :-extra]
109
+ padding_mask = padding_mask.view(
110
+ padding_mask.size(0), features.size(1), -1)
111
+ padding_mask = padding_mask.all(-1)
112
+ return padding_mask
113
+
114
+ def preprocess(
115
+ self,
116
+ source: torch.Tensor,
117
+ fbank_mean: float=15.41663,
118
+ fbank_std: float=6.55582, ) -> torch.Tensor:
119
+ fbanks = []
120
+ for waveform in source:
121
+ waveform = waveform.unsqueeze(0) * 2**15
122
+ fbank = ta_kaldi.fbank(
123
+ waveform,
124
+ num_mel_bins=128,
125
+ sample_frequency=16000,
126
+ frame_length=25,
127
+ frame_shift=10)
128
+ fbanks.append(fbank)
129
+ fbank = torch.stack(fbanks, dim=0)
130
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
131
+ return fbank
132
+
133
+ def extract_features(
134
+ self,
135
+ source: torch.Tensor,
136
+ padding_mask: Optional[torch.Tensor]=None,
137
+ fbank_mean: float=15.41663,
138
+ fbank_std: float=6.55582, ):
139
+ fbank = self.preprocess(
140
+ source, fbank_mean=fbank_mean, fbank_std=fbank_std)
141
+
142
+ if padding_mask is not None:
143
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
144
+
145
+ fbank = fbank.unsqueeze(1)
146
+ features = self.patch_embedding(fbank)
147
+ features = features.reshape(features.shape[0], features.shape[1], -1)
148
+ features = features.transpose(1, 2)
149
+ features = self.layer_norm(features)
150
+
151
+ if padding_mask is not None:
152
+ padding_mask = self.forward_padding_mask(features, padding_mask)
153
+
154
+ if self.post_extract_proj is not None:
155
+ features = self.post_extract_proj(features)
156
+
157
+ x = self.dropout_input(features)
158
+
159
+ x, layer_results = self.encoder(
160
+ x,
161
+ padding_mask=padding_mask, )
162
+
163
+ if self.predictor is not None:
164
+ x = self.predictor_dropout(x)
165
+ logits = self.predictor(x)
166
+
167
+ if padding_mask is not None and padding_mask.any():
168
+ logits[padding_mask] = 0
169
+ logits = logits.sum(dim=1)
170
+ logits = logits / (~padding_mask).sum(
171
+ dim=1).unsqueeze(-1).expand_as(logits)
172
+ else:
173
+ logits = logits.mean(dim=1)
174
+
175
+ lprobs = torch.sigmoid(logits)
176
+
177
+ return lprobs, padding_mask
178
+ else:
179
+ return x, padding_mask
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md ADDED
@@ -0,0 +1,127 @@
1
+
2
+ # BEATs
3
+
4
+ [**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
5
+
6
+ Official PyTorch implementation and pretrained models of BEATs
7
+
8
+ ## Pre-Trained and Fine-Tuned Tokenizers and Models
9
+ Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
10
+ |---|---|---|---|---
11
+ Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
12
+ Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
13
+ Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
14
+ Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
15
+ Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
16
+
17
+
18
+ ### Load Tokenizers
19
+
20
+ ```python
21
+ import torch
22
+ from Tokenizers import TokenizersConfig, Tokenizers
23
+
24
+ # load the pre-trained checkpoints
25
+ checkpoint = torch.load('/path/to/tokenizer.pt')
26
+
27
+ cfg = TokenizersConfig(checkpoint['cfg'])
28
+ BEATs_tokenizer = Tokenizers(cfg)
29
+ BEATs_tokenizer.load_state_dict(checkpoint['model'])
30
+ BEATs_tokenizer.eval()
31
+
32
+ # tokenize the audio and generate the labels
33
+ audio_input_16khz = torch.randn(1, 10000)
34
+ padding_mask = torch.zeros(1, 10000).bool()
35
+
36
+ labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
37
+ ```
38
+
39
+
40
+ ### Load Pre-Trained Models
41
+
42
+ ```python
43
+ import torch
44
+ from BEATs import BEATs, BEATsConfig
45
+
46
+ # load the pre-trained checkpoints
47
+ checkpoint = torch.load('/path/to/model.pt')
48
+
49
+ cfg = BEATsConfig(checkpoint['cfg'])
50
+ BEATs_model = BEATs(cfg)
51
+ BEATs_model.load_state_dict(checkpoint['model'])
52
+ BEATs_model.eval()
53
+
54
+ # extract the audio representation
55
+ audio_input_16khz = torch.randn(1, 10000)
56
+ padding_mask = torch.zeros(1, 10000).bool()
57
+
58
+ representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
59
+ ```
60
+
61
+
62
+ ### Load Fine-tuned Models
63
+
64
+ ```python
65
+ import torch
66
+ from BEATs import BEATs, BEATsConfig
67
+
68
+ # load the fine-tuned checkpoints
69
+ checkpoint = torch.load('/path/to/model.pt')
70
+
71
+ cfg = BEATsConfig(checkpoint['cfg'])
72
+ BEATs_model = BEATs(cfg)
73
+ BEATs_model.load_state_dict(checkpoint['model'])
74
+ BEATs_model.eval()
75
+
76
+ # predict the classification probability of each class
77
+ audio_input_16khz = torch.randn(3, 10000)
78
+ padding_mask = torch.zeros(3, 10000).bool()
79
+
80
+ probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
81
+
82
+ for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
83
+ top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
84
+ print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
85
+ ```
86
+
87
+ ## Evaluation Results
88
+
89
+ ### Comparing with the SOTA Single Models
90
+ ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
91
+
92
+
93
+ ### Comparing with the SOTA Ensemble Models
94
+ ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
95
+
96
+
97
+ ### Comparing Different BEATS Tokenizers
98
+ ![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
99
+
100
+
101
+ ### Comparing Different Pre-Training Targets
102
+ ![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
103
+
104
+
105
+ ## License
106
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
107
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project.
108
+
109
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
110
+
111
+
112
+ ### Reference
113
+ If you find our work is useful in your research, please cite the following paper:
114
+ ``` latex
115
+ @article{Chen2022beats,
116
+ title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
117
+ author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
118
+ eprint={2212.09058},
119
+ archivePrefix={arXiv},
120
+ year={2022}
121
+ }
122
+ ```
123
+ ### Contact Information
124
+
125
+ For help or issues using BEATs models, please submit a GitHub issue.
126
+
127
+ For other communications related to BEATs, please contact Yu Wu (`[email protected]`).
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import logging
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+ from backbone import (
16
+ TransformerEncoder, )
17
+ from quantizer import (
18
+ NormEMAVectorQuantizer, )
19
+ from torch.nn import LayerNorm
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class TokenizersConfig:
25
+ def __init__(self, cfg=None):
26
+ self.input_patch_size: int = -1 # patch size of patch embedding
27
+ self.embed_dim: int = 512 # patch embedding dimension
28
+ self.conv_bias: bool = False # include bias in conv encoder
29
+
30
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
31
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
32
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
33
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
34
+ self.activation_fn: str = "gelu" # activation function to use
35
+
36
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
37
+ self.deep_norm: bool = False # apply deep_norm first in the transformer
38
+
39
+ # dropouts
40
+ self.dropout: float = 0.1 # dropout probability for the transformer
41
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
42
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
43
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
44
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
45
+
46
+ # positional embeddings
47
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
48
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
49
+
50
+ # relative position embedding
51
+ self.relative_position_embedding: bool = False # apply relative position embedding
52
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
53
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
54
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
55
+
56
+ # quantizer
57
+ self.quant_n: int = 1024 # codebook number in quantizer
58
+ self.quant_dim: int = 256 # codebook dimension in quantizer
59
+
60
+ if cfg is not None:
61
+ self.update(cfg)
62
+
63
+ def update(self, cfg: dict):
64
+ self.__dict__.update(cfg)
65
+
66
+
67
+ class Tokenizers(nn.Module):
68
+ def __init__(
69
+ self,
70
+ cfg: TokenizersConfig, ) -> None:
71
+ super().__init__()
72
+ logger.info(f"Tokenizers Config: {cfg.__dict__}")
73
+
74
+ self.cfg = cfg
75
+
76
+ self.embed = cfg.embed_dim
77
+ self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
+ if self.embed != cfg.encoder_embed_dim else
79
+ None)
80
+
81
+ self.input_patch_size = cfg.input_patch_size
82
+ self.patch_embedding = nn.Conv2d(
83
+ 1,
84
+ self.embed,
85
+ kernel_size=self.input_patch_size,
86
+ stride=self.input_patch_size,
87
+ bias=cfg.conv_bias)
88
+
89
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
90
+
91
+ assert not cfg.deep_norm or not cfg.layer_norm_first
92
+ self.encoder = TransformerEncoder(cfg)
93
+ self.layer_norm = LayerNorm(self.embed)
94
+
95
+ self.quantize = NormEMAVectorQuantizer(
96
+ n_embed=cfg.quant_n,
97
+ embedding_dim=cfg.quant_dim,
98
+ beta=1.0,
99
+ kmeans_init=True,
100
+ decay=0.99, )
101
+ self.quant_n = cfg.quant_n
102
+ self.quantize_layer = nn.Sequential(
103
+ nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
104
+ nn.Tanh(),
105
+ nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
106
+ )
107
+
108
+ def forward_padding_mask(
109
+ self,
110
+ features: torch.Tensor,
111
+ padding_mask: torch.Tensor, ) -> torch.Tensor:
112
+ extra = padding_mask.size(1) % features.size(1)
113
+ if extra > 0:
114
+ padding_mask = padding_mask[:, :-extra]
115
+ padding_mask = padding_mask.view(
116
+ padding_mask.size(0), features.size(1), -1)
117
+ padding_mask = padding_mask.all(-1)
118
+ return padding_mask
119
+
120
+ def preprocess(
121
+ self,
122
+ source: torch.Tensor,
123
+ fbank_mean: float=15.41663,
124
+ fbank_std: float=6.55582, ) -> torch.Tensor:
125
+ fbanks = []
126
+ for waveform in source:
127
+ waveform = waveform.unsqueeze(0) * 2**15
128
+ fbank = ta_kaldi.fbank(
129
+ waveform,
130
+ num_mel_bins=128,
131
+ sample_frequency=16000,
132
+ frame_length=25,
133
+ frame_shift=10)
134
+ fbanks.append(fbank)
135
+ fbank = torch.stack(fbanks, dim=0)
136
+ fbank = (fbank - fbank_mean) / (2 * fbank_std)
137
+ return fbank
138
+
139
+ def extract_labels(
140
+ self,
141
+ source: torch.Tensor,
142
+ padding_mask: Optional[torch.Tensor]=None,
143
+ fbank_mean: float=15.41663,
144
+ fbank_std: float=6.55582, ):
145
+ fbank = self.preprocess(
146
+ source, fbank_mean=fbank_mean, fbank_std=fbank_std)
147
+
148
+ if padding_mask is not None:
149
+ padding_mask = self.forward_padding_mask(fbank, padding_mask)
150
+
151
+ fbank = fbank.unsqueeze(1)
152
+ features = self.patch_embedding(fbank)
153
+ features = features.reshape(features.shape[0], features.shape[1], -1)
154
+ features = features.transpose(1, 2)
155
+ features = self.layer_norm(features)
156
+
157
+ if padding_mask is not None:
158
+ padding_mask = self.forward_padding_mask(features, padding_mask)
159
+
160
+ if self.post_extract_proj is not None:
161
+ features = self.post_extract_proj(features)
162
+
163
+ x = self.dropout_input(features)
164
+
165
+ x, layer_results = self.encoder(
166
+ x,
167
+ padding_mask=padding_mask, )
168
+
169
+ quantize_input = self.quantize_layer(x)
170
+ quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
+
172
+ return embed_ind
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # this folder is modified from https://github.com/microsoft/unilm/tree/master/beats
2
+ # ontology.json is from https://github.com/audioset/ontology/
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py ADDED
@@ -0,0 +1,791 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import math
10
+ from typing import Dict
11
+ from typing import Optional
12
+ from typing import Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from torch import Tensor
19
+ from torch.nn import LayerNorm
20
+ from torch.nn import Parameter
21
+
22
+ from .modules import get_activation_fn
23
+ from .modules import GLU_Linear
24
+ from .modules import GradMultiply
25
+ from .modules import quant_noise
26
+ from .modules import SamePad
27
+
28
+
29
+ class TransformerEncoder(nn.Module):
30
+ def __init__(self, args):
31
+ super().__init__()
32
+
33
+ self.dropout = args.dropout
34
+ self.embedding_dim = args.encoder_embed_dim
35
+
36
+ self.pos_conv = nn.Conv1d(
37
+ self.embedding_dim,
38
+ self.embedding_dim,
39
+ kernel_size=args.conv_pos,
40
+ padding=args.conv_pos // 2,
41
+ groups=args.conv_pos_groups, )
42
+ dropout = 0
43
+ std = math.sqrt(
44
+ (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
45
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
46
+ nn.init.constant_(self.pos_conv.bias, 0)
47
+
48
+ self.pos_conv = nn.utils.weight_norm(
49
+ self.pos_conv, name="weight", dim=2)
50
+ self.pos_conv = nn.Sequential(self.pos_conv,
51
+ SamePad(args.conv_pos), nn.GELU())
52
+
53
+ if hasattr(args, "relative_position_embedding"):
54
+ self.relative_position_embedding = args.relative_position_embedding
55
+ self.num_buckets = args.num_buckets
56
+ self.max_distance = args.max_distance
57
+ else:
58
+ self.relative_position_embedding = False
59
+ self.num_buckets = 0
60
+ self.max_distance = 0
61
+
62
+ self.layers = nn.ModuleList([
63
+ TransformerSentenceEncoderLayer(
64
+ embedding_dim=self.embedding_dim,
65
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
66
+ num_attention_heads=args.encoder_attention_heads,
67
+ dropout=self.dropout,
68
+ attention_dropout=args.attention_dropout,
69
+ activation_dropout=args.activation_dropout,
70
+ activation_fn=args.activation_fn,
71
+ layer_norm_first=args.layer_norm_first,
72
+ deep_norm=args.deep_norm,
73
+ has_relative_attention_bias=self.relative_position_embedding,
74
+ num_buckets=self.num_buckets,
75
+ max_distance=self.max_distance,
76
+ gru_rel_pos=args.gru_rel_pos,
77
+ encoder_layers=args.encoder_layers, )
78
+ for i in range(args.encoder_layers)
79
+ ])
80
+ if self.relative_position_embedding:
81
+ for i in range(1, args.encoder_layers):
82
+ del self.layers[i].self_attn.relative_attention_bias
83
+ self.layers[i].self_attn.relative_attention_bias = self.layers[
84
+ 0].self_attn.relative_attention_bias
85
+
86
+ self.layer_norm_first = args.layer_norm_first
87
+ self.layer_norm = LayerNorm(self.embedding_dim)
88
+ self.layerdrop = args.encoder_layerdrop
89
+
90
+ self.apply(init_bert_params)
91
+
92
+ if args.deep_norm:
93
+ deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
94
+ for i in range(args.encoder_layers):
95
+ nn.init.xavier_normal_(
96
+ self.layers[i].self_attn.k_proj.weight, gain=1)
97
+ nn.init.xavier_normal_(
98
+ self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
99
+ nn.init.xavier_normal_(
100
+ self.layers[i].self_attn.q_proj.weight, gain=1)
101
+ nn.init.xavier_normal_(
102
+ self.layers[i].self_attn.out_proj.weight,
103
+ gain=deep_norm_beta)
104
+ nn.init.xavier_normal_(
105
+ self.layers[i].fc1.weight, gain=deep_norm_beta)
106
+ nn.init.xavier_normal_(
107
+ self.layers[i].fc2.weight, gain=deep_norm_beta)
108
+
109
+ self.layer_wise_gradient_decay_ratio = getattr(
110
+ args, "layer_wise_gradient_decay_ratio", 1)
111
+
112
+ def forward(self, x, padding_mask=None, layer=None):
113
+ x, layer_results = self.extract_features(x, padding_mask, layer)
114
+
115
+ if self.layer_norm_first and layer is None:
116
+ x = self.layer_norm(x)
117
+
118
+ return x, layer_results
119
+
120
+ def extract_features(self, x, padding_mask=None, tgt_layer=None):
121
+
122
+ if padding_mask is not None:
123
+ x[padding_mask] = 0
124
+
125
+ x_conv = self.pos_conv(x.transpose(1, 2))
126
+ x_conv = x_conv.transpose(1, 2)
127
+ x = x + x_conv
128
+
129
+ if not self.layer_norm_first:
130
+ x = self.layer_norm(x)
131
+
132
+ x = F.dropout(x, p=self.dropout, training=self.training)
133
+
134
+ # B x T x C -> T x B x C
135
+ x = x.transpose(0, 1)
136
+
137
+ layer_results = []
138
+ z = None
139
+ if tgt_layer is not None:
140
+ layer_results.append((x, z))
141
+ r = None
142
+ pos_bias = None
143
+ for i, layer in enumerate(self.layers):
144
+ if self.layer_wise_gradient_decay_ratio != 1.0:
145
+ x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
146
+ dropout_probability = np.random.random()
147
+ if not self.training or (dropout_probability > self.layerdrop):
148
+ x, z, pos_bias = layer(
149
+ x,
150
+ self_attn_padding_mask=padding_mask,
151
+ need_weights=False,
152
+ pos_bias=pos_bias)
153
+ if tgt_layer is not None:
154
+ layer_results.append((x, z))
155
+ if i == tgt_layer:
156
+ r = x
157
+ break
158
+
159
+ if r is not None:
160
+ x = r
161
+
162
+ # T x B x C -> B x T x C
163
+ x = x.transpose(0, 1)
164
+
165
+ return x, layer_results
166
+
167
+
168
+ class TransformerSentenceEncoderLayer(nn.Module):
169
+ def __init__(
170
+ self,
171
+ embedding_dim: float=768,
172
+ ffn_embedding_dim: float=3072,
173
+ num_attention_heads: float=8,
174
+ dropout: float=0.1,
175
+ attention_dropout: float=0.1,
176
+ activation_dropout: float=0.1,
177
+ activation_fn: str="relu",
178
+ layer_norm_first: bool=False,
179
+ deep_norm: bool=False,
180
+ has_relative_attention_bias: bool=False,
181
+ num_buckets: int=0,
182
+ max_distance: int=0,
183
+ rescale_init: bool=False,
184
+ gru_rel_pos: bool=False,
185
+ encoder_layers: int=0, ) -> None:
186
+
187
+ super().__init__()
188
+ self.embedding_dim = embedding_dim
189
+ self.dropout = dropout
190
+ self.activation_dropout = activation_dropout
191
+
192
+ self.activation_name = activation_fn
193
+ self.activation_fn = get_activation_fn(activation_fn)
194
+ self.self_attn = MultiheadAttention(
195
+ self.embedding_dim,
196
+ num_attention_heads,
197
+ dropout=attention_dropout,
198
+ self_attention=True,
199
+ has_relative_attention_bias=has_relative_attention_bias,
200
+ num_buckets=num_buckets,
201
+ max_distance=max_distance,
202
+ rescale_init=rescale_init,
203
+ gru_rel_pos=gru_rel_pos, )
204
+
205
+ self.dropout1 = nn.Dropout(dropout)
206
+ self.dropout2 = nn.Dropout(self.activation_dropout)
207
+ self.dropout3 = nn.Dropout(dropout)
208
+
209
+ self.layer_norm_first = layer_norm_first
210
+
211
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
212
+
213
+ if self.activation_name == "glu":
214
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
215
+ "swish")
216
+ else:
217
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
218
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
219
+
220
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
221
+
222
+ self.deep_norm = deep_norm
223
+ if self.deep_norm:
224
+ self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
225
+ else:
226
+ self.deep_norm_alpha = 1
227
+
228
+ def forward(self,
229
+ x: torch.Tensor,
230
+ self_attn_mask: torch.Tensor=None,
231
+ self_attn_padding_mask: torch.Tensor=None,
232
+ need_weights: bool=False,
233
+ pos_bias=None):
234
+ residual = x
235
+
236
+ if self.layer_norm_first:
237
+ x = self.self_attn_layer_norm(x)
238
+ x, attn, pos_bias = self.self_attn(
239
+ query=x,
240
+ key=x,
241
+ value=x,
242
+ key_padding_mask=self_attn_padding_mask,
243
+ need_weights=False,
244
+ attn_mask=self_attn_mask,
245
+ position_bias=pos_bias)
246
+ x = self.dropout1(x)
247
+ x = residual + x
248
+
249
+ residual = x
250
+ x = self.final_layer_norm(x)
251
+ if self.activation_name == "glu":
252
+ x = self.fc1(x)
253
+ else:
254
+ x = self.activation_fn(self.fc1(x))
255
+ x = self.dropout2(x)
256
+ x = self.fc2(x)
257
+ x = self.dropout3(x)
258
+ x = residual + x
259
+ else:
260
+ x, attn, pos_bias = self.self_attn(
261
+ query=x,
262
+ key=x,
263
+ value=x,
264
+ key_padding_mask=self_attn_padding_mask,
265
+ need_weights=need_weights,
266
+ attn_mask=self_attn_mask,
267
+ position_bias=pos_bias)
268
+
269
+ x = self.dropout1(x)
270
+ x = residual * self.deep_norm_alpha + x
271
+
272
+ x = self.self_attn_layer_norm(x)
273
+
274
+ residual = x
275
+ if self.activation_name == "glu":
276
+ x = self.fc1(x)
277
+ else:
278
+ x = self.activation_fn(self.fc1(x))
279
+ x = self.dropout2(x)
280
+ x = self.fc2(x)
281
+ x = self.dropout3(x)
282
+ x = residual * self.deep_norm_alpha + x
283
+ x = self.final_layer_norm(x)
284
+
285
+ return x, attn, pos_bias
286
+
287
+
288
+ class MultiheadAttention(nn.Module):
289
+ """Multi-headed attention.
290
+
291
+ See "Attention Is All You Need" for more details.
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ embed_dim,
297
+ num_heads,
298
+ kdim=None,
299
+ vdim=None,
300
+ dropout=0.0,
301
+ bias=True,
302
+ add_bias_kv=False,
303
+ add_zero_attn=False,
304
+ self_attention=False,
305
+ encoder_decoder_attention=False,
306
+ q_noise=0.0,
307
+ qn_block_size=8,
308
+ has_relative_attention_bias=False,
309
+ num_buckets=32,
310
+ max_distance=128,
311
+ gru_rel_pos=False,
312
+ rescale_init=False, ):
313
+ super().__init__()
314
+ self.embed_dim = embed_dim
315
+ self.kdim = kdim if kdim is not None else embed_dim
316
+ self.vdim = vdim if vdim is not None else embed_dim
317
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
318
+
319
+ self.num_heads = num_heads
320
+ self.dropout_module = nn.Dropout(dropout)
321
+
322
+ self.has_relative_attention_bias = has_relative_attention_bias
323
+ self.num_buckets = num_buckets
324
+ self.max_distance = max_distance
325
+ if self.has_relative_attention_bias:
326
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
327
+
328
+ self.head_dim = embed_dim // num_heads
329
+ self.q_head_dim = self.head_dim
330
+ self.k_head_dim = self.head_dim
331
+ assert (self.head_dim * num_heads == self.embed_dim
332
+ ), "embed_dim must be divisible by num_heads"
333
+ self.scaling = self.head_dim**-0.5
334
+
335
+ self.self_attention = self_attention
336
+ self.encoder_decoder_attention = encoder_decoder_attention
337
+
338
+ assert not self.self_attention or self.qkv_same_dim, (
339
+ "Self-attention requires query, key and "
340
+ "value to be of the same size")
341
+
342
+ k_bias = True
343
+ if rescale_init:
344
+ k_bias = False
345
+
346
+ k_embed_dim = embed_dim
347
+ q_embed_dim = embed_dim
348
+
349
+ self.k_proj = quant_noise(
350
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise,
351
+ qn_block_size)
352
+ self.v_proj = quant_noise(
353
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
354
+ self.q_proj = quant_noise(
355
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise,
356
+ qn_block_size)
357
+
358
+ self.out_proj = quant_noise(
359
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
360
+
361
+ if add_bias_kv:
362
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
363
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
364
+ else:
365
+ self.bias_k = self.bias_v = None
366
+
367
+ self.add_zero_attn = add_zero_attn
368
+
369
+ self.gru_rel_pos = gru_rel_pos
370
+ if self.gru_rel_pos:
371
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
372
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
373
+
374
+ self.reset_parameters()
375
+
376
+ def reset_parameters(self):
377
+ if self.qkv_same_dim:
378
+ # Empirically observed the convergence to be much better with
379
+ # the scaled initialization
380
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
381
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
382
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
383
+ else:
384
+ nn.init.xavier_uniform_(self.k_proj.weight)
385
+ nn.init.xavier_uniform_(self.v_proj.weight)
386
+ nn.init.xavier_uniform_(self.q_proj.weight)
387
+
388
+ nn.init.xavier_uniform_(self.out_proj.weight)
389
+ if self.out_proj.bias is not None:
390
+ nn.init.constant_(self.out_proj.bias, 0.0)
391
+ if self.bias_k is not None:
392
+ nn.init.xavier_normal_(self.bias_k)
393
+ if self.bias_v is not None:
394
+ nn.init.xavier_normal_(self.bias_v)
395
+ if self.has_relative_attention_bias:
396
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
397
+
398
+ def _relative_positions_bucket(self, relative_positions,
399
+ bidirectional=True):
400
+ num_buckets = self.num_buckets
401
+ max_distance = self.max_distance
402
+ relative_buckets = 0
403
+
404
+ if bidirectional:
405
+ num_buckets = num_buckets // 2
406
+ relative_buckets += (
407
+ relative_positions > 0).to(torch.long) * num_buckets
408
+ relative_positions = torch.abs(relative_positions)
409
+ else:
410
+ relative_positions = -torch.min(
411
+ relative_positions, torch.zeros_like(relative_positions))
412
+
413
+ max_exact = num_buckets // 2
414
+ is_small = relative_positions < max_exact
415
+
416
+ relative_postion_if_large = max_exact + (
417
+ torch.log(relative_positions.float() / max_exact) / math.log(
418
+ max_distance / max_exact) *
419
+ (num_buckets - max_exact)).to(torch.long)
420
+ relative_postion_if_large = torch.min(
421
+ relative_postion_if_large,
422
+ torch.full_like(relative_postion_if_large, num_buckets - 1))
423
+
424
+ relative_buckets += torch.where(is_small, relative_positions,
425
+ relative_postion_if_large)
426
+ return relative_buckets
427
+
428
+ def compute_bias(self, query_length, key_length):
429
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
430
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
431
+ relative_position = memory_position - context_position
432
+ relative_position_bucket = self._relative_positions_bucket(
433
+ relative_position, bidirectional=True)
434
+ relative_position_bucket = relative_position_bucket.to(
435
+ self.relative_attention_bias.weight.device)
436
+ values = self.relative_attention_bias(relative_position_bucket)
437
+ values = values.permute([2, 0, 1])
438
+ return values
439
+
440
+ def forward(self,
441
+ query,
442
+ key: Optional[Tensor],
443
+ value: Optional[Tensor],
444
+ key_padding_mask: Optional[Tensor]=None,
445
+ incremental_state: Optional[Dict[str, Dict[str, Optional[
446
+ Tensor]]]]=None,
447
+ need_weights: bool=True,
448
+ static_kv: bool=False,
449
+ attn_mask: Optional[Tensor]=None,
450
+ before_softmax: bool=False,
451
+ need_head_weights: bool=False,
452
+ position_bias: Optional[Tensor]=None
453
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
454
+ """Input shape: Time x Batch x Channel
455
+
456
+ Args:
457
+ key_padding_mask (ByteTensor, optional): mask to exclude
458
+ keys that are pads, of shape `(batch, src_len)`, where
459
+ padding elements are indicated by 1s.
460
+ need_weights (bool, optional): return the attention weights,
461
+ averaged over heads (default: False).
462
+ attn_mask (ByteTensor, optional): typically used to
463
+ implement causal attention, where the mask prevents the
464
+ attention from looking forward in time (default: None).
465
+ before_softmax (bool, optional): return the raw attention
466
+ weights and values before the attention softmax.
467
+ need_head_weights (bool, optional): return the attention
468
+ weights for each head. Implies *need_weights*. Default:
469
+ return the average attention weights over all heads.
470
+ """
471
+ if need_head_weights:
472
+ need_weights = True
473
+
474
+ is_tpu = query.device.type == "xla"
475
+
476
+ tgt_len, bsz, embed_dim = query.size()
477
+ src_len = tgt_len
478
+ assert embed_dim == self.embed_dim
479
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
480
+ if key is not None:
481
+ src_len, key_bsz, _ = key.size()
482
+ if not torch.jit.is_scripting():
483
+ assert key_bsz == bsz
484
+ assert value is not None
485
+ assert (src_len, bsz) == value.shape[:2]
486
+
487
+ if self.has_relative_attention_bias and position_bias is None:
488
+ position_bias = self.compute_bias(tgt_len, src_len)
489
+ position_bias = position_bias.unsqueeze(0).repeat(
490
+ bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
491
+
492
+ if incremental_state is not None:
493
+ saved_state = self._get_input_buffer(incremental_state)
494
+ if saved_state is not None and "prev_key" in saved_state:
495
+ # previous time steps are cached - no need to recompute
496
+ # key and value if they are static
497
+ if static_kv:
498
+ assert self.encoder_decoder_attention and not self.self_attention
499
+ key = value = None
500
+ else:
501
+ saved_state = None
502
+
503
+ if self.self_attention:
504
+ q = self.q_proj(query)
505
+ k = self.k_proj(query)
506
+ v = self.v_proj(query)
507
+ elif self.encoder_decoder_attention:
508
+ # encoder-decoder attention
509
+ q = self.q_proj(query)
510
+ if key is None:
511
+ assert value is None
512
+ k = v = None
513
+ else:
514
+ k = self.k_proj(key)
515
+ v = self.v_proj(key)
516
+
517
+ else:
518
+ assert key is not None and value is not None
519
+ q = self.q_proj(query)
520
+ k = self.k_proj(key)
521
+ v = self.v_proj(value)
522
+ q *= self.scaling
523
+ alpha = 32
524
+ q *= 1 / alpha
525
+
526
+ if self.bias_k is not None:
527
+ assert self.bias_v is not None
528
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
529
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
530
+ if attn_mask is not None:
531
+ attn_mask = torch.cat(
532
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
533
+ dim=1)
534
+ if key_padding_mask is not None:
535
+ key_padding_mask = torch.cat(
536
+ [
537
+ key_padding_mask,
538
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
539
+ ],
540
+ dim=1, )
541
+
542
+ q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim)
543
+ .transpose(0, 1))
544
+ if k is not None:
545
+ k = (k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim)
546
+ .transpose(0, 1))
547
+ if v is not None:
548
+ v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim)
549
+ .transpose(0, 1))
550
+
551
+ if saved_state is not None:
552
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
553
+ if "prev_key" in saved_state:
554
+ _prev_key = saved_state["prev_key"]
555
+ assert _prev_key is not None
556
+ prev_key = _prev_key.view(bsz * self.num_heads, -1,
557
+ self.head_dim)
558
+ if static_kv:
559
+ k = prev_key
560
+ else:
561
+ assert k is not None
562
+ k = torch.cat([prev_key, k], dim=1)
563
+ src_len = k.size(1)
564
+ if "prev_value" in saved_state:
565
+ _prev_value = saved_state["prev_value"]
566
+ assert _prev_value is not None
567
+ prev_value = _prev_value.view(bsz * self.num_heads, -1,
568
+ self.head_dim)
569
+ if static_kv:
570
+ v = prev_value
571
+ else:
572
+ assert v is not None
573
+ v = torch.cat([prev_value, v], dim=1)
574
+ prev_key_padding_mask: Optional[Tensor] = None
575
+ if "prev_key_padding_mask" in saved_state:
576
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
577
+ assert k is not None and v is not None
578
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
579
+ key_padding_mask=key_padding_mask,
580
+ prev_key_padding_mask=prev_key_padding_mask,
581
+ batch_size=bsz,
582
+ src_len=k.size(1),
583
+ static_kv=static_kv, )
584
+
585
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
586
+ self.head_dim)
587
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
588
+ self.head_dim)
589
+ saved_state["prev_key_padding_mask"] = key_padding_mask
590
+ # In this branch incremental_state is never None
591
+ assert incremental_state is not None
592
+ incremental_state = self._set_input_buffer(incremental_state,
593
+ saved_state)
594
+ assert k is not None
595
+ assert k.size(1) == src_len
596
+
597
+ # This is part of a workaround to get around fork/join parallelism
598
+ # not supporting Optional types.
599
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
600
+ key_padding_mask = None
601
+
602
+ if key_padding_mask is not None:
603
+ assert key_padding_mask.size(0) == bsz
604
+ assert key_padding_mask.size(1) == src_len
605
+
606
+ if self.add_zero_attn:
607
+ assert v is not None
608
+ src_len += 1
609
+ k = torch.cat(
610
+ [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
611
+ v = torch.cat(
612
+ [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
613
+ if attn_mask is not None:
614
+ attn_mask = torch.cat(
615
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
616
+ dim=1)
617
+ if key_padding_mask is not None:
618
+ key_padding_mask = torch.cat(
619
+ [
620
+ key_padding_mask,
621
+ torch.zeros(key_padding_mask.size(0),
622
+ 1).type_as(key_padding_mask),
623
+ ],
624
+ dim=1, )
625
+
626
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
627
+ attn_weights = (
628
+ attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
629
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
630
+ bsz)
631
+
632
+ assert list(
633
+ attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
634
+
635
+ if attn_mask is not None:
636
+ attn_mask = attn_mask.unsqueeze(0)
637
+ attn_weights += attn_mask
638
+
639
+ if key_padding_mask is not None:
640
+ # don't attend to padding symbols
641
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
642
+ src_len)
643
+ if not is_tpu:
644
+ attn_weights = attn_weights.masked_fill(
645
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
646
+ float("-inf"), )
647
+ else:
648
+ attn_weights = attn_weights.transpose(0, 2)
649
+ attn_weights = attn_weights.masked_fill(key_padding_mask,
650
+ float("-inf"))
651
+ attn_weights = attn_weights.transpose(0, 2)
652
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
653
+ src_len)
654
+
655
+ if before_softmax:
656
+ return attn_weights, v, position_bias
657
+
658
+ if position_bias is not None:
659
+ attn_mask_rel_pos = position_bias
660
+ if self.gru_rel_pos == 1:
661
+ query_layer = q.view(bsz, self.num_heads, tgt_len,
662
+ self.q_head_dim) * alpha / self.scaling
663
+ _B, _H, _L, __ = query_layer.size()
664
+ gate_a, gate_b = torch.sigmoid(
665
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(
666
+ -1, keepdim=False)).chunk(
667
+ 2, dim=-1)
668
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
669
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len,
670
+ 1) * position_bias
671
+
672
+ attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
673
+
674
+ attn_weights = attn_weights + attn_mask_rel_pos
675
+
676
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
677
+ attn_weights = attn_weights_float.type_as(attn_weights)
678
+ attn_probs = self.dropout_module(attn_weights)
679
+
680
+ assert v is not None
681
+ attn = torch.bmm(attn_probs, v)
682
+ assert list(
683
+ attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
684
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
685
+ attn = self.out_proj(attn)
686
+ attn_weights: Optional[Tensor] = None
687
+ if need_weights:
688
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
689
+ src_len).transpose(1, 0)
690
+ if not need_head_weights:
691
+ # average attention weights over heads
692
+ attn_weights = attn_weights.mean(dim=0)
693
+
694
+ return attn, attn_weights, position_bias
695
+
696
+ @staticmethod
697
+ def _append_prev_key_padding_mask(
698
+ key_padding_mask: Optional[Tensor],
699
+ prev_key_padding_mask: Optional[Tensor],
700
+ batch_size: int,
701
+ src_len: int,
702
+ static_kv: bool, ) -> Optional[Tensor]:
703
+ # saved key padding masks have shape (bsz, seq_len)
704
+ if prev_key_padding_mask is not None and static_kv:
705
+ new_key_padding_mask = prev_key_padding_mask
706
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
707
+ new_key_padding_mask = torch.cat(
708
+ [prev_key_padding_mask.float(), key_padding_mask.float()],
709
+ dim=1)
710
+ # During incremental decoding, as the padding token enters and
711
+ # leaves the frame, there will be a time when prev or current
712
+ # is None
713
+ elif prev_key_padding_mask is not None:
714
+ if src_len > prev_key_padding_mask.size(1):
715
+ filler = torch.zeros(
716
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
717
+ device=prev_key_padding_mask.device, )
718
+ new_key_padding_mask = torch.cat(
719
+ [prev_key_padding_mask.float(), filler.float()], dim=1)
720
+ else:
721
+ new_key_padding_mask = prev_key_padding_mask.float()
722
+ elif key_padding_mask is not None:
723
+ if src_len > key_padding_mask.size(1):
724
+ filler = torch.zeros(
725
+ (batch_size, src_len - key_padding_mask.size(1)),
726
+ device=key_padding_mask.device, )
727
+ new_key_padding_mask = torch.cat(
728
+ [filler.float(), key_padding_mask.float()], dim=1)
729
+ else:
730
+ new_key_padding_mask = key_padding_mask.float()
731
+ else:
732
+ new_key_padding_mask = prev_key_padding_mask
733
+ return new_key_padding_mask
734
+
735
+ def _get_input_buffer(
736
+ self,
737
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
738
+ ) -> Dict[str, Optional[Tensor]]:
739
+ result = self.get_incremental_state(incremental_state, "attn_state")
740
+ if result is not None:
741
+ return result
742
+ else:
743
+ empty_result: Dict[str, Optional[Tensor]] = {}
744
+ return empty_result
745
+
746
+ def _set_input_buffer(
747
+ self,
748
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
749
+ buffer: Dict[str, Optional[Tensor]], ):
750
+ return self.set_incremental_state(incremental_state, "attn_state",
751
+ buffer)
752
+
753
+ def apply_sparse_mask(self,
754
+ attn_weights,
755
+ tgt_len: int,
756
+ src_len: int,
757
+ bsz: int):
758
+ return attn_weights
759
+
760
+
761
+ def init_bert_params(module):
762
+ """
763
+ Initialize the weights specific to the BERT Model.
764
+ This overrides the default initializations depending on the specified arguments.
765
+ 1. If normal_init_linear_weights is set then weights of linear
766
+ layer will be initialized using the normal distribution and
767
+ bias will be set to the specified value.
768
+ 2. If normal_init_embed_weights is set then weights of embedding
769
+ layer will be initialized using the normal distribution.
770
+ 3. If normal_init_proj_weights is set then weights of
771
+ in_project_weight for MultiheadAttention are initialized using
772
+ the normal distribution (to be validated).
773
+ """
774
+
775
+ def normal_(data):
776
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
777
+ # so that the RNG is consistent with and without FSDP
778
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
779
+
780
+ if isinstance(module, nn.Linear):
781
+ normal_(module.weight.data)
782
+ if module.bias is not None:
783
+ module.bias.data.zero_()
784
+ if isinstance(module, nn.Embedding):
785
+ normal_(module.weight.data)
786
+ if module.padding_idx is not None:
787
+ module.weight.data[module.padding_idx].zero_()
788
+ if isinstance(module, MultiheadAttention):
789
+ normal_(module.q_proj.weight.data)
790
+ normal_(module.k_proj.weight.data)
791
+ normal_(module.v_proj.weight.data)
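
The relative-position handling above is a T5-style bucketing scheme: `compute_bias` maps signed query/key offsets into a fixed number of buckets and looks them up in a per-head embedding table that layer 0 owns and the remaining layers share. A minimal sketch of querying that table in isolation is shown below; the dimensions are illustrative, not the values from any released BEATs checkpoint.

```python
import torch
from AR.exps.beats.backbone import MultiheadAttention

# Illustrative sizes only; a real config supplies embed_dim, heads, buckets, etc.
mha = MultiheadAttention(
    embed_dim=768,
    num_heads=12,
    self_attention=True,
    has_relative_attention_bias=True,
    num_buckets=320,
    max_distance=800,
)
# One (num_heads, query_len, key_len) bias map, later added to the attention logits.
bias = mha.compute_bias(query_length=16, key_length=16)
print(bias.shape)  # torch.Size([12, 16, 16])
```
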
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py ADDED
@@ -0,0 +1,19 @@
1
+ import json
2
+ import os
3
+
4
+ # Get the directory that contains this script
5
+ script_dir = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ # Filename of the JSON file
8
+ json_filename = "ontology.json"
9
+
10
+ # Build the full path to the JSON file
11
+ json_path = os.path.join(script_dir, json_filename)
12
+
13
+ id_name_dict = {}
14
+
15
+ with open(json_path, 'r') as f:
16
+ json_items = json.load(f)
17
+ # e.g. '/m/0dgw9r' -> 'Human sounds'
18
+ for item in json_items:
19
+ id_name_dict[item['id']] = item['name']
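
A quick usage sketch of the mapping built above; the example key is the one mentioned in the loop comment, nothing else is assumed.

```python
from AR.exps.beats.config import id_name_dict

# AudioSet ontology ids map to human-readable tag names,
# e.g. '/m/0dgw9r' -> 'Human sounds' as noted in the comment above.
print(id_name_dict['/m/0dgw9r'])
```
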
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py ADDED
@@ -0,0 +1,220 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+
17
+ class GradMultiply(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, scale):
20
+ ctx.scale = scale
21
+ res = x.new(x)
22
+ return res
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad):
26
+ return grad * ctx.scale, None
27
+
28
+
29
+ class SamePad(nn.Module):
30
+ def __init__(self, kernel_size, causal=False):
31
+ super().__init__()
32
+ if causal:
33
+ self.remove = kernel_size - 1
34
+ else:
35
+ self.remove = 1 if kernel_size % 2 == 0 else 0
36
+
37
+ def forward(self, x):
38
+ if self.remove > 0:
39
+ x = x[:, :, :-self.remove]
40
+ return x
41
+
42
+
43
+ class Swish(nn.Module):
44
+ def __init__(self):
45
+ super(Swish, self).__init__()
46
+ self.act = torch.nn.Sigmoid()
47
+
48
+ def forward(self, x):
49
+ return x * self.act(x)
50
+
51
+
52
+ class GLU_Linear(nn.Module):
53
+ def __init__(self,
54
+ input_dim,
55
+ output_dim,
56
+ glu_type="sigmoid",
57
+ bias_in_glu=True):
58
+ super(GLU_Linear, self).__init__()
59
+
60
+ self.glu_type = glu_type
61
+ self.output_dim = output_dim
62
+
63
+ if glu_type == "sigmoid":
64
+ self.glu_act = torch.nn.Sigmoid()
65
+ elif glu_type == "swish":
66
+ self.glu_act = Swish()
67
+ elif glu_type == "relu":
68
+ self.glu_act = torch.nn.ReLU()
69
+ elif glu_type == "gelu":
70
+ self.glu_act = torch.nn.GELU()
71
+
72
+ if bias_in_glu:
73
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
74
+ else:
75
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
76
+
77
+ def forward(self, x):
78
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
79
+ x = self.linear(x)
80
+
81
+ if self.glu_type == "bilinear":
82
+ x = (x[:, :, 0:self.output_dim] *
83
+ x[:, :, self.output_dim:self.output_dim * 2])
84
+ else:
85
+ x = (x[:, :, 0:self.output_dim] *
86
+ self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
87
+
88
+ return x
89
+
90
+
91
+ def gelu_accurate(x):
92
+ if not hasattr(gelu_accurate, "_a"):
93
+ gelu_accurate._a = math.sqrt(2 / math.pi)
94
+ return (0.5 * x * (1 + torch.tanh(gelu_accurate._a *
95
+ (x + 0.044715 * torch.pow(x, 3)))))
96
+
97
+
98
+ def gelu(x: torch.Tensor) -> torch.Tensor:
99
+ return torch.nn.functional.gelu(x.float()).type_as(x)
100
+
101
+
102
+ def get_activation_fn(activation: str):
103
+ """Returns the activation function corresponding to `activation`"""
104
+
105
+ if activation == "relu":
106
+ return F.relu
107
+ elif activation == "gelu":
108
+ return gelu
109
+ elif activation == "gelu_fast":
110
+ warnings.warn(
111
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate")
112
+ return gelu_accurate
113
+ elif activation == "gelu_accurate":
114
+ return gelu_accurate
115
+ elif activation == "tanh":
116
+ return torch.tanh
117
+ elif activation == "linear":
118
+ return lambda x: x
119
+ elif activation == "glu":
120
+ return lambda x: x
121
+ else:
122
+ raise RuntimeError(
123
+ "--activation-fn {} not supported".format(activation))
124
+
125
+
126
+ def quant_noise(module, p, block_size):
127
+ """
128
+ Wraps modules and applies quantization noise to the weights for
129
+ subsequent quantization with Iterative Product Quantization as
130
+ described in "Training with Quantization Noise for Extreme Model Compression"
131
+
132
+ Args:
133
+ - module: nn.Module
134
+ - p: amount of Quantization Noise
135
+ - block_size: size of the blocks for subsequent quantization with iPQ
136
+
137
+ Remarks:
138
+ - Module weights must have the right sizes wrt the block size
139
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
140
+ - For more detail on how to quantize by blocks with convolutional weights,
141
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
142
+ - We implement the simplest form of noise here as stated in the paper
143
+ which consists in randomly dropping blocks
144
+ """
145
+
146
+ # if no quantization noise, don't register hook
147
+ if p <= 0:
148
+ return module
149
+
150
+ # supported modules
151
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
152
+
153
+ # test whether module.weight has the right sizes wrt block_size
154
+ is_conv = module.weight.ndim == 4
155
+
156
+ # 2D matrix
157
+ if not is_conv:
158
+ assert (
159
+ module.weight.size(1) %
160
+ block_size == 0), "Input features must be a multiple of block sizes"
161
+
162
+ # 4D matrix
163
+ else:
164
+ # 1x1 convolutions
165
+ if module.kernel_size == (1, 1):
166
+ assert (module.in_channels % block_size == 0
167
+ ), "Input channels must be a multiple of block sizes"
168
+ # regular convolutions
169
+ else:
170
+ k = module.kernel_size[0] * module.kernel_size[1]
171
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
172
+
173
+ def _forward_pre_hook(mod, input):
174
+ # no noise for evaluation
175
+ if mod.training:
176
+ if not is_conv:
177
+ # gather weight and sizes
178
+ weight = mod.weight
179
+ in_features = weight.size(1)
180
+ out_features = weight.size(0)
181
+
182
+ # split weight matrix into blocks and randomly drop selected blocks
183
+ mask = torch.zeros(
184
+ in_features // block_size * out_features,
185
+ device=weight.device)
186
+ mask.bernoulli_(p)
187
+ mask = mask.repeat_interleave(block_size, -1).view(-1,
188
+ in_features)
189
+
190
+ else:
191
+ # gather weight and sizes
192
+ weight = mod.weight
193
+ in_channels = mod.in_channels
194
+ out_channels = mod.out_channels
195
+
196
+ # split weight matrix into blocks and randomly drop selected blocks
197
+ if mod.kernel_size == (1, 1):
198
+ mask = torch.zeros(
199
+ int(in_channels // block_size * out_channels),
200
+ device=weight.device, )
201
+ mask.bernoulli_(p)
202
+ mask = mask.repeat_interleave(block_size, -1).view(
203
+ -1, in_channels)
204
+ else:
205
+ mask = torch.zeros(
206
+ weight.size(0), weight.size(1), device=weight.device)
207
+ mask.bernoulli_(p)
208
+ mask = (
209
+ mask.unsqueeze(2).unsqueeze(3)
210
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
211
+
212
+ # scale weights and apply mask
213
+ mask = mask.to(
214
+ torch.
215
+ bool) # x.bool() is not currently supported in TorchScript
216
+ s = 1 / (1 - p)
217
+ mod.weight.data = s * weight.masked_fill(mask, 0)
218
+
219
+ module.register_forward_pre_hook(_forward_pre_hook)
220
+ return module
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json ADDED
The diff for this file is too large to render. See raw diff
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py ADDED
@@ -0,0 +1,235 @@
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------'
9
+ import torch
10
+ import torch.distributed as distributed
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from einops import rearrange, repeat
16
+ except ImportError:
17
+ pass
18
+
19
+
20
+ def l2norm(t):
21
+ return F.normalize(t, p=2, dim=-1)
22
+
23
+
24
+ def ema_inplace(moving_avg, new, decay):
25
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
26
+
27
+
28
+ def sample_vectors(samples, num):
29
+ num_samples, device = samples.shape[0], samples.device
30
+
31
+ if num_samples >= num:
32
+ indices = torch.randperm(num_samples, device=device)[:num]
33
+ else:
34
+ indices = torch.randint(0, num_samples, (num, ), device=device)
35
+
36
+ return samples[indices]
37
+
38
+
39
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
40
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
41
+
42
+ means = sample_vectors(samples, num_clusters)
43
+
44
+ for _ in range(num_iters):
45
+ if use_cosine_sim:
46
+ dists = samples @ means.t()
47
+ else:
48
+ diffs = rearrange(samples, 'n d -> n () d') \
49
+ - rearrange(means, 'c d -> () c d')
50
+ dists = -(diffs**2).sum(dim=-1)
51
+
52
+ buckets = dists.max(dim=-1).indices
53
+ bins = torch.bincount(buckets, minlength=num_clusters)
54
+ zero_mask = bins == 0
55
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
56
+
57
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
58
+ new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
59
+ new_means = new_means / bins_min_clamped[..., None]
60
+
61
+ if use_cosine_sim:
62
+ new_means = l2norm(new_means)
63
+
64
+ means = torch.where(zero_mask[..., None], means, new_means)
65
+
66
+ return means, bins
67
+
68
+
69
+ class EmbeddingEMA(nn.Module):
70
+ def __init__(self,
71
+ num_tokens,
72
+ codebook_dim,
73
+ decay=0.99,
74
+ eps=1e-5,
75
+ kmeans_init=True,
76
+ codebook_init_path=''):
77
+ super().__init__()
78
+ self.num_tokens = num_tokens
79
+ self.codebook_dim = codebook_dim
80
+ self.decay = decay
81
+ self.eps = eps
82
+ if codebook_init_path == '':
83
+ if not kmeans_init:
84
+ weight = torch.randn(num_tokens, codebook_dim)
85
+ weight = l2norm(weight)
86
+ else:
87
+ weight = torch.zeros(num_tokens, codebook_dim)
88
+ self.register_buffer('initted', torch.Tensor([not kmeans_init]))
89
+ else:
90
+ print(f"load init codebook weight from {codebook_init_path}")
91
+ codebook_ckpt_weight = torch.load(
92
+ codebook_init_path, map_location='cpu')
93
+ weight = codebook_ckpt_weight.clone()
94
+ self.register_buffer('initted', torch.Tensor([True]))
95
+
96
+ self.weight = nn.Parameter(weight, requires_grad=False)
97
+ self.cluster_size = nn.Parameter(
98
+ torch.zeros(num_tokens), requires_grad=False)
99
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
100
+ # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
101
+ self.update = True
102
+
103
+ @torch.jit.ignore
104
+ def init_embed_(self, data):
105
+ if self.initted:
106
+ return
107
+ print("Performing Kemans init for codebook")
108
+ embed, cluster_size = kmeans(
109
+ data, self.num_tokens, 10, use_cosine_sim=True)
110
+ self.weight.data.copy_(embed)
111
+ self.cluster_size.data.copy_(cluster_size)
112
+ self.initted.data.copy_(torch.Tensor([True]))
113
+
114
+ def forward(self, embed_id):
115
+ return F.embedding(embed_id, self.weight)
116
+
117
+ def cluster_size_ema_update(self, new_cluster_size):
118
+ self.cluster_size.data.mul_(self.decay).add_(
119
+ new_cluster_size, alpha=1 - self.decay)
120
+
121
+ def embed_avg_ema_update(self, new_embed_avg):
122
+ self.embed_avg.data.mul_(self.decay).add_(
123
+ new_embed_avg, alpha=1 - self.decay)
124
+
125
+ def weight_update(self, num_tokens):
126
+ n = self.cluster_size.sum()
127
+ smoothed_cluster_size = (
128
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n)
129
+ # normalize embedding average with smoothed cluster size
130
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
131
+ # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
132
+ self.weight.data.copy_(embed_normalized)
133
+
134
+
135
+ def norm_ema_inplace(moving_avg, new, decay):
136
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
137
+ moving_avg.data.copy_(l2norm(moving_avg.data))
138
+
139
+
140
+ class NormEMAVectorQuantizer(nn.Module):
141
+ def __init__(self,
142
+ n_embed,
143
+ embedding_dim,
144
+ beta,
145
+ decay=0.99,
146
+ eps=1e-5,
147
+ statistic_code_usage=True,
148
+ kmeans_init=False,
149
+ codebook_init_path=''):
150
+ super().__init__()
151
+ self.codebook_dim = embedding_dim
152
+ self.num_tokens = n_embed
153
+ self.beta = beta
154
+ self.decay = decay
155
+
156
+ # learnable = True if orthogonal_reg_weight > 0 else False
157
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay,
158
+ eps, kmeans_init, codebook_init_path)
159
+
160
+ self.statistic_code_usage = statistic_code_usage
161
+ if statistic_code_usage:
162
+ self.register_buffer('cluster_size', torch.zeros(n_embed))
163
+ if distributed.is_available() and distributed.is_initialized():
164
+ print(
165
+ "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
166
+ )
167
+ self.all_reduce_fn = distributed.all_reduce
168
+ else:
169
+ self.all_reduce_fn = nn.Identity()
170
+
171
+ def reset_cluster_size(self, device):
172
+ if self.statistic_code_usage:
173
+ self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
174
+ self.cluster_size = self.cluster_size.to(device)
175
+
176
+ def forward(self, z):
177
+ # reshape z -> (batch, height, width, channel) and flatten
178
+ # z, 'b c h w -> b h w c'
179
+ # z = rearrange(z, 'b c h w -> b h w c')
180
+ # z = z.transpose(1, 2)
181
+ z = l2norm(z)
182
+ z_flattened = z.reshape(-1, self.codebook_dim)
183
+
184
+ self.embedding.init_embed_(z_flattened)
185
+
186
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
187
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
188
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
189
+
190
+ encoding_indices = torch.argmin(d, dim=1)
191
+
192
+ z_q = self.embedding(encoding_indices).view(z.shape)
193
+
194
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
195
+
196
+ if not self.training:
197
+ with torch.no_grad():
198
+ cluster_size = encodings.sum(0)
199
+ self.all_reduce_fn(cluster_size)
200
+ ema_inplace(self.cluster_size, cluster_size, self.decay)
201
+
202
+ if self.training and self.embedding.update:
203
+ # EMA cluster size
204
+
205
+ bins = encodings.sum(0)
206
+ self.all_reduce_fn(bins)
207
+
208
+ # self.embedding.cluster_size_ema_update(bins)
209
+ ema_inplace(self.cluster_size, bins, self.decay)
210
+
211
+ zero_mask = (bins == 0)
212
+ bins = bins.masked_fill(zero_mask, 1.)
213
+
214
+ embed_sum = z_flattened.t() @ encodings
215
+ self.all_reduce_fn(embed_sum)
216
+
217
+ embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
218
+ embed_normalized = l2norm(embed_normalized)
219
+
220
+ embed_normalized = torch.where(
221
+ zero_mask[..., None], self.embedding.weight, embed_normalized)
222
+ norm_ema_inplace(self.embedding.weight, embed_normalized,
223
+ self.decay)
224
+
225
+ # compute loss for embedding
226
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
227
+
228
+ # preserve gradients
229
+ z_q = z + (z_q - z).detach()
230
+
231
+ # reshape back to match original input shape
232
+ # z_q, 'b h w c -> b c h w'
233
+ # z_q = rearrange(z_q, 'b h w c -> b c h w')
234
+ # z_q = z_q.transpose(1, 2)
235
+ return z_q, loss, encoding_indices
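
A minimal sketch of driving the EMA vector quantizer on dummy features; the sizes are hypothetical and `kmeans_init=False` keeps the example free of the optional einops dependency used by the k-means path.

```python
import torch
from AR.exps.beats.quantizer import NormEMAVectorQuantizer

vq = NormEMAVectorQuantizer(
    n_embed=1024, embedding_dim=256, beta=1.0, kmeans_init=False)

z = torch.randn(2, 50, 256)        # (batch, frames, channel) features
z_q, commit_loss, indices = vq(z)  # quantized features, commitment loss, codebook ids
print(z_q.shape, indices.shape)    # torch.Size([2, 50, 256]) torch.Size([100])
```
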
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py ADDED
@@ -0,0 +1,321 @@
1
+ # Use the audio-tagging tool BEATs to filter out audio whose top-1 tag is not 'speech'
2
+ # non_speech.npy stores a python dict with the tags of the non-speech audio; it is smaller and faster to load and search
3
+ # the audio_tag directory stores {utt_id}.txt, whose first line is the lowercase top-1 tag
4
+ import argparse
5
+ import os
6
+ import time
7
+ import traceback
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from pathlib import Path
10
+
11
+ import librosa
12
+ import numpy as np
13
+ import torch
14
+ import tqdm
15
+ from AR.exps.beats.BEATs import BEATs
16
+ from AR.exps.beats.BEATs import BEATsConfig
17
+ from AR.exps.beats.config import id_name_dict
18
+ from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
19
+ from soundstorm.utils import check_txt_file
20
+
21
+
22
+ def get_BEATs_top1(wav,
23
+ BEATs_model,
24
+ BEATs_label_dict,
25
+ device: str='cpu',
26
+ topk: int=1):
27
+ wav = torch.tensor(wav).unsqueeze(0).to(device)
28
+ padding_mask = torch.zeros(wav.shape).bool().to(device)
29
+ probs = BEATs_model.extract_features(wav, padding_mask=padding_mask)[0]
30
+ # single-utterance inference
31
+ probs = probs[0]
32
+ topk_label_prob, topk_label_idx = probs.topk(k=topk)
33
+ topk_label = [
34
+ BEATs_label_dict[label_idx.item()] for label_idx in topk_label_idx
35
+ ]
36
+ topk_label_name = [id_name_dict[label] for label in topk_label]
37
+ top1_label = topk_label_name[0]
38
+ return top1_label
39
+
40
+
41
+ def process_sentence(args,
42
+ fp: Path,
43
+ train_dump_dir: Path,
44
+ dev_dump_dir: Path,
45
+ test_dump_dir: Path,
46
+ VAD_dict,
47
+ BEATs_model,
48
+ BEATs_label_dict,
49
+ device: str='cpu'):
50
+ utt_id = fp.stem
51
+ sr = args.sr
52
+ record = []
53
+ train_audio_tag_dir = train_dump_dir / "audio_tag"
54
+ train_audio_tag_dir.mkdir(parents=True, exist_ok=True)
55
+
56
+ dev_audio_tag_dir = dev_dump_dir / "audio_tag"
57
+ dev_audio_tag_dir.mkdir(parents=True, exist_ok=True)
58
+
59
+ test_audio_tag_dir = test_dump_dir / "audio_tag"
60
+ test_audio_tag_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ try:
63
+ # get info for path
64
+ wav_path_list = str(fp).strip().split('/')
65
+ sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
66
+ -3], wav_path_list[-2]
67
+ wav_name = wav_path_list[-1][:-5]
68
+ assert wav_name == utt_id
69
+ # key_name for big wav
70
+ key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
71
+ # handle the case where this audio is missing from the VAD dict
72
+ if key_name not in VAD_dict.keys():
73
+ print(key_name, 'not in VAD_dict !')
74
+ return record
75
+ wav = None
76
+ sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
77
+ len_dict = len(sorted_split_VAD_dict)
78
+ for index, item in enumerate(sorted_split_VAD_dict):
79
+ split_name, value = item
80
+ start, end = value
81
+ # train | dev | test
82
+ if index == len_dict - 1:
83
+ subset = 'test'
84
+ audio_tag_path = test_audio_tag_dir / (split_name + ".txt")
85
+ elif index == len_dict - 2:
86
+ subset = 'dev'
87
+ audio_tag_path = dev_audio_tag_dir / (split_name + ".txt")
88
+ else:
89
+ subset = 'train'
90
+ audio_tag_path = train_audio_tag_dir / (split_name + ".txt")
91
+
92
+ if os.path.exists(audio_tag_path) and check_txt_file(
93
+ audio_tag_path):
94
+ # print(audio_tag_path, 'exits!')
95
+ pass
96
+ else:
97
+ # this check ensures the big wav is loaded at most once inside the sub-wav loop
98
+ if wav is None:
99
+ # load big wav
100
+ # loading at the outer level would waste time when all sub-wav features already exist
101
+ wav, _ = librosa.load(str(fp), sr=sr)
102
+ sub_wav = wav[int(start * sr):int(end * sr)]
103
+ audio_tag_top1 = get_BEATs_top1(
104
+ wav=sub_wav,
105
+ BEATs_model=BEATs_model,
106
+ BEATs_label_dict=BEATs_label_dict,
107
+ device=device)
108
+
109
+ with open(audio_tag_path, 'w') as f:
110
+ f.write(audio_tag_top1)
111
+
112
+ sub_record = {
113
+ "utt_id": split_name,
114
+ "audio_tag_path": audio_tag_path,
115
+ "subset": subset
116
+ }
117
+ # record becomes a List of Dict
118
+ record.append(sub_record)
119
+ except Exception:
120
+ print("occur Exception")
121
+ traceback.print_exc()
122
+ # record may be an incomplete List here
123
+ return record
124
+ return record
125
+
126
+
127
+ def process_sentences(args,
128
+ fps: Path,
129
+ train_dump_dir: Path,
130
+ dev_dump_dir: Path,
131
+ test_dump_dir: Path,
132
+ VAD_dict,
133
+ BEATs_model,
134
+ BEATs_label_dict,
135
+ device: str='cpu',
136
+ nprocs: int=1):
137
+ print("nprocs:", nprocs)
138
+ if nprocs == 1:
139
+ results = []
140
+ for fp in tqdm.tqdm(fps, total=len(fps)):
141
+ record = process_sentence(
142
+ args=args,
143
+ fp=fp,
144
+ train_dump_dir=train_dump_dir,
145
+ dev_dump_dir=dev_dump_dir,
146
+ test_dump_dir=test_dump_dir,
147
+ VAD_dict=VAD_dict,
148
+ BEATs_model=BEATs_model,
149
+ BEATs_label_dict=BEATs_label_dict,
150
+ device=device)
151
+ if record:
152
+ results.append(record)
153
+ else:
154
+ with ThreadPoolExecutor(nprocs) as pool:
155
+ futures = []
156
+ with tqdm.tqdm(total=len(fps)) as progress:
157
+ for fp in fps:
158
+ future = pool.submit(process_sentence, args, fp,
159
+ train_dump_dir, dev_dump_dir,
160
+ test_dump_dir, VAD_dict, BEATs_model,
161
+ BEATs_label_dict, device)
162
+ future.add_done_callback(lambda p: progress.update())
163
+ futures.append(future)
164
+
165
+ results = []
166
+ for ft in futures:
167
+ record = ft.result()
168
+ if record:
169
+ results.append(record)
170
+
171
+ # collect the results and save them as .npy files (one per subset)
172
+ non_speech_dict = dict()
173
+ non_speech_dict['train'] = {}
174
+ non_speech_dict['dev'] = {}
175
+ non_speech_dict['test'] = {}
176
+ # each record is a List of Dict: one record per big wav, one sub_record per sub wav
177
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
178
+ save_start_time = time.time()
179
+ for record in tqdm.tqdm(results, total=len(results), colour='green'):
180
+ for sub_record in record:
181
+ # wrap in try/except because the txt file may be corrupted
182
+ try:
183
+ utt_id = sub_record["utt_id"]
184
+ subset = sub_record["subset"]
185
+ audio_tag_top1 = check_txt_file(sub_record["audio_tag_path"])
186
+ if audio_tag_top1 is not False:
187
+ if 'speech' not in audio_tag_top1.lower():
188
+ non_speech_dict[subset][utt_id] = audio_tag_top1
189
+ else:
190
+ # print(f'audio tag result of {utt_id} is speech')
191
+ pass
192
+ else:
193
+ print(f'audio tag result of {utt_id} is False')
194
+ except Exception:
195
+ print(f"{utt_id} occur Exception")
196
+ traceback.print_exc()
197
+ continue
198
+
199
+ train_filename = train_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
200
+ dev_filename = dev_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
201
+ test_filename = test_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
202
+ np.save(train_filename, non_speech_dict['train'])
203
+ print(f"npy file '{train_filename}' write down")
204
+
205
+ np.save(dev_filename, non_speech_dict['dev'])
206
+ print(f"npy file '{dev_filename}' write down")
207
+
208
+ np.save(test_filename, non_speech_dict['test'])
209
+ print(f"npy file '{test_filename}' write down")
210
+ print('time of save stage:', time.time() - save_start_time)
211
+
212
+
213
+ def main():
214
+ # parse config and args
215
+ parser = argparse.ArgumentParser(
216
+ description="Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'."
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--data_dir", default=None, type=str, help="directory to dataset.")
221
+
222
+ parser.add_argument(
223
+ "--dump_dir",
224
+ type=str,
225
+ required=True,
226
+ help="directory to dump feature files.")
227
+
228
+ parser.add_argument(
229
+ "--num-cpu", type=int, default=1, help="number of process.")
230
+
231
+ parser.add_argument(
232
+ '--sr', type=int, default=16000, help='sample rate of model')
233
+
234
+ # For LibriLight dataset
235
+ parser.add_argument(
236
+ "--sub_dataset",
237
+ default="small",
238
+ type=str,
239
+ help="name of sub dataset of LibriLight",
240
+ choices=['small', 'medium', 'large', 'duplicate'], )
241
+ parser.add_argument(
242
+ "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
243
+ parser.add_argument("--nshard", type=int, default=3)
244
+ parser.add_argument("--rank", type=int, default=0)
245
+
246
+ # for BEATs
247
+ parser.add_argument(
248
+ "--BEATs_ckpt_path",
249
+ type=str,
250
+ default='./pretrained_model/BEATs_iter1_finetuned_on_AS2M_cpt1.pt')
251
+
252
+ args = parser.parse_args()
253
+
254
+ data_dir = Path(args.data_dir).expanduser()
255
+ dump_dir = Path(args.dump_dir).expanduser()
256
+ # use absolute path
257
+ dump_dir = dump_dir.resolve()
258
+ dump_dir.mkdir(parents=True, exist_ok=True)
259
+
260
+ assert data_dir.is_dir()
261
+
262
+ # sub_dataset here
263
+ sub_dataset_dir = data_dir / args.sub_dataset
264
+ # only spk_id in the list, sorted in lexicographical order
265
+ speaker_list = sorted(os.listdir(sub_dataset_dir))
266
+ start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
267
+ # speaker_list for this rank
268
+ speaker_list = speaker_list[start:end]
269
+
270
+ all_wav_files = []
271
+
272
+ for speaker in speaker_list:
273
+ wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
274
+ # filter out ._*.flac
275
+ wav_files = [
276
+ file for file in wav_files if not file.name.startswith('._')
277
+ ]
278
+ all_wav_files += wav_files
279
+
280
+ print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
281
+ # get VAD info
282
+ VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
283
+
284
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
285
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
286
+ train_dump_dir = sub_dataset_dump_dir / "train"
287
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
288
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
289
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
290
+ test_dump_dir = sub_dataset_dump_dir / "test"
291
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
292
+
293
+ BEATs_ckpt = torch.load(args.BEATs_ckpt_path)
294
+
295
+ BEATs_cfg = BEATsConfig(BEATs_ckpt['cfg'])
296
+ BEATs_model = BEATs(BEATs_cfg)
297
+ BEATs_model.load_state_dict(BEATs_ckpt['model'])
298
+ BEATs_model.eval()
299
+ # cpu or cuda
300
+ device = 'cpu'
301
+ BEATs_model.to(device)
302
+
303
+ BEATs_label_dict = BEATs_ckpt['label_dict']
304
+
305
+ # each big wav contributes one dev and one test split; the ratio is roughly 96:2:2
306
+ if all_wav_files:
307
+ process_sentences(
308
+ args=args,
309
+ fps=all_wav_files,
310
+ train_dump_dir=train_dump_dir,
311
+ dev_dump_dir=dev_dump_dir,
312
+ test_dump_dir=test_dump_dir,
313
+ VAD_dict=VAD_dict,
314
+ BEATs_model=BEATs_model,
315
+ BEATs_label_dict=BEATs_label_dict,
316
+ device=device,
317
+ nprocs=args.num_cpu)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()
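
The script only persists `non_speech_{rank}_{nshard}.npy` per subset, so downstream filtering reduces to a dict-membership check. A sketch of reading one of those files back (the path is illustrative) follows the same load pattern the script uses for the VAD dict.

```python
import numpy as np

non_speech = np.load(
    'dump/small/train/non_speech_0_3.npy', allow_pickle=True).item()

# Keys are split utt_ids, values are the offending top-1 BEATs tag;
# an utterance absent from this dict was tagged as speech.
for utt_id, tag in list(non_speech.items())[:5]:
    print(utt_id, '->', tag)
```
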
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ 1. read text of dataset
3
+ 2. text -> IPA by GruutPhonemizer
4
+ 3. save out a *.npy dict for all text
5
+ my_dict = {"utt_id1": text1, "utt_id2": text2}
6
+ np.save(output_filename, my_dict)
7
+ my_dict = np.load(output_filename, allow_pickle=True).item()
8
+ """
9
+ import argparse
10
+ import os
+ import traceback
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from operator import itemgetter
13
+ from pathlib import Path
14
+ from typing import List
15
+
16
+ import numpy as np
17
+ import tqdm
18
+ from AR.text_processing.phonemizer import GruutPhonemizer
19
+
20
+
21
+ def read_txt(txt_file):
22
+ utt_name = txt_file.stem
23
+ utt_id = utt_name.split('.')[0]
24
+ try:
25
+ with open(txt_file, 'r') as file:
26
+ txt = file.readline()
27
+ record = {"utt_id": utt_id, "txt": txt}
28
+ except Exception:
29
+ print("occur Exception")
30
+ traceback.print_exc()
31
+ return None
32
+ return record
33
+
34
+
35
+ def read_txts(txt_files: List[Path], nprocs: int=1):
36
+ if nprocs == 1:
37
+ results = []
38
+ for txt_file in tqdm.tqdm(txt_files, total=len(txt_files)):
39
+ record = read_txt(txt_file=txt_file)
40
+ if record:
41
+ results.append(record)
42
+ else:
43
+ with ThreadPoolExecutor(nprocs) as pool:
44
+ futures = []
45
+ with tqdm.tqdm(total=len(txt_files)) as progress:
46
+ for txt_file in txt_files:
47
+ future = pool.submit(read_txt, txt_file)
48
+ future.add_done_callback(lambda p: progress.update())
49
+ futures.append(future)
50
+
51
+ results = []
52
+ for ft in futures:
53
+ record = ft.result()
54
+ if record:
55
+ results.append(record)
56
+
57
+ results.sort(key=itemgetter("utt_id"))
58
+ return_list = []
59
+ for item in results:
60
+ return_list.append((item["utt_id"], item["txt"]))
61
+ return return_list
62
+
63
+
64
+ def process_sentence(item, phonemizer):
65
+ utt_id, text = item
66
+ try:
67
+ phonemes = phonemizer.phonemize(text, espeak=False)
68
+ record = {"utt_id": utt_id, "phonemes": phonemes}
69
+ except Exception:
70
+ print("occur Exception")
71
+ traceback.print_exc()
72
+ return None
73
+ return record
74
+
75
+
76
+ def process_sentences(items, phonemizer, output_dir, nprocs: int=1):
77
+ if nprocs == 1:
78
+ results = []
79
+ for item in tqdm.tqdm(items, total=len(items)):
80
+ record = process_sentence(item=item, phonemizer=phonemizer)
81
+ if record:
82
+ results.append(record)
83
+ else:
84
+ with ThreadPoolExecutor(nprocs) as pool:
85
+ futures = []
86
+ with tqdm.tqdm(total=len(items)) as progress:
87
+ for item in items:
88
+ future = pool.submit(process_sentence, item, phonemizer)
89
+ future.add_done_callback(lambda p: progress.update())
90
+ futures.append(future)
91
+
92
+ results = []
93
+ for ft in futures:
94
+ record = ft.result()
95
+ if record:
96
+ results.append(record)
97
+ results.sort(key=itemgetter("utt_id"))
98
+ npy_dict = {}
99
+ for item in results:
100
+ utt_id = item["utt_id"]
101
+ phonemes = item["phonemes"]
102
+ npy_dict[utt_id] = phonemes
103
+ filename = output_dir / 'phonemes.npy'
104
+ np.save(filename, npy_dict)
105
+ print(f"npy file '{filename}' write down")
106
+
107
+
108
+ def main():
109
+ # parse config and args
110
+ parser = argparse.ArgumentParser(description="Get phones for datasets")
111
+
112
+ parser.add_argument(
113
+ "--dataset",
114
+ default="ljspeech",
115
+ type=str,
116
+ help="name of dataset, should in {ljspeech, libritts} now")
117
+
118
+ parser.add_argument(
119
+ "--data_dir", default=None, type=str, help="directory to dataset.")
120
+
121
+ parser.add_argument(
122
+ "--dump_dir",
123
+ type=str,
124
+ required=True,
125
+ help="directory to dump feature files.")
126
+ parser.add_argument(
127
+ "--num-cpu", type=int, default=1, help="number of process.")
128
+
129
+ args = parser.parse_args()
130
+
131
+ data_dir = Path(args.data_dir).expanduser()
132
+ dump_dir = Path(args.dump_dir).expanduser()
133
+ # use absolute path
134
+ dump_dir = dump_dir.resolve()
135
+ dump_dir.mkdir(parents=True, exist_ok=True)
136
+
137
+ assert data_dir.is_dir()
138
+
139
+ if args.dataset == "ljspeech":
140
+ data_dict = {}
141
+ text_path = data_dir / 'metadata.csv'
142
+ with open(text_path, 'r') as rf:
143
+ for line in rf:
144
+ line_list = line.strip().split('|')
145
+ utt_id = line_list[0]
146
+ raw_text = line_list[-1]
147
+ data_dict[utt_id] = raw_text
148
+
149
+ sorted_dict = sorted(data_dict.items())
150
+
151
+ num_train = 12900
152
+ num_dev = 100
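+ # LJSpeech-1.1 has 13,100 clips in total, so the remaining ~100 clips become the test set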
153
+ # (utt_id, txt)
154
+ train_txts = sorted_dict[:num_train]
155
+ dev_txts = sorted_dict[num_train:num_train + num_dev]
156
+ test_txts = sorted_dict[num_train + num_dev:]
157
+
158
+ elif args.dataset == "libritts":
159
+ '''
160
+ we use train-clean-100, train-clean-360 and train-other-500 here
161
+ and split dev and test from them; don't use test-* and dev-* because their speakers are disjoint from the training speakers
162
+ the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav
163
+ there are about 2311 speakers in these subsets; we split 1 dev and 1 test wav out from each speaker
164
+ '''
165
+ txt_files = []
166
+ train_txt_files = []
167
+ dev_txt_files = []
168
+ test_txt_files = []
169
+ sub_num_dev = 1
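+ # per speaker: the last sub_num_dev files go to test, the sub_num_dev before them to dev, the rest to train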
170
+ for sub_dataset_name in {
171
+ "train-clean-100", "train-clean-360", "train-other-500"
172
+ }:
173
+ sub_dataset_dir = data_dir / sub_dataset_name
174
+ # filter out hidden files
175
+ speaker_list = [
176
+ file for file in os.listdir(sub_dataset_dir)
177
+ if not file.startswith('.')
178
+ ]
179
+ for speaker in speaker_list:
180
+ txt_files = sorted(
181
+ list((sub_dataset_dir / speaker).rglob(
182
+ "*/*.normalized.txt")))
183
+ # filter out ._* files
184
+ txt_files = [
185
+ file for file in txt_files if not file.name.startswith('._')
186
+ ]
187
+ train_txt_files += txt_files[:-sub_num_dev * 2]
188
+ dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev]
189
+ test_txt_files += txt_files[-sub_num_dev:]
190
+ print("len(train_txt_files):", len(train_txt_files))
191
+ print("len(dev_txt_files):", len(dev_txt_files))
192
+ print("len(test_txt_files):", len(test_txt_files))
193
+
194
+ train_txts = read_txts(train_txt_files)
195
+ dev_txts = read_txts(dev_txt_files)
196
+ test_txts = read_txts(test_txt_files)
197
+
198
+ else:
199
+ print("dataset should in {ljspeech, libritts} now!")
200
+
201
+ train_dump_dir = dump_dir / "train"
202
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
203
+ dev_dump_dir = dump_dir / "dev"
204
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
205
+ test_dump_dir = dump_dir / "test"
206
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
207
+
208
+ phonemizer = GruutPhonemizer(language='en-us')
209
+
210
+ # process for the 3 sections
211
+ if train_txts:
212
+ process_sentences(
213
+ items=train_txts,
214
+ output_dir=train_dump_dir,
215
+ phonemizer=phonemizer,
216
+ nprocs=args.num_cpu)
217
+ if dev_txts:
218
+ process_sentences(
219
+ items=dev_txts,
220
+ output_dir=dev_dump_dir,
221
+ phonemizer=phonemizer,
222
+ nprocs=args.num_cpu)
223
+ if test_txts:
224
+ process_sentences(
225
+ items=test_txts,
226
+ output_dir=test_dump_dir,
227
+ phonemizer=phonemizer,
228
+ nprocs=args.num_cpu)
229
+
230
+
231
+ if __name__ == "__main__":
232
+ main()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py ADDED
@@ -0,0 +1,198 @@
1
+ """
2
+ 1. read text of dataset; for LibriLight, read txt_*.npy -> needs to be organized into a list of (utt_id, txt) pairs
3
+ 2. text -> IPA by GruutPhonemizer
4
+ 3. save out a *.npy dict for all text
5
+ 4. each LibriLight split is processed separately
6
+ my_dict = {"utt_id1": text1, "utt_id2": text2}
7
+ np.save(output_filename, my_dict)
8
+ my_dict = np.load(output_filename, allow_pickle=True).item()
9
+ """
10
+ import argparse
11
+ import os
12
+ import time
13
+ import traceback
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from operator import itemgetter
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import tqdm
20
+ from AR.text_processing.phonemizer import GruutPhonemizer
21
+ from soundstorm.utils import check_txt_file
22
+
23
+
24
+ def read_txts(txt_file: Path, nprocs: int=1):
25
+ '''
26
+ txt_file: path of npy dict, {"utt_id1": text1, "utt_id2": text2}
27
+ '''
28
+ txt_dict = np.load(txt_file, allow_pickle=True).item()
29
+ #[(utt_id, txt), ...]
30
+ return_list = list(txt_dict.items())
31
+ return return_list
32
+
33
+
34
+ def process_sentence(item, phonemizer, output_dir):
35
+ utt_id, text = item
36
+ phonemes_dir = output_dir / "phonemes"
37
+ phonemes_dir.mkdir(parents=True, exist_ok=True)
38
+ phonemes_path = phonemes_dir / (utt_id + ".txt")
39
+ try:
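+ # skip utterances whose phoneme file already exists and is readable; otherwise phonemize the text and write it out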
40
+ if os.path.exists(phonemes_path) and check_txt_file(phonemes_path):
41
+ # print(phonemes_path, 'exists!')
42
+ pass
43
+ else:
44
+ phonemes = phonemizer.phonemize(text, espeak=False)
45
+ with open(phonemes_path, 'w') as f:
46
+ f.write(phonemes)
47
+ record = {"utt_id": utt_id, "phonemes_path": phonemes_path}
48
+ except Exception:
49
+ print("occur Exception")
50
+ traceback.print_exc()
51
+ return None
52
+ return record
53
+
54
+
55
+ def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1):
56
+ print("nprocs:", nprocs)
57
+ if nprocs == 1:
58
+ results = []
59
+ for item in tqdm.tqdm(items, total=len(items)):
60
+ record = process_sentence(
61
+ item=item, phonemizer=phonemizer, output_dir=output_dir)
62
+ if record:
63
+ results.append(record)
64
+ else:
65
+ with ThreadPoolExecutor(nprocs) as pool:
66
+ futures = []
67
+ with tqdm.tqdm(total=len(items)) as progress:
68
+ for item in items:
69
+ future = pool.submit(process_sentence, item, phonemizer,
70
+ output_dir)
71
+ future.add_done_callback(lambda p: progress.update())
72
+ futures.append(future)
73
+
74
+ results = []
75
+ for ft in futures:
76
+ record = ft.result()
77
+ if record:
78
+ results.append(record)
79
+
80
+ results.sort(key=itemgetter("utt_id"))
81
+
82
+ npy_dict = {}
83
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
84
+ save_start_time = time.time()
85
+ for item in tqdm.tqdm(results, total=len(results), colour='green'):
86
+ # wrap in try/except because the txt file may be corrupted
87
+ try:
88
+ utt_id = item["utt_id"]
89
+ phonemes = check_txt_file(item["phonemes_path"])
90
+ if phonemes is not False:
91
+ npy_dict[utt_id] = phonemes
92
+ else:
93
+ print(f'phonemes of {utt_id} is False')
94
+ except Exception:
95
+ print(f"{utt_id} occur Exception")
96
+ traceback.print_exc()
97
+ continue
98
+
99
+ filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy'
100
+ np.save(filename, npy_dict)
101
+ print(f"npy file '{filename}' write down")
102
+ print('time of save stage:', time.time() - save_start_time)
103
+
104
+
105
+ def main():
106
+ # parse config and args
107
+ parser = argparse.ArgumentParser(
108
+ description="Get phones for LibriLight dataset from txt_*.npy")
109
+
110
+ parser.add_argument(
111
+ "--dump_dir",
112
+ type=str,
113
+ required=True,
114
+ help="directory to dump feature files.")
115
+ parser.add_argument(
116
+ "--num-cpu", type=int, default=1, help="number of process.")
117
+
118
+ parser.add_argument(
119
+ '--train_txt_dir',
120
+ type=str,
121
+ default='dump/small/train/',
122
+ help='dir of train txt files')
123
+ parser.add_argument(
124
+ '--dev_txt_dir',
125
+ type=str,
126
+ default='dump/small/dev/',
127
+ help='dir of dev txt files')
128
+ parser.add_argument(
129
+ '--test_txt_dir',
130
+ type=str,
131
+ default='dump/small/test/',
132
+ help='dir of test txt files')
133
+
134
+ parser.add_argument(
135
+ "--sub_dataset",
136
+ default="small",
137
+ type=str,
138
+ help="name of sub dataset of LibriLight",
139
+ choices=['small', 'medium', 'large', 'duplicate'], )
140
+ parser.add_argument("--nshard", type=int, default=3)
141
+ parser.add_argument("--rank", type=int, default=0)
142
+
143
+ args = parser.parse_args()
144
+ print(f"nshard: {args.nshard}, rank: {args.rank}")
145
+
146
+ train_txt_dir = Path(args.train_txt_dir)
147
+ dev_txt_dir = Path(args.dev_txt_dir)
148
+ test_txt_dir = Path(args.test_txt_dir)
149
+
150
+ dump_dir = Path(args.dump_dir).expanduser()
151
+ # use absolute path
152
+ dump_dir = dump_dir.resolve()
153
+ dump_dir.mkdir(parents=True, exist_ok=True)
154
+
155
+ train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
+ dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
+ test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
+
159
+ train_txts = read_txts(train_txt_file)
160
+ dev_txts = read_txts(dev_txt_file)
161
+ test_txts = read_txts(test_txt_file)
162
+
163
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
164
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
165
+ train_dump_dir = sub_dataset_dump_dir / "train"
166
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
167
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
168
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
169
+ test_dump_dir = sub_dataset_dump_dir / "test"
170
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
171
+ phonemizer = GruutPhonemizer(language='en-us')
172
+
173
+ # process for the 3 sections
174
+ if train_txts:
175
+ process_sentences(
176
+ args=args,
177
+ items=train_txts,
178
+ output_dir=train_dump_dir,
179
+ phonemizer=phonemizer,
180
+ nprocs=args.num_cpu)
181
+ if dev_txts:
182
+ process_sentences(
183
+ args=args,
184
+ items=dev_txts,
185
+ output_dir=dev_dump_dir,
186
+ phonemizer=phonemizer,
187
+ nprocs=args.num_cpu)
188
+ if test_txts:
189
+ process_sentences(
190
+ args=args,
191
+ items=test_txts,
192
+ output_dir=test_dump_dir,
193
+ phonemizer=phonemizer,
194
+ nprocs=args.num_cpu)
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py ADDED
@@ -0,0 +1,255 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from pathlib import Path
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import tqdm
11
+ import whisper
12
+ from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
13
+ from soundstorm.utils import check_txt_file
14
+
15
+
16
+ def process_sentence(args,
17
+ fp: Path,
18
+ train_dump_dir: Path,
19
+ dev_dump_dir: Path,
20
+ test_dump_dir: Path,
21
+ VAD_dict):
22
+ asr_model = whisper.load_model("tiny.en")
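+ # note: a tiny.en Whisper ASR model is loaded on every call, i.e. once per big wav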
23
+ utt_id = fp.stem
24
+ sr = args.sr
25
+ record = []
26
+ train_txt_dir = train_dump_dir / "txt"
27
+ train_txt_dir.mkdir(parents=True, exist_ok=True)
28
+
29
+ dev_txt_dir = dev_dump_dir / "txt"
30
+ dev_txt_dir.mkdir(parents=True, exist_ok=True)
31
+
32
+ test_txt_dir = test_dump_dir / "txt"
33
+ test_txt_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ try:
36
+ # get info for path
37
+ wav_path_list = str(fp).strip().split('/')
38
+ sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
39
+ -3], wav_path_list[-2]
40
+ wav_name = wav_path_list[-1][:-5]
41
+ assert wav_name == utt_id
42
+ # key_name for big wav
43
+ key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
44
+ # handle the case where this audio clip has no entry in the VAD dict
45
+ if key_name not in VAD_dict.keys():
46
+ print(key_name, 'not in VAD_dict !')
47
+ return record
48
+ wav = None
49
+ sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
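+ # VAD_dict[key_name] maps each sub-wav split_name to its (start, end) boundaries in seconds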
50
+ len_dict = len(sorted_split_VAD_dict)
51
+ for index, item in enumerate(sorted_split_VAD_dict):
52
+ split_name, value = item
53
+ start, end = value
54
+ # train | dev | test
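+ # the last sub-wav of each big wav goes to test, the second-to-last to dev, and the rest to train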
55
+ if index == len_dict - 1:
56
+ subset = 'test'
57
+ txt_path = test_txt_dir / (split_name + ".txt")
58
+ elif index == len_dict - 2:
59
+ subset = 'dev'
60
+ txt_path = dev_txt_dir / (split_name + ".txt")
61
+ else:
62
+ subset = 'train'
63
+ txt_path = train_txt_dir / (split_name + ".txt")
64
+
65
+ if os.path.exists(txt_path) and check_txt_file(txt_path):
66
+ # print(txt_path, 'exists!')
67
+ pass
68
+ else:
69
+ # this check ensures the big wav is loaded only once inside the sub-wav loop
70
+ if wav is None:
71
+ # load big wav
72
+ # loading at the outermost level would waste time whenever all sub-wav features already exist
73
+ wav, _ = librosa.load(str(fp), sr=sr)
74
+ sub_wav = wav[int(start * sr):int(end * sr)]
75
+ asr_result = asr_model.transcribe(sub_wav)["text"]
76
+ with open(txt_path, 'w') as f:
77
+ f.write(asr_result)
78
+
79
+ sub_record = {
80
+ "utt_id": split_name,
81
+ "txt_path": txt_path,
82
+ "subset": subset
83
+ }
84
+ # record becomes a list of dicts
85
+ record.append(sub_record)
86
+ except Exception:
87
+ print("occur Exception")
88
+ traceback.print_exc()
89
+ # record may be an incomplete list at this point
90
+ return record
91
+ return record
92
+
93
+
94
+ def process_sentences(args,
95
+ fps: Path,
96
+ train_dump_dir: Path,
97
+ dev_dump_dir: Path,
98
+ test_dump_dir: Path,
99
+ VAD_dict,
100
+ nprocs: int=1):
101
+ print("nprocs:", nprocs)
102
+ if nprocs == 1:
103
+ results = []
104
+ for fp in tqdm.tqdm(fps, total=len(fps)):
105
+ record = process_sentence(
106
+ args=args,
107
+ fp=fp,
108
+ train_dump_dir=train_dump_dir,
109
+ dev_dump_dir=dev_dump_dir,
110
+ test_dump_dir=test_dump_dir,
111
+ VAD_dict=VAD_dict)
112
+ if record:
113
+ results.append(record)
114
+ else:
115
+ with ThreadPoolExecutor(nprocs) as pool:
116
+ futures = []
117
+ with tqdm.tqdm(total=len(fps)) as progress:
118
+ for fp in fps:
119
+ future = pool.submit(process_sentence, args, fp,
120
+ train_dump_dir, dev_dump_dir,
121
+ test_dump_dir, VAD_dict)
122
+ future.add_done_callback(lambda p: progress.update())
123
+ futures.append(future)
124
+
125
+ results = []
126
+ for ft in futures:
127
+ record = ft.result()
128
+ if record:
129
+ results.append(record)
130
+
131
+ # save the collected results with np.save() as per-subset .npy dicts
132
+ txt_dict = dict()
133
+ txt_dict['train'] = {}
134
+ txt_dict['dev'] = {}
135
+ txt_dict['test'] = {}
136
+ # results is a list of records; each big wav yields one record and each sub wav yields one sub_record
137
+ print(f"start to save {args.rank}_{args.nshard}.npy ...")
138
+ save_start_time = time.time()
139
+ for record in tqdm.tqdm(results, total=len(results), colour='green'):
140
+ for sub_record in record:
141
+ # wrap in try/except because the txt file may be corrupted
142
+ try:
143
+ utt_id = sub_record["utt_id"]
144
+ subset = sub_record["subset"]
145
+ asr_result = check_txt_file(sub_record["txt_path"])
146
+ if asr_result is not False:
147
+ txt_dict[subset][utt_id] = asr_result
148
+ else:
149
+ print(f'asr result of {utt_id} is False')
150
+ except Exception:
151
+ print(f"{utt_id} occur Exception")
152
+ traceback.print_exc()
153
+ continue
154
+
155
+ train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
+ dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
+ test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
+ np.save(train_filename, txt_dict['train'])
159
+ print(f"npy file '{train_filename}' write down")
160
+
161
+ np.save(dev_filename, txt_dict['dev'])
162
+ print(f"npy file '{dev_filename}' write down")
163
+
164
+ np.save(test_filename, txt_dict['test'])
165
+ print(f"npy file '{test_filename}' write down")
166
+ print('time of save stage:', time.time() - save_start_time)
167
+
168
+
169
+ def main():
170
+ # parse config and args
171
+ parser = argparse.ArgumentParser(
172
+ description="Preprocess audio and then extract features for LibriLight.")
173
+
174
+ parser.add_argument(
175
+ "--data_dir", default=None, type=str, help="directory to dataset.")
176
+
177
+ parser.add_argument(
178
+ "--dump_dir",
179
+ type=str,
180
+ required=True,
181
+ help="directory to dump feature files.")
182
+
183
+ parser.add_argument(
184
+ "--num-cpu", type=int, default=1, help="number of process.")
185
+
186
+ parser.add_argument(
187
+ '--sr', type=int, default=16000, help='sample rate of model')
188
+
189
+ # For LibriLight dataset
190
+ parser.add_argument(
191
+ "--sub_dataset",
192
+ default="small",
193
+ type=str,
194
+ help="name of sub dataset of LibriLight",
195
+ choices=['small', 'medium', 'large', 'duplicate'], )
196
+ parser.add_argument(
197
+ "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
198
+ parser.add_argument("--nshard", type=int, default=3)
199
+ parser.add_argument("--rank", type=int, default=0)
200
+
201
+ args = parser.parse_args()
202
+
203
+ data_dir = Path(args.data_dir).expanduser()
204
+ dump_dir = Path(args.dump_dir).expanduser()
205
+ # use absolute path
206
+ dump_dir = dump_dir.resolve()
207
+ dump_dir.mkdir(parents=True, exist_ok=True)
208
+
209
+ assert data_dir.is_dir()
210
+
211
+ # sub_dataset here
212
+ sub_dataset_dir = data_dir / args.sub_dataset
213
+ # only spk_id in the list, sorted in lexicographical order
214
+ speaker_list = sorted(os.listdir(sub_dataset_dir))
215
+ start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
216
+ # speaker_list for this rank
217
+ speaker_list = speaker_list[start:end]
218
+
219
+ all_wav_files = []
220
+
221
+ for speaker in speaker_list:
222
+ wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
223
+ # filter out ._*.flac
224
+ wav_files = [
225
+ file for file in wav_files if not file.name.startswith('._')
226
+ ]
227
+ all_wav_files += wav_files
228
+
229
+ print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
230
+ # get VAD info
231
+ VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
232
+
233
+ sub_dataset_dump_dir = dump_dir / args.sub_dataset
234
+ sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
235
+ train_dump_dir = sub_dataset_dump_dir / "train"
236
+ train_dump_dir.mkdir(parents=True, exist_ok=True)
237
+ dev_dump_dir = sub_dataset_dump_dir / "dev"
238
+ dev_dump_dir.mkdir(parents=True, exist_ok=True)
239
+ test_dump_dir = sub_dataset_dump_dir / "test"
240
+ test_dump_dir.mkdir(parents=True, exist_ok=True)
241
+
242
+ # each big wav contributes one dev and one test clip, giving a split of roughly 96:2:2
243
+ if all_wav_files:
244
+ process_sentences(
245
+ args=args,
246
+ fps=all_wav_files,
247
+ train_dump_dir=train_dump_dir,
248
+ dev_dump_dir=dev_dump_dir,
249
+ test_dump_dir=test_dump_dir,
250
+ VAD_dict=VAD_dict,
251
+ nprocs=args.num_cpu)
252
+
253
+
254
+ if __name__ == "__main__":
255
+ main()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py ADDED
@@ -0,0 +1,35 @@
1
+ import numpy
2
+ import pandas
3
+
4
+ semantic_path = 'dump/semantic.tsv'
5
+ phoneme_path = 'dump/phoneme.npy'
6
+ train_semantic_path = 'dump/semantic_train.tsv'
7
+ train_phoneme_path = 'dump/phoneme_train.npy'
8
+ dev_semantic_path = 'dump/semantic_dev.tsv'
9
+ dev_phoneme_path = 'dump/phoneme_dev.npy'
10
+
11
+ # read dump/semantic.tsv
12
+ semantic_df = pandas.read_csv(semantic_path, sep='\t')
13
+ # pd.DataFrame(columns=["item_name", "semantic_audio"])
14
+ # read dump/phoneme.npy
15
+ phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item()
16
+
17
+ dev_num = 20
18
+ # randomly sample dev_num rows from semantic_df
19
+ dev_df = semantic_df.sample(n=dev_num)
20
+ # the rest is the train set
21
+ train_df = semantic_df.drop(dev_df.index)
22
+ # save the splits
23
+ dev_df.to_csv(dev_semantic_path, sep='\t', index=False)
24
+ train_df.to_csv(train_semantic_path, sep='\t', index=False)
25
+
26
+ # take the item_name column of dev_df as the keys of dev_phoneme_dict
27
+ dev_item_names = dev_df['item_name'].tolist()
28
+ dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict}
29
+ train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names}
30
+
31
+ numpy.save(dev_phoneme_path, dev_phoneme_dict)
32
+ numpy.save(train_phoneme_path, train_phoneme_dict)
33
+
34
+
35
+
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # text to semantic
2
+ import argparse
3
+ import os
4
+ import re
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import whisper
12
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
13
+ from AR.text_processing.phonemizer import GruutPhonemizer
14
+ from AR.utils.io import load_yaml_config
+ # NOTE: SemanticTokenizer is used below but was never imported; the module path here is an assumption
+ from soundstorm.s2.models.hubert.semantic_tokenizer import SemanticTokenizer
15
+
16
+
17
+ def get_batch(text, phonemizer):
18
+ # phoneme_ids and phoneme_ids_len are what we need here
19
+ phoneme = phonemizer.phonemize(text, espeak=False)
20
+ phoneme_ids = phonemizer.transform(phoneme)
21
+ phoneme_ids_len = len(phoneme_ids)
22
+ phoneme_ids = np.array(phoneme_ids)
23
+ # add batch axis here
24
+ phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0)
25
+ phoneme_ids_len = torch.tensor([phoneme_ids_len])
26
+ print("phoneme:", phoneme)
27
+ batch = {
28
+ # torch.Tensor (B, max_phoneme_length)
29
+ "phoneme_ids": phoneme_ids,
30
+ # torch.Tensor (B)
31
+ "phoneme_ids_len": phoneme_ids_len
32
+ }
33
+ return batch
34
+
35
+
36
+ def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
37
+ sample_rate = 16000
38
+ # to get prompt
39
+ prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
40
+ wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
41
+ # take the last 3 s but drop the final 0.1 s, to keep AR S1 inference from stopping early
42
+ wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
43
+ # trailing silence also needs to be trimmed from the wav, otherwise inference may stop early as well
44
+ prompt_text = asr_model.transcribe(wav)["text"]
45
+ # remove the trailing period to keep AR S1 inference from stopping early; keeping it can introduce a pause
46
+ prompt_text = prompt_text.replace(".", "")
47
+ prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
48
+ prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
49
+ prompt_phoneme_ids_len = len(prompt_phoneme_ids)
50
+ # get prompt_semantic
51
+ # (T) -> (1, T)
52
+ wav = torch.tensor(wav).unsqueeze(0)
53
+ wav = wav.cuda()
54
+ # (1, T)
55
+ prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
56
+ prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
57
+ prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])
58
+
59
+ result = {
60
+ 'prompt_name': prompt_name,
61
+ 'prompt_phoneme_ids': prompt_phoneme_ids,
62
+ 'prompt_semantic_tokens': prompt_semantic_tokens,
63
+ 'prompt_phoneme_ids_len': prompt_phoneme_ids_len
64
+ }
65
+
66
+ return result
67
+
68
+
69
+ def parse_args():
70
+ # parse args and config
71
+ parser = argparse.ArgumentParser(
72
+ description="Run SoundStorm AR S1 model for input text file")
73
+
74
+ parser.add_argument(
75
+ '--config_file',
76
+ type=str,
77
+ default='conf/default.yaml',
78
+ help='path of config file')
79
+
80
+ parser.add_argument(
81
+ "--text_file",
82
+ type=str,
83
+ help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
84
+ )
85
+
86
+ parser.add_argument(
87
+ '--ckpt_path',
88
+ type=str,
89
+ default='exp/default/ckpt/epoch=99-step=49000.ckpt',
90
+ help='Checkpoint file of SoundStorm AR S1 model.')
91
+
92
+ parser.add_argument(
93
+ '--prompt_wav_path',
94
+ type=str,
95
+ default=None,
96
+ help='extract prompt semantic and prompt phonemes from prompt wav')
97
+
98
+ # to get semantic tokens from prompt_wav
99
+ parser.add_argument("--hubert_path", type=str, default=None)
100
+ parser.add_argument("--quantizer_path", type=str, default=None)
101
+
102
+ parser.add_argument("--output_dir", type=str, help="output dir.")
103
+
104
+ args = parser.parse_args()
105
+ return args
106
+
107
+
108
+ def main():
109
+ args = parse_args()
110
+ config = load_yaml_config(args.config_file)
111
+
112
+ output_dir = Path(args.output_dir)
113
+ output_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ hz = 50
116
+ max_sec = config['data']['max_sec']
117
+
118
+ # get models
119
+ t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
120
+ checkpoint_path=args.ckpt_path, config=config)
121
+ t2s_model.cuda()
122
+ t2s_model.eval()
123
+
124
+ phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')
125
+
126
+ # models for prompt
127
+ asr_model = whisper.load_model("tiny.en")
128
+
129
+ semantic_tokenizer = SemanticTokenizer(
130
+ hubert_path=args.hubert_path,
131
+ quantizer_path=args.quantizer_path,
132
+ duplicate=True)
133
+
134
+ prompt_result = get_prompt(
135
+ prompt_wav_path=args.prompt_wav_path,
136
+ asr_model=asr_model,
137
+ phonemizer=phonemizer,
138
+ semantic_tokenizer=semantic_tokenizer)
139
+
140
+ # zero prompt => the generated semantic tokens have the correct content, but the timbre is scrambled
141
+ # (B, 1)
142
+ # prompt = torch.ones(
143
+ # batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0
144
+
145
+ prompt = prompt_result['prompt_semantic_tokens']
146
+ prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
147
+ prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']
148
+
149
+ sentences = []
150
+ with open(args.text_file, 'rt', encoding='utf-8') as f:
151
+ for line in f:
152
+ if line.strip() != "":
153
+ items = re.split(r"\s+", line.strip(), 1)
154
+ utt_id = items[0]
155
+ sentence = " ".join(items[1:])
156
+ sentences.append((utt_id, sentence))
157
+ semantic_data = [['item_name', 'semantic_audio']]
158
+ for utt_id, sentence in sentences[1:]:
159
+ # a pseudo batch has to be built by hand to feed the model
160
+ batch = get_batch(sentence, phonemizer)
161
+ # concatenate the prompt with the real input
162
+ all_phoneme_ids = torch.cat(
163
+ [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
164
+ # alternatively, just take shape[-1] of all_phoneme_ids
165
+ all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
166
+ st = time.time()
167
+ with torch.no_grad():
168
+ pred_semantic = t2s_model.model.infer(
169
+ all_phoneme_ids.cuda(),
170
+ all_phoneme_len.cuda(),
171
+ prompt.cuda(),
172
+ top_k=config['inference']['top_k'],
173
+ early_stop_num=hz * max_sec)
174
+ print(f'{time.time() - st} sec used in T2S')
175
+
176
+ # drop the part corresponding to the prompt
177
+ prompt_len = prompt.shape[-1]
178
+ pred_semantic = pred_semantic[:, prompt_len:]
179
+
180
+ # bs = 1
181
+ pred_semantic = pred_semantic[0]
182
+ semantic_token = pred_semantic.detach().cpu().numpy().tolist()
183
+ semantic_token_str = ' '.join(str(x) for x in semantic_token)
184
+ semantic_data.append([utt_id, semantic_token_str])
185
+
186
+ delimiter = '\t'
187
+ filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
188
+ with open(filename, 'w', encoding='utf-8') as writer:
189
+ for row in semantic_data:
190
+ line = delimiter.join(row)
191
+ writer.write(line + '\n')
192
+ # clear semantic_data for the next sentence
193
+ semantic_data = [['item_name', 'semantic_audio']]
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py ADDED
@@ -0,0 +1,139 @@
1
+ # test from dump file
2
+ import argparse
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from AR.data.dataset import Text2SemanticDataset
9
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
10
+ from AR.utils.io import load_yaml_config
11
+ from torch.utils.data import DataLoader
12
+
13
+
14
+ def parse_args():
15
+ # parse args and config
16
+ parser = argparse.ArgumentParser(
17
+ description="Run SoundStorm AR S1 model for test set.")
18
+
19
+ parser.add_argument(
20
+ '--config_file',
21
+ type=str,
22
+ default='conf/default.yaml',
23
+ help='path of config file')
24
+
25
+ # args for dataset
26
+ parser.add_argument(
27
+ '--test_semantic_path',
28
+ type=str,
29
+ default='dump/test/semantic_token.tsv')
30
+ parser.add_argument(
31
+ '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')
32
+
33
+ parser.add_argument(
34
+ '--ckpt_path',
35
+ type=str,
36
+ default='exp/default/ckpt/epoch=99-step=49000.ckpt',
37
+ help='Checkpoint file of SoundStorm AR S1 model.')
38
+
39
+ parser.add_argument("--output_dir", type=str, help="output dir.")
40
+
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def main():
46
+ args = parse_args()
47
+
48
+ config = load_yaml_config(args.config_file)
49
+
50
+ output_dir = Path(args.output_dir)
51
+ output_dir.mkdir(parents=True, exist_ok=True)
52
+
53
+ batch_size = 1
54
+ hz = 50
55
+ max_sec = config['data']['max_sec']
56
+
57
+ # get dataset
58
+ test_dataset = Text2SemanticDataset(
59
+ phoneme_path=args.test_phoneme_path,
60
+ semantic_path=args.test_semantic_path,
61
+ # max_sec should match the value used in training, otherwise quality may degrade (repetition, dropped characters, etc.)
63
+ # but setting it too small here would filter out long samples; to avoid that, truncate at inference time instead
63
+ max_sec=100,
64
+ max_sample=8,
65
+ pad_val=config['data']['pad_val'])
66
+ # get model
67
+ t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
68
+ checkpoint_path=args.ckpt_path, config=config)
69
+ t2s_model.cuda()
70
+ t2s_model.eval()
71
+
72
+ # fetch batch_size items at a time
74
+ # create the DataLoader with the dataset's collate_fn
74
+ dataloader = DataLoader(
75
+ test_dataset,
76
+ batch_size=batch_size,
77
+ shuffle=False,
78
+ collate_fn=test_dataset.collate)
79
+
80
+ item_names = test_dataset.__get_item_names__()
81
+
82
+ # read the data batch by batch; with bs=1 and shuffle=False the order matches __get_item_names__
83
+ semantic_data = [['item_name', 'semantic_audio']]
84
+ for i, batch in enumerate(dataloader):
85
+ # batch size must be 1 here
86
+ utt_id = item_names[i]
87
+ if i == 0:
88
+ print("utt_id:", utt_id)
89
+ # with bs > 1 the batch would be zero-padded
91
+ # keep this consistent with validation_step()
91
+ semantic_len = batch['semantic_ids'].size(1)
92
+ # use up to the first 150 tokens of batch['semantic_ids'] as the prompt
94
+ # across repeated synthesis runs the first prompt_len tokens stay identical to the prompt
94
+ prompt_len = min(int(semantic_len * 0.5), 150)
95
+ # what should the prompt be for plain-text input? => see t2s.py
96
+ prompt = batch['semantic_ids'][:, :prompt_len]
97
+ # # zero prompt => still yields semantic tokens with the correct text content, but the timbre is scrambled
99
+ # which shows that semantic tokens still carry timbre information
99
+ # prompt = torch.ones(
100
+ # batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
101
+ # print("prompt:", prompt)
102
+ # print("prompt.shape:", prompt.shape)
103
+ np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())
104
+
105
+ st = time.time()
106
+ with torch.no_grad():
107
+ # calculate acc for test
108
+ loss, acc = t2s_model.model.forward(
109
+ batch['phoneme_ids'].cuda(),
110
+ batch['phoneme_ids_len'].cuda(),
111
+ batch['semantic_ids'].cuda(),
112
+ batch['semantic_ids_len'].cuda())
113
+ print("top_3_acc of this batch:", acc)
114
+ pred_semantic = t2s_model.model.infer(
115
+ batch['phoneme_ids'].cuda(),
116
+ batch['phoneme_ids_len'].cuda(),
117
+ prompt.cuda(),
118
+ top_k=config['inference']['top_k'],
119
+ # hz * max_sec in train dataloader
120
+ # the generated length is 1002, which probably includes some padding
121
+ early_stop_num=hz * max_sec)
122
+ # bs = 1
123
+ pred_semantic = pred_semantic[0]
124
+ print(f'{time.time() - st} sec used in T2S')
125
+ semantic_token = pred_semantic.detach().cpu().numpy().tolist()
126
+ semantic_token_str = ' '.join(str(x) for x in semantic_token)
127
+ semantic_data.append([utt_id, semantic_token_str])
128
+ else:
129
+ break
130
+ delimiter = '\t'
131
+ filename = output_dir / "semantic_token.tsv"
132
+ with open(filename, 'w', encoding='utf-8') as writer:
133
+ for row in semantic_data:
134
+ line = delimiter.join(row)
135
+ writer.write(line + '\n')
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt ADDED
@@ -0,0 +1,10 @@
1
+ 001 Life was like a box of chocolates, you never know what you're gonna get.
2
+ 002 With great power there must come great responsibility.
3
+ 003 To be or not to be, that’s a question.
4
+ 004 A man can be destroyed but not defeated
5
+ 005 Do not, for one repulse, give up the purpose that you resolved to effort.
6
+ 006 Death is just a part of life, something we're all destined to do.
7
+ 007 I think it's hard winning a war with words.
8
+ 008 Don’t argue with the people of strong determination, because they may change the fact!
9
+ 009 Love you three thousand times.
10
+ 010 tidy tiger tied a tie tighter to tidy her tiny tall.
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py ADDED
@@ -0,0 +1,103 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
+ import argparse
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from pytorch_lightning import seed_everything
9
+ from pytorch_lightning import Trainer
10
+ from pytorch_lightning.callbacks import ModelCheckpoint
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from pytorch_lightning.strategies import DDPStrategy
13
+ from AR.data.data_module import Text2SemanticDataModule
14
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
+ from soundstorm.utils.io import load_yaml_config
16
+ logging.getLogger('numba').setLevel(logging.WARNING)
17
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
18
+ torch.set_float32_matmul_precision('high')
19
+ from soundstorm.utils import get_newest_ckpt
20
+
21
+
22
+ def main(args):
23
+ output_dir = Path(args.output_dir)
24
+ output_dir.mkdir(parents=True, exist_ok=True)
25
+
26
+ ckpt_dir = output_dir / 'ckpt'
27
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
28
+
29
+ config = load_yaml_config(args.config_file)
30
+
31
+ seed_everything(config["train"]["seed"], workers=True)
32
+ ckpt_callback: ModelCheckpoint = ModelCheckpoint(
33
+ save_top_k=-1,
34
+ save_on_train_epoch_end=False,
35
+ every_n_epochs=config["train"]["save_every_n_epoch"],
36
+ dirpath=ckpt_dir)
37
+ logger = WandbLogger(
38
+ project="AR_S1",
39
+ name=output_dir.stem,
40
+ save_dir=output_dir,
41
+ # resume the loss curve
42
+ resume=True,
43
+ # id='k19kvsq8'
44
+ )
45
+ trainer: Trainer = Trainer(
46
+ max_epochs=config["train"]["epochs"],
47
+ accelerator='gpu',
48
+ devices=-1,
49
+ benchmark=False,
50
+ fast_dev_run=False,
51
+ strategy=DDPStrategy(find_unused_parameters=True),
52
+ precision=config["train"]["precision"],
53
+ logger=logger,
54
+ callbacks=[ckpt_callback])
55
+
56
+ model: Text2SemanticLightningModule = Text2SemanticLightningModule(
57
+ config, output_dir)
58
+
59
+ data_module: Text2SemanticDataModule = Text2SemanticDataModule(
60
+ config,
61
+ train_semantic_path=args.train_semantic_path,
62
+ train_phoneme_path=args.train_phoneme_path,
63
+ dev_semantic_path=args.dev_semantic_path,
64
+ dev_phoneme_path=args.dev_phoneme_path)
65
+
66
+ try:
67
+ # use a regex to extract the numeric part of the checkpoint filenames and sort them numerically
68
+ newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
69
+ ckpt_path = ckpt_dir / newest_ckpt_name
70
+ except Exception:
71
+ ckpt_path = None
72
+ print("ckpt_path:", ckpt_path)
73
+ trainer.fit(model, data_module, ckpt_path=ckpt_path)
74
+
75
+
76
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
77
+ if __name__ == '__main__':
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument(
80
+ '--config_file',
81
+ type=str,
82
+ default='conf/default.yaml',
83
+ help='path of config file')
84
+ # args for dataset
85
+ parser.add_argument(
86
+ '--train_semantic_path',
87
+ type=str,
88
+ default='dump/train/semantic_token.tsv')
89
+ parser.add_argument(
90
+ '--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
91
+ parser.add_argument(
92
+ '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
93
+ parser.add_argument(
94
+ '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
95
+ parser.add_argument(
96
+ '--output_dir',
97
+ type=str,
98
+ default='exp/default',
99
+ help='directory to save the results')
100
+
101
+ args = parser.parse_args()
102
+ logging.info(str(args))
103
+ main(args)
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py ADDED
@@ -0,0 +1,170 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
+ import argparse
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from pytorch_lightning import seed_everything
9
+ from pytorch_lightning import Trainer
10
+ from pytorch_lightning.callbacks import ModelCheckpoint
11
+ from pytorch_lightning.loggers import WandbLogger
12
+ from pytorch_lightning.strategies import DDPStrategy
13
+ from AR.data.data_module_librilight_6k import Text2SemanticDataModule
14
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
+ from soundstorm.utils import get_newest_ckpt
16
+ from soundstorm.utils.io import load_yaml_config
17
+
18
+ logging.getLogger('numba').setLevel(logging.WARNING)
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ torch.set_float32_matmul_precision('high')
21
+
22
+
23
+ def main(args):
24
+ output_dir = Path(args.output_dir)
25
+ output_dir.mkdir(parents=True, exist_ok=True)
26
+
27
+ ckpt_dir = output_dir / 'ckpt'
28
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ config = load_yaml_config(args.config_file)
31
+
32
+ seed_everything(config["train"]["seed"], workers=True)
33
+
34
+ ckpt_callback: ModelCheckpoint = ModelCheckpoint(
35
+ save_top_k=-1,
36
+ save_on_train_epoch_end=False,
37
+ every_n_train_steps=config["train"]["every_n_train_steps"],
38
+ dirpath=ckpt_dir)
39
+ logger = WandbLogger(
40
+ project="AR_S1_LibriLight",
41
+ name=output_dir.stem,
42
+ save_dir=output_dir,
43
+ # resume the loss curve
44
+ resume=True,
45
+ # id='k19kvsq8'
46
+ )
47
+ trainer: Trainer = Trainer(
48
+ max_epochs=config["train"]["epochs"],
49
+ accelerator='gpu',
50
+ devices=-1,
51
+ benchmark=False,
52
+ fast_dev_run=False,
53
+ strategy=DDPStrategy(find_unused_parameters=True),
54
+ precision=config["train"]["precision"],
55
+ logger=logger,
56
+ callbacks=[ckpt_callback])
57
+
58
+ model: Text2SemanticLightningModule = Text2SemanticLightningModule(
59
+ config, output_dir)
60
+
61
+ data_module: Text2SemanticDataModule = Text2SemanticDataModule(
62
+ config,
63
+ train_semantic_dirs=args.train_semantic_dirs,
64
+ train_phoneme_dirs=args.train_phoneme_dirs,
65
+ dev_semantic_dirs=args.dev_semantic_dirs,
66
+ dev_phoneme_dirs=args.dev_phoneme_dirs,
67
+ train_non_speech_dirs=args.train_non_speech_dirs,
68
+ dev_non_speech_dirs=args.dev_non_speech_dirs)
69
+ try:
70
+ newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
71
+ ckpt_path = ckpt_dir / newest_ckpt_name
72
+ except Exception:
73
+ ckpt_path = None
74
+
75
+ print("ckpt_path:", ckpt_path)
76
+ trainer.fit(model, data_module, ckpt_path=ckpt_path)
77
+
78
+
79
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
80
+ if __name__ == '__main__':
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument(
83
+ '--config_file',
84
+ type=str,
85
+ default='conf/default.yaml',
86
+ help='path of config file')
87
+ # args for dataset
88
+ parser.add_argument(
89
+ '--train_semantic_dirs',
90
+ type=list,
91
+ nargs='+',
92
+ default=["dump/small/train/"],
93
+ help='dirs of train semantic')
94
+ parser.add_argument(
95
+ '--train_phoneme_dirs',
96
+ type=list,
97
+ nargs='+',
98
+ default=["dump/small/train/"],
99
+ help='dirs of train phoneme')
100
+ parser.add_argument(
101
+ '--dev_semantic_dirs',
102
+ type=list,
103
+ nargs='+',
104
+ default=["dump/small/dev/"],
105
+ help='dirs of dev semantic')
106
+ parser.add_argument(
107
+ '--dev_phoneme_dirs',
108
+ type=list,
109
+ nargs='+',
110
+ default=["dump/small/dev/"],
111
+ help='dirs of dev phoneme')
112
+ parser.add_argument(
113
+ '--output_dir',
114
+ type=str,
115
+ default='exp/default',
116
+ help='directory to save the results')
117
+
118
+ parser.add_argument(
119
+ '--train_non_speech_dirs',
120
+ type=list,
121
+ nargs='+',
122
+ default=None,
123
+ help='dirs of train non_speech data')
124
+
125
+ parser.add_argument(
126
+ '--dev_non_speech_dirs',
127
+ type=list,
128
+ nargs='+',
129
+ default=None,
130
+ help='dirs of dev non_speech data')
131
+
132
+ args = parser.parse_args()
133
+
134
+ new_train_semantic_dirs = []
135
+ new_train_phoneme_dirs = []
136
+ new_dev_semantic_dirs = []
137
+ new_dev_phoneme_dirs = []
138
+
139
+ new_train_non_speech_dirs = []
140
+ new_dev_non_speech_dirs = []
141
+
142
+ # format dataset dirs
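+ # argparse applies type=list to each argument, turning every path string into a list of characters, so join them back into strings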
143
+ for item in args.train_semantic_dirs:
144
+ new_train_semantic_dirs.append(''.join(item))
145
+ args.train_semantic_dirs = new_train_semantic_dirs
146
+
147
+ for item in args.train_phoneme_dirs:
148
+ new_train_phoneme_dirs.append(''.join(item))
149
+ args.train_phoneme_dirs = new_train_phoneme_dirs
150
+
151
+ for item in args.dev_semantic_dirs:
152
+ new_dev_semantic_dirs.append(''.join(item))
153
+ args.dev_semantic_dirs = new_dev_semantic_dirs
154
+
155
+ for item in args.dev_phoneme_dirs:
156
+ new_dev_phoneme_dirs.append(''.join(item))
157
+ args.dev_phoneme_dirs = new_dev_phoneme_dirs
158
+
159
+ if args.train_non_speech_dirs is not None:
160
+ for item in args.train_non_speech_dirs:
161
+ new_train_non_speech_dirs.append(''.join(item))
162
+ args.train_non_speech_dirs = new_train_non_speech_dirs
163
+
164
+ if args.dev_non_speech_dirs is not None:
165
+ for item in args.dev_non_speech_dirs:
166
+ new_dev_non_speech_dirs.append(''.join(item))
167
+ args.dev_non_speech_dirs = new_dev_non_speech_dirs
168
+
169
+ logging.info(str(args))
170
+ main(args)
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py ADDED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,128 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os,sys
3
+ now_dir = os.getcwd()
4
+ sys.path.append(now_dir)
5
+ from typing import Dict
6
+
7
+ import torch
8
+ from pytorch_lightning import LightningModule
9
+ from AR.models.t2s_model import Text2SemanticDecoder
10
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
11
+ from AR.modules.optim import ScaledAdam
12
+
13
+
14
+ class Text2SemanticLightningModule(LightningModule):
15
+ def __init__(self, config, output_dir,is_train=True):
16
+ super().__init__()
17
+ self.config = config
18
+ self.top_k = 3
19
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
20
+ pretrained_s1=config.get("pretrained_s1")
21
+ if(pretrained_s1 and is_train):
22
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
23
+ print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
24
+ if is_train:
25
+ self.automatic_optimization = False
26
+ self.save_hyperparameters()
27
+ self.eval_dir = output_dir / 'eval'
28
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ def training_step(self, batch: Dict, batch_idx: int):
31
+
32
+ opt = self.optimizers()
33
+ scheduler = self.lr_schedulers()
34
+ loss, acc = self.model.forward(
35
+ batch['phoneme_ids'], batch['phoneme_ids_len'],
36
+ batch['semantic_ids'], batch['semantic_ids_len'],
37
+ batch['bert_feature'])
38
+ self.manual_backward(loss)
39
+ if batch_idx > 0 and batch_idx % 4 == 0:
40
+ opt.step()
41
+ opt.zero_grad()
42
+ scheduler.step()
43
+
44
+ self.log(
45
+ "total_loss",
46
+ loss,
47
+ on_step=True,
48
+ on_epoch=True,
49
+ prog_bar=True,
50
+ sync_dist=True)
51
+ self.log(
52
+ "lr",
53
+ scheduler.get_last_lr()[0],
54
+ on_epoch=True,
55
+ prog_bar=True,
56
+ sync_dist=True)
57
+ self.log(
58
+ f"top_{self.top_k}_acc",
59
+ acc,
60
+ on_step=True,
61
+ on_epoch=True,
62
+ prog_bar=True,
63
+ sync_dist=True)
64
+
65
+ def validation_step(self, batch: Dict, batch_idx: int):return
66
+ # # get loss
67
+ # loss, acc = self.model.forward(
68
+ # batch['phoneme_ids'], batch['phoneme_ids_len'],
69
+ # batch['semantic_ids'], batch['semantic_ids_len'],
70
+ # batch['bert_feature']
71
+ # )
72
+ #
73
+ # self.log(
74
+ # "val_total_loss",
75
+ # loss,
76
+ # on_step=True,
77
+ # on_epoch=True,
78
+ # prog_bar=True,
79
+ # sync_dist=True)
80
+ # self.log(
81
+ # f"val_top_{self.top_k}_acc",
82
+ # acc,
83
+ # on_step=True,
84
+ # on_epoch=True,
85
+ # prog_bar=True,
86
+ # sync_dist=True)
87
+ #
88
+ # # get infer output
89
+ # semantic_len = batch['semantic_ids'].size(1)
90
+ # prompt_len = min(int(semantic_len * 0.5), 150)
91
+ # prompt = batch['semantic_ids'][:, :prompt_len]
92
+ # pred_semantic = self.model.infer(batch['phoneme_ids'],
93
+ # batch['phoneme_ids_len'], prompt,
94
+ # batch['bert_feature']
95
+ # )
96
+ # save_name = f'semantic_toks_{batch_idx}.pt'
97
+ # save_path = os.path.join(self.eval_dir, save_name)
98
+ # torch.save(pred_semantic.detach().cpu(), save_path)
99
+
100
+ def configure_optimizers(self):
101
+ model_parameters = self.model.parameters()
102
+ parameters_names = []
103
+ parameters_names.append([
104
+ name_param_pair[0]
105
+ for name_param_pair in self.model.named_parameters()
106
+ ])
107
+ lm_opt = ScaledAdam(
108
+ model_parameters,
109
+ lr=0.01,
110
+ betas=(0.9, 0.95),
111
+ clipping_scale=2.0,
112
+ parameters_names=parameters_names,
113
+ show_dominant_parameters=False,
114
+ clipping_update_period=1000, )
115
+
116
+ return {
117
+ "optimizer": lm_opt,
118
+ "lr_scheduler": {
119
+ "scheduler":
120
+ WarmupCosineLRSchedule(
121
+ lm_opt,
122
+ init_lr=self.config['optimizer']['lr_init'],
123
+ peak_lr=self.config['optimizer']['lr'],
124
+ end_lr=self.config['optimizer']['lr_end'],
125
+ warmup_steps=self.config['optimizer']['warmup_steps'],
126
+ total_steps=self.config['optimizer']['decay_steps'])
127
+ }
128
+ }
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py ADDED
@@ -0,0 +1,298 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
7
+ from AR.modules.embedding import SinePositionalEmbedding
8
+ from AR.modules.embedding import TokenEmbedding
9
+ from AR.modules.transformer import LayerNorm
10
+ from AR.modules.transformer import TransformerEncoder
11
+ from AR.modules.transformer import TransformerEncoderLayer
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from torchmetrics.classification import MulticlassAccuracy
15
+
16
+ default_config = {
17
+ "embedding_dim": 512,
18
+ "hidden_dim": 512,
19
+ "num_head": 8,
20
+ "num_layers": 12,
21
+ "num_codebook": 8,
22
+ "p_dropout": 0.0,
23
+ "vocab_size": 1024 + 1,
24
+ "phoneme_vocab_size": 512,
25
+ "EOS": 1024
26
+ }
27
+
28
+
29
+ class Text2SemanticDecoder(nn.Module):
30
+ def __init__(self, config, norm_first=False, top_k=3):
31
+ super(Text2SemanticDecoder, self).__init__()
32
+ self.model_dim = config['model']["hidden_dim"]
33
+ self.embedding_dim = config['model']["embedding_dim"]
34
+ self.num_head = config['model']["head"]
35
+ self.num_layers = config['model']["n_layer"]
36
+ self.norm_first = norm_first
37
+ self.vocab_size = config['model']["vocab_size"]
38
+ self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
39
+ self.p_dropout = config['model']["dropout"]
40
+ self.EOS = config['model']["EOS"]
41
+ self.norm_first = norm_first
42
+ assert self.EOS == self.vocab_size - 1
43
+ # should be same as num of kmeans bin
44
+ # assert self.EOS == 1024
45
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
46
+ self.ar_text_embedding = TokenEmbedding(
47
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
48
+ self.ar_text_position = SinePositionalEmbedding(
49
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True)
50
+ self.ar_audio_embedding = TokenEmbedding(
51
+ self.embedding_dim, self.vocab_size, self.p_dropout)
52
+ self.ar_audio_position = SinePositionalEmbedding(
53
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True)
54
+
55
+ self.h = TransformerEncoder(
56
+ TransformerEncoderLayer(
57
+ d_model=self.model_dim,
58
+ nhead=self.num_head,
59
+ dim_feedforward=self.model_dim * 4,
60
+ dropout=0.1,
61
+ batch_first=True,
62
+ norm_first=norm_first, ),
63
+ num_layers=self.num_layers,
64
+ norm=LayerNorm(self.model_dim) if norm_first else None, )
65
+
66
+ self.ar_predict_layer = nn.Linear(
67
+ self.model_dim, self.vocab_size, bias=False)
68
+ self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
69
+
70
+ self.ar_accuracy_metric = MulticlassAccuracy(
71
+ self.vocab_size,
72
+ top_k=top_k,
73
+ average="micro",
74
+ multidim_average="global",
75
+ ignore_index=self.EOS, )
76
+
77
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
78
+ '''
79
+ x: phoneme_ids
80
+ y: semantic_ids
81
+ '''
82
+ x = self.ar_text_embedding(x)
83
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
84
+ x = self.ar_text_position(x)
85
+ x_mask = make_pad_mask(x_lens)
86
+
87
+ y_mask = make_pad_mask(y_lens)
88
+ y_mask_int = y_mask.type(torch.int64)
89
+ codes = y.type(torch.int64) * (1 - y_mask_int)
90
+
91
+ # Training
92
+ # AR Decoder
93
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
94
+ x_len = x_lens.max()
95
+ y_len = y_lens.max()
96
+ y_emb = self.ar_audio_embedding(y)
97
+ y_pos = self.ar_audio_position(y_emb)
98
+
99
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
100
+ ar_xy_padding_mask = xy_padding_mask
101
+
102
+ x_attn_mask = F.pad(
103
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
104
+ (0, y_len),
105
+ value=True, )
106
+ y_attn_mask = F.pad(
107
+ torch.triu(
108
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
109
+ diagonal=1, ),
110
+ (x_len, 0),
111
+ value=False, )
112
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
113
+ bsz, src_len = x.shape[0], x_len + y_len
114
+ _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
115
+ .expand(-1, self.num_head, -1, -1)
116
+ .reshape(bsz * self.num_head, 1, src_len))
117
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
118
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
119
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
120
+ xy_attn_mask = new_attn_mask
121
+ # x and the complete y are fed into the model in a single pass
122
+ xy_pos = torch.concat([x, y_pos], dim=1)
123
+ xy_dec, _ = self.h(
124
+ (xy_pos, None),
125
+ mask=xy_attn_mask, )
126
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
127
+ # loss
128
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction='sum'
129
+ loss = F.cross_entropy(logits, targets, reduction='sum')
130
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
131
+ return loss, acc
132
+
133
+ # need to check how this differs from forward() and what to feed as prompts when there are no semantic tokens
134
+ def infer(self,
135
+ x,
136
+ x_lens,
137
+ prompts,
138
+ bert_feature,
139
+ top_k: int=-100,
140
+ early_stop_num: int=-1,
141
+ temperature: float=1.0):
142
+
143
+ x = self.ar_text_embedding(x)
144
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
145
+ x = self.ar_text_position(x)
146
+
147
+ # AR Decoder
148
+ y = prompts
149
+ prefix_len = y.shape[1]
150
+ x_len = x.shape[1]
151
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
152
+ stop = False
153
+ for _ in tqdm(range(1500)):
154
+ y_emb = self.ar_audio_embedding(y)
155
+ y_pos = self.ar_audio_position(y_emb)
156
+ # x is fed into the model together with the gradually growing y
157
+ xy_pos = torch.concat([x, y_pos], dim=1)
158
+ y_len = y.shape[1]
159
+ x_attn_mask_pad = F.pad(
160
+ x_attn_mask,
161
+ (0, y_len),
162
+ value=True, )
163
+ y_attn_mask = F.pad(
164
+ torch.triu(
165
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
166
+ (x_len, 0),
167
+ value=False, )
168
+ xy_attn_mask = torch.concat(
169
+ [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
170
+
171
+ xy_dec, _ = self.h(
172
+ (xy_pos, None),
173
+ mask=xy_attn_mask, )
174
+ logits = self.ar_predict_layer(xy_dec[:, -1])
175
+ samples = topk_sampling(
176
+ logits, top_k=top_k, top_p=1.0, temperature=temperature)
177
+
178
+ if early_stop_num != -1 and (y.shape[1] - prefix_len
179
+ ) > early_stop_num:
180
+ print("use early stop num:", early_stop_num)
181
+ stop = True
182
+
183
+ if torch.argmax(
184
+ logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
185
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
186
+ stop = True
187
+ if stop:
188
+ if prompts.shape[1] == y.shape[1]:
189
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
190
+ print('bad zero prediction')
191
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
192
+ break
193
+ # the newly generated semantic_ids are appended to the previous y to form the new y
194
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
195
+ # import os
196
+ # os._exit(2333)
197
+ y = torch.concat([y, samples], dim=1)
198
+ return y
199
+
200
+ def pad_y_eos(self, y, y_mask_int, eos_id):
201
+ targets = F.pad(
202
+ y, (0, 1), value=0) + eos_id * F.pad(
203
+ y_mask_int, (0, 1), value=1)
204
+ # shift by one position: inputs are targets[:, :-1], prediction targets are targets[:, 1:]
205
+ return targets[:, :-1], targets[:, 1:]
206
+
207
+ def infer_panel(self,
208
+ x,  # all text tokens
209
+ x_lens,
210
+ prompts,  # reference audio tokens
211
+ bert_feature,
212
+ top_k: int=-100,
213
+ early_stop_num: int=-1,
214
+ temperature: float=1.0):
215
+
216
+ x = self.ar_text_embedding(x)
217
+ x = x + self.bert_proj(bert_feature.transpose(1,2))
218
+ x = self.ar_text_position(x)
219
+
220
+ # AR Decoder
221
+ y = prompts
222
+ prefix_len = y.shape[1]
223
+ x_len = x.shape[1]
224
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
225
+ stop = False
226
+ # print(1111111,self.num_layers)
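+ # incremental decoding with a per-layer key/value cache: after the first step only the newest token is embedded and fed through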
227
+ cache={
228
+ "all_stage":self.num_layers,
229
+ "k":[None]*self.num_layers,###根据配置自己手写
230
+ "v":[None]*self.num_layers,
231
+ # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
232
+ "y_emb":None,##只需要对最新的samples求emb,再拼历史的就行
233
+ # "logits":None,###原版就已经只对结尾求再拼接了,不用管
234
+ # "xy_dec":None,###不需要,本来只需要最后一个做logits
235
+ "first_infer":1,
236
+ "stage":0
237
+ }
238
+ for idx in tqdm(range(1500)):
239
+ if(cache["first_infer"]==1):
240
+ y_emb = self.ar_audio_embedding(y)
241
+ else:
242
+ y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
243
+ cache["y_emb"]=y_emb
244
+ y_pos = self.ar_audio_position(y_emb)
245
+ # x is fed into the model together with the gradually growing y
246
+ if(cache["first_infer"]==1):
247
+ xy_pos = torch.concat([x, y_pos], dim=1)
248
+ else:
249
+ xy_pos=y_pos[:,-1:]
250
+ y_len = y_pos.shape[1]
251
+ # the following three tensors are not cached
252
+ if (cache["first_infer"] == 1):
253
+ x_attn_mask_pad = F.pad(
254
+ x_attn_mask,
255
+ (0, y_len),  # extend the all-zero xx block to all-zero xx plus all-one xy, shape (x, x+y)
256
+ value=True, )
257
+ y_attn_mask = F.pad(  # extend the upper-triangular ones of yy with zeros on the left for xy, shape (y, x+y)
258
+ torch.triu(
259
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
260
+ (x_len, 0),
261
+ value=False, )
262
+ xy_attn_mask = torch.concat(
263
+ [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
264
+ else:
265
+ # rightmost column only (this was wrong)
266
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
267
+ # xy_attn_mask[:,-1]=False
268
+ # bottom row only (this is correct)
269
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
270
+ # pdb.set_trace()
271
+ # the cache does its main work here
272
+ # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
273
+ xy_dec, _ = self.h(
274
+ (xy_pos, None),
275
+ mask=xy_attn_mask,cache=cache )
276
+ logits = self.ar_predict_layer(xy_dec[:, -1])  # no change needed: with the cache enabled the decoder output is a single frame by default, so taking the last frame is equivalent
277
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
278
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
279
+ if early_stop_num != -1 and (y.shape[1] - prefix_len
280
+ ) > early_stop_num:
281
+ print("use early stop num:", early_stop_num)
282
+ stop = True
283
+
284
+ if torch.argmax(
285
+ logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
286
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
287
+ stop = True
288
+ if stop:
289
+ if prompts.shape[1] == y.shape[1]:
290
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
291
+ print('bad zero prediction')
292
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
293
+ break
294
+ # the semantic_ids generated this step are appended to the previous y to form the new y
295
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
296
+ y = torch.concat([y, samples], dim=1)
297
+ cache["first_infer"]=0
298
+ return y,idx
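The `first_infer` / `cache` bookkeeping above is a standard incremental-decoding pattern: the first step runs the full `x + y` prefix through the decoder, and every later step feeds only the newest token while the per-layer keys and values grow inside `cache`. The sketch below is a self-contained toy illustration of that growth pattern; `toy_step` and its fake projections are hypothetical and not part of this repository.

```python
import torch

def toy_step(new_emb, cache):
    """Append this step's key/value to the cache (or initialize it on the first call)."""
    k_new, v_new = new_emb * 0.5, new_emb * 2.0   # stand-ins for per-layer k/v projections
    if cache["k"] is None:                        # analogous to first_infer == 1: cache the whole prefix
        cache["k"], cache["v"] = k_new, v_new
    else:                                         # later steps: only the newest frame is appended
        cache["k"] = torch.cat([cache["k"], k_new], dim=1)
        cache["v"] = torch.cat([cache["v"], v_new], dim=1)
    return cache["k"].shape[1]                    # current cached sequence length

cache = {"k": None, "v": None}
print(toy_step(torch.randn(1, 4, 8), cache))      # 4: the full x+y prefix
for _ in range(3):
    print(toy_step(torch.randn(1, 1, 8), cache))  # 5, 6, 7: one new token per step
```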
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py ADDED
@@ -0,0 +1,164 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+
6
+
7
+ def sequence_mask(length, max_length=None):
8
+ if max_length is None:
9
+ max_length = length.max()
10
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
11
+ return x.unsqueeze(0) < length.unsqueeze(1)
12
+
13
+
14
+ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
15
+ """
16
+ Args:
17
+ lengths:
18
+ A 1-D tensor containing sentence lengths.
19
+ max_len:
20
+ The length of masks.
21
+ Returns:
22
+ Return a 2-D bool tensor, where masked positions
23
+ are filled with `True` and non-masked positions are
24
+ filled with `False`.
25
+
26
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
27
+ #>>> make_pad_mask(lengths)
28
+ tensor([[False, True, True, True, True],
29
+ [False, False, False, True, True],
30
+ [False, False, True, True, True],
31
+ [False, False, False, False, False]])
32
+ """
33
+ assert lengths.ndim == 1, lengths.ndim
34
+ max_len = max(max_len, lengths.max())
35
+ n = lengths.size(0)
36
+ seq_range = torch.arange(0, max_len, device=lengths.device)
37
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
38
+
39
+ return expanded_lengths >= lengths.unsqueeze(-1)
40
+
41
+
42
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
43
+ def top_k_top_p_filtering(logits,
44
+ top_k=0,
45
+ top_p=1.0,
46
+ filter_value=-float("Inf"),
47
+ min_tokens_to_keep=1):
48
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
49
+ Args:
50
+ logits: logits distribution shape (batch size, vocabulary size)
51
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
52
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
53
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
54
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
55
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
56
+ """
57
+ if top_k > 0:
58
+ top_k = min(max(top_k, min_tokens_to_keep),
59
+ logits.size(-1)) # Safety check
60
+ # Remove all tokens with a probability less than the last token of the top-k
61
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
62
+ logits[indices_to_remove] = filter_value
63
+
64
+ if top_p < 1.0:
65
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
66
+ cumulative_probs = torch.cumsum(
67
+ F.softmax(sorted_logits, dim=-1), dim=-1)
68
+
69
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
70
+ sorted_indices_to_remove = cumulative_probs > top_p
71
+ if min_tokens_to_keep > 1:
72
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
73
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
74
+ # Shift the indices to the right to keep also the first token above the threshold
75
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
76
+ ..., :-1].clone()
77
+ sorted_indices_to_remove[..., 0] = 0
78
+
79
+ # scatter sorted tensors to original indexing
80
+ indices_to_remove = sorted_indices_to_remove.scatter(
81
+ 1, sorted_indices, sorted_indices_to_remove)
82
+ logits[indices_to_remove] = filter_value
83
+ return logits
84
+
85
+
86
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
87
+ # temperature: (`optional`) float
88
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
89
+ # top_k: (`optional`) int
90
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
91
+ # top_p: (`optional`) float
92
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
93
+
94
+ # Temperature (higher temperature => more likely to sample low probability tokens)
95
+ if temperature != 1.0:
96
+ logits = logits / temperature
97
+ # Top-p/top-k filtering
98
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
99
+ # Sample
100
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
101
+ return token
102
+
103
+
104
+ from typing import Optional, Tuple
105
+ def multinomial_sample_one_no_sync(
106
+ probs_sort,
107
+ ): # Does multinomial sampling without a cuda synchronization
108
+ q = torch.empty_like(probs_sort).exponential_(1)
109
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
110
+
111
+
112
+ def logits_to_probs(
113
+ logits,
114
+ previous_tokens: Optional[torch.Tensor] = None,
115
+ temperature: float = 1.0,
116
+ top_k: Optional[int] = None,
117
+ top_p: Optional[int] = None,
118
+ repetition_penalty: float = 1.0,
119
+ ):
120
+ previous_tokens=previous_tokens.squeeze()
121
+ # print(logits.shape,previous_tokens.shape)
122
+ # pdb.set_trace()
123
+ if previous_tokens is not None and repetition_penalty != 1.0:
124
+ previous_tokens = previous_tokens.long()
125
+ score = torch.gather(logits, dim=0, index=previous_tokens)
126
+ score = torch.where(
127
+ score < 0, score * repetition_penalty, score / repetition_penalty
128
+ )
129
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
130
+
131
+ if top_p is not None and top_p < 1.0:
132
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
133
+ cum_probs = torch.cumsum(
134
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
135
+ )
136
+ sorted_indices_to_remove = cum_probs > top_p
137
+ sorted_indices_to_remove[0] = False # keep at least one option
138
+ indices_to_remove = sorted_indices_to_remove.scatter(
139
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
140
+ )
141
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
142
+
143
+ logits = logits / max(temperature, 1e-5)
144
+
145
+ if top_k is not None:
146
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
147
+ pivot = v.select(-1, -1).unsqueeze(-1)
148
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
149
+
150
+ probs = torch.nn.functional.softmax(logits, dim=-1)
151
+ return probs
152
+
153
+
154
+ def sample(
155
+ logits,
156
+ previous_tokens: Optional[torch.Tensor] = None,
157
+ **sampling_kwargs,
158
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
159
+ probs = logits_to_probs(
160
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
161
+ )
162
+ idx_next = multinomial_sample_one_no_sync(probs)
163
+ return idx_next, probs
164
+
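A hedged usage sketch of the sampling helpers defined in this file. It assumes the module is importable as `AR.models.utils` (the import path used elsewhere in this repository) and that `torch` and `torchaudio` are installed; the vocabulary size and token history below are made up for illustration.

```python
import torch
from AR.models.utils import sample, topk_sampling  # assumed import path

vocab = 1025                                    # e.g. 1024 semantic tokens + EOS (illustrative)
logits = torch.randn(vocab)                     # un-normalized scores for one decoding step
history = torch.randint(0, vocab, (1, 12))      # previously generated token ids

# repetition-penalized top-k sampling, as used in infer_panel above
next_token, probs = sample(logits, history, top_k=15, top_p=1.0,
                           repetition_penalty=1.35)
print(next_token.shape, probs.shape)            # torch.Size([1]) torch.Size([1025])

# batched top-k sampling, as used in the non-cached infer path
batch_logits = torch.randn(1, vocab)
print(topk_sampling(batch_logits, top_k=15, temperature=1.0).shape)  # torch.Size([1, 1])
```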
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py ADDED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py ADDED
@@ -0,0 +1,397 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear
7
+ from torch.nn import Module
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.init import xavier_uniform_
11
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
+ from torch.nn.parameter import Parameter
13
+
14
+ from torch.nn import functional as F
15
+ from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
16
+ F.multi_head_attention_forward=multi_head_attention_forward_patched
17
+
18
+ class MultiheadAttention(Module):
19
+ r"""Allows the model to jointly attend to information
20
+ from different representation subspaces as described in the paper:
21
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
22
+
23
+ Multi-Head Attention is defined as:
24
+
25
+ .. math::
26
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
27
+
28
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
29
+
30
+ ``forward()`` will use a special optimized implementation if all of the following
31
+ conditions are met:
32
+
33
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
34
+ restriction will be loosened in the future.)
35
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
36
+ - training is disabled (using ``.eval()``)
37
+ - dropout is 0
38
+ - ``add_bias_kv`` is ``False``
39
+ - ``add_zero_attn`` is ``False``
40
+ - ``batch_first`` is ``True`` and the input is batched
41
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
42
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
43
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
44
+ nor ``attn_mask`` is passed
45
+
46
+ If the optimized implementation is in use, a
47
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
48
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
49
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
50
+ will be returned, and an additional speedup proportional to the fraction of the input
51
+ that is padding can be expected.
52
+
53
+ Args:
54
+ embed_dim: Total dimension of the model.
55
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
56
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
57
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
58
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
59
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
60
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
61
+ Default: ``False``.
62
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
63
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
64
+ batch_first: If ``True``, then the input and output tensors are provided
65
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
66
+
67
+ Examples::
68
+
69
+ >>> # xdoctest: +SKIP
70
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
71
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
72
+
73
+ """
74
+ __constants__ = ["batch_first"]
75
+ bias_k: Optional[torch.Tensor]
76
+ bias_v: Optional[torch.Tensor]
77
+
78
+ def __init__(
79
+ self,
80
+ embed_dim,
81
+ num_heads,
82
+ dropout=0.0,
83
+ bias=True,
84
+ add_bias_kv=False,
85
+ add_zero_attn=False,
86
+ kdim=None,
87
+ vdim=None,
88
+ batch_first=False,
89
+ linear1_cls=Linear,
90
+ linear2_cls=Linear,
91
+ device=None,
92
+ dtype=None, ) -> None:
93
+ factory_kwargs = {"device": device, "dtype": dtype}
94
+ super(MultiheadAttention, self).__init__()
95
+ self.embed_dim = embed_dim
96
+ self.kdim = kdim if kdim is not None else embed_dim
97
+ self.vdim = vdim if vdim is not None else embed_dim
98
+ self._qkv_same_embed_dim = (self.kdim == embed_dim and
99
+ self.vdim == embed_dim)
100
+
101
+ self.num_heads = num_heads
102
+ self.dropout = dropout
103
+ self.batch_first = batch_first
104
+ self.head_dim = embed_dim // num_heads
105
+ assert (self.head_dim * num_heads == self.embed_dim
106
+ ), "embed_dim must be divisible by num_heads"
107
+
108
+ if add_bias_kv:
109
+ self.bias_k = Parameter(
110
+ torch.empty((1, 1, embed_dim), **factory_kwargs))
111
+ self.bias_v = Parameter(
112
+ torch.empty((1, 1, embed_dim), **factory_kwargs))
113
+ else:
114
+ self.bias_k = self.bias_v = None
115
+
116
+ if linear1_cls == Linear:
117
+ if not self._qkv_same_embed_dim:
118
+ self.q_proj_weight = Parameter(
119
+ torch.empty((embed_dim, embed_dim), **factory_kwargs))
120
+ self.k_proj_weight = Parameter(
121
+ torch.empty((embed_dim, self.kdim), **factory_kwargs))
122
+ self.v_proj_weight = Parameter(
123
+ torch.empty((embed_dim, self.vdim), **factory_kwargs))
124
+ self.register_parameter("in_proj_weight", None)
125
+ else:
126
+ self.in_proj_weight = Parameter(
127
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
128
+ self.register_parameter("q_proj_weight", None)
129
+ self.register_parameter("k_proj_weight", None)
130
+ self.register_parameter("v_proj_weight", None)
131
+
132
+ if bias:
133
+ self.in_proj_bias = Parameter(
134
+ torch.empty(3 * embed_dim, **factory_kwargs))
135
+ else:
136
+ self.register_parameter("in_proj_bias", None)
137
+ self.out_proj = NonDynamicallyQuantizableLinear(
138
+ embed_dim, embed_dim, bias=bias, **factory_kwargs)
139
+
140
+ self._reset_parameters()
141
+ else:
142
+ if not self._qkv_same_embed_dim:
143
+ raise NotImplementedError
144
+ else:
145
+ self.in_proj_linear = linear1_cls(
146
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
147
+ self.in_proj_weight = self.in_proj_linear.weight
148
+
149
+ self.register_parameter("q_proj_weight", None)
150
+ self.register_parameter("k_proj_weight", None)
151
+ self.register_parameter("v_proj_weight", None)
152
+
153
+ if bias:
154
+ self.in_proj_bias = self.in_proj_linear.bias
155
+ else:
156
+ self.register_parameter("in_proj_bias", None)
157
+
158
+ self.out_proj = linear2_cls(
159
+ embed_dim, embed_dim, bias=bias, **factory_kwargs)
160
+
161
+ if self.bias_k is not None:
162
+ xavier_normal_(self.bias_k)
163
+ if self.bias_v is not None:
164
+ xavier_normal_(self.bias_v)
165
+
166
+ self.add_zero_attn = add_zero_attn
167
+
168
+ def _reset_parameters(self):
169
+ if self._qkv_same_embed_dim:
170
+ xavier_uniform_(self.in_proj_weight)
171
+ else:
172
+ xavier_uniform_(self.q_proj_weight)
173
+ xavier_uniform_(self.k_proj_weight)
174
+ xavier_uniform_(self.v_proj_weight)
175
+
176
+ if self.in_proj_bias is not None:
177
+ constant_(self.in_proj_bias, 0.0)
178
+ constant_(self.out_proj.bias, 0.0)
179
+
180
+ if self.bias_k is not None:
181
+ xavier_normal_(self.bias_k)
182
+ if self.bias_v is not None:
183
+ xavier_normal_(self.bias_v)
184
+
185
+ def __setstate__(self, state):
186
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
187
+ if "_qkv_same_embed_dim" not in state:
188
+ state["_qkv_same_embed_dim"] = True
189
+
190
+ super(MultiheadAttention, self).__setstate__(state)
191
+
192
+ def forward(
193
+ self,
194
+ query: Tensor,
195
+ key: Tensor,
196
+ value: Tensor,
197
+ key_padding_mask: Optional[Tensor]=None,
198
+ need_weights: bool=True,
199
+ attn_mask: Optional[Tensor]=None,
200
+ average_attn_weights: bool=True,cache=None
201
+ ) -> Tuple[Tensor, Optional[Tensor]]:
202
+ r"""
203
+ Args:
204
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
205
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
206
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
207
+ Queries are compared against key-value pairs to produce the output.
208
+ See "Attention Is All You Need" for more details.
209
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
210
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
211
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
212
+ See "Attention Is All You Need" for more details.
213
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
214
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
215
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
216
+ See "Attention Is All You Need" for more details.
217
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
218
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
219
+ Binary and byte masks are supported.
220
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
221
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
222
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
223
+ Default: ``True``.
224
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
225
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
226
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
227
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
228
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
229
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
230
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
231
+ the attention weight.
232
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
233
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
234
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
235
+
236
+ Outputs:
237
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
238
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
239
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
240
+ embedding dimension ``embed_dim``.
241
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
242
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
243
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
244
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
245
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
246
+
247
+ .. note::
248
+ `batch_first` argument is ignored for unbatched inputs.
249
+ """
250
+ is_batched = query.dim() == 3
251
+ if key_padding_mask is not None:
252
+ _kpm_dtype = key_padding_mask.dtype
253
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
254
+ key_padding_mask):
255
+ raise AssertionError(
256
+ "only bool and floating types of key_padding_mask are supported"
257
+ )
258
+ why_not_fast_path = ""
259
+ if not is_batched:
260
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
261
+ elif query is not key or key is not value:
262
+ # When lifting this restriction, don't forget to either
263
+ # enforce that the dtypes all match or test cases where
264
+ # they don't!
265
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
266
+ elif (self.in_proj_bias is not None and
267
+ query.dtype != self.in_proj_bias.dtype):
268
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
269
+ elif (self.in_proj_weight is not None and
270
+ query.dtype != self.in_proj_weight.dtype):
271
+ # this case will fail anyway, but at least they'll get a useful error message.
272
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
273
+ elif self.training:
274
+ why_not_fast_path = "training is enabled"
275
+ elif not self.batch_first:
276
+ why_not_fast_path = "batch_first was not True"
277
+ elif self.bias_k is not None:
278
+ why_not_fast_path = "self.bias_k was not None"
279
+ elif self.bias_v is not None:
280
+ why_not_fast_path = "self.bias_v was not None"
281
+ elif self.dropout:
282
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
283
+ elif self.add_zero_attn:
284
+ why_not_fast_path = "add_zero_attn was enabled"
285
+ elif not self._qkv_same_embed_dim:
286
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
287
+ elif attn_mask is not None:
288
+ why_not_fast_path = "attn_mask was not None"
289
+ elif query.is_nested and key_padding_mask is not None:
290
+ why_not_fast_path = (
291
+ "key_padding_mask is not supported with NestedTensor input")
292
+ elif self.num_heads % 2 == 1:
293
+ why_not_fast_path = "num_heads is odd"
294
+ elif torch.is_autocast_enabled():
295
+ why_not_fast_path = "autocast is enabled"
296
+
297
+ if not why_not_fast_path:
298
+ tensor_args = (query, key, value, self.in_proj_weight,
299
+ self.in_proj_bias, self.out_proj.weight,
300
+ self.out_proj.bias, )
301
+ # We have to use list comprehensions below because TorchScript does not support
302
+ # generator expressions.
303
+ if torch.overrides.has_torch_function(tensor_args):
304
+ why_not_fast_path = "some Tensor argument has_torch_function"
305
+ elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
306
+ for x in tensor_args]):
307
+ why_not_fast_path = (
308
+ "some Tensor argument is neither CUDA nor CPU")
309
+ elif torch.is_grad_enabled() and any(
310
+ [x is not None and x.requires_grad for x in tensor_args]):
311
+ why_not_fast_path = (
312
+ "grad is enabled and at least one of query or the "
313
+ "input/output projection weights or biases requires_grad")
314
+ if not why_not_fast_path:
315
+ return torch._native_multi_head_attention(
316
+ query,
317
+ key,
318
+ value,
319
+ self.embed_dim,
320
+ self.num_heads,
321
+ self.in_proj_weight,
322
+ self.in_proj_bias,
323
+ self.out_proj.weight,
324
+ self.out_proj.bias,
325
+ key_padding_mask
326
+ if key_padding_mask is not None else attn_mask,
327
+ need_weights,
328
+ average_attn_weights,
329
+ 1 if key_padding_mask is not None else 0
330
+ if attn_mask is not None else None, )
331
+
332
+ any_nested = query.is_nested or key.is_nested or value.is_nested
333
+ assert not any_nested, (
334
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
335
+ + f"The fast path was not hit because {why_not_fast_path}")
336
+
337
+ if self.batch_first and is_batched:
338
+ # make sure that the transpose op does not affect the "is" property
339
+ if key is value:
340
+ if query is key:
341
+ query = key = value = query.transpose(1, 0)
342
+ else:
343
+ query, key = [x.transpose(1, 0) for x in (query, key)]
344
+ value = key
345
+ else:
346
+ query, key, value = [
347
+ x.transpose(1, 0) for x in (query, key, value)
348
+ ]
349
+
350
+ if not self._qkv_same_embed_dim:
351
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
352
+ query,
353
+ key,
354
+ value,
355
+ self.embed_dim,
356
+ self.num_heads,
357
+ self.in_proj_weight,
358
+ self.in_proj_bias,
359
+ self.bias_k,
360
+ self.bias_v,
361
+ self.add_zero_attn,
362
+ self.dropout,
363
+ self.out_proj.weight,
364
+ self.out_proj.bias,
365
+ training=self.training,
366
+ key_padding_mask=key_padding_mask,
367
+ need_weights=need_weights,
368
+ attn_mask=attn_mask,
369
+ use_separate_proj_weight=True,
370
+ q_proj_weight=self.q_proj_weight,
371
+ k_proj_weight=self.k_proj_weight,
372
+ v_proj_weight=self.v_proj_weight,
373
+ average_attn_weights=average_attn_weights,cache=cache )
374
+ else:
375
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
376
+ query,
377
+ key,
378
+ value,
379
+ self.embed_dim,
380
+ self.num_heads,
381
+ self.in_proj_weight,
382
+ self.in_proj_bias,
383
+ self.bias_k,
384
+ self.bias_v,
385
+ self.add_zero_attn,
386
+ self.dropout,
387
+ self.out_proj.weight,
388
+ self.out_proj.bias,
389
+ training=self.training,
390
+ key_padding_mask=key_padding_mask,
391
+ need_weights=need_weights,
392
+ attn_mask=attn_mask,
393
+ average_attn_weights=average_attn_weights,cache=cache )
394
+ if self.batch_first and is_batched:
395
+ return attn_output.transpose(1, 0), attn_output_weights
396
+ else:
397
+ return attn_output, attn_output_weights
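For comparison, the stock `torch.nn.MultiheadAttention` call below illustrates the boolean attention-mask convention documented in the docstring above (`True` means the position may not be attended to); the patched class keeps the same interface and simply threads an extra `cache` argument into the patched forward function. This is a reference sketch with standard PyTorch in a fresh session (importing this module monkey-patches `F.multi_head_attention_forward`), not the repository's class.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
x = torch.randn(1, 6, 16)                                             # (batch, seq, embed_dim)
causal = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)   # True = masked
with torch.no_grad():
    out, attn = mha(x, x, x, attn_mask=causal, need_weights=True)
print(out.shape, attn.shape)                                          # [1, 6, 16] and [1, 6, 6]
```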
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py ADDED
@@ -0,0 +1,78 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float=0.0, ):
14
+ super().__init__()
15
+
16
+ self.vocab_size = vocab_size
17
+ self.embedding_dim = embedding_dim
18
+
19
+ self.dropout = torch.nn.Dropout(p=dropout)
20
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
21
+
22
+ @property
23
+ def weight(self) -> torch.Tensor:
24
+ return self.word_embeddings.weight
25
+
26
+ def embedding(self, index: int) -> torch.Tensor:
27
+ return self.word_embeddings.weight[index:index + 1]
28
+
29
+ def forward(self, x: torch.Tensor):
30
+ x = self.word_embeddings(x)
31
+ x = self.dropout(x)
32
+ return x
33
+
34
+
35
+ class SinePositionalEmbedding(nn.Module):
36
+ def __init__(
37
+ self,
38
+ embedding_dim: int,
39
+ dropout: float=0.0,
40
+ scale: bool=False,
41
+ alpha: bool=False, ):
42
+ super().__init__()
43
+ self.embedding_dim = embedding_dim
44
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
45
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
46
+ self.dropout = torch.nn.Dropout(p=dropout)
47
+
48
+ self.reverse = False
49
+ self.pe = None
50
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
51
+
52
+ def extend_pe(self, x):
53
+ """Reset the positional encodings."""
54
+ if self.pe is not None:
55
+ if self.pe.size(1) >= x.size(1):
56
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
57
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
58
+ return
59
+ pe = torch.zeros(x.size(1), self.embedding_dim)
60
+ if self.reverse:
61
+ position = torch.arange(
62
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
63
+ else:
64
+ position = torch.arange(
65
+ 0, x.size(1), dtype=torch.float32).unsqueeze(1)
66
+ div_term = torch.exp(
67
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
68
+ -(math.log(10000.0) / self.embedding_dim))
69
+ pe[:, 0::2] = torch.sin(position * div_term)
70
+ pe[:, 1::2] = torch.cos(position * div_term)
71
+ pe = pe.unsqueeze(0)
72
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ self.extend_pe(x)
76
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
77
+ output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
78
+ return self.dropout(output)
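A hedged usage sketch of the two embedding modules above, assuming they are importable as `AR.modules.embedding` (the layout used by this repository); the vocabulary size and sequence length are illustrative.

```python
import torch
from AR.modules.embedding import TokenEmbedding, SinePositionalEmbedding  # assumed import path

tok = TokenEmbedding(embedding_dim=512, vocab_size=1025, dropout=0.0)
pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)

ids = torch.randint(0, 1025, (1, 20))   # (batch, seq) token ids
x = tok(ids)                            # (1, 20, 512) token embeddings
x = pos(x)                              # adds the (learnably scaled) sine/cosine positional encoding
print(x.shape)                          # torch.Size([1, 20, 512])
```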
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py ADDED
@@ -0,0 +1,85 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
2
+ import math
3
+
4
+ import torch
5
+ from matplotlib import pyplot as plt
6
+ from torch import nn
7
+ from torch.optim import Adam
8
+
9
+
10
+ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
11
+ """
12
+ Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
13
+ """
14
+
15
+ def __init__(self,
16
+ optimizer,
17
+ init_lr,
18
+ peak_lr,
19
+ end_lr,
20
+ warmup_steps=10000,
21
+ total_steps=400000,
22
+ current_step=0):
23
+ self.init_lr = init_lr
24
+ self.peak_lr = peak_lr
25
+ self.end_lr = end_lr
26
+ self.optimizer = optimizer
27
+ self._warmup_rate = (peak_lr - init_lr) / warmup_steps
28
+ self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
29
+ self._current_step = current_step
30
+ self.lr = init_lr
31
+ self.warmup_steps = warmup_steps
32
+ self.total_steps = total_steps
33
+ self._last_lr = [self.lr]
34
+
35
+ def set_lr(self, lr):
36
+ self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
37
+ for g in self.optimizer.param_groups:
38
+ # g['lr'] = lr
39
+ g['lr'] = self.end_lr  # locked: always use end_lr instead of the scheduled lr
40
+
41
+ def step(self):
42
+ if self._current_step < self.warmup_steps:
43
+ lr = self.init_lr + self._warmup_rate * self._current_step
44
+
45
+ elif self._current_step > self.total_steps:
46
+ lr = self.end_lr
47
+
48
+ else:
49
+ decay_ratio = (self._current_step - self.warmup_steps) / (
50
+ self.total_steps - self.warmup_steps)
51
+ if decay_ratio < 0.0 or decay_ratio > 1.0:
52
+ raise RuntimeError(
53
+ "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
54
+ )
55
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
56
+ lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
57
+
58
+ self.lr=lr=self.end_lr=0.002  # the schedule misbehaves, so hard-lock the learning rate to a constant 0.002, overriding the cosine decay above
59
+ self.set_lr(lr)
60
+ self.lr = lr
61
+ self._current_step += 1
62
+ return self.lr
63
+
64
+
65
+
66
+ if __name__ == '__main__':
67
+ m = nn.Linear(10, 10)
68
+ opt = Adam(m.parameters(), lr=1e-4)
69
+ s = WarmupCosineLRSchedule(
70
+ opt,
71
+ 1e-6,
72
+ 2e-4,
73
+ 1e-6,
74
+ warmup_steps=2000,
75
+ total_steps=20000,
76
+ current_step=0)
77
+ lrs = []
78
+ for i in range(25000):
79
+ s.step()
80
+ lrs.append(s.lr)
81
+ print(s.lr)
82
+
83
+ plt.plot(lrs)
84
+ plt.plot(range(0, 25000), lrs)
85
+ plt.show()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py ADDED
@@ -0,0 +1,622 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ from collections import defaultdict
19
+ from typing import List
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import Tensor
24
+ from torch.optim import Optimizer
25
+
26
+
27
+ class BatchedOptimizer(Optimizer):
28
+ """
29
+ This class adds to class Optimizer the capability to optimize parameters in batches:
30
+ it will stack the parameters and their grads for you so the optimizer can work
31
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
32
+ as it reduces the number of kernels launched in the optimizer.
33
+
34
+ Args:
35
+ params:
36
+ """
37
+
38
+ def __init__(self, params, defaults):
39
+ super(BatchedOptimizer, self).__init__(params, defaults)
40
+
41
+ @contextlib.contextmanager
42
+ def batched_params(self, param_group, group_params_names):
43
+ """
44
+ This function returns (technically, yields) a list of
45
+ of tuples (p, state), where
46
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
47
+ that share the same shape, and its gradient is also stacked;
48
+ `state` is the state corresponding to this batch of parameters
49
+ (it will be physically located in the "state" for one of the real
50
+ parameters, the last one that has any particular shape and dtype).
51
+
52
+ This function is decorated as a context manager so that it can
53
+ write parameters back to their "real" locations.
54
+
55
+ The idea is, instead of doing:
56
+ <code>
57
+ for p in group["params"]:
58
+ state = self.state[p]
59
+ ...
60
+ </code>
61
+ you can do:
62
+ <code>
63
+ with self.batched_params(group["params"]) as batches:
64
+ for p, state, p_names in batches:
65
+ ...
66
+ </code>
67
+
68
+ Args:
69
+ group: a parameter group, which is a list of parameters; should be
70
+ one of self.param_groups.
71
+ group_params_names: name for each parameter in group,
72
+ which is List[str].
73
+ """
74
+ batches = defaultdict(
75
+ list
76
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
77
+ batches_names = defaultdict(
78
+ list
79
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
80
+
81
+ assert len(param_group) == len(group_params_names)
82
+ for p, named_p in zip(param_group, group_params_names):
83
+ key = (str(p.dtype), *p.shape)
84
+ batches[key].append(p)
85
+ batches_names[key].append(named_p)
86
+
87
+ batches_names_keys = list(batches_names.keys())
88
+ sorted_idx = sorted(
89
+ range(len(batches_names)), key=lambda i: batches_names_keys[i])
90
+ batches_names = [
91
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
92
+ ]
93
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
+
95
+ stacked_params_dict = dict()
96
+
97
+ # turn batches into a list, in deterministic order.
98
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
+ # one for each batch in `batches`.
100
+ tuples = []
101
+
102
+ for batch, batch_names in zip(batches, batches_names):
103
+ p = batch[0]
104
+ # we arbitrarily store the state in the
105
+ # state corresponding to the 1st parameter in the
106
+ # group. class Optimizer will take care of saving/loading state.
107
+ state = self.state[p]
108
+ p_stacked = torch.stack(batch)
109
+ grad = torch.stack([
110
+ torch.zeros_like(p) if p.grad is None else p.grad for p in batch
111
+ ])
112
+ p_stacked.grad = grad
113
+ stacked_params_dict[key] = p_stacked
114
+ tuples.append((p_stacked, state, batch_names))
115
+
116
+ yield tuples # <-- calling code will do the actual optimization here!
117
+
118
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
119
+ for i, p in enumerate(batch): # batch is list of Parameter
120
+ p.copy_(stacked_params[i])
121
+
122
+
123
+ class ScaledAdam(BatchedOptimizer):
124
+ """
125
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
126
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
127
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
128
+ param = underlying_param * log_scale.exp())
129
+
130
+
131
+ Args:
132
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
133
+ lr: The learning rate. We will typically use a learning rate schedule that starts
134
+ at 0.03 and decreases over time, i.e. much higher than other common
135
+ optimizers.
136
+ clipping_scale: (e.g. 2.0)
137
+ A scale for gradient-clipping: if specified, the normalized gradients
138
+ over the whole model will be clipped to have 2-norm equal to
139
+ `clipping_scale` times the median 2-norm over the most recent period
140
+ of `clipping_update_period` minibatches. By "normalized gradients",
141
+ we mean after multiplying by the rms parameter value for this tensor
142
+ [for non-scalars]; this is appropriate because our update is scaled
143
+ by this quantity.
144
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
145
+ Must satisfy 0 < beta1 <= beta2 < 1.
146
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
147
+ scale of each parameter tensor and scalar parameters of the model.
148
+ If each parameter were decomposed
149
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
150
+ would be the scaling factor on the learning rate of p_scale.
151
+ eps: A general-purpose epsilon to prevent division by zero
152
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
153
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
154
+ parameter tensor to be >= this value)
155
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
156
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
157
+ parameter tensor to be <= this value)
158
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
159
+ model has any parameters with numel() == 1).
160
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
161
+ of the parameter tensor. This is provided to save a little time
162
+ in the update.
163
+ clipping_update_period: if clipping_scale is specified, this is the period
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ params,
169
+ lr=3e-02,
170
+ clipping_scale=None,
171
+ betas=(0.9, 0.98),
172
+ scalar_lr_scale=0.1,
173
+ eps=1.0e-08,
174
+ param_min_rms=1.0e-05,
175
+ param_max_rms=3.0,
176
+ scalar_max=10.0,
177
+ size_update_period=4,
178
+ clipping_update_period=100,
179
+ parameters_names=None,
180
+ show_dominant_parameters=True, ):
181
+
182
+ assert parameters_names is not None, (
183
+ "Please prepare parameters_names,"
184
+ "which is a List[List[str]]. Each List[str] is for a group"
185
+ "and each str is for a parameter")
186
+ defaults = dict(
187
+ lr=lr,
188
+ clipping_scale=clipping_scale,
189
+ betas=betas,
190
+ scalar_lr_scale=scalar_lr_scale,
191
+ eps=eps,
192
+ param_min_rms=param_min_rms,
193
+ param_max_rms=param_max_rms,
194
+ scalar_max=scalar_max,
195
+ size_update_period=size_update_period,
196
+ clipping_update_period=clipping_update_period, )
197
+
198
+ super(ScaledAdam, self).__init__(params, defaults)
199
+ assert len(self.param_groups) == len(parameters_names)
200
+ self.parameters_names = parameters_names
201
+ self.show_dominant_parameters = show_dominant_parameters
202
+
203
+ def __setstate__(self, state):
204
+ super(ScaledAdam, self).__setstate__(state)
205
+
206
+ @torch.no_grad()
207
+ def step(self, closure=None):
208
+ """Performs a single optimization step.
209
+
210
+ Arguments:
211
+ closure (callable, optional): A closure that reevaluates the model
212
+ and returns the loss.
213
+ """
214
+ loss = None
215
+ if closure is not None:
216
+ with torch.enable_grad():
217
+ loss = closure()
218
+
219
+ batch = True
220
+
221
+ for group, group_params_names in zip(self.param_groups,
222
+ self.parameters_names):
223
+
224
+ with self.batched_params(group["params"],
225
+ group_params_names) as batches:
226
+
227
+ # batches is list of pairs (stacked_param, state). stacked_param is like
228
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
229
+ # a stacking dim, it is not a real dim.
230
+
231
+ if (len(batches[0][1]) ==
232
+ 0): # if len(first state) == 0: not yet initialized
233
+ clipping_scale = 1
234
+ else:
235
+ clipping_scale = self._get_clipping_scale(group, batches)
236
+
237
+ for p, state, _ in batches:
238
+ # Perform optimization step.
239
+ # grad is not going to be None, we handled that when creating the batches.
240
+ grad = p.grad
241
+ if grad.is_sparse:
242
+ raise RuntimeError(
243
+ "ScaledAdam optimizer does not support sparse gradients"
244
+ )
245
+ # State initialization
246
+ if len(state) == 0:
247
+ self._init_state(group, p, state)
248
+
249
+ self._step_one_batch(group, p, state, clipping_scale)
250
+
251
+ return loss
252
+
253
+ def _init_state(self, group: dict, p: Tensor, state: dict):
254
+ """
255
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
256
+ is actually the batch dimension, corresponding to batched-together
257
+ parameters of a given shape.
258
+
259
+
260
+ Args:
261
+ group: Dict to look up configuration values.
262
+ p: The parameter that we are initializing the state for
263
+ state: Dict from string to whatever state we are initializing
264
+ """
265
+ size_update_period = group["size_update_period"]
266
+
267
+ state["step"] = 0
268
+
269
+ kwargs = {"device": p.device, "dtype": p.dtype}
270
+
271
+ # 'delta' implements conventional momentum. There are
272
+ # several different kinds of update going on, so rather than
273
+ # compute "exp_avg" like in Adam, we store and decay a
274
+ # parameter-change "delta", which combines all forms of
275
+ # update. this is equivalent to how it's done in Adam,
276
+ # except for the first few steps.
277
+ state["delta"] = torch.zeros_like(
278
+ p, memory_format=torch.preserve_format)
279
+
280
+ batch_size = p.shape[0]
281
+ numel = p.numel() // batch_size
282
+ numel = p.numel()
283
+
284
+ if numel > 1:
285
+ # "param_rms" just periodically records the scalar root-mean-square value of
286
+ # the parameter tensor.
287
+ # it has a shape like (batch_size, 1, 1, 1, 1)
288
+ param_rms = (
289
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
290
+ state["param_rms"] = param_rms
291
+
292
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
293
+ state["scale_grads"] = torch.zeros(size_update_period,
294
+ *param_rms.shape, **kwargs)
295
+
296
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
297
+ state["exp_avg_sq"] = torch.zeros_like(
298
+ p, memory_format=torch.preserve_format)
299
+
300
+ def _get_clipping_scale(self,
301
+ group: dict,
302
+ tuples: List[Tuple[Tensor, dict, List[str]]]
303
+ ) -> float:
304
+ """
305
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
306
+ by this amount before applying the rest of the update.
307
+
308
+ Args:
309
+ group: the parameter group, an item in self.param_groups
310
+ tuples: a list of tuples of (param, state, param_names)
311
+ where param is a batched set of parameters,
312
+ with a .grad (1st dim is batch dim)
313
+ and state is the state-dict where optimization parameters are kept.
314
+ param_names is a List[str], where each str is the name of a parameter
315
+ in batched set of parameters "param".
316
+ """
317
+ assert len(tuples) >= 1
318
+ clipping_scale = group["clipping_scale"]
319
+ (first_p, first_state, _) = tuples[0]
320
+ step = first_state["step"]
321
+ if clipping_scale is None or step == 0:
322
+ # no clipping. return early on step == 0 because the other
323
+ # parameters' state won't have been initialized yet.
324
+ return 1.0
325
+ clipping_update_period = group["clipping_update_period"]
326
+
327
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
328
+ for (p, state, param_names) in tuples:
329
+ grad = p.grad
330
+ if grad.is_sparse:
331
+ raise RuntimeError(
332
+ "ScaledAdam optimizer does not support sparse gradients")
333
+ if p.numel() == p.shape[0]: # a batch of scalars
334
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
335
+ else:
336
+ tot_sumsq += ((grad * state["param_rms"])**2).sum()
337
+
338
+ tot_norm = tot_sumsq.sqrt()
339
+ if "model_norms" not in first_state:
340
+ first_state["model_norms"] = torch.zeros(
341
+ clipping_update_period, device=p.device)
342
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
343
+
344
+ if step % clipping_update_period == 0:
345
+ # Print some stats.
346
+ # We don't reach here if step == 0 because we would have returned
347
+ # above.
348
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
349
+ quartiles = []
350
+ for n in range(0, 5):
351
+ index = min(
352
+ clipping_update_period - 1,
353
+ (clipping_update_period // 4) * n, )
354
+ quartiles.append(sorted_norms[index].item())
355
+
356
+ median = quartiles[2]
357
+ threshold = clipping_scale * median
358
+ first_state["model_norm_threshold"] = threshold
359
+ percent_clipped = (first_state["num_clipped"] * 100.0 /
360
+ clipping_update_period
361
+ if "num_clipped" in first_state else 0.0)
362
+ first_state["num_clipped"] = 0
363
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
364
+ logging.info(
365
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
366
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
367
+ )
368
+
369
+ if step < clipping_update_period:
370
+ return 1.0 # We have not yet estimated a norm to clip to.
371
+ else:
372
+ try:
373
+ model_norm_threshold = first_state["model_norm_threshold"]
374
+ except KeyError:
375
+ logging.info(
376
+ "Warning: model_norm_threshold not in state: possibly "
377
+ "you changed config when restarting, adding clipping_scale option?"
378
+ )
379
+ return 1.0
380
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
381
+ if ans < 1.0:
382
+ first_state["num_clipped"] += 1
383
+ if ans < 0.1:
384
+ logging.warn(
385
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
386
+ )
387
+ if self.show_dominant_parameters:
388
+ assert p.shape[0] == len(param_names)
389
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
390
+ return ans
391
+
392
+ def _show_gradient_dominating_parameter(
393
+ self, tuples: List[Tuple[Tensor, dict, List[str]]],
394
+ tot_sumsq: Tensor):
395
+ """
396
+ Show information about the parameter that dominates tot_sumsq.
397
+
398
+ Args:
399
+ tuples: a list of tuples of (param, state, param_names)
400
+ where param is a batched set of parameters,
401
+ with a .grad (1st dim is batch dim)
402
+ and state is the state-dict where optimization parameters are kept.
403
+ param_names is a List[str], where each str is the name of a parameter
404
+ in batched set of parameters "param".
405
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
406
+ from tuples, we still pass it to save some time.
407
+ """
408
+ all_sumsq_orig = {}
409
+ for (p, state, batch_param_names) in tuples:
410
+ # p is a stacked batch parameters.
411
+ batch_grad = p.grad
412
+ if p.numel() == p.shape[0]: # a batch of scalars
413
+ batch_sumsq_orig = batch_grad**2
414
+ # Dummy values used by the following `zip` statement.
415
+ batch_rms_orig = torch.ones(p.shape[0])
416
+ else:
417
+ batch_rms_orig = state["param_rms"]
418
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
419
+ dim=list(range(1, batch_grad.ndim)))
420
+
421
+ for name, sumsq_orig, rms, grad in zip(batch_param_names,
422
+ batch_sumsq_orig,
423
+ batch_rms_orig, batch_grad):
424
+
425
+ proportion_orig = sumsq_orig / tot_sumsq
426
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
427
+
428
+ assert torch.isclose(
429
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
430
+ torch.tensor(1.0), )
431
+ sorted_by_proportion = {
432
+ k: v
433
+ for k, v in sorted(
434
+ all_sumsq_orig.items(),
435
+ key=lambda item: item[1][0],
436
+ reverse=True, )
437
+ }
438
+ dominant_param_name = next(iter(sorted_by_proportion))
439
+ (dominant_proportion, dominant_sumsq, dominant_rms,
440
+ dominant_grad, ) = sorted_by_proportion[dominant_param_name]
441
+ logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
442
+ f" with proportion {dominant_proportion:.2f},"
443
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
444
+ f"={dominant_sumsq:.3e},"
445
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
446
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}")
447
+
448
+ def _step_one_batch(self,
449
+ group: dict,
450
+ p: Tensor,
451
+ state: dict,
452
+ clipping_scale: float):
453
+ """
454
+ Do the step for one parameter, which is actually going to be a batch of
455
+ `real` parameters, with dim 0 as the batch dim.
456
+ Args:
457
+ group: dict to look up configuration values
458
+ p: parameter to update (actually multiple parameters stacked together
459
+ as a batch)
460
+ state: state-dict for p, to look up the optimizer state
461
+ """
462
+ lr = group["lr"]
463
+ size_update_period = group["size_update_period"]
464
+ beta1 = group["betas"][0]
465
+
466
+ grad = p.grad
467
+ if clipping_scale != 1.0:
468
+ grad = grad * clipping_scale
469
+ step = state["step"]
470
+ delta = state["delta"]
471
+
472
+ delta.mul_(beta1)
473
+ batch_size = p.shape[0]
474
+ numel = p.numel() // batch_size
475
+ if numel > 1:
476
+ # Update the size/scale of p, and set param_rms
477
+ scale_grads = state["scale_grads"]
478
+ scale_grads[step % size_update_period] = (p * grad).sum(
479
+ dim=list(range(1, p.ndim)), keepdim=True)
480
+ if step % size_update_period == size_update_period - 1:
481
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
482
+ param_rms.copy_((p**2)
483
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
484
+ .sqrt())
485
+ if step > 0:
486
+ # self._size_update() learns the overall scale on the
487
+                     # parameter, by shrinking or expanding it.
+                     self._size_update(group, scale_grads, p, state)
+
+         if numel == 1:
+             # For parameters with 1 element we just use regular Adam.
+             # Updates delta.
+             self._step_scalar(group, p, state)
+         else:
+             self._step(group, p, state)
+
+         state["step"] = step + 1
+
+     def _size_update(self,
+                      group: dict,
+                      scale_grads: Tensor,
+                      p: Tensor,
+                      state: dict) -> None:
+         """
+         Called only when p.numel() > 1; this updates the scale of the parameter.
+         If we imagine p = underlying_param * scale.exp(), and we are doing
+         gradient descent on the underlying param and on scale, this function does
+         the update on `scale`.
+
+         Args:
+             group: dict to look up configuration values
+             scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1, ...)
+                 containing grads w.r.t. the scales.
+             p: The parameter to update
+             state: The state-dict of p
+         """
+
+         param_rms = state["param_rms"]
+         beta1, beta2 = group["betas"]
+         size_lr = group["lr"] * group["scalar_lr_scale"]
+         param_min_rms = group["param_min_rms"]
+         param_max_rms = group["param_max_rms"]
+         eps = group["eps"]
+         step = state["step"]
+         batch_size = p.shape[0]
+
+         size_update_period = scale_grads.shape[0]
+         # correct beta2 for the size update period: we will have
+         # faster decay at this level.
+         beta2_corr = beta2**size_update_period
+
+         scale_exp_avg_sq = state[
+             "scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
+         scale_exp_avg_sq.mul_(beta2_corr).add_(
+             (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
+             alpha=1 - beta2_corr, )  # shape is (batch_size, 1, 1, ...)
+
+         # The 1st time we reach here is when size_step == 1.
+         size_step = (step + 1) // size_update_period
+         bias_correction2 = 1 - beta2_corr**size_step
+         # we don't bother with bias_correction1; this will help prevent divergence
+         # at the start of training.
+
+         denom = scale_exp_avg_sq.sqrt() + eps
+
+         scale_step = (-size_lr * (bias_correction2**0.5) *
+                       scale_grads.sum(dim=0) / denom)
+
+         is_too_small = param_rms < param_min_rms
+         is_too_large = param_rms > param_max_rms
+
+         # when the param gets too small, just don't shrink it any further.
+         scale_step.masked_fill_(is_too_small, 0.0)
+         # when it gets too large, stop it from getting any larger.
+         scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+         delta = state["delta"]
+         # the factor of (1-beta1) relates to momentum.
+         delta.add_(p * scale_step, alpha=(1 - beta1))
+
+     def _step(self, group: dict, p: Tensor, state: dict):
+         """
+         This function does the core update of self.step(), in the case where the members of
+         the batch have more than 1 element.
+
+         Args:
+             group: A dict which will be used to look up configuration values
+             p: The parameter to be updated (its gradient is read from p.grad)
+             state: The state-dict corresponding to parameter p
+
+         This function modifies p.
+         """
+         grad = p.grad
+         lr = group["lr"]
+         beta1, beta2 = group["betas"]
+         eps = group["eps"]
+         param_min_rms = group["param_min_rms"]
+         step = state["step"]
+
+         exp_avg_sq = state["exp_avg_sq"]
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
+
+         this_step = state["step"] - (state["zero_step"]
+                                      if "zero_step" in state else 0)
+         bias_correction2 = 1 - beta2**(this_step + 1)
+         if bias_correction2 < 0.99:
+             # note: not in-place.
+             exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+
+         denom = exp_avg_sq.sqrt()
+         denom += eps
+         grad = grad / denom
+
+         alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
+
+         delta = state["delta"]
+         delta.add_(grad * alpha)
+         p.add_(delta)
+
+     def _step_scalar(self, group: dict, p: Tensor, state: dict):
+         """
+         A simplified form of the core update for scalar tensors, where we cannot get a good
+         estimate of the parameter rms.
+         """
+         beta1, beta2 = group["betas"]
+         scalar_max = group["scalar_max"]
+         eps = group["eps"]
+         lr = group["lr"] * group["scalar_lr_scale"]
+         grad = p.grad
+
+         exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
+         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+         # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
+         # slower update at the start will help stability anyway.
+         bias_correction2 = 1 - beta2**(state["step"] + 1)
+         denom = (exp_avg_sq / bias_correction2).sqrt() + eps
+
+         delta = state["delta"]
+         delta.add_(grad / denom, alpha=-lr * (1 - beta1))
+         p.clamp_(min=-scalar_max, max=scalar_max)
+         p.add_(delta)
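To make the scale update above concrete: `_size_update` treats each batched parameter as p = underlying_param * scale.exp(), takes an Adam-style step on the log-scale, and folds that step into `delta` as `p * scale_step`. The short sketch below is illustrative only; the tensor shapes and the step size are made up and do not come from this repository. It shows why adding `p * scale_step` approximates the multiplicative rescaling `p * scale_step.exp()` when the step is small.

    import torch

    p = torch.randn(4, 16)            # a batched parameter, one row per original tensor
    scale_step = torch.tensor(-0.01)  # a small step on the log-scale (negative = shrink)

    p_exact = p * scale_step.exp()    # exact multiplicative rescaling
    p_approx = p + p * scale_step     # first-order form that _size_update folds into `delta`

    print((p_exact - p_approx).abs().max())  # tiny for small steps

The clamping against param_min_rms / param_max_rms above then simply freezes or reverses this scale step once the parameter's RMS leaves the allowed range.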
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py ADDED
@@ -0,0 +1,388 @@
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import Tensor
+ from torch.nn.functional import *
+ # _in_projection is needed for the use_separate_proj_weight branch below.
+ from torch.nn.functional import (_mha_shape_check, _canonical_mask, _none_or_dtype,
+                                  _in_projection, _in_projection_packed)
+
+ def multi_head_attention_forward_patched(
+     query: Tensor,
+     key: Tensor,
+     value: Tensor,
+     embed_dim_to_check: int,
+     num_heads: int,
+     in_proj_weight: Optional[Tensor],
+     in_proj_bias: Optional[Tensor],
+     bias_k: Optional[Tensor],
+     bias_v: Optional[Tensor],
+     add_zero_attn: bool,
+     dropout_p: float,
+     out_proj_weight: Tensor,
+     out_proj_bias: Optional[Tensor],
+     training: bool = True,
+     key_padding_mask: Optional[Tensor] = None,
+     need_weights: bool = True,
+     attn_mask: Optional[Tensor] = None,
+     use_separate_proj_weight: bool = False,
+     q_proj_weight: Optional[Tensor] = None,
+     k_proj_weight: Optional[Tensor] = None,
+     v_proj_weight: Optional[Tensor] = None,
+     static_k: Optional[Tensor] = None,
+     static_v: Optional[Tensor] = None,
+     average_attn_weights: bool = True,
+     is_causal: bool = False,
+     cache=None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+     r"""
+     Args:
+         query, key, value: map a query and a set of key-value pairs to an output.
+             See "Attention Is All You Need" for more details.
+         embed_dim_to_check: total dimension of the model.
+         num_heads: parallel attention heads.
+         in_proj_weight, in_proj_bias: input projection weight and bias.
+         bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+         add_zero_attn: add a new batch of zeros to the key and
+             value sequences at dim=1.
+         dropout_p: probability of an element to be zeroed.
+         out_proj_weight, out_proj_bias: the output projection weight and bias.
+         training: apply dropout if ``True``.
+         key_padding_mask: if provided, specified padding elements in the key will
+             be ignored by the attention. This is a binary mask. When the value is True,
+             the corresponding value on the attention layer will be filled with -inf.
+         need_weights: output attn_output_weights.
+             Default: ``True``
+             Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
+             for best performance when attention weights are not needed.
+             *Setting ``need_weights`` to ``True``
+             leads to a significant performance degradation.*
+         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+         is_causal: If specified, applies a causal mask as attention mask, and ignores
+             attn_mask for computing scaled dot product attention.
+             Default: ``False``.
+             .. warning::
+                 is_causal provides a hint that attn_mask is the
+                 causal mask. Providing incorrect hints can result in
+                 incorrect execution, including forward and backward
+                 compatibility.
+         use_separate_proj_weight: the function accepts the projection weights for query, key,
+             and value in different forms. If false, in_proj_weight will be used, which is
+             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+         q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+         static_k, static_v: static key and value used for attention operators.
+         average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
+             Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
+             when ``need_weights=True``. Default: True
+
+
+     Shape:
+         Inputs:
+         - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+           the embedding dimension.
+         - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+           the embedding dimension.
+         - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+           the embedding dimension.
+         - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
+           If a FloatTensor is provided, it will be directly added to the value.
+           If a BoolTensor is provided, the positions with the
+           value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+           S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+           positions. If a BoolTensor is provided, positions with ``True``
+           are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+           is provided, it will be added to the attention weight.
+         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+
+     Outputs:
+         - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+           E is the embedding dimension.
+         - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
+           attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
+           :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
+           :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+           head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
+     """
+     tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+     if has_torch_function(tens_ops):
+         return handle_torch_function(
+             multi_head_attention_forward,
+             tens_ops,
+             query,
+             key,
+             value,
+             embed_dim_to_check,
+             num_heads,
+             in_proj_weight,
+             in_proj_bias,
+             bias_k,
+             bias_v,
+             add_zero_attn,
+             dropout_p,
+             out_proj_weight,
+             out_proj_bias,
+             training=training,
+             key_padding_mask=key_padding_mask,
+             need_weights=need_weights,
+             attn_mask=attn_mask,
+             is_causal=is_causal,
+             use_separate_proj_weight=use_separate_proj_weight,
+             q_proj_weight=q_proj_weight,
+             k_proj_weight=k_proj_weight,
+             v_proj_weight=v_proj_weight,
+             static_k=static_k,
+             static_v=static_v,
+             average_attn_weights=average_attn_weights,
+             cache=cache,
+         )
+
+     is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
+
+     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
+     # is batched, run the computation and before returning squeeze the
+     # batch dimension so that the output doesn't carry this temporary batch dimension.
+     if not is_batched:
+         # unsqueeze if the input is unbatched
+         query = query.unsqueeze(1)
+         key = key.unsqueeze(1)
+         value = value.unsqueeze(1)
+         if key_padding_mask is not None:
+             key_padding_mask = key_padding_mask.unsqueeze(0)
+
+     # set up shape vars
+     tgt_len, bsz, embed_dim = query.shape
+     src_len, _, _ = key.shape
+
+     key_padding_mask = _canonical_mask(
+         mask=key_padding_mask,
+         mask_name="key_padding_mask",
+         other_type=_none_or_dtype(attn_mask),
+         other_name="attn_mask",
+         target_type=query.dtype
+     )
+
+     if is_causal and attn_mask is None:
+         raise RuntimeError(
+             "Need attn_mask if specifying the is_causal hint. "
+             "You may use the Transformer module method "
+             "`generate_square_subsequent_mask` to create this mask."
+         )
+
+     if is_causal and key_padding_mask is None and not need_weights:
+         # when we have a kpm or need weights, we need attn_mask
+         # Otherwise, we pass the is_causal hint through as the is_causal
+         # indicator to SDPA.
+         attn_mask = None
+     else:
+         attn_mask = _canonical_mask(
+             mask=attn_mask,
+             mask_name="attn_mask",
+             other_type=None,
+             other_name="",
+             target_type=query.dtype,
+             check_other=False,
+         )
+
+         if key_padding_mask is not None:
+             # We have the attn_mask, and use that to merge kpm into it.
+             # Turn off use of is_causal hint, as the merged mask is no
+             # longer causal.
+             is_causal = False
+
+     assert embed_dim == embed_dim_to_check, \
+         f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+     if isinstance(embed_dim, torch.Tensor):
+         # embed_dim can be a tensor when JIT tracing
+         head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
+     else:
+         head_dim = embed_dim // num_heads
+     assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+     if use_separate_proj_weight:
+         # allow MHA to have different embedding dimensions when separate projection weights are used
+         assert key.shape[:2] == value.shape[:2], \
+             f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+     else:
+         assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
+     #
+     # compute in-projection
+     #
+     if not use_separate_proj_weight:
+         assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
+         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+     else:
+         assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+         assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+         assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
+         if in_proj_bias is None:
+             b_q = b_k = b_v = None
+         else:
+             b_q, b_k, b_v = in_proj_bias.chunk(3)
+         q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+     if cache is not None:
+         if cache["first_infer"] == 1:
+             cache["k"][cache["stage"]] = k
+             # print(0,cache["k"].shape)
+             cache["v"][cache["stage"]] = v
+         else:  # each of the 12 layers keeps its own cached k/v
+             # print(1,cache["k"].shape)
+             # the time axis was originally dim 1, but the in-projection may have transposed it to dim 0
+             cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]], k], 0)
+             cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
+             # print(2, cache["k"].shape)
+         src_len = cache["k"][cache["stage"]].shape[0]
+         k = cache["k"][cache["stage"]]
+         v = cache["v"][cache["stage"]]
+         # if attn_mask is not None:
+         #     attn_mask = attn_mask[-1:,]
+         # print(attn_mask.shape,attn_mask)
+         cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
+     # print(2333,cache)
+     # prep attention mask
+
+     attn_mask = _canonical_mask(
+         mask=attn_mask,
+         mask_name="attn_mask",
+         other_type=None,
+         other_name="",
+         target_type=q.dtype,
+         check_other=False,
+     )
+
+     if attn_mask is not None:
+         # ensure attn_mask's dim is 3
+         if attn_mask.dim() == 2:
+             correct_2d_size = (tgt_len, src_len)
+             if attn_mask.shape != correct_2d_size:
+                 raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
+             attn_mask = attn_mask.unsqueeze(0)
+         elif attn_mask.dim() == 3:
+             correct_3d_size = (bsz * num_heads, tgt_len, src_len)
+             if attn_mask.shape != correct_3d_size:
+                 raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
+         else:
+             raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
+     # add bias along batch dimension (currently second)
+     if bias_k is not None and bias_v is not None:
+         assert static_k is None, "bias cannot be added to static key."
+         assert static_v is None, "bias cannot be added to static value."
+         k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+         v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+         if attn_mask is not None:
+             attn_mask = pad(attn_mask, (0, 1))
+         if key_padding_mask is not None:
+             key_padding_mask = pad(key_padding_mask, (0, 1))
+     else:
+         assert bias_k is None
+         assert bias_v is None
+
+     #
+     # reshape q, k, v for multihead attention and make em batch first
+     #
+     q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+     if static_k is None:
+         k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+     else:
+         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+         assert static_k.size(0) == bsz * num_heads, \
+             f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+         assert static_k.size(2) == head_dim, \
+             f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+         k = static_k
+     if static_v is None:
+         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+     else:
+         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+         assert static_v.size(0) == bsz * num_heads, \
+             f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+         assert static_v.size(2) == head_dim, \
+             f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+         v = static_v
+
+     # add zero attention along batch dimension (now first)
+     if add_zero_attn:
+         zero_attn_shape = (bsz * num_heads, 1, head_dim)
+         k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+         v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
+         if attn_mask is not None:
+             attn_mask = pad(attn_mask, (0, 1))
+         if key_padding_mask is not None:
+             key_padding_mask = pad(key_padding_mask, (0, 1))
+
+     # update source sequence length after adjustments
+     src_len = k.size(1)
+
+     # merge key padding and attention masks
+     if key_padding_mask is not None:
+         assert key_padding_mask.shape == (bsz, src_len), \
+             f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+         key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
+             expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+         if attn_mask is None:
+             attn_mask = key_padding_mask
+         else:
+             attn_mask = attn_mask + key_padding_mask
+
+     # adjust dropout probability
+     if not training:
+         dropout_p = 0.0
+
+     #
+     # (deep breath) calculate attention and out projection
+     #
+
+     if need_weights:
+         B, Nt, E = q.shape
+         q_scaled = q / math.sqrt(E)
+
+         assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
+
+         if attn_mask is not None:
+             attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
+         else:
+             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
+         attn_output_weights = softmax(attn_output_weights, dim=-1)
+         if dropout_p > 0.0:
+             attn_output_weights = dropout(attn_output_weights, p=dropout_p)
+
+         attn_output = torch.bmm(attn_output_weights, v)
+
+         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
+         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+
+         # optionally average attention weights over heads
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         if average_attn_weights:
+             attn_output_weights = attn_output_weights.mean(dim=1)
+
+         if not is_batched:
+             # squeeze the output if input was unbatched
+             attn_output = attn_output.squeeze(1)
+             attn_output_weights = attn_output_weights.squeeze(0)
+         return attn_output, attn_output_weights
+     else:
+         # attn_mask can be either (L,S) or (N*num_heads, L, S)
+         # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
+         # in order to match the input for SDPA of (N, num_heads, L, S)
+         if attn_mask is not None:
+             if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
+                 attn_mask = attn_mask.unsqueeze(0)
+             else:
+                 attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
+
+         q = q.view(bsz, num_heads, tgt_len, head_dim)
+         k = k.view(bsz, num_heads, src_len, head_dim)
+         v = v.view(bsz, num_heads, src_len, head_dim)
+
+         attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+         attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
+
+         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+         if not is_batched:
+             # squeeze the output if input was unbatched
+             attn_output = attn_output.squeeze(1)
+         return attn_output, None
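The `cache` argument drives the incremental-decoding path above: on the first call (`cache["first_infer"] == 1`) the projected k/v are stored per layer, on later calls the new k/v are concatenated along dim 0 (the time axis in this layout), and `cache["stage"]` cycles through the layers modulo `cache["all_stage"]`. A minimal sketch of how a caller might initialize such a cache is shown below; the field layout is inferred from the reads above and the layer count from the "12 layers" comment, not taken from the repository's calling code.

    import torch

    num_layers = 12  # assumption: the inline comment above mentions 12 decoder layers

    cache = {
        "all_stage": num_layers,   # number of attention layers sharing this dict
        "stage": 0,                # layer currently being run; advanced modulo all_stage
        "first_infer": 1,          # 1 for the prefill pass over the prompt, 0 afterwards
        "k": [None] * num_layers,  # per-layer key cache, each of shape (seq_len, batch, embed_dim)
        "v": [None] * num_layers,  # per-layer value cache
    }

    # After the prefill pass, a caller would flip the flag so that each subsequent
    # single-token step appends its new k/v along dim 0 instead of overwriting.
    cache["first_infer"] = 0

Because the cached k/v already cover the whole prefix, `src_len` is recomputed from the cache, and only the new query positions attend over it, which is what makes the per-step cost independent of the prompt length.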