Spaces:
Runtime error
Runtime error
adde revision
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -34
- .gitignore +163 -0
- Languages/ben/G_100000.pth +3 -0
- Languages/ben/config.json +3 -0
- Languages/ben/vocab.txt +3 -0
- Languages/ell/G_100000.pth +3 -0
- Languages/ell/config.json +3 -0
- Languages/ell/vocab.txt +3 -0
- Languages/fra/G_100000.pth +3 -0
- Languages/fra/config.json +3 -0
- Languages/fra/vocab.txt +3 -0
- Languages/guj/G_100000.pth +3 -0
- Languages/guj/config.json +3 -0
- Languages/guj/vocab.txt +3 -0
- Languages/hin/G_100000.pth +3 -0
- Languages/hin/config.json +3 -0
- Languages/hin/vocab.txt +3 -0
- Languages/nld/G_100000.pth +3 -0
- Languages/nld/config.json +3 -0
- Languages/nld/vocab.txt +3 -0
- Languages/pol/G_100000.pth +3 -0
- Languages/pol/config.json +3 -0
- Languages/pol/vocab.txt +3 -0
- app.py +317 -0
- aux_files/uroman.pl +3 -0
- configurations/__init__.py +0 -0
- configurations/get_constants.py +176 -0
- configurations/get_hyperparameters.py +19 -0
- df/__init__.py +3 -0
- df/checkpoint.py +213 -0
- df/config.py +266 -0
- df/deepfilternet2.py +453 -0
- df/enhance.py +333 -0
- df/logger.py +212 -0
- df/model.py +24 -0
- df/modules.py +956 -0
- df/multiframe.py +329 -0
- df/utils.py +230 -0
- libdf/__init__.py +3 -0
- libdf/__init__.pyi +57 -0
- libdf/py.typed +0 -0
- model_weights/voice_enhance/checkpoints/model_96.ckpt.best +3 -0
- model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt +3 -0
- model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt +3 -0
- model_weights/voiceover/freevc-24.json +3 -0
- model_weights/voiceover/freevc-24.pth +3 -0
- model_weights/wavlm_models/WavLM-Large.pt +3 -0
- model_weights/wavlm_models/WavLM-Large.pt.txt +3 -0
- nnet/__init__.py +0 -0
- nnet/attentions.py +300 -0
.gitattributes
CHANGED
@@ -1,35 +1,6 @@
|
|
1 |
-
*.
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.
|
25 |
-
*.
|
26 |
-
|
27 |
-
*.
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.txt filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.pt* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt* filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.pl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
Temp_Audios/
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
db.sqlite3-journal
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
.pybuilder/
|
78 |
+
target/
|
79 |
+
|
80 |
+
# Jupyter Notebook
|
81 |
+
.ipynb_checkpoints
|
82 |
+
|
83 |
+
# IPython
|
84 |
+
profile_default/
|
85 |
+
ipython_config.py
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
# For a library or package, you might want to ignore these files since the code is
|
89 |
+
# intended to run in multiple environments; otherwise, check them in:
|
90 |
+
# .python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# poetry
|
100 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
101 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
102 |
+
# commonly ignored for libraries.
|
103 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
104 |
+
#poetry.lock
|
105 |
+
|
106 |
+
# pdm
|
107 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
108 |
+
#pdm.lock
|
109 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
110 |
+
# in version control.
|
111 |
+
# https://pdm.fming.dev/#use-with-ide
|
112 |
+
.pdm.toml
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
*.ini
|
133 |
+
|
134 |
+
# Spyder project settings
|
135 |
+
.spyderproject
|
136 |
+
.spyproject
|
137 |
+
|
138 |
+
# Rope project settings
|
139 |
+
.ropeproject
|
140 |
+
|
141 |
+
# mkdocs documentation
|
142 |
+
/site
|
143 |
+
|
144 |
+
# mypy
|
145 |
+
.mypy_cache/
|
146 |
+
.dmypy.json
|
147 |
+
dmypy.json
|
148 |
+
|
149 |
+
# Pyre type checker
|
150 |
+
.pyre/
|
151 |
+
|
152 |
+
# pytype static type analyzer
|
153 |
+
.pytype/
|
154 |
+
|
155 |
+
# Cython debug symbols
|
156 |
+
cython_debug/
|
157 |
+
|
158 |
+
# PyCharm
|
159 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
+
#.idea/
|
Languages/ben/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8c098eab2e5e378fc52bec57683839bbc641b2241033dab17174f6e37db29a4
|
3 |
+
size 145512166
|
Languages/ben/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/ben/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7085f1a1f6040b4da0ac55bb3ff91b77229d1ed14f7d86df2b23676a1a2cb81b
|
3 |
+
size 268
|
Languages/ell/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75bfa237f0fe859b34c4340bc7dccd944678cf9984bce5b5a82e2c90ca268db8
|
3 |
+
size 145504497
|
Languages/ell/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/ell/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c53d89f446eba9d061b510e31900d235ef0e021e44a790978dae5a4350a4013
|
3 |
+
size 164
|
Languages/fra/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63725b5a9201548b2247af02bd69a059335bddf52c1b858dbe38a43a40478bd7
|
3 |
+
size 145489135
|
Languages/fra/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/fra/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2b57f0f246b488fe914508d82a8607e1aea357beb0f801069b39bfeb3a4c0d47
|
3 |
+
size 104
|
Languages/guj/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:427ac3c74f61be494b389cae7d771311d0bcf576f4e2f1b22f257539e26e323a
|
3 |
+
size 145501427
|
Languages/guj/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/guj/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:611d4c5d7ba4bce727c1154277aea43df7a534e22e523877d1885a36727d63c3
|
3 |
+
size 232
|
Languages/hin/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f1d5e47edd7368ff40ff5673ddfc606ea713e785420d26c2da396b555458d3b
|
3 |
+
size 145510619
|
Languages/hin/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/hin/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eea03474615c78a1c42d1299c345b6865421c07485544ac3361bff472e5005ac
|
3 |
+
size 266
|
Languages/nld/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2b09e9917b07f06dd911045c8fc8738594b4c4d65c55223c46335093a4904816
|
3 |
+
size 145486855
|
Languages/nld/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/nld/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d7e65d89b8be768ac4a1e53643aebfe42830b82910bfed87f13904e2c5292a4
|
3 |
+
size 94
|
Languages/pol/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d6f4a9de92a6eb15bca8cb01826d8a9938ab6fb2c04a1c13a06d1d170c88ba6
|
3 |
+
size 145490647
|
Languages/pol/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49
|
3 |
+
size 1887
|
Languages/pol/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5514b17eb0cd17950849e3f2f2f22a6c7c2d18f2729f5b2fbfc2f2e5f035dc4a
|
3 |
+
size 103
|
app.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# load the libraries for the application
|
3 |
+
# -------------------------------------------
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import nltk
|
7 |
+
import torch
|
8 |
+
import librosa
|
9 |
+
import tempfile
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from scipy.io import wavfile
|
15 |
+
from nnet import utils, commons
|
16 |
+
from transformers import pipeline
|
17 |
+
from scipy.io.wavfile import write
|
18 |
+
from faster_whisper import WhisperModel
|
19 |
+
from nnet.models import SynthesizerTrn as vitsTRN
|
20 |
+
from nnet.models_vc import SynthesizerTrn as freeTRN
|
21 |
+
from nnet.mel_processing import mel_spectrogram_torch
|
22 |
+
from configurations.get_constants import constantConfig
|
23 |
+
|
24 |
+
from speaker_encoder.voice_encoder import SpeakerEncoder
|
25 |
+
|
26 |
+
from df.enhance import enhance, init_df, load_audio, save_audio
|
27 |
+
from configurations.get_hyperparameters import hyperparameterConfig
|
28 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
29 |
+
|
30 |
+
nltk.download('punkt')
|
31 |
+
from nltk.tokenize import sent_tokenize
|
32 |
+
|
33 |
+
# making the FreeVC function
|
34 |
+
# ---------------------------------
|
35 |
+
class FreeVCModel:
|
36 |
+
def __init__(self, config, ptfile, speaker_model, wavLM_model, device='cpu'):
|
37 |
+
self.hps = utils.get_hparams_from_file(config)
|
38 |
+
|
39 |
+
self.net_g = freeTRN(
|
40 |
+
self.hps.data.filter_length // 2 + 1,
|
41 |
+
self.hps.train.segment_size // self.hps.data.hop_length,
|
42 |
+
**self.hps.model
|
43 |
+
).to(hyperparameters.device)
|
44 |
+
_ = self.net_g.eval()
|
45 |
+
_ = utils.load_checkpoint(ptfile, self.net_g, None, True)
|
46 |
+
|
47 |
+
self.cmodel = utils.get_cmodel(device, wavLM_model)
|
48 |
+
|
49 |
+
if self.hps.model.use_spk:
|
50 |
+
self.smodel = SpeakerEncoder(speaker_model)
|
51 |
+
|
52 |
+
def convert(self, src, tgt):
|
53 |
+
fs_src, src_audio = src
|
54 |
+
fs_tgt, tgt_audio = tgt
|
55 |
+
|
56 |
+
src = f"{constants.temp_audio_folder}/src.wav"
|
57 |
+
tgt = f"{constants.temp_audio_folder}/tgt.wav"
|
58 |
+
out = f"{constants.temp_audio_folder}/cnvr.wav"
|
59 |
+
with torch.no_grad():
|
60 |
+
wavfile.write(tgt, fs_tgt, tgt_audio)
|
61 |
+
wav_tgt, _ = librosa.load(tgt, sr=self.hps.data.sampling_rate)
|
62 |
+
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
|
63 |
+
if self.hps.model.use_spk:
|
64 |
+
g_tgt = self.smodel.embed_utterance(wav_tgt)
|
65 |
+
g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(hyperparameters.device.type)
|
66 |
+
else:
|
67 |
+
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(hyperparameters.device.type)
|
68 |
+
mel_tgt = mel_spectrogram_torch(
|
69 |
+
wav_tgt,
|
70 |
+
self.hps.data.filter_length,
|
71 |
+
self.hps.data.n_mel_channels,
|
72 |
+
self.hps.data.sampling_rate,
|
73 |
+
self.hps.data.hop_length,
|
74 |
+
self.hps.data.win_length,
|
75 |
+
self.hps.data.mel_fmin,
|
76 |
+
self.hps.data.mel_fmax,
|
77 |
+
)
|
78 |
+
wavfile.write(src, fs_src, src_audio)
|
79 |
+
wav_src, _ = librosa.load(src, sr=self.hps.data.sampling_rate)
|
80 |
+
wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(hyperparameters.device.type)
|
81 |
+
c = utils.get_content(self.cmodel, wav_src)
|
82 |
+
|
83 |
+
if self.hps.model.use_spk:
|
84 |
+
audio = self.net_g.infer(c, g=g_tgt)
|
85 |
+
else:
|
86 |
+
audio = self.net_g.infer(c, mel=mel_tgt)
|
87 |
+
audio = audio[0][0].data.cpu().float().numpy()
|
88 |
+
write(out, 24000, audio)
|
89 |
+
|
90 |
+
return out
|
91 |
+
|
92 |
+
# load the system configurations
|
93 |
+
constants = constantConfig()
|
94 |
+
hyperparameters = hyperparameterConfig()
|
95 |
+
|
96 |
+
# load the models
|
97 |
+
model, df_state, _ = init_df(hyperparameters.voice_enhacing_model, config_allow_defaults=True) # voice enhancing model
|
98 |
+
stt_model = WhisperModel(hyperparameters.stt_model, device=hyperparameters.device.type, compute_type="float32") #speech to text model
|
99 |
+
|
100 |
+
trans_model = AutoModelForSeq2SeqLM.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model], torch_dtype=torch.bfloat16).to(hyperparameters.device)
|
101 |
+
trans_tokenizer = AutoTokenizer.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model])
|
102 |
+
|
103 |
+
modelConvertSpeech = FreeVCModel(config=hyperparameters.text2speech_config, ptfile=hyperparameters.text2speech_model,
|
104 |
+
speaker_model=hyperparameters.text2speech_encoder, wavLM_model=hyperparameters.wavlm_model,
|
105 |
+
device=hyperparameters.device.type)
|
106 |
+
|
107 |
+
# download the language model if doesn't existing
|
108 |
+
# ----------------------------------------------------
|
109 |
+
def download(lang, lang_directory):
|
110 |
+
|
111 |
+
if not os.path.exists(f"{lang_directory}/{lang}"):
|
112 |
+
cmd = ";".join([
|
113 |
+
f"wget {constants.language_download_web}/{lang}.tar.gz -O {lang_directory}/{lang}.tar.gz",
|
114 |
+
f"tar zxvf {lang_directory}/{lang}.tar.gz -C {lang_directory}"
|
115 |
+
])
|
116 |
+
subprocess.check_output(cmd, shell=True)
|
117 |
+
try:
|
118 |
+
os.remove(f"{lang_directory}/{lang}.tar.gz")
|
119 |
+
except:
|
120 |
+
pass
|
121 |
+
return f"{lang_directory}/{lang}"
|
122 |
+
|
123 |
+
def preprocess_char(text, lang=None):
|
124 |
+
"""
|
125 |
+
Special treatement of characters in certain languages
|
126 |
+
"""
|
127 |
+
if lang == 'ron':
|
128 |
+
text = text.replace("ț", "ţ")
|
129 |
+
return text
|
130 |
+
|
131 |
+
def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
|
132 |
+
txt = preprocess_char(txt, lang=lang)
|
133 |
+
is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
|
134 |
+
if is_uroman:
|
135 |
+
txt = text_mapper.uromanize(txt, f'{uroman_dir}/bin/uroman.pl')
|
136 |
+
|
137 |
+
txt = txt.lower()
|
138 |
+
txt = text_mapper.filter_oov(txt)
|
139 |
+
return txt
|
140 |
+
|
141 |
+
def detect_language(text,LID):
|
142 |
+
predictions = LID.predict(text)
|
143 |
+
detected_lang_code = predictions[0][0].replace("__label__", "")
|
144 |
+
return detected_lang_code
|
145 |
+
|
146 |
+
# text to speech
|
147 |
+
class TextMapper(object):
|
148 |
+
def __init__(self, vocab_file):
|
149 |
+
self.symbols = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
|
150 |
+
self.SPACE_ID = self.symbols.index(" ")
|
151 |
+
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
|
152 |
+
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
|
153 |
+
|
154 |
+
def text_to_sequence(self, text, cleaner_names):
|
155 |
+
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
156 |
+
Args:
|
157 |
+
text: string to convert to a sequence
|
158 |
+
cleaner_names: names of the cleaner functions to run the text through
|
159 |
+
Returns:
|
160 |
+
List of integers corresponding to the symbols in the text
|
161 |
+
'''
|
162 |
+
sequence = []
|
163 |
+
clean_text = text.strip()
|
164 |
+
for symbol in clean_text:
|
165 |
+
symbol_id = self._symbol_to_id[symbol]
|
166 |
+
sequence += [symbol_id]
|
167 |
+
return sequence
|
168 |
+
|
169 |
+
def uromanize(self, text, uroman_pl):
|
170 |
+
with tempfile.NamedTemporaryFile() as tf, \
|
171 |
+
tempfile.NamedTemporaryFile() as tf2:
|
172 |
+
with open(tf.name, "w") as f:
|
173 |
+
f.write("\n".join([text]))
|
174 |
+
cmd = f"perl " + uroman_pl
|
175 |
+
cmd += f" -l xxx "
|
176 |
+
cmd += f" < {tf.name} > {tf2.name}"
|
177 |
+
os.system(cmd)
|
178 |
+
outtexts = []
|
179 |
+
with open(tf2.name) as f:
|
180 |
+
for line in f:
|
181 |
+
line = re.sub(r"\s+", " ", line).strip()
|
182 |
+
outtexts.append(line)
|
183 |
+
outtext = outtexts[0]
|
184 |
+
return outtext
|
185 |
+
|
186 |
+
def get_text(self, text, hps):
|
187 |
+
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
|
188 |
+
if hps.data.add_blank:
|
189 |
+
text_norm = commons.intersperse(text_norm, 0)
|
190 |
+
text_norm = torch.LongTensor(text_norm)
|
191 |
+
return text_norm
|
192 |
+
|
193 |
+
def filter_oov(self, text):
|
194 |
+
val_chars = self._symbol_to_id
|
195 |
+
txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
|
196 |
+
return txt_filt
|
197 |
+
|
198 |
+
def speech_to_text(audio_file):
|
199 |
+
try:
|
200 |
+
fs, audio = audio_file
|
201 |
+
wavfile.write(constants.input_speech_file, fs, audio)
|
202 |
+
audio0, _ = load_audio(constants.input_speech_file, sr=df_state.sr())
|
203 |
+
|
204 |
+
# Enhance the SNR of the audio
|
205 |
+
enhanced = enhance(model, df_state, audio0)
|
206 |
+
save_audio(constants.enhanced_speech_file, enhanced, df_state.sr())
|
207 |
+
|
208 |
+
segments, info = stt_model.transcribe(constants.enhanced_speech_file)
|
209 |
+
|
210 |
+
speech_text = ''
|
211 |
+
for segment in segments:
|
212 |
+
speech_text = f'{speech_text}{segment.text}'
|
213 |
+
try:
|
214 |
+
source_lang_nllb = [k for k, v in constants.flores_codes_to_tts_codes.items() if v[:2] == info.language][0]
|
215 |
+
except:
|
216 |
+
source_lang_nllb = 'language cant be determined, select manually'
|
217 |
+
|
218 |
+
# text translation
|
219 |
+
return speech_text, gr.Dropdown.update(value=source_lang_nllb)
|
220 |
+
except:
|
221 |
+
return '', gr.Dropdown.update(value='English')
|
222 |
+
|
223 |
+
# Text tp speech
|
224 |
+
def text_to_speech(text, target_lang):
|
225 |
+
txt = text
|
226 |
+
|
227 |
+
# LANG = get_target_tts_lang(target_lang)
|
228 |
+
LANG = constants.flores_codes_to_tts_codes[target_lang]
|
229 |
+
ckpt_dir = download(LANG, lang_directory=constants.language_directory)
|
230 |
+
|
231 |
+
vocab_file = f"{ckpt_dir}/{constants.language_vocab_text}"
|
232 |
+
config_file = f"{ckpt_dir}/{constants.language_vocab_configuration}"
|
233 |
+
hps = utils.get_hparams_from_file(config_file)
|
234 |
+
text_mapper = TextMapper(vocab_file)
|
235 |
+
net_g = vitsTRN(
|
236 |
+
len(text_mapper.symbols),
|
237 |
+
hps.data.filter_length // 2 + 1,
|
238 |
+
hps.train.segment_size // hps.data.hop_length,
|
239 |
+
**hps.model)
|
240 |
+
net_g.to(hyperparameters.device)
|
241 |
+
_ = net_g.eval()
|
242 |
+
|
243 |
+
g_pth = f"{ckpt_dir}/{constants.language_vocab_model}"
|
244 |
+
|
245 |
+
_ = utils.load_checkpoint(g_pth, net_g, None)
|
246 |
+
|
247 |
+
txt = preprocess_text(txt, text_mapper, hps, lang=LANG, uroman_dir=constants.uroman_directory)
|
248 |
+
stn_tst = text_mapper.get_text(txt, hps)
|
249 |
+
with torch.no_grad():
|
250 |
+
x_tst = stn_tst.unsqueeze(0).to(hyperparameters.device)
|
251 |
+
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(hyperparameters.device)
|
252 |
+
hyp = net_g.infer(
|
253 |
+
x_tst, x_tst_lengths, noise_scale=.667,
|
254 |
+
noise_scale_w=0.8, length_scale=1.0
|
255 |
+
)[0][0,0].cpu().float().numpy()
|
256 |
+
|
257 |
+
return hps.data.sampling_rate, hyp
|
258 |
+
|
259 |
+
def translation(audio, text, source_lang_nllb, target_code_nllb, output_type, sentence_mode):
|
260 |
+
target_code = constants.flores_codes[target_code_nllb]
|
261 |
+
translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source_lang_nllb, tgt_lang=target_code, device=hyperparameters.device)
|
262 |
+
|
263 |
+
# output = translator(text, max_length=400)[0]['translation_text']
|
264 |
+
if sentence_mode == "Sentence-wise":
|
265 |
+
sentences = sent_tokenize(text)
|
266 |
+
translated_sentences = []
|
267 |
+
for sentence in sentences:
|
268 |
+
translated_sentence = translator(sentence, max_length=400)[0]['translation_text']
|
269 |
+
translated_sentences.append(translated_sentence)
|
270 |
+
output = ' '.join(translated_sentences)
|
271 |
+
else:
|
272 |
+
output = translator(text, max_length=1024)[0]['translation_text']
|
273 |
+
|
274 |
+
# get the text to speech
|
275 |
+
fs_out, audio_out = text_to_speech(output, target_code_nllb)
|
276 |
+
|
277 |
+
if output_type == 'own voice':
|
278 |
+
out_file = modelConvertSpeech.convert((fs_out, audio_out), audio)
|
279 |
+
return output, out_file
|
280 |
+
|
281 |
+
wavfile.write(constants.text2speech_wavfile, fs_out, audio_out)
|
282 |
+
return output, constants.text2speech_wavfile
|
283 |
+
|
284 |
+
with gr.Blocks(title = "Octopus Translation App") as octopus_translator:
|
285 |
+
with gr.Row():
|
286 |
+
audio_file = gr.Audio(source="microphone")
|
287 |
+
|
288 |
+
with gr.Row():
|
289 |
+
input_text = gr.Textbox(label="Input text")
|
290 |
+
source_language = gr.Dropdown(list(constants.flores_codes.keys()), value='English', label='Source (Autoselected)', interactive=True)
|
291 |
+
|
292 |
+
with gr.Row():
|
293 |
+
output_text = gr.Textbox(label='Translated text')
|
294 |
+
target_language = gr.Dropdown(list(constants.flores_codes.keys()), value='German', label='Target', interactive=True)
|
295 |
+
|
296 |
+
|
297 |
+
with gr.Row():
|
298 |
+
output_speech = gr.Audio(label='Translated speech')
|
299 |
+
translate_button = gr.Button('Translate')
|
300 |
+
|
301 |
+
|
302 |
+
with gr.Row():
|
303 |
+
enhance_audio = gr.Radio(['yes', 'no'], value='yes', label='Enhance input voice', interactive=True)
|
304 |
+
input_type = gr.Radio(['Whole text', 'Sentence-wise'],value='Sentence-wise', label="Translation Mode", interactive=True)
|
305 |
+
output_audio_type = gr.Radio(['standard speaker', 'voice transfer'], value='voice transfer', label='Enhance output voice', interactive=True)
|
306 |
+
|
307 |
+
audio_file.change(speech_to_text,
|
308 |
+
inputs=[audio_file],
|
309 |
+
outputs=[input_text, source_language])
|
310 |
+
|
311 |
+
translate_button.click(translation,
|
312 |
+
inputs=[audio_file, input_text,
|
313 |
+
source_language, target_language,
|
314 |
+
output_audio_type, input_type],
|
315 |
+
outputs=[output_text, output_speech])
|
316 |
+
|
317 |
+
octopus_translator.launch(share=False)
|
aux_files/uroman.pl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ceece2c05343e8bc3b1a7cdc8cecd530af94a7928013c0e4224fd5c729fb29a
|
3 |
+
size 5347
|
configurations/__init__.py
ADDED
File without changes
|
configurations/get_constants.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
class constantConfig:
|
4 |
+
def __init__(self):
|
5 |
+
self.flores_codes={'Acehnese (Arabic script)': 'ace_Arab',
|
6 |
+
'Acehnese (Latin script)': 'ace_Latn',
|
7 |
+
'Mesopotamian Arabic': 'acm_Arab',
|
8 |
+
'Ta’izzi-Adeni Arabic': 'acq_Arab',
|
9 |
+
'Tunisian Arabic': 'aeb_Arab',
|
10 |
+
'Afrikaans': 'afr_Latn',
|
11 |
+
'South Levantine Arabic': 'ajp_Arab',
|
12 |
+
'Akan': 'aka_Latn',
|
13 |
+
'Amharic': 'amh_Ethi',
|
14 |
+
'North Levantine Arabic': 'apc_Arab',
|
15 |
+
'Modern Standard Arabic': 'arb_Arab',
|
16 |
+
'Modern Standard Arabic (Romanized)': 'arb_Latn',
|
17 |
+
'Najdi Arabic': 'ars_Arab',
|
18 |
+
'Moroccan Arabic': 'ary_Arab',
|
19 |
+
'Egyptian Arabic': 'arz_Arab',
|
20 |
+
'Assamese': 'asm_Beng',
|
21 |
+
'Asturian': 'ast_Latn',
|
22 |
+
'Awadhi': 'awa_Deva',
|
23 |
+
'Central Aymara': 'ayr_Latn',
|
24 |
+
'South Azerbaijani': 'azb_Arab',
|
25 |
+
'North Azerbaijani': 'azj_Latn',
|
26 |
+
'Bashkir': 'bak_Cyrl',
|
27 |
+
'Bambara': 'bam_Latn',
|
28 |
+
'Balinese': 'ban_Latn',
|
29 |
+
'Belarusian': 'bel_Cyrl',
|
30 |
+
'Bemba': 'bem_Latn',
|
31 |
+
'Bengali': 'ben_Beng',
|
32 |
+
'Bhojpuri': 'bho_Deva',
|
33 |
+
'Banjar (Arabic script)': 'bjn_Arab',
|
34 |
+
'Banjar (Latin script)': 'bjn_Latn',
|
35 |
+
'Standard Tibetan': 'bod_Tibt',
|
36 |
+
'Bosnian': 'bos_Latn',
|
37 |
+
'Buginese': 'bug_Latn',
|
38 |
+
'Bulgarian': 'bul_Cyrl',
|
39 |
+
'Catalan': 'cat_Latn',
|
40 |
+
'Cebuano': 'ceb_Latn',
|
41 |
+
'Czech': 'ces_Latn',
|
42 |
+
'Chokwe': 'cjk_Latn',
|
43 |
+
'Central Kurdish': 'ckb_Arab',
|
44 |
+
'Crimean Tatar': 'crh_Latn',
|
45 |
+
'Welsh': 'cym_Latn',
|
46 |
+
'Danish': 'dan_Latn',
|
47 |
+
'German': 'deu_Latn',
|
48 |
+
'Southwestern Dinka': 'dik_Latn',
|
49 |
+
'Dyula': 'dyu_Latn',
|
50 |
+
'Dzongkha': 'dzo_Tibt',
|
51 |
+
'Greek': 'ell_Grek',
|
52 |
+
'English': 'eng_Latn',
|
53 |
+
'Esperanto': 'epo_Latn',
|
54 |
+
'Estonian': 'est_Latn',
|
55 |
+
'Basque': 'eus_Latn',
|
56 |
+
'Ewe': 'ewe_Latn',
|
57 |
+
'Faroese': 'fao_Latn',
|
58 |
+
'Fijian': 'fij_Latn',
|
59 |
+
'Finnish': 'fin_Latn',
|
60 |
+
'Fon': 'fon_Latn',
|
61 |
+
'French': 'fra_Latn',
|
62 |
+
'Friulian': 'fur_Latn',
|
63 |
+
'Nigerian Fulfulde': 'fuv_Latn',
|
64 |
+
'Scottish Gaelic': 'gla_Latn',
|
65 |
+
'Irish': 'gle_Latn',
|
66 |
+
'Galician': 'glg_Latn',
|
67 |
+
'Guarani': 'grn_Latn',
|
68 |
+
'Gujarati': 'guj_Gujr',
|
69 |
+
'Haitian Creole': 'hat_Latn',
|
70 |
+
'Hausa': 'hau_Latn',
|
71 |
+
'Hebrew': 'heb_Hebr',
|
72 |
+
'Hindi': 'hin_Deva',
|
73 |
+
'Chhattisgarhi': 'hne_Deva',
|
74 |
+
'Croatian': 'hrv_Latn',
|
75 |
+
'Hungarian': 'hun_Latn',
|
76 |
+
'Armenian': 'hye_Armn',
|
77 |
+
'Igbo': 'ibo_Latn',
|
78 |
+
'Ilocano': 'ilo_Latn',
|
79 |
+
'Indonesian': 'ind_Latn',
|
80 |
+
'Icelandic': 'isl_Latn',
|
81 |
+
'Italian': 'ita_Latn',
|
82 |
+
'Javanese': 'jav_Latn',
|
83 |
+
'Japanese': 'jpn_Jpan',
|
84 |
+
'Kabyle': 'kab_Latn',
|
85 |
+
'Jingpho': 'kac_Latn',
|
86 |
+
'Kamba': 'kam_Latn',
|
87 |
+
'Kannada': 'kan_Knda',
|
88 |
+
'Kashmiri (Arabic script)': 'kas_Arab',
|
89 |
+
'Kashmiri (Devanagari script)': 'kas_Deva',
|
90 |
+
'Georgian': 'kat_Geor',
|
91 |
+
'Central Kanuri (Arabic script)': 'knc_Arab',
|
92 |
+
'Central Kanuri (Latin script)': 'knc_Latn',
|
93 |
+
'Kazakh': 'kaz_Cyrl',
|
94 |
+
'Kabiyè': 'kbp_Latn',
|
95 |
+
'Kabuverdianu': 'kea_Latn',
|
96 |
+
'Khmer': 'khm_Khmr',
|
97 |
+
'Kikuyu': 'kik_Latn',
|
98 |
+
'Kinyarwanda': 'kin_Latn', 'Kyrgyz': 'kir_Cyrl', 'Kimbundu': 'kmb_Latn',
|
99 |
+
'Northern Kurdish': 'kmr_Latn', 'Kikongo': 'kon_Latn',
|
100 |
+
'Korean': 'kor_Hang', 'Lao': 'lao_Laoo', 'Ligurian': 'lij_Latn',
|
101 |
+
'Limburgish': 'lim_Latn', 'Lingala': 'lin_Latn', 'Lithuanian': 'lit_Latn', 'Lombard': 'lmo_Latn',
|
102 |
+
'Latgalian': 'ltg_Latn', 'Luxembourgish': 'ltz_Latn', 'Luba-Kasai': 'lua_Latn', 'Ganda': 'lug_Latn',
|
103 |
+
'Luo': 'luo_Latn', 'Mizo': 'lus_Latn', 'Standard Latvian': 'lvs_Latn', 'Magahi': 'mag_Deva',
|
104 |
+
'Maithili': 'mai_Deva', 'Malayalam': 'mal_Mlym', 'Marathi': 'mar_Deva',
|
105 |
+
'Minangkabau (Arabic script)': 'min_Arab', 'Minangkabau (Latin script)': 'min_Latn',
|
106 |
+
'Macedonian': 'mkd_Cyrl', 'Plateau Malagasy': 'plt_Latn', 'Maltese': 'mlt_Latn',
|
107 |
+
'Meitei (Bengali script)': 'mni_Beng', 'Halh Mongolian': 'khk_Cyrl', 'Mossi': 'mos_Latn',
|
108 |
+
'Maori': 'mri_Latn', 'Burmese': 'mya_Mymr', 'Dutch': 'nld_Latn', 'Norwegian Nynorsk': 'nno_Latn',
|
109 |
+
'Norwegian Bokmål': 'nob_Latn', 'Nepali': 'npi_Deva', 'Northern Sotho': 'nso_Latn',
|
110 |
+
'Nuer': 'nus_Latn',
|
111 |
+
'Nyanja': 'nya_Latn', 'Occitan': 'oci_Latn', 'West Central Oromo': 'gaz_Latn', 'Odia': 'ory_Orya',
|
112 |
+
'Pangasinan': 'pag_Latn', 'Eastern Panjabi': 'pan_Guru', 'Papiamento': 'pap_Latn',
|
113 |
+
'Western Persian': 'pes_Arab',
|
114 |
+
'Polish': 'pol_Latn', 'Portuguese': 'por_Latn', 'Dari': 'prs_Arab', 'Southern Pashto': 'pbt_Arab',
|
115 |
+
'Ayacucho Quechua': 'quy_Latn', 'Romanian': 'ron_Latn', 'Rundi': 'run_Latn', 'Russian': 'rus_Cyrl',
|
116 |
+
'Sango': 'sag_Latn', 'Sanskrit': 'san_Deva', 'Santali': 'sat_Olck', 'Sicilian': 'scn_Latn',
|
117 |
+
'Shan': 'shn_Mymr',
|
118 |
+
'Sinhala': 'sin_Sinh', 'Slovak': 'slk_Latn', 'Slovenian': 'slv_Latn', 'Samoan': 'smo_Latn',
|
119 |
+
'Shona': 'sna_Latn',
|
120 |
+
'Sindhi': 'snd_Arab', 'Somali': 'som_Latn', 'Southern Sotho': 'sot_Latn', 'Spanish': 'spa_Latn',
|
121 |
+
'Tosk Albanian': 'als_Latn', 'Sardinian': 'srd_Latn', 'Serbian': 'srp_Cyrl', 'Swati': 'ssw_Latn',
|
122 |
+
'Sundanese': 'sun_Latn', 'Swedish': 'swe_Latn', 'Swahili': 'swh_Latn', 'Silesian': 'szl_Latn',
|
123 |
+
'Tamil': 'tam_Taml', 'Tatar': 'tat_Cyrl', 'Telugu': 'tel_Telu', 'Tajik': 'tgk_Cyrl',
|
124 |
+
'Tagalog': 'tgl_Latn',
|
125 |
+
'Thai': 'tha_Thai', 'Tigrinya': 'tir_Ethi', 'Tamasheq (Latin script)': 'taq_Latn',
|
126 |
+
'Tamasheq (Tifinagh script)': 'taq_Tfng',
|
127 |
+
'Tok Pisin': 'tpi_Latn', 'Tswana': 'tsn_Latn', 'Tsonga': 'tso_Latn', 'Turkmen': 'tuk_Latn', 'Tumbuka': 'tum_Latn',
|
128 |
+
'Turkish': 'tur_Latn', 'Twi': 'twi_Latn', 'Central Atlas Tamazight': 'tzm_Tfng',
|
129 |
+
'Uyghur': 'uig_Arab',
|
130 |
+
'Ukrainian': 'ukr_Cyrl', 'Umbundu': 'umb_Latn', 'Urdu': 'urd_Arab', 'Northern Uzbek': 'uzn_Latn',
|
131 |
+
'Venetian': 'vec_Latn',
|
132 |
+
'Vietnamese': 'vie_Latn', 'Waray': 'war_Latn', 'Wolof': 'wol_Latn', 'Xhosa': 'xho_Latn',
|
133 |
+
'Eastern Yiddish': 'ydd_Hebr',
|
134 |
+
'Yoruba': 'yor_Latn', 'Yue Chinese': 'yue_Hant', 'Chinese (Simplified)': 'zho_Hans',
|
135 |
+
'Chinese (Traditional)': 'zho_Hant',
|
136 |
+
'Standard Malay': 'zsm_Latn', 'Zulu': 'zul_Latn'}
|
137 |
+
|
138 |
+
self.model_name_dict = {'0.6B': 'facebook/nllb-200-distilled-600M',
|
139 |
+
'1.3B': 'facebook/nllb-200-distilled-1.3B',
|
140 |
+
'3.3B': 'facebook/nllb-200-3.3B',
|
141 |
+
}
|
142 |
+
|
143 |
+
self.whisper_codes_to_flores_codes = {"de" : self.flores_codes['German'],
|
144 |
+
"en" : self.flores_codes['English'],
|
145 |
+
"pl" : self.flores_codes['Polish'],
|
146 |
+
"hi" : self.flores_codes['Hindi']
|
147 |
+
}
|
148 |
+
|
149 |
+
self.flores_codes_to_tts_codes = {'Acehnese': 'ace', 'Mesopotamian Arabic': 'acm', 'Ta’izzi-Adeni Arabic': 'acq', 'Tunisian Arabic': 'aeb', 'Afrikaans': 'afr', 'South Levantine Arabic': 'ajp', 'Akan': 'aka', 'Amharic': 'amh', 'North Levantine Arabic': 'apc', 'Modern Standard Arabic': 'arb', 'Najdi Arabic': 'ars', 'Moroccan Arabic': 'ary', 'Egyptian Arabic': 'arz', 'Assamese': 'asm', 'Asturian': 'ast', 'Awadhi': 'awa', 'Central Aymara': 'ayr', 'South Azerbaijani': 'azb', 'North Azerbaijani': 'azj', 'Bashkir': 'bak', 'Bambara': 'bam', 'Balinese': 'ban', 'Belarusian': 'bel', 'Bemba': 'bem', 'Bengali': 'ben', 'Bhojpuri': 'bho', 'Banjar': 'bjn', 'Standard Tibetan': 'bod', 'Bosnian': 'bos', 'Buginese': 'bug', 'Bulgarian': 'bul', 'Catalan': 'cat', 'Cebuano': 'ceb', 'Czech': 'ces', 'Chokwe': 'cjk', 'Central Kurdish': 'ckb', 'Crimean Tatar': 'crh', 'Welsh': 'cym', 'Danish': 'dan', 'German': 'deu', 'Southwestern Dinka': 'dik', 'Dyula': 'dyu', 'Dzongkha': 'dzo', 'Greek': 'ell', 'English': 'eng', 'Esperanto': 'epo', 'Estonian': 'est', 'Basque': 'eus', 'Ewe': 'ewe', 'Faroese': 'fao', 'Fijian': 'fij', 'Finnish': 'fin', 'Fon': 'fon', 'French': 'fra', 'Friulian': 'fur', 'Nigerian Fulfulde': 'fuv', 'Scottish Gaelic': 'gla', 'Irish': 'gle', 'Galician': 'glg', 'Guarani': 'grn', 'Gujarati': 'guj', 'Haitian Creole': 'hat', 'Hausa': 'hau', 'Hebrew': 'heb', 'Hindi': 'hin', 'Chhattisgarhi': 'hne', 'Croatian': 'hrv', 'Hungarian': 'hun', 'Armenian': 'hye', 'Igbo': 'ibo', 'Ilocano': 'ilo', 'Indonesian': 'ind', 'Icelandic': 'isl', 'Italian': 'ita', 'Javanese': 'jav', 'Japanese': 'jpn', 'Kabyle': 'kab', 'Jingpho': 'kac', 'Kamba': 'kam', 'Kannada': 'kan', 'Kashmiri': 'kas', 'Georgian': 'kat', 'Central Kanuri': 'knc', 'Kazakh': 'kaz', 'Kabiyè': 'kbp', 'Kabuverdianu': 'kea', 'Khmer': 'khm', 'Kikuyu': 'kik', 'Kinyarwanda': 'kin', 'Kyrgyz': 'kir', 'Kimbundu': 'kmb', 'Northern Kurdish': 'kmr', 'Kikongo': 'kon', 'Korean': 'kor', 'Lao': 'lao', 'Ligurian': 'lij', 'Limburgish': 'lim', 'Lingala': 'lin', 'Lithuanian': 'lit', 'Lombard': 'lmo', 'Latgalian': 'ltg', 'Luxembourgish': 'ltz', 'Luba-Kasai': 'lua', 'Ganda': 'lug', 'Luo': 'luo', 'Mizo': 'lus', 'Standard Latvian': 'lvs', 'Magahi': 'mag', 'Maithili': 'mai', 'Malayalam': 'mal', 'Marathi': 'mar', 'Minangkabau': 'min', 'Macedonian': 'mkd', 'Plateau Malagasy': 'plt', 'Maltese': 'mlt', 'Meitei': 'mni', 'Halh Mongolian': 'khk', 'Mossi': 'mos', 'Maori': 'mri', 'Burmese': 'mya', 'Dutch': 'nld', 'Norwegian Nynorsk': 'nno', 'Norwegian Bokmål': 'nob', 'Nepali': 'npi', 'Northern Sotho': 'nso', 'Nuer': 'nus', 'Nyanja': 'nya', 'Occitan': 'oci', 'West Central Oromo': 'gaz', 'Odia': 'ory', 'Pangasinan': 'pag', 'Eastern Panjabi': 'pan', 'Papiamento': 'pap', 'Western Persian': 'pes', 'Polish': 'pol', 'Portuguese': 'por', 'Dari': 'prs', 'Southern Pashto': 'pbt', 'Ayacucho Quechua': 'quy', 'Romanian': 'ron', 'Rundi': 'run', 'Russian': 'rus', 'Sango': 'sag', 'Sanskrit': 'san', 'Santali': 'sat', 'Sicilian': 'scn', 'Shan': 'shn', 'Sinhala': 'sin', 'Slovak': 'slk', 'Slovenian': 'slv', 'Samoan': 'smo', 'Shona': 'sna', 'Sindhi': 'snd', 'Somali': 'som', 'Southern Sotho': 'sot', 'Spanish': 'spa', 'Tosk Albanian': 'als', 'Sardinian': 'srd', 'Serbian': 'srp', 'Swati': 'ssw', 'Sundanese': 'sun', 'Swedish': 'swe', 'Swahili': 'swh', 'Silesian': 'szl', 'Tamil': 'tam', 'Tatar': 'tat', 'Telugu': 'tel', 'Tajik': 'tgk', 'Tagalog': 'tgl', 'Thai': 'tha', 'Tigrinya': 'tir', 'Tamasheq': 'taq', 'Tok Pisin': 'tpi', 'Tswana': 'tsn', 'Tsonga': 'tso', 'Turkmen': 'tuk', 'Tumbuka': 'tum', 'Turkish': 'tur', 'Twi': 'twi', 'Central Atlas Tamazight': 'tzm', 'Uyghur': 'uig', 'Ukrainian': 'ukr', 'Umbundu': 'umb', 'Urdu': 'urd', 'Northern Uzbek': 'uzn', 'Venetian': 'vec', 'Vietnamese': 'vie', 'Waray': 'war', 'Wolof': 'wol', 'Xhosa': 'xho', 'Eastern Yiddish': 'ydd', 'Yoruba': 'yor', 'Yue Chinese': 'yue', 'Chinese': 'zho', 'Standard Malay': 'zsm', 'Zulu': 'zul'}
|
150 |
+
|
151 |
+
self.language_directory = 'Languages'
|
152 |
+
self.uroman_directory = 'aux_files'
|
153 |
+
|
154 |
+
self.language_download_web = 'https://dl.fbaipublicfiles.com/mms/tts'
|
155 |
+
self.language_vocab_text = "vocab.txt"
|
156 |
+
self.language_vocab_configuration = "config.json"
|
157 |
+
self.language_vocab_model = "G_100000.pth"
|
158 |
+
|
159 |
+
# creating the audio files temporary
|
160 |
+
# ---------------------------------------
|
161 |
+
self.temp_audio_folder = 'Temp_Audios'
|
162 |
+
|
163 |
+
self.text2speech_wavfile = f'{self.temp_audio_folder}/text2speech.wav'
|
164 |
+
self.enhanced_speech_file = f"{self.temp_audio_folder}/enhanced.mp3"
|
165 |
+
self.input_speech_file = f'{self.temp_audio_folder}/output.wav'
|
166 |
+
|
167 |
+
|
168 |
+
try:
|
169 |
+
os.makedirs(self.language_directory)
|
170 |
+
except:
|
171 |
+
pass
|
172 |
+
|
173 |
+
try:
|
174 |
+
os.makedirs(self.temp_audio_folder)
|
175 |
+
except:
|
176 |
+
pass
|
configurations/get_hyperparameters.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class hyperparameterConfig:
|
4 |
+
def __init__(self):
|
5 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
+
|
7 |
+
self.stt_model = "large-v2"
|
8 |
+
self.nllb_model = '1.3B'
|
9 |
+
|
10 |
+
# text to speech model
|
11 |
+
self.text2speech_model = 'model_weights/voiceover/freevc-24.pth'
|
12 |
+
self.text2speech_config = 'model_weights/voiceover/freevc-24.json'
|
13 |
+
self.text2speech_encoder = 'model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt'
|
14 |
+
|
15 |
+
# voice enhancing model
|
16 |
+
self.voice_enhacing_model = 'model_weights/voice_enhance'
|
17 |
+
|
18 |
+
# loading the wavlm model
|
19 |
+
self.wavlm_model = 'model_weights/wavlm_models/WavLM-Large.pt'
|
df/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .config import config
|
2 |
+
|
3 |
+
__all__ = ["config"]
|
df/checkpoint.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from loguru import logger
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from df.config import Csv, config
|
12 |
+
from df.model import init_model
|
13 |
+
from df.utils import check_finite_module
|
14 |
+
from libdf import DF
|
15 |
+
|
16 |
+
|
17 |
+
def get_epoch(cp) -> int:
|
18 |
+
return int(os.path.basename(cp).split(".")[0].split("_")[-1])
|
19 |
+
|
20 |
+
|
21 |
+
def load_model(
|
22 |
+
cp_dir: Optional[str],
|
23 |
+
df_state: DF,
|
24 |
+
jit: bool = False,
|
25 |
+
mask_only: bool = False,
|
26 |
+
train_df_only: bool = False,
|
27 |
+
extension: str = "ckpt",
|
28 |
+
epoch: Union[str, int, None] = "latest",
|
29 |
+
) -> Tuple[nn.Module, int]:
|
30 |
+
if mask_only and train_df_only:
|
31 |
+
raise ValueError("Only one of `mask_only` `train_df_only` can be enabled")
|
32 |
+
model = init_model(df_state, run_df=mask_only is False, train_mask=train_df_only is False)
|
33 |
+
if jit:
|
34 |
+
model = torch.jit.script(model)
|
35 |
+
blacklist: List[str] = config("CP_BLACKLIST", [], Csv(), save=False, section="train") # type: ignore
|
36 |
+
if cp_dir is not None:
|
37 |
+
epoch = read_cp(
|
38 |
+
model, "model", cp_dir, blacklist=blacklist, extension=extension, epoch=epoch
|
39 |
+
)
|
40 |
+
epoch = 0 if epoch is None else epoch
|
41 |
+
else:
|
42 |
+
epoch = 0
|
43 |
+
return model, epoch
|
44 |
+
|
45 |
+
|
46 |
+
def read_cp(
|
47 |
+
obj: Union[torch.optim.Optimizer, nn.Module],
|
48 |
+
name: str,
|
49 |
+
dirname: str,
|
50 |
+
epoch: Union[str, int, None] = "latest",
|
51 |
+
extension="ckpt",
|
52 |
+
blacklist=[],
|
53 |
+
log: bool = True,
|
54 |
+
):
|
55 |
+
checkpoints = []
|
56 |
+
if isinstance(epoch, str):
|
57 |
+
assert epoch in ("best", "latest")
|
58 |
+
if epoch == "best":
|
59 |
+
checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
|
60 |
+
if len(checkpoints) == 0:
|
61 |
+
logger.warning("Could not find `best` checkpoint. Checking for default...")
|
62 |
+
if len(checkpoints) == 0:
|
63 |
+
checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
|
64 |
+
checkpoints += glob.glob(os.path.join(dirname, f"{name}*.{extension}.best"))
|
65 |
+
if len(checkpoints) == 0:
|
66 |
+
return None
|
67 |
+
if isinstance(epoch, int):
|
68 |
+
latest = next((x for x in checkpoints if get_epoch(x) == epoch), None)
|
69 |
+
if latest is None:
|
70 |
+
logger.error(f"Could not find checkpoint of epoch {epoch}")
|
71 |
+
exit(1)
|
72 |
+
else:
|
73 |
+
latest = max(checkpoints, key=get_epoch)
|
74 |
+
epoch = get_epoch(latest)
|
75 |
+
if log:
|
76 |
+
logger.info("Found checkpoint {} with epoch {}".format(latest, epoch))
|
77 |
+
latest = torch.load(latest, map_location="cpu")
|
78 |
+
latest = {k.replace("clc", "df"): v for k, v in latest.items()}
|
79 |
+
if blacklist:
|
80 |
+
reg = re.compile("".join(f"({b})|" for b in blacklist)[:-1])
|
81 |
+
len_before = len(latest)
|
82 |
+
latest = {k: v for k, v in latest.items() if reg.search(k) is None}
|
83 |
+
if len(latest) < len_before:
|
84 |
+
logger.info("Filtered checkpoint modules: {}".format(blacklist))
|
85 |
+
if isinstance(obj, nn.Module):
|
86 |
+
while True:
|
87 |
+
try:
|
88 |
+
missing, unexpected = obj.load_state_dict(latest, strict=False)
|
89 |
+
except RuntimeError as e:
|
90 |
+
e_str = str(e)
|
91 |
+
logger.warning(e_str)
|
92 |
+
if "size mismatch" in e_str:
|
93 |
+
latest = {k: v for k, v in latest.items() if k not in e_str}
|
94 |
+
continue
|
95 |
+
raise e
|
96 |
+
break
|
97 |
+
for key in missing:
|
98 |
+
logger.warning(f"Missing key: '{key}'")
|
99 |
+
for key in unexpected:
|
100 |
+
if key.endswith(".h0"):
|
101 |
+
continue
|
102 |
+
logger.warning(f"Unexpected key: {key}")
|
103 |
+
return epoch
|
104 |
+
obj.load_state_dict(latest)
|
105 |
+
|
106 |
+
|
107 |
+
def write_cp(
|
108 |
+
obj: Union[torch.optim.Optimizer, nn.Module],
|
109 |
+
name: str,
|
110 |
+
dirname: str,
|
111 |
+
epoch: int,
|
112 |
+
extension="ckpt",
|
113 |
+
metric: Optional[float] = None,
|
114 |
+
cmp="min",
|
115 |
+
):
|
116 |
+
check_finite_module(obj)
|
117 |
+
n_keep = config("n_checkpoint_history", default=3, cast=int, section="train")
|
118 |
+
n_keep_best = config("n_best_checkpoint_history", default=5, cast=int, section="train")
|
119 |
+
if metric is not None:
|
120 |
+
assert cmp in ("min", "max")
|
121 |
+
metric = float(metric) # Make sure it is not an integer
|
122 |
+
# Each line contains a previous best with entries: (epoch, metric)
|
123 |
+
with open(os.path.join(dirname, ".best"), "a+") as prev_best_f:
|
124 |
+
prev_best_f.seek(0) # "a+" creates a file in read/write mode without truncating
|
125 |
+
lines = prev_best_f.readlines()
|
126 |
+
if len(lines) == 0:
|
127 |
+
prev_best = float("inf" if cmp == "min" else "-inf")
|
128 |
+
else:
|
129 |
+
prev_best = float(lines[-1].strip().split(" ")[1])
|
130 |
+
cmp = "__lt__" if cmp == "min" else "__gt__"
|
131 |
+
if getattr(metric, cmp)(prev_best):
|
132 |
+
logger.info(f"Saving new best checkpoint at epoch {epoch} with metric: {metric}")
|
133 |
+
prev_best_f.seek(0, os.SEEK_END)
|
134 |
+
np.savetxt(prev_best_f, np.array([[float(epoch), metric]]))
|
135 |
+
cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}.best")
|
136 |
+
torch.save(obj.state_dict(), cp_name)
|
137 |
+
cleanup(name, dirname, extension + ".best", nkeep=n_keep_best)
|
138 |
+
cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}")
|
139 |
+
logger.info(f"Writing checkpoint {cp_name} with epoch {epoch}")
|
140 |
+
torch.save(obj.state_dict(), cp_name)
|
141 |
+
cleanup(name, dirname, extension, nkeep=n_keep)
|
142 |
+
|
143 |
+
|
144 |
+
def cleanup(name: str, dirname: str, extension: str, nkeep=5):
|
145 |
+
if nkeep < 0:
|
146 |
+
return
|
147 |
+
checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}"))
|
148 |
+
if len(checkpoints) == 0:
|
149 |
+
return
|
150 |
+
checkpoints = sorted(checkpoints, key=get_epoch, reverse=True)
|
151 |
+
for cp in checkpoints[nkeep:]:
|
152 |
+
logger.debug("Removing old checkpoint: {}".format(cp))
|
153 |
+
os.remove(cp)
|
154 |
+
|
155 |
+
|
156 |
+
def check_patience(
|
157 |
+
dirname: str, max_patience: int, new_metric: float, cmp: str = "min", raise_: bool = True
|
158 |
+
):
|
159 |
+
cmp = "__lt__" if cmp == "min" else "__gt__"
|
160 |
+
new_metric = float(new_metric) # Make sure it is not an integer
|
161 |
+
prev_patience, prev_metric = read_patience(dirname)
|
162 |
+
if prev_patience is None or getattr(new_metric, cmp)(prev_metric):
|
163 |
+
# We have a better new_metric, reset patience
|
164 |
+
write_patience(dirname, 0, new_metric)
|
165 |
+
else:
|
166 |
+
# We don't have a better metric, decrement patience
|
167 |
+
new_patience = prev_patience + 1
|
168 |
+
write_patience(dirname, new_patience, prev_metric)
|
169 |
+
if new_patience >= max_patience:
|
170 |
+
if raise_:
|
171 |
+
raise ValueError(
|
172 |
+
f"No improvements on validation metric ({new_metric}) for {max_patience} epochs. "
|
173 |
+
"Stopping."
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
return False
|
177 |
+
return True
|
178 |
+
|
179 |
+
|
180 |
+
def read_patience(dirname: str) -> Tuple[Optional[int], float]:
|
181 |
+
fn = os.path.join(dirname, ".patience")
|
182 |
+
if not os.path.isfile(fn):
|
183 |
+
return None, 0.0
|
184 |
+
patience, metric = np.loadtxt(fn)
|
185 |
+
return int(patience), float(metric)
|
186 |
+
|
187 |
+
|
188 |
+
def write_patience(dirname: str, new_patience: int, metric: float):
|
189 |
+
return np.savetxt(os.path.join(dirname, ".patience"), [new_patience, metric])
|
190 |
+
|
191 |
+
|
192 |
+
def test_check_patience():
|
193 |
+
import tempfile
|
194 |
+
|
195 |
+
with tempfile.TemporaryDirectory() as d:
|
196 |
+
check_patience(d, 3, 1.0)
|
197 |
+
check_patience(d, 3, 1.0)
|
198 |
+
check_patience(d, 3, 1.0)
|
199 |
+
assert check_patience(d, 3, 1.0, raise_=False) is False
|
200 |
+
|
201 |
+
with tempfile.TemporaryDirectory() as d:
|
202 |
+
check_patience(d, 3, 1.0)
|
203 |
+
check_patience(d, 3, 0.9)
|
204 |
+
check_patience(d, 3, 1.0)
|
205 |
+
check_patience(d, 3, 1.0)
|
206 |
+
assert check_patience(d, 3, 1.0, raise_=False) is False
|
207 |
+
|
208 |
+
with tempfile.TemporaryDirectory() as d:
|
209 |
+
check_patience(d, 3, 1.0, cmp="max")
|
210 |
+
check_patience(d, 3, 1.9, cmp="max")
|
211 |
+
check_patience(d, 3, 1.0, cmp="max")
|
212 |
+
check_patience(d, 3, 1.0, cmp="max")
|
213 |
+
assert check_patience(d, 3, 1.0, cmp="max", raise_=False) is False
|
df/config.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import string
|
3 |
+
from configparser import ConfigParser
|
4 |
+
from shlex import shlex
|
5 |
+
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
T = TypeVar("T")
|
10 |
+
|
11 |
+
|
12 |
+
class DfParams:
|
13 |
+
def __init__(self):
|
14 |
+
# Sampling rate used for training
|
15 |
+
self.sr: int = config("SR", cast=int, default=48_000, section="DF")
|
16 |
+
# FFT size in samples
|
17 |
+
self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF")
|
18 |
+
# STFT Hop size in samples
|
19 |
+
self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF")
|
20 |
+
# Number of ERB bands
|
21 |
+
self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF")
|
22 |
+
# Number of deep filtering bins; DF is applied from 0th to nb_df-th frequency bins
|
23 |
+
self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF")
|
24 |
+
# Normalization decay factor; used for complex and erb features
|
25 |
+
self.norm_tau: float = config("NORM_TAU", 1, float, section="DF")
|
26 |
+
# Local SNR minimum value, ground truth will be truncated
|
27 |
+
self.lsnr_max: int = config("LSNR_MAX", 35, int, section="DF")
|
28 |
+
# Local SNR maximum value, ground truth will be truncated
|
29 |
+
self.lsnr_min: int = config("LSNR_MIN", -15, int, section="DF")
|
30 |
+
# Minimum number of frequency bins per ERB band
|
31 |
+
self.min_nb_freqs = config("MIN_NB_ERB_FREQS", 2, int, section="DF")
|
32 |
+
# Deep Filtering order
|
33 |
+
self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF")
|
34 |
+
# Deep Filtering look-ahead
|
35 |
+
self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF")
|
36 |
+
# Pad mode. By default, padding will be handled on the input side:
|
37 |
+
# - `input`, which pads the input features passed to the model
|
38 |
+
# - `output`, which pads the output spectrogram corresponding to `df_lookahead`
|
39 |
+
self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF")
|
40 |
+
|
41 |
+
|
42 |
+
class Config:
|
43 |
+
"""Adopted from python-decouple"""
|
44 |
+
|
45 |
+
DEFAULT_SECTION = "settings"
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
self.parser: ConfigParser = None # type: ignore
|
49 |
+
self.path = ""
|
50 |
+
self.modified = False
|
51 |
+
self.allow_defaults = True
|
52 |
+
|
53 |
+
def load(
|
54 |
+
self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False
|
55 |
+
):
|
56 |
+
self.allow_defaults = allow_defaults
|
57 |
+
if self.parser is not None and not allow_reload:
|
58 |
+
raise ValueError("Config already loaded")
|
59 |
+
self.parser = ConfigParser()
|
60 |
+
self.path = path
|
61 |
+
if path is not None and os.path.isfile(path):
|
62 |
+
with open(path) as f:
|
63 |
+
self.parser.read_file(f)
|
64 |
+
else:
|
65 |
+
if config_must_exist:
|
66 |
+
raise ValueError(f"No config file found at '{path}'.")
|
67 |
+
if not self.parser.has_section(self.DEFAULT_SECTION):
|
68 |
+
self.parser.add_section(self.DEFAULT_SECTION)
|
69 |
+
self._fix_clc()
|
70 |
+
self._fix_df()
|
71 |
+
|
72 |
+
def use_defaults(self):
|
73 |
+
self.load(path=None, config_must_exist=False)
|
74 |
+
|
75 |
+
def save(self, path: str):
|
76 |
+
if not self.modified:
|
77 |
+
logger.debug("Config not modified. No need to overwrite on disk.")
|
78 |
+
return
|
79 |
+
if self.parser is None:
|
80 |
+
self.parser = ConfigParser()
|
81 |
+
for section in self.parser.sections():
|
82 |
+
if len(self.parser[section]) == 0:
|
83 |
+
self.parser.remove_section(section)
|
84 |
+
with open(path, mode="w") as f:
|
85 |
+
self.parser.write(f)
|
86 |
+
|
87 |
+
def tostr(self, value, cast):
|
88 |
+
if isinstance(cast, Csv) and isinstance(value, (tuple, list)):
|
89 |
+
return "".join(str(v) + cast.delimiter for v in value)[:-1]
|
90 |
+
return str(value)
|
91 |
+
|
92 |
+
def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T:
|
93 |
+
section = self.DEFAULT_SECTION if section is None else section
|
94 |
+
section = section.lower()
|
95 |
+
if not self.parser.has_section(section):
|
96 |
+
self.parser.add_section(section)
|
97 |
+
if self.parser.has_option(section, option):
|
98 |
+
if value == self.cast(self.parser.get(section, option), cast):
|
99 |
+
return value
|
100 |
+
self.modified = True
|
101 |
+
self.parser.set(section, option, self.tostr(value, cast))
|
102 |
+
return value
|
103 |
+
|
104 |
+
def __call__(
|
105 |
+
self,
|
106 |
+
option: str,
|
107 |
+
default: Any = None,
|
108 |
+
cast: Type[T] = str,
|
109 |
+
save: bool = True,
|
110 |
+
section: Optional[str] = None,
|
111 |
+
) -> T:
|
112 |
+
# Get value either from an ENV or from the .ini file
|
113 |
+
section = self.DEFAULT_SECTION if section is None else section
|
114 |
+
value = None
|
115 |
+
if self.parser is None:
|
116 |
+
raise ValueError("No configuration loaded")
|
117 |
+
if not self.parser.has_section(section.lower()):
|
118 |
+
self.parser.add_section(section.lower())
|
119 |
+
if option in os.environ:
|
120 |
+
value = os.environ[option]
|
121 |
+
if save:
|
122 |
+
self.parser.set(section, option, self.tostr(value, cast))
|
123 |
+
elif self.parser.has_option(section, option):
|
124 |
+
value = self.read_from_section(section, option, default, cast=cast, save=save)
|
125 |
+
elif self.parser.has_option(section.lower(), option):
|
126 |
+
value = self.read_from_section(section.lower(), option, default, cast=cast, save=save)
|
127 |
+
elif self.parser.has_option(self.DEFAULT_SECTION, option):
|
128 |
+
logger.warning(
|
129 |
+
f"Couldn't find option {option} in section {section}. "
|
130 |
+
"Falling back to default settings section."
|
131 |
+
)
|
132 |
+
value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save)
|
133 |
+
elif default is None:
|
134 |
+
raise ValueError("Value {} not found.".format(option))
|
135 |
+
elif not self.allow_defaults and save:
|
136 |
+
raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
|
137 |
+
else:
|
138 |
+
value = default
|
139 |
+
if save:
|
140 |
+
self.set(option, value, cast, section)
|
141 |
+
return self.cast(value, cast)
|
142 |
+
|
143 |
+
def cast(self, value, cast):
|
144 |
+
# Do the casting to get the correct type
|
145 |
+
if cast is bool:
|
146 |
+
value = str(value).lower()
|
147 |
+
if value in {"true", "yes", "y", "on", "1"}:
|
148 |
+
return True # type: ignore
|
149 |
+
elif value in {"false", "no", "n", "off", "0"}:
|
150 |
+
return False # type: ignore
|
151 |
+
raise ValueError("Parse error")
|
152 |
+
return cast(value)
|
153 |
+
|
154 |
+
def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T:
|
155 |
+
section = self.DEFAULT_SECTION if section is None else section
|
156 |
+
if not self.parser.has_section(section):
|
157 |
+
raise KeyError(section)
|
158 |
+
if not self.parser.has_option(section, option):
|
159 |
+
raise KeyError(option)
|
160 |
+
return self.cast(self.parser.get(section, option), cast)
|
161 |
+
|
162 |
+
def read_from_section(
|
163 |
+
self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True
|
164 |
+
) -> str:
|
165 |
+
value = self.parser.get(section, option)
|
166 |
+
if not save:
|
167 |
+
# Set to default or remove to not read it at trainig start again
|
168 |
+
if default is None:
|
169 |
+
self.parser.remove_option(section, option)
|
170 |
+
elif not self.allow_defaults:
|
171 |
+
raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
|
172 |
+
else:
|
173 |
+
self.parser.set(section, option, self.tostr(default, cast))
|
174 |
+
elif section.lower() != section:
|
175 |
+
self.parser.set(section.lower(), option, self.tostr(value, cast))
|
176 |
+
self.parser.remove_option(section, option)
|
177 |
+
self.modified = True
|
178 |
+
return value
|
179 |
+
|
180 |
+
def overwrite(self, section: str, option: str, value: Any):
|
181 |
+
if not self.parser.has_section(section):
|
182 |
+
return ValueError(f"Section not found: '{section}'")
|
183 |
+
if not self.parser.has_option(section, option):
|
184 |
+
return ValueError(f"Option not found '{option}' in section '{section}'")
|
185 |
+
self.modified = True
|
186 |
+
cast = type(value)
|
187 |
+
return self.parser.set(section, option, self.tostr(value, cast))
|
188 |
+
|
189 |
+
def _fix_df(self):
|
190 |
+
"""Renaming of some groups/options for compatibility with old models."""
|
191 |
+
if self.parser.has_section("deepfilternet") and self.parser.has_section("df"):
|
192 |
+
sec_deepfilternet = self.parser["deepfilternet"]
|
193 |
+
sec_df = self.parser["df"]
|
194 |
+
if "df_order" in sec_deepfilternet:
|
195 |
+
sec_df["df_order"] = sec_deepfilternet["df_order"]
|
196 |
+
del sec_deepfilternet["df_order"]
|
197 |
+
if "df_lookahead" in sec_deepfilternet:
|
198 |
+
sec_df["df_lookahead"] = sec_deepfilternet["df_lookahead"]
|
199 |
+
del sec_deepfilternet["df_lookahead"]
|
200 |
+
|
201 |
+
def _fix_clc(self):
|
202 |
+
"""Renaming of some groups/options for compatibility with old models."""
|
203 |
+
if (
|
204 |
+
not self.parser.has_section("deepfilternet")
|
205 |
+
and self.parser.has_section("train")
|
206 |
+
and self.parser.get("train", "model") == "convgru5"
|
207 |
+
):
|
208 |
+
self.overwrite("train", "model", "deepfilternet")
|
209 |
+
self.parser.add_section("deepfilternet")
|
210 |
+
self.parser["deepfilternet"] = self.parser["convgru"]
|
211 |
+
del self.parser["convgru"]
|
212 |
+
if not self.parser.has_section("df") and self.parser.has_section("clc"):
|
213 |
+
self.parser["df"] = self.parser["clc"]
|
214 |
+
del self.parser["clc"]
|
215 |
+
for section in self.parser.sections():
|
216 |
+
for k, v in self.parser[section].items():
|
217 |
+
if "clc" in k.lower():
|
218 |
+
self.parser.set(section, k.lower().replace("clc", "df"), v)
|
219 |
+
del self.parser[section][k]
|
220 |
+
|
221 |
+
def __repr__(self):
|
222 |
+
msg = ""
|
223 |
+
for section in self.parser.sections():
|
224 |
+
msg += f"{section}:\n"
|
225 |
+
for k, v in self.parser[section].items():
|
226 |
+
msg += f" {k}: {v}\n"
|
227 |
+
return msg
|
228 |
+
|
229 |
+
|
230 |
+
config = Config()
|
231 |
+
|
232 |
+
|
233 |
+
class Csv(object):
|
234 |
+
"""
|
235 |
+
Produces a csv parser that return a list of transformed elements. From python-decouple.
|
236 |
+
"""
|
237 |
+
|
238 |
+
def __init__(
|
239 |
+
self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list
|
240 |
+
):
|
241 |
+
"""
|
242 |
+
Parameters:
|
243 |
+
cast -- callable that transforms the item just before it's added to the list.
|
244 |
+
delimiter -- string of delimiters chars passed to shlex.
|
245 |
+
strip -- string of non-relevant characters to be passed to str.strip after the split.
|
246 |
+
post_process -- callable to post process all casted values. Default is `list`.
|
247 |
+
"""
|
248 |
+
self.cast: Type[T] = cast
|
249 |
+
self.delimiter = delimiter
|
250 |
+
self.strip = strip
|
251 |
+
self.post_process = post_process
|
252 |
+
|
253 |
+
def __call__(self, value: Union[str, Tuple[T], List[T]]) -> List[T]:
|
254 |
+
"""The actual transformation"""
|
255 |
+
if isinstance(value, (tuple, list)):
|
256 |
+
# if default value is a list
|
257 |
+
value = "".join(str(v) + self.delimiter for v in value)[:-1]
|
258 |
+
|
259 |
+
def transform(s):
|
260 |
+
return self.cast(s.strip(self.strip))
|
261 |
+
|
262 |
+
splitter = shlex(value, posix=True)
|
263 |
+
splitter.whitespace = self.delimiter
|
264 |
+
splitter.whitespace_split = True
|
265 |
+
|
266 |
+
return self.post_process(transform(s) for s in splitter)
|
df/deepfilternet2.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Final, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from loguru import logger
|
6 |
+
from torch import Tensor, nn
|
7 |
+
|
8 |
+
from df.config import Csv, DfParams, config
|
9 |
+
from df.modules import (
|
10 |
+
Conv2dNormAct,
|
11 |
+
ConvTranspose2dNormAct,
|
12 |
+
DfOp,
|
13 |
+
GroupedGRU,
|
14 |
+
GroupedLinear,
|
15 |
+
GroupedLinearEinsum,
|
16 |
+
Mask,
|
17 |
+
SqueezedGRU,
|
18 |
+
erb_fb,
|
19 |
+
get_device,
|
20 |
+
)
|
21 |
+
from df.multiframe import MF_METHODS, MultiFrameModule
|
22 |
+
from libdf import DF
|
23 |
+
|
24 |
+
|
25 |
+
class ModelParams(DfParams):
|
26 |
+
section = "deepfilternet"
|
27 |
+
|
28 |
+
def __init__(self):
|
29 |
+
super().__init__()
|
30 |
+
self.conv_lookahead: int = config(
|
31 |
+
"CONV_LOOKAHEAD", cast=int, default=0, section=self.section
|
32 |
+
)
|
33 |
+
self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section)
|
34 |
+
self.conv_depthwise: bool = config(
|
35 |
+
"CONV_DEPTHWISE", cast=bool, default=True, section=self.section
|
36 |
+
)
|
37 |
+
self.convt_depthwise: bool = config(
|
38 |
+
"CONVT_DEPTHWISE", cast=bool, default=True, section=self.section
|
39 |
+
)
|
40 |
+
self.conv_kernel: List[int] = config(
|
41 |
+
"CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore
|
42 |
+
)
|
43 |
+
self.conv_kernel_inp: List[int] = config(
|
44 |
+
"CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section # type: ignore
|
45 |
+
)
|
46 |
+
self.emb_hidden_dim: int = config(
|
47 |
+
"EMB_HIDDEN_DIM", cast=int, default=256, section=self.section
|
48 |
+
)
|
49 |
+
self.emb_num_layers: int = config(
|
50 |
+
"EMB_NUM_LAYERS", cast=int, default=2, section=self.section
|
51 |
+
)
|
52 |
+
self.df_hidden_dim: int = config(
|
53 |
+
"DF_HIDDEN_DIM", cast=int, default=256, section=self.section
|
54 |
+
)
|
55 |
+
self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section)
|
56 |
+
self.df_output_layer: str = config(
|
57 |
+
"DF_OUTPUT_LAYER", default="linear", section=self.section
|
58 |
+
)
|
59 |
+
self.df_pathway_kernel_size_t: int = config(
|
60 |
+
"DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section
|
61 |
+
)
|
62 |
+
self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section)
|
63 |
+
self.df_num_layers: int = config("DF_NUM_LAYERS", cast=int, default=3, section=self.section)
|
64 |
+
self.df_n_iter: int = config("DF_N_ITER", cast=int, default=2, section=self.section)
|
65 |
+
self.gru_type: str = config("GRU_TYPE", default="grouped", section=self.section)
|
66 |
+
self.gru_groups: int = config("GRU_GROUPS", cast=int, default=1, section=self.section)
|
67 |
+
self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section)
|
68 |
+
self.group_shuffle: bool = config(
|
69 |
+
"GROUP_SHUFFLE", cast=bool, default=True, section=self.section
|
70 |
+
)
|
71 |
+
self.dfop_method: str = config("DFOP_METHOD", cast=str, default="df", section=self.section)
|
72 |
+
self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section)
|
73 |
+
|
74 |
+
|
75 |
+
def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True):
|
76 |
+
p = ModelParams()
|
77 |
+
if df_state is None:
|
78 |
+
df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
|
79 |
+
erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False)
|
80 |
+
erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
|
81 |
+
model = DfNet(erb, erb_inverse, run_df, train_mask)
|
82 |
+
return model.to(device=get_device())
|
83 |
+
|
84 |
+
|
85 |
+
class Add(nn.Module):
|
86 |
+
def forward(self, a, b):
|
87 |
+
return a + b
|
88 |
+
|
89 |
+
|
90 |
+
class Concat(nn.Module):
|
91 |
+
def forward(self, a, b):
|
92 |
+
return torch.cat((a, b), dim=-1)
|
93 |
+
|
94 |
+
|
95 |
+
class Encoder(nn.Module):
|
96 |
+
def __init__(self):
|
97 |
+
super().__init__()
|
98 |
+
p = ModelParams()
|
99 |
+
assert p.nb_erb % 4 == 0, "erb_bins should be divisible by 4"
|
100 |
+
|
101 |
+
self.erb_conv0 = Conv2dNormAct(
|
102 |
+
1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
|
103 |
+
)
|
104 |
+
conv_layer = partial(
|
105 |
+
Conv2dNormAct,
|
106 |
+
in_ch=p.conv_ch,
|
107 |
+
out_ch=p.conv_ch,
|
108 |
+
kernel_size=p.conv_kernel,
|
109 |
+
bias=False,
|
110 |
+
separable=True,
|
111 |
+
)
|
112 |
+
self.erb_conv1 = conv_layer(fstride=2)
|
113 |
+
self.erb_conv2 = conv_layer(fstride=2)
|
114 |
+
self.erb_conv3 = conv_layer(fstride=1)
|
115 |
+
self.df_conv0 = Conv2dNormAct(
|
116 |
+
2, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
|
117 |
+
)
|
118 |
+
self.df_conv1 = conv_layer(fstride=2)
|
119 |
+
self.erb_bins = p.nb_erb
|
120 |
+
self.emb_in_dim = p.conv_ch * p.nb_erb // 4
|
121 |
+
self.emb_out_dim = p.emb_hidden_dim
|
122 |
+
if p.gru_type == "grouped":
|
123 |
+
self.df_fc_emb = GroupedLinear(
|
124 |
+
p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
df_fc_emb = GroupedLinearEinsum(
|
128 |
+
p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups
|
129 |
+
)
|
130 |
+
self.df_fc_emb = nn.Sequential(df_fc_emb, nn.ReLU(inplace=True))
|
131 |
+
if p.enc_concat:
|
132 |
+
self.emb_in_dim *= 2
|
133 |
+
self.combine = Concat()
|
134 |
+
else:
|
135 |
+
self.combine = Add()
|
136 |
+
self.emb_out_dim = p.emb_hidden_dim
|
137 |
+
self.emb_n_layers = p.emb_num_layers
|
138 |
+
assert p.gru_type in ("grouped", "squeeze"), f"But got {p.gru_type}"
|
139 |
+
if p.gru_type == "grouped":
|
140 |
+
self.emb_gru = GroupedGRU(
|
141 |
+
self.emb_in_dim,
|
142 |
+
self.emb_out_dim,
|
143 |
+
num_layers=1,
|
144 |
+
batch_first=True,
|
145 |
+
groups=p.gru_groups,
|
146 |
+
shuffle=p.group_shuffle,
|
147 |
+
add_outputs=True,
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
self.emb_gru = SqueezedGRU(
|
151 |
+
self.emb_in_dim,
|
152 |
+
self.emb_out_dim,
|
153 |
+
num_layers=1,
|
154 |
+
batch_first=True,
|
155 |
+
linear_groups=p.lin_groups,
|
156 |
+
linear_act_layer=partial(nn.ReLU, inplace=True),
|
157 |
+
)
|
158 |
+
self.lsnr_fc = nn.Sequential(nn.Linear(self.emb_out_dim, 1), nn.Sigmoid())
|
159 |
+
self.lsnr_scale = p.lsnr_max - p.lsnr_min
|
160 |
+
self.lsnr_offset = p.lsnr_min
|
161 |
+
|
162 |
+
def forward(
|
163 |
+
self, feat_erb: Tensor, feat_spec: Tensor
|
164 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
165 |
+
# Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
|
166 |
+
# erb: [B, 1, T, Fe]
|
167 |
+
# spec: [B, 2, T, Fc]
|
168 |
+
# b, _, t, _ = feat_erb.shape
|
169 |
+
e0 = self.erb_conv0(feat_erb) # [B, C, T, F]
|
170 |
+
e1 = self.erb_conv1(e0) # [B, C*2, T, F/2]
|
171 |
+
e2 = self.erb_conv2(e1) # [B, C*4, T, F/4]
|
172 |
+
e3 = self.erb_conv3(e2) # [B, C*4, T, F/4]
|
173 |
+
c0 = self.df_conv0(feat_spec) # [B, C, T, Fc]
|
174 |
+
c1 = self.df_conv1(c0) # [B, C*2, T, Fc]
|
175 |
+
cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1]
|
176 |
+
cemb = self.df_fc_emb(cemb) # [T, B, C * F/4]
|
177 |
+
emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F/4]
|
178 |
+
emb = self.combine(emb, cemb)
|
179 |
+
emb, _ = self.emb_gru(emb) # [B, T, -1]
|
180 |
+
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
181 |
+
return e0, e1, e2, e3, emb, c0, lsnr
|
182 |
+
|
183 |
+
|
184 |
+
class ErbDecoder(nn.Module):
|
185 |
+
def __init__(self):
|
186 |
+
super().__init__()
|
187 |
+
p = ModelParams()
|
188 |
+
assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"
|
189 |
+
|
190 |
+
self.emb_out_dim = p.emb_hidden_dim
|
191 |
+
|
192 |
+
if p.gru_type == "grouped":
|
193 |
+
self.emb_gru = GroupedGRU(
|
194 |
+
p.conv_ch * p.nb_erb // 4, # For compat
|
195 |
+
self.emb_out_dim,
|
196 |
+
num_layers=p.emb_num_layers - 1,
|
197 |
+
batch_first=True,
|
198 |
+
groups=p.gru_groups,
|
199 |
+
shuffle=p.group_shuffle,
|
200 |
+
add_outputs=True,
|
201 |
+
)
|
202 |
+
# SqueezedGRU uses GroupedLinearEinsum, so let's use it here as well
|
203 |
+
fc_emb = GroupedLinear(
|
204 |
+
p.emb_hidden_dim,
|
205 |
+
p.conv_ch * p.nb_erb // 4,
|
206 |
+
groups=p.lin_groups,
|
207 |
+
shuffle=p.group_shuffle,
|
208 |
+
)
|
209 |
+
self.fc_emb = nn.Sequential(fc_emb, nn.ReLU(inplace=True))
|
210 |
+
else:
|
211 |
+
self.emb_gru = SqueezedGRU(
|
212 |
+
self.emb_out_dim,
|
213 |
+
self.emb_out_dim,
|
214 |
+
output_size=p.conv_ch * p.nb_erb // 4,
|
215 |
+
num_layers=p.emb_num_layers - 1,
|
216 |
+
batch_first=True,
|
217 |
+
gru_skip_op=nn.Identity,
|
218 |
+
linear_groups=p.lin_groups,
|
219 |
+
linear_act_layer=partial(nn.ReLU, inplace=True),
|
220 |
+
)
|
221 |
+
self.fc_emb = nn.Identity()
|
222 |
+
tconv_layer = partial(
|
223 |
+
ConvTranspose2dNormAct,
|
224 |
+
kernel_size=p.conv_kernel,
|
225 |
+
bias=False,
|
226 |
+
separable=True,
|
227 |
+
)
|
228 |
+
conv_layer = partial(
|
229 |
+
Conv2dNormAct,
|
230 |
+
bias=False,
|
231 |
+
separable=True,
|
232 |
+
)
|
233 |
+
# convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
|
234 |
+
self.conv3p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
|
235 |
+
self.convt3 = conv_layer(p.conv_ch, p.conv_ch, kernel_size=p.conv_kernel)
|
236 |
+
self.conv2p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
|
237 |
+
self.convt2 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
|
238 |
+
self.conv1p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
|
239 |
+
self.convt1 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2)
|
240 |
+
self.conv0p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1)
|
241 |
+
self.conv0_out = conv_layer(
|
242 |
+
p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid
|
243 |
+
)
|
244 |
+
|
245 |
+
def forward(self, emb, e3, e2, e1, e0) -> Tensor:
|
246 |
+
# Estimates erb mask
|
247 |
+
b, _, t, f8 = e3.shape
|
248 |
+
emb, _ = self.emb_gru(emb)
|
249 |
+
emb = self.fc_emb(emb)
|
250 |
+
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8]
|
251 |
+
e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4]
|
252 |
+
e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2]
|
253 |
+
e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F]
|
254 |
+
m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F]
|
255 |
+
return m
|
256 |
+
|
257 |
+
|
258 |
+
class DfOutputReshapeMF(nn.Module):
|
259 |
+
"""Coefficients output reshape for multiframe/MultiFrameModule
|
260 |
+
|
261 |
+
Requires input of shape B, C, T, F, 2.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self, df_order: int, df_bins: int):
|
265 |
+
super().__init__()
|
266 |
+
self.df_order = df_order
|
267 |
+
self.df_bins = df_bins
|
268 |
+
|
269 |
+
def forward(self, coefs: Tensor) -> Tensor:
|
270 |
+
# [B, T, F, O*2] -> [B, O, T, F, 2]
|
271 |
+
coefs = coefs.view(*coefs.shape[:-1], -1, 2)
|
272 |
+
coefs = coefs.permute(0, 3, 1, 2, 4)
|
273 |
+
return coefs
|
274 |
+
|
275 |
+
|
276 |
+
class DfDecoder(nn.Module):
|
277 |
+
def __init__(self, out_channels: int = -1):
|
278 |
+
super().__init__()
|
279 |
+
p = ModelParams()
|
280 |
+
layer_width = p.conv_ch
|
281 |
+
self.emb_dim = p.emb_hidden_dim
|
282 |
+
|
283 |
+
self.df_n_hidden = p.df_hidden_dim
|
284 |
+
self.df_n_layers = p.df_num_layers
|
285 |
+
self.df_order = p.df_order
|
286 |
+
self.df_bins = p.nb_df
|
287 |
+
self.gru_groups = p.gru_groups
|
288 |
+
self.df_out_ch = out_channels if out_channels > 0 else p.df_order * 2
|
289 |
+
|
290 |
+
conv_layer = partial(Conv2dNormAct, separable=True, bias=False)
|
291 |
+
kt = p.df_pathway_kernel_size_t
|
292 |
+
self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1))
|
293 |
+
if p.gru_type == "grouped":
|
294 |
+
self.df_gru = GroupedGRU(
|
295 |
+
p.emb_hidden_dim,
|
296 |
+
p.df_hidden_dim,
|
297 |
+
num_layers=self.df_n_layers,
|
298 |
+
batch_first=True,
|
299 |
+
groups=p.gru_groups,
|
300 |
+
shuffle=p.group_shuffle,
|
301 |
+
add_outputs=True,
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
self.df_gru = SqueezedGRU(
|
305 |
+
p.emb_hidden_dim,
|
306 |
+
p.df_hidden_dim,
|
307 |
+
num_layers=self.df_n_layers,
|
308 |
+
batch_first=True,
|
309 |
+
gru_skip_op=nn.Identity,
|
310 |
+
linear_act_layer=partial(nn.ReLU, inplace=True),
|
311 |
+
)
|
312 |
+
p.df_gru_skip = p.df_gru_skip.lower()
|
313 |
+
assert p.df_gru_skip in ("none", "identity", "groupedlinear")
|
314 |
+
self.df_skip: Optional[nn.Module]
|
315 |
+
if p.df_gru_skip == "none":
|
316 |
+
self.df_skip = None
|
317 |
+
elif p.df_gru_skip == "identity":
|
318 |
+
assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match"
|
319 |
+
self.df_skip = nn.Identity()
|
320 |
+
elif p.df_gru_skip == "groupedlinear":
|
321 |
+
self.df_skip = GroupedLinearEinsum(
|
322 |
+
p.emb_hidden_dim, p.df_hidden_dim, groups=p.lin_groups
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
raise NotImplementedError()
|
326 |
+
assert p.df_output_layer in ("linear", "groupedlinear")
|
327 |
+
self.df_out: nn.Module
|
328 |
+
out_dim = self.df_bins * self.df_out_ch
|
329 |
+
if p.df_output_layer == "linear":
|
330 |
+
df_out = nn.Linear(self.df_n_hidden, out_dim)
|
331 |
+
elif p.df_output_layer == "groupedlinear":
|
332 |
+
df_out = GroupedLinearEinsum(self.df_n_hidden, out_dim, groups=p.lin_groups)
|
333 |
+
else:
|
334 |
+
raise NotImplementedError
|
335 |
+
self.df_out = nn.Sequential(df_out, nn.Tanh())
|
336 |
+
self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid())
|
337 |
+
self.out_transform = DfOutputReshapeMF(self.df_order, self.df_bins)
|
338 |
+
|
339 |
+
def forward(self, emb: Tensor, c0: Tensor) -> Tuple[Tensor, Tensor]:
|
340 |
+
b, t, _ = emb.shape
|
341 |
+
c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden
|
342 |
+
if self.df_skip is not None:
|
343 |
+
c += self.df_skip(emb)
|
344 |
+
c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last
|
345 |
+
alpha = self.df_fc_a(c) # [B, T, 1]
|
346 |
+
c = self.df_out(c) # [B, T, F*O*2], O: df_order
|
347 |
+
c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2]
|
348 |
+
c = self.out_transform(c)
|
349 |
+
return c, alpha
|
350 |
+
|
351 |
+
|
352 |
+
class DfNet(nn.Module):
|
353 |
+
run_df: Final[bool]
|
354 |
+
pad_specf: Final[bool]
|
355 |
+
|
356 |
+
def __init__(
|
357 |
+
self,
|
358 |
+
erb_fb: Tensor,
|
359 |
+
erb_inv_fb: Tensor,
|
360 |
+
run_df: bool = True,
|
361 |
+
train_mask: bool = True,
|
362 |
+
):
|
363 |
+
super().__init__()
|
364 |
+
p = ModelParams()
|
365 |
+
layer_width = p.conv_ch
|
366 |
+
assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8"
|
367 |
+
self.df_lookahead = p.df_lookahead if p.pad_mode == "model" else 0
|
368 |
+
self.nb_df = p.nb_df
|
369 |
+
self.freq_bins: int = p.fft_size // 2 + 1
|
370 |
+
self.emb_dim: int = layer_width * p.nb_erb
|
371 |
+
self.erb_bins: int = p.nb_erb
|
372 |
+
if p.conv_lookahead > 0 and p.pad_mode.startswith("input"):
|
373 |
+
self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0)
|
374 |
+
else:
|
375 |
+
self.pad_feat = nn.Identity()
|
376 |
+
self.pad_specf = p.pad_mode.endswith("specf")
|
377 |
+
if p.df_lookahead > 0 and self.pad_specf:
|
378 |
+
self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -p.df_lookahead, p.df_lookahead), 0.0)
|
379 |
+
else:
|
380 |
+
self.pad_spec = nn.Identity()
|
381 |
+
if (p.conv_lookahead > 0 or p.df_lookahead > 0) and p.pad_mode.startswith("output"):
|
382 |
+
assert p.conv_lookahead == p.df_lookahead
|
383 |
+
pad = (0, 0, 0, 0, -p.conv_lookahead, p.conv_lookahead)
|
384 |
+
self.pad_out = nn.ConstantPad3d(pad, 0.0)
|
385 |
+
else:
|
386 |
+
self.pad_out = nn.Identity()
|
387 |
+
self.register_buffer("erb_fb", erb_fb)
|
388 |
+
self.enc = Encoder()
|
389 |
+
self.erb_dec = ErbDecoder()
|
390 |
+
self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf)
|
391 |
+
|
392 |
+
self.df_order = p.df_order
|
393 |
+
self.df_bins = p.nb_df
|
394 |
+
self.df_op: Union[DfOp, MultiFrameModule]
|
395 |
+
if p.dfop_method == "real_unfold":
|
396 |
+
raise ValueError("RealUnfold DF OP is now unsupported.")
|
397 |
+
assert p.df_output_layer != "linear", "Must be used with `groupedlinear`"
|
398 |
+
self.df_op = MF_METHODS[p.dfop_method](
|
399 |
+
num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead
|
400 |
+
)
|
401 |
+
n_ch_out = self.df_op.num_channels()
|
402 |
+
self.df_dec = DfDecoder(out_channels=n_ch_out)
|
403 |
+
|
404 |
+
self.run_df = run_df
|
405 |
+
if not run_df:
|
406 |
+
logger.warning("Runing without DF")
|
407 |
+
self.train_mask = train_mask
|
408 |
+
assert p.df_n_iter == 1
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
spec: Tensor,
|
413 |
+
feat_erb: Tensor,
|
414 |
+
feat_spec: Tensor, # Not used, take spec modified by mask instead
|
415 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
416 |
+
"""Forward method of DeepFilterNet2.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
|
420 |
+
feat_erb (Tensor): ERB features of shape [B, 1, T, E]
|
421 |
+
feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F']
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
|
425 |
+
m (Tensor): ERB mask estimate of shape [B, 1, T, E]
|
426 |
+
lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
|
427 |
+
"""
|
428 |
+
feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
|
429 |
+
|
430 |
+
feat_erb = self.pad_feat(feat_erb)
|
431 |
+
feat_spec = self.pad_feat(feat_spec)
|
432 |
+
e0, e1, e2, e3, emb, c0, lsnr = self.enc(feat_erb, feat_spec)
|
433 |
+
m = self.erb_dec(emb, e3, e2, e1, e0)
|
434 |
+
|
435 |
+
m = self.pad_out(m.unsqueeze(-1)).squeeze(-1)
|
436 |
+
spec = self.mask(spec, m)
|
437 |
+
|
438 |
+
if self.run_df:
|
439 |
+
df_coefs, df_alpha = self.df_dec(emb, c0)
|
440 |
+
df_coefs = self.pad_out(df_coefs)
|
441 |
+
|
442 |
+
if self.pad_specf:
|
443 |
+
# Only pad the lower part of the spectrum.
|
444 |
+
spec_f = self.pad_spec(spec)
|
445 |
+
spec_f = self.df_op(spec_f, df_coefs)
|
446 |
+
spec[..., : self.nb_df, :] = spec_f[..., : self.nb_df, :]
|
447 |
+
else:
|
448 |
+
spec = self.pad_spec(spec)
|
449 |
+
spec = self.df_op(spec, df_coefs)
|
450 |
+
else:
|
451 |
+
df_alpha = torch.zeros(())
|
452 |
+
|
453 |
+
return spec, m, lsnr, df_alpha
|
df/enhance.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
from typing import Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchaudio as ta
|
9 |
+
from loguru import logger
|
10 |
+
from numpy import ndarray
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from torchaudio.backend.common import AudioMetaData
|
14 |
+
|
15 |
+
import df
|
16 |
+
from df import config
|
17 |
+
from df.checkpoint import load_model as load_model_cp
|
18 |
+
from df.logger import init_logger, warn_once
|
19 |
+
from df.model import ModelParams
|
20 |
+
from df.modules import get_device
|
21 |
+
from df.utils import as_complex, as_real, get_norm_alpha, resample
|
22 |
+
from libdf import DF, erb, erb_norm, unit_norm
|
23 |
+
|
24 |
+
|
25 |
+
def main(args):
|
26 |
+
model, df_state, suffix = init_df(
|
27 |
+
args.model_base_dir,
|
28 |
+
post_filter=args.pf,
|
29 |
+
log_level=args.log_level,
|
30 |
+
config_allow_defaults=True,
|
31 |
+
epoch=args.epoch,
|
32 |
+
)
|
33 |
+
if args.output_dir is None:
|
34 |
+
args.output_dir = "."
|
35 |
+
elif not os.path.isdir(args.output_dir):
|
36 |
+
os.mkdir(args.output_dir)
|
37 |
+
df_sr = ModelParams().sr
|
38 |
+
n_samples = len(args.noisy_audio_files)
|
39 |
+
for i, file in enumerate(args.noisy_audio_files):
|
40 |
+
progress = (i + 1) / n_samples * 100
|
41 |
+
audio, meta = load_audio(file, df_sr)
|
42 |
+
t0 = time.time()
|
43 |
+
audio = enhance(
|
44 |
+
model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
|
45 |
+
)
|
46 |
+
t1 = time.time()
|
47 |
+
t_audio = audio.shape[-1] / df_sr
|
48 |
+
t = t1 - t0
|
49 |
+
rtf = t / t_audio
|
50 |
+
fn = os.path.basename(file)
|
51 |
+
p_str = f"{progress:2.0f}% | " if n_samples > 1 else ""
|
52 |
+
logger.info(f"{p_str}Enhanced noisy audio file '{fn}' in {t:.1f}s (RT factor: {rtf:.3f})")
|
53 |
+
audio = resample(audio, df_sr, meta.sample_rate)
|
54 |
+
save_audio(
|
55 |
+
file, audio, sr=meta.sample_rate, output_dir=args.output_dir, suffix=suffix, log=False
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
def init_df(
|
60 |
+
model_base_dir: Optional[str] = None,
|
61 |
+
post_filter: bool = False,
|
62 |
+
log_level: str = "INFO",
|
63 |
+
log_file: Optional[str] = "enhance.log",
|
64 |
+
config_allow_defaults: bool = False,
|
65 |
+
epoch: Union[str, int, None] = "best",
|
66 |
+
default_model: str = "DeepFilterNet2",
|
67 |
+
) -> Tuple[nn.Module, DF, str]:
|
68 |
+
"""Initializes and loads config, model and deep filtering state.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
model_base_dir (str): Path to the model directory containing checkpoint and config. If None,
|
72 |
+
load the pretrained DeepFilterNet2 model.
|
73 |
+
post_filter (bool): Enable post filter for some minor, extra noise reduction.
|
74 |
+
log_level (str): Control amount of logging. Defaults to `INFO`.
|
75 |
+
log_file (str): Optional log file name. None disables it. Defaults to `enhance.log`.
|
76 |
+
config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
|
77 |
+
epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
|
78 |
+
`none` disables checkpoint loading. Defaults to `best`.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
model (nn.Modules): Intialized model, moved to GPU if available.
|
82 |
+
df_state (DF): Deep filtering state for stft/istft/erb
|
83 |
+
suffix (str): Suffix based on the model name. This can be used for saving the enhanced
|
84 |
+
audio.
|
85 |
+
"""
|
86 |
+
try:
|
87 |
+
from icecream import ic, install
|
88 |
+
|
89 |
+
ic.configureOutput(includeContext=True)
|
90 |
+
install()
|
91 |
+
except ImportError:
|
92 |
+
pass
|
93 |
+
use_default_model = False
|
94 |
+
if model_base_dir == "DeepFilterNet":
|
95 |
+
default_model = "DeepFilterNet"
|
96 |
+
use_default_model = True
|
97 |
+
elif model_base_dir == "DeepFilterNet2":
|
98 |
+
use_default_model = True
|
99 |
+
if model_base_dir is None or use_default_model:
|
100 |
+
use_default_model = True
|
101 |
+
model_base_dir = os.path.relpath(
|
102 |
+
os.path.join(
|
103 |
+
os.path.dirname(df.__file__), os.pardir, "pretrained_models", default_model
|
104 |
+
)
|
105 |
+
)
|
106 |
+
if not os.path.isdir(model_base_dir):
|
107 |
+
raise NotADirectoryError("Base directory not found at {}".format(model_base_dir))
|
108 |
+
log_file = os.path.join(model_base_dir, log_file) if log_file is not None else None
|
109 |
+
init_logger(file=log_file, level=log_level, model=model_base_dir)
|
110 |
+
if use_default_model:
|
111 |
+
logger.info(f"Using {default_model} model at {model_base_dir}")
|
112 |
+
config.load(
|
113 |
+
os.path.join(model_base_dir, "config.ini"),
|
114 |
+
config_must_exist=True,
|
115 |
+
allow_defaults=config_allow_defaults,
|
116 |
+
allow_reload=True,
|
117 |
+
)
|
118 |
+
if post_filter:
|
119 |
+
config.set("mask_pf", True, bool, ModelParams().section)
|
120 |
+
logger.info("Running with post-filter")
|
121 |
+
p = ModelParams()
|
122 |
+
df_state = DF(
|
123 |
+
sr=p.sr,
|
124 |
+
fft_size=p.fft_size,
|
125 |
+
hop_size=p.hop_size,
|
126 |
+
nb_bands=p.nb_erb,
|
127 |
+
min_nb_erb_freqs=p.min_nb_freqs,
|
128 |
+
)
|
129 |
+
checkpoint_dir = os.path.join(model_base_dir, "checkpoints")
|
130 |
+
load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none")
|
131 |
+
if not load_cp:
|
132 |
+
checkpoint_dir = None
|
133 |
+
try:
|
134 |
+
mask_only = config.get("mask_only", cast=bool, section="train")
|
135 |
+
except KeyError:
|
136 |
+
mask_only = False
|
137 |
+
model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only)
|
138 |
+
if (epoch is None or epoch == 0) and load_cp:
|
139 |
+
logger.error("Could not find a checkpoint")
|
140 |
+
exit(1)
|
141 |
+
logger.debug(f"Loaded checkpoint from epoch {epoch}")
|
142 |
+
model = model.to(get_device())
|
143 |
+
# Set suffix to model name
|
144 |
+
suffix = os.path.basename(os.path.abspath(model_base_dir))
|
145 |
+
if post_filter:
|
146 |
+
suffix += "_pf"
|
147 |
+
logger.info("Model loaded")
|
148 |
+
return model, df_state, suffix
|
149 |
+
|
150 |
+
|
151 |
+
def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
|
152 |
+
spec = df.analysis(audio.numpy()) # [C, Tf] -> [C, Tf, F]
|
153 |
+
a = get_norm_alpha(False)
|
154 |
+
erb_fb = df.erb_widths()
|
155 |
+
with warnings.catch_warnings():
|
156 |
+
warnings.simplefilter("ignore", UserWarning)
|
157 |
+
erb_feat = torch.as_tensor(erb_norm(erb(spec, erb_fb), a)).unsqueeze(1)
|
158 |
+
spec_feat = as_real(torch.as_tensor(unit_norm(spec[..., :nb_df], a)).unsqueeze(1))
|
159 |
+
spec = as_real(torch.as_tensor(spec).unsqueeze(1))
|
160 |
+
if device is not None:
|
161 |
+
spec = spec.to(device)
|
162 |
+
erb_feat = erb_feat.to(device)
|
163 |
+
spec_feat = spec_feat.to(device)
|
164 |
+
return spec, erb_feat, spec_feat
|
165 |
+
|
166 |
+
|
167 |
+
def load_audio(
|
168 |
+
file: str, sr: Optional[int], verbose=True, **kwargs
|
169 |
+
) -> Tuple[Tensor, AudioMetaData]:
|
170 |
+
"""Loads an audio file using torchaudio.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
file (str): Path to an audio file.
|
174 |
+
sr (int): Optionally resample audio to specified target sampling rate.
|
175 |
+
**kwargs: Passed to torchaudio.load(). Depends on the backend. The resample method
|
176 |
+
may be set via `method` which is passed to `resample()`.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
audio (Tensor): Audio tensor of shape [C, T], if channels_first=True (default).
|
180 |
+
info (AudioMetaData): Meta data of the original audio file. Contains the original sr.
|
181 |
+
"""
|
182 |
+
ikwargs = {}
|
183 |
+
if "format" in kwargs:
|
184 |
+
ikwargs["format"] = kwargs["format"]
|
185 |
+
rkwargs = {}
|
186 |
+
if "method" in kwargs:
|
187 |
+
rkwargs["method"] = kwargs.pop("method")
|
188 |
+
info: AudioMetaData = ta.info(file, **ikwargs)
|
189 |
+
audio, orig_sr = ta.load(file, **kwargs)
|
190 |
+
if sr is not None and orig_sr != sr:
|
191 |
+
if verbose:
|
192 |
+
warn_once(
|
193 |
+
f"Audio sampling rate does not match model sampling rate ({orig_sr}, {sr}). "
|
194 |
+
"Resampling..."
|
195 |
+
)
|
196 |
+
audio = resample(audio, orig_sr, sr, **rkwargs)
|
197 |
+
return audio, info
|
198 |
+
|
199 |
+
|
200 |
+
def save_audio(
|
201 |
+
file: str,
|
202 |
+
audio: Union[Tensor, ndarray],
|
203 |
+
sr: int,
|
204 |
+
output_dir: Optional[str] = None,
|
205 |
+
suffix: Optional[str] = None,
|
206 |
+
log: bool = False,
|
207 |
+
dtype=torch.int16,
|
208 |
+
):
|
209 |
+
outpath = file
|
210 |
+
if suffix is not None:
|
211 |
+
file, ext = os.path.splitext(file)
|
212 |
+
outpath = file + f"_{suffix}" + ext
|
213 |
+
if output_dir is not None:
|
214 |
+
outpath = os.path.join(output_dir, os.path.basename(outpath))
|
215 |
+
if log:
|
216 |
+
logger.info(f"Saving audio file '{outpath}'")
|
217 |
+
audio = torch.as_tensor(audio)
|
218 |
+
if audio.ndim == 1:
|
219 |
+
audio.unsqueeze_(0)
|
220 |
+
if dtype == torch.int16 and audio.dtype != torch.int16:
|
221 |
+
audio = (audio * (1 << 15)).to(torch.int16)
|
222 |
+
if dtype == torch.float32 and audio.dtype != torch.float32:
|
223 |
+
audio = audio.to(torch.float32) / (1 << 15)
|
224 |
+
ta.save(outpath, audio, sr)
|
225 |
+
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def enhance(
|
229 |
+
model: nn.Module, df_state: DF, audio: Tensor, pad=False, atten_lim_db: Optional[float] = None
|
230 |
+
):
|
231 |
+
model.eval()
|
232 |
+
bs = audio.shape[0]
|
233 |
+
if hasattr(model, "reset_h0"):
|
234 |
+
model.reset_h0(batch_size=bs, device=get_device())
|
235 |
+
orig_len = audio.shape[-1]
|
236 |
+
n_fft, hop = 0, 0
|
237 |
+
if pad:
|
238 |
+
n_fft, hop = df_state.fft_size(), df_state.hop_size()
|
239 |
+
# Pad audio to compensate for the delay due to the real-time STFT implementation
|
240 |
+
audio = F.pad(audio, (0, n_fft))
|
241 |
+
nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df))
|
242 |
+
spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device())
|
243 |
+
enhanced = model(spec, erb_feat, spec_feat)[0].cpu()
|
244 |
+
enhanced = as_complex(enhanced.squeeze(1))
|
245 |
+
if atten_lim_db is not None and abs(atten_lim_db) > 0:
|
246 |
+
lim = 10 ** (-abs(atten_lim_db) / 20)
|
247 |
+
enhanced = as_complex(spec.squeeze(1)) * lim + enhanced * (1 - lim)
|
248 |
+
audio = torch.as_tensor(df_state.synthesis(enhanced.numpy()))
|
249 |
+
if pad:
|
250 |
+
# The frame size is equal to p.hop_size. Given a new frame, the STFT loop requires e.g.
|
251 |
+
# ceil((n_fft-hop)/hop). I.e. for 50% overlap, then hop=n_fft//2
|
252 |
+
# requires 1 additional frame lookahead; 75% requires 3 additional frames lookahead.
|
253 |
+
# Thus, the STFT/ISTFT loop introduces an algorithmic delay of n_fft - hop.
|
254 |
+
assert n_fft % hop == 0 # This is only tested for 50% and 75% overlap
|
255 |
+
d = n_fft - hop
|
256 |
+
audio = audio[:, d : orig_len + d]
|
257 |
+
return audio
|
258 |
+
|
259 |
+
|
260 |
+
def parse_epoch_type(value: str) -> Union[int, str]:
|
261 |
+
try:
|
262 |
+
return int(value)
|
263 |
+
except ValueError:
|
264 |
+
assert value in ("best", "latest")
|
265 |
+
return value
|
266 |
+
|
267 |
+
|
268 |
+
def setup_df_argument_parser(default_log_level: str = "INFO") -> argparse.ArgumentParser:
|
269 |
+
parser = argparse.ArgumentParser()
|
270 |
+
parser.add_argument(
|
271 |
+
"--model-base-dir",
|
272 |
+
"-m",
|
273 |
+
type=str,
|
274 |
+
default=None,
|
275 |
+
help="Model directory containing checkpoints and config. "
|
276 |
+
"To load a pretrained model, you may just provide the model name, e.g. `DeepFilterNet`. "
|
277 |
+
"By default, the pretrained DeepFilterNet2 model is loaded.",
|
278 |
+
)
|
279 |
+
parser.add_argument(
|
280 |
+
"--pf",
|
281 |
+
help="Post-filter that slightly over-attenuates very noisy sections.",
|
282 |
+
action="store_true",
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--output-dir",
|
286 |
+
"-o",
|
287 |
+
type=str,
|
288 |
+
default=None,
|
289 |
+
help="Directory in which the enhanced audio files will be stored.",
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--log-level",
|
293 |
+
type=str,
|
294 |
+
default=default_log_level,
|
295 |
+
help="Logger verbosity. Can be one of (debug, info, error, none)",
|
296 |
+
)
|
297 |
+
parser.add_argument("--debug", "-d", action="store_const", const="DEBUG", dest="log_level")
|
298 |
+
parser.add_argument(
|
299 |
+
"--epoch",
|
300 |
+
"-e",
|
301 |
+
default="best",
|
302 |
+
type=parse_epoch_type,
|
303 |
+
help="Epoch for checkpoint loading. Can be one of ['best', 'latest', <int>].",
|
304 |
+
)
|
305 |
+
return parser
|
306 |
+
|
307 |
+
|
308 |
+
def run():
|
309 |
+
parser = setup_df_argument_parser()
|
310 |
+
parser.add_argument(
|
311 |
+
"--compensate-delay",
|
312 |
+
"-D",
|
313 |
+
action="store_true",
|
314 |
+
help="Add some paddig to compensate the delay introduced by the real-time STFT/ISTFT implementation.",
|
315 |
+
)
|
316 |
+
parser.add_argument(
|
317 |
+
"--atten-lim",
|
318 |
+
"-a",
|
319 |
+
type=int,
|
320 |
+
default=None,
|
321 |
+
help="Attenuation limit in dB by mixing the enhanced signal with the noisy signal.",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"noisy_audio_files",
|
325 |
+
type=str,
|
326 |
+
nargs="+",
|
327 |
+
help="List of noise files to mix with the clean speech file.",
|
328 |
+
)
|
329 |
+
main(parser.parse_args())
|
330 |
+
|
331 |
+
|
332 |
+
if __name__ == "__main__":
|
333 |
+
run()
|
df/logger.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import warnings
|
4 |
+
from collections import defaultdict
|
5 |
+
from copy import deepcopy
|
6 |
+
from typing import Dict, Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from loguru import logger
|
11 |
+
from torch.types import Number
|
12 |
+
|
13 |
+
from df.modules import GroupedLinearEinsum
|
14 |
+
from df.utils import get_branch_name, get_commit_hash, get_device, get_host
|
15 |
+
|
16 |
+
_logger_initialized = False
|
17 |
+
WARN_ONCE_NO = logger.level("WARNING").no + 1
|
18 |
+
DEPRECATED_NO = logger.level("WARNING").no + 2
|
19 |
+
|
20 |
+
|
21 |
+
def init_logger(file: Optional[str] = None, level: str = "INFO", model: Optional[str] = None):
|
22 |
+
global _logger_initialized, _duplicate_filter
|
23 |
+
if _logger_initialized:
|
24 |
+
logger.debug("Logger already initialized.")
|
25 |
+
else:
|
26 |
+
logger.remove()
|
27 |
+
level = level.upper()
|
28 |
+
if level.lower() != "none":
|
29 |
+
log_format = Formatter(debug=logger.level(level).no <= logger.level("DEBUG").no).format
|
30 |
+
logger.add(
|
31 |
+
sys.stdout,
|
32 |
+
level=level,
|
33 |
+
format=log_format,
|
34 |
+
filter=lambda r: r["level"].no not in {WARN_ONCE_NO, DEPRECATED_NO},
|
35 |
+
)
|
36 |
+
if file is not None:
|
37 |
+
logger.add(
|
38 |
+
file,
|
39 |
+
level=level,
|
40 |
+
format=log_format,
|
41 |
+
filter=lambda r: r["level"].no != WARN_ONCE_NO,
|
42 |
+
)
|
43 |
+
|
44 |
+
logger.info(f"Running on torch {torch.__version__}")
|
45 |
+
logger.info(f"Running on host {get_host()}")
|
46 |
+
commit = get_commit_hash()
|
47 |
+
if commit is not None:
|
48 |
+
logger.info(f"Git commit: {commit}, branch: {get_branch_name()}")
|
49 |
+
if (jobid := os.getenv("SLURM_JOB_ID")) is not None:
|
50 |
+
logger.info(f"Slurm jobid: {jobid}")
|
51 |
+
logger.level("WARNONCE", no=WARN_ONCE_NO, color="<yellow><bold>")
|
52 |
+
logger.add(
|
53 |
+
sys.stderr,
|
54 |
+
level=max(logger.level(level).no, WARN_ONCE_NO),
|
55 |
+
format=log_format,
|
56 |
+
filter=lambda r: r["level"].no == WARN_ONCE_NO and _duplicate_filter(r),
|
57 |
+
)
|
58 |
+
logger.level("DEPRECATED", no=DEPRECATED_NO, color="<yellow><bold>")
|
59 |
+
logger.add(
|
60 |
+
sys.stderr,
|
61 |
+
level=max(logger.level(level).no, DEPRECATED_NO),
|
62 |
+
format=log_format,
|
63 |
+
filter=lambda r: r["level"].no == DEPRECATED_NO and _duplicate_filter(r),
|
64 |
+
)
|
65 |
+
if model is not None:
|
66 |
+
logger.info("Loading model settings of {}", os.path.basename(model.rstrip("/")))
|
67 |
+
_logger_initialized = True
|
68 |
+
|
69 |
+
|
70 |
+
def warn_once(message, *args, **kwargs):
|
71 |
+
logger.log("WARNONCE", message, *args, **kwargs)
|
72 |
+
|
73 |
+
|
74 |
+
def log_deprecated(message, *args, **kwargs):
|
75 |
+
logger.log("DEPRECATED", message, *args, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
class Formatter:
|
79 |
+
def __init__(self, debug=False):
|
80 |
+
if debug:
|
81 |
+
self.fmt = (
|
82 |
+
"<green>{time:YYYY-MM-DD HH:mm:ss}</green>"
|
83 |
+
" | <level>{level: <8}</level>"
|
84 |
+
" | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>"
|
85 |
+
" | <level>{message}</level>"
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
self.fmt = (
|
89 |
+
"<green>{time:YYYY-MM-DD HH:mm:ss}</green>"
|
90 |
+
" | <level>{level: <8}</level>"
|
91 |
+
" | <cyan>DF</cyan>"
|
92 |
+
" | <level>{message}</level>"
|
93 |
+
)
|
94 |
+
self.fmt += "\n{exception}"
|
95 |
+
|
96 |
+
def format(self, record):
|
97 |
+
if record["level"].no == WARN_ONCE_NO:
|
98 |
+
return self.fmt.replace("{level: <8}", "WARNING ")
|
99 |
+
return self.fmt
|
100 |
+
|
101 |
+
|
102 |
+
def _metrics_key(k_: Tuple[str, float]):
|
103 |
+
k0 = k_[0]
|
104 |
+
ks = k0.split("_")
|
105 |
+
if len(ks) > 2:
|
106 |
+
try:
|
107 |
+
return int(ks[-1])
|
108 |
+
except ValueError:
|
109 |
+
return 1000
|
110 |
+
elif k0 == "loss":
|
111 |
+
return -999
|
112 |
+
elif "loss" in k0.lower():
|
113 |
+
return -998
|
114 |
+
elif k0 == "lr":
|
115 |
+
return 998
|
116 |
+
elif k0 == "wd":
|
117 |
+
return 999
|
118 |
+
else:
|
119 |
+
return -101
|
120 |
+
|
121 |
+
|
122 |
+
def log_metrics(prefix: str, metrics: Dict[str, Number], level="INFO"):
|
123 |
+
msg = ""
|
124 |
+
stages = defaultdict(str)
|
125 |
+
loss_msg = ""
|
126 |
+
for n, v in sorted(metrics.items(), key=_metrics_key):
|
127 |
+
if abs(v) > 1e-3:
|
128 |
+
m = f" | {n}: {v:.5f}"
|
129 |
+
else:
|
130 |
+
m = f" | {n}: {v:.3E}"
|
131 |
+
if "stage" in n:
|
132 |
+
s = n.split("stage_")[1].split("_snr")[0]
|
133 |
+
stages[s] += m.replace(f"stage_{s}_", "")
|
134 |
+
elif ("valid" in prefix or "test" in prefix) and "loss" in n.lower():
|
135 |
+
loss_msg += m
|
136 |
+
else:
|
137 |
+
msg += m
|
138 |
+
for s, msg_s in stages.items():
|
139 |
+
logger.log(level, f"{prefix} | stage {s}" + msg_s)
|
140 |
+
if len(stages) == 0:
|
141 |
+
logger.log(level, prefix + msg)
|
142 |
+
if len(loss_msg) > 0:
|
143 |
+
logger.log(level, prefix + loss_msg)
|
144 |
+
|
145 |
+
|
146 |
+
class DuplicateFilter:
|
147 |
+
"""
|
148 |
+
Filters away duplicate log messages.
|
149 |
+
Modified version of: https://stackoverflow.com/a/60462619
|
150 |
+
"""
|
151 |
+
|
152 |
+
def __init__(self):
|
153 |
+
self.msgs = set()
|
154 |
+
|
155 |
+
def __call__(self, record) -> bool:
|
156 |
+
k = f"{record['level']}{record['message']}"
|
157 |
+
if k in self.msgs:
|
158 |
+
return False
|
159 |
+
else:
|
160 |
+
self.msgs.add(k)
|
161 |
+
return True
|
162 |
+
|
163 |
+
|
164 |
+
_duplicate_filter = DuplicateFilter()
|
165 |
+
|
166 |
+
|
167 |
+
def log_model_summary(model: torch.nn.Module, verbose=False):
|
168 |
+
try:
|
169 |
+
import ptflops
|
170 |
+
except ImportError:
|
171 |
+
logger.debug("Failed to import ptflops. Cannot print model summary.")
|
172 |
+
return
|
173 |
+
|
174 |
+
from df.model import ModelParams
|
175 |
+
|
176 |
+
# Generate input of 1 second audio
|
177 |
+
# Necessary inputs are:
|
178 |
+
# spec: [B, 1, T, F, 2], F: freq bin
|
179 |
+
# feat_erb: [B, 1, T, E], E: ERB bands
|
180 |
+
# feat_spec: [B, 2, T, C*2], C: Complex features
|
181 |
+
p = ModelParams()
|
182 |
+
b = 1
|
183 |
+
t = p.sr // p.hop_size
|
184 |
+
device = get_device()
|
185 |
+
spec = torch.randn([b, 1, t, p.fft_size // 2 + 1, 2]).to(device)
|
186 |
+
feat_erb = torch.randn([b, 1, t, p.nb_erb]).to(device)
|
187 |
+
feat_spec = torch.randn([b, 1, t, p.nb_df, 2]).to(device)
|
188 |
+
|
189 |
+
warnings.filterwarnings("ignore", "RNN module weights", category=UserWarning, module="torch")
|
190 |
+
macs, params = ptflops.get_model_complexity_info(
|
191 |
+
deepcopy(model),
|
192 |
+
(t,),
|
193 |
+
input_constructor=lambda _: {"spec": spec, "feat_erb": feat_erb, "feat_spec": feat_spec},
|
194 |
+
as_strings=False,
|
195 |
+
print_per_layer_stat=verbose,
|
196 |
+
verbose=verbose,
|
197 |
+
custom_modules_hooks={
|
198 |
+
GroupedLinearEinsum: grouped_linear_flops_counter_hook,
|
199 |
+
},
|
200 |
+
)
|
201 |
+
logger.info(f"Model complexity: {params/1e6:.3f}M #Params, {macs/1e6:.1f}M MACS")
|
202 |
+
|
203 |
+
|
204 |
+
def grouped_linear_flops_counter_hook(module: GroupedLinearEinsum, input, output):
|
205 |
+
# input: ([B, T, I],)
|
206 |
+
# output: [B, T, H]
|
207 |
+
input = input[0] # [B, T, I]
|
208 |
+
output_last_dim = module.weight.shape[-1]
|
209 |
+
input = input.unflatten(-1, (module.groups, module.ws)) # [B, T, G, I/G]
|
210 |
+
# GroupedLinear calculates "...gi,...gih->...gh"
|
211 |
+
weight_flops = np.prod(input.shape) * output_last_dim
|
212 |
+
module.__flops__ += int(weight_flops) # type: ignore
|
df/model.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from df.config import DfParams, config
|
7 |
+
|
8 |
+
|
9 |
+
class ModelParams(DfParams):
|
10 |
+
def __init__(self):
|
11 |
+
self.__model = config("MODEL", default="deepfilternet", section="train")
|
12 |
+
self.__params = getattr(import_module("df." + self.__model), "ModelParams")()
|
13 |
+
|
14 |
+
def __getattr__(self, attr: str):
|
15 |
+
return getattr(self.__params, attr)
|
16 |
+
|
17 |
+
|
18 |
+
def init_model(*args, **kwargs):
|
19 |
+
"""Initialize the model specified in the config."""
|
20 |
+
model = config("MODEL", default="deepfilternet", section="train")
|
21 |
+
logger.info(f"Initializing model `{model}`")
|
22 |
+
model = getattr(import_module("df." + model), "init_model")(*args, **kwargs)
|
23 |
+
model.to(memory_format=torch.channels_last)
|
24 |
+
return model
|
df/modules.py
ADDED
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections import OrderedDict
|
3 |
+
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.nn import init
|
10 |
+
from torch.nn.parameter import Parameter
|
11 |
+
from typing_extensions import Final
|
12 |
+
|
13 |
+
from df.model import ModelParams
|
14 |
+
from df.utils import as_complex, as_real, get_device, get_norm_alpha
|
15 |
+
from libdf import unit_norm_init
|
16 |
+
|
17 |
+
|
18 |
+
class Conv2dNormAct(nn.Sequential):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
in_ch: int,
|
22 |
+
out_ch: int,
|
23 |
+
kernel_size: Union[int, Iterable[int]],
|
24 |
+
fstride: int = 1,
|
25 |
+
dilation: int = 1,
|
26 |
+
fpad: bool = True,
|
27 |
+
bias: bool = True,
|
28 |
+
separable: bool = False,
|
29 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
30 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
31 |
+
):
|
32 |
+
"""Causal Conv2d by delaying the signal for any lookahead.
|
33 |
+
|
34 |
+
Expected input format: [B, C, T, F]
|
35 |
+
"""
|
36 |
+
lookahead = 0 # This needs to be handled on the input feature side
|
37 |
+
# Padding on time axis
|
38 |
+
kernel_size = (
|
39 |
+
(kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
40 |
+
)
|
41 |
+
if fpad:
|
42 |
+
fpad_ = kernel_size[1] // 2 + dilation - 1
|
43 |
+
else:
|
44 |
+
fpad_ = 0
|
45 |
+
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
|
46 |
+
layers = []
|
47 |
+
if any(x > 0 for x in pad):
|
48 |
+
layers.append(nn.ConstantPad2d(pad, 0.0))
|
49 |
+
groups = math.gcd(in_ch, out_ch) if separable else 1
|
50 |
+
if groups == 1:
|
51 |
+
separable = False
|
52 |
+
if max(kernel_size) == 1:
|
53 |
+
separable = False
|
54 |
+
layers.append(
|
55 |
+
nn.Conv2d(
|
56 |
+
in_ch,
|
57 |
+
out_ch,
|
58 |
+
kernel_size=kernel_size,
|
59 |
+
padding=(0, fpad_),
|
60 |
+
stride=(1, fstride), # Stride over time is always 1
|
61 |
+
dilation=(1, dilation), # Same for dilation
|
62 |
+
groups=groups,
|
63 |
+
bias=bias,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
if separable:
|
67 |
+
layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
|
68 |
+
if norm_layer is not None:
|
69 |
+
layers.append(norm_layer(out_ch))
|
70 |
+
if activation_layer is not None:
|
71 |
+
layers.append(activation_layer())
|
72 |
+
super().__init__(*layers)
|
73 |
+
|
74 |
+
|
75 |
+
class ConvTranspose2dNormAct(nn.Sequential):
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
in_ch: int,
|
79 |
+
out_ch: int,
|
80 |
+
kernel_size: Union[int, Tuple[int, int]],
|
81 |
+
fstride: int = 1,
|
82 |
+
dilation: int = 1,
|
83 |
+
fpad: bool = True,
|
84 |
+
bias: bool = True,
|
85 |
+
separable: bool = False,
|
86 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
87 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
88 |
+
):
|
89 |
+
"""Causal ConvTranspose2d.
|
90 |
+
|
91 |
+
Expected input format: [B, C, T, F]
|
92 |
+
"""
|
93 |
+
# Padding on time axis, with lookahead = 0
|
94 |
+
lookahead = 0 # This needs to be handled on the input feature side
|
95 |
+
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
96 |
+
if fpad:
|
97 |
+
fpad_ = kernel_size[1] // 2
|
98 |
+
else:
|
99 |
+
fpad_ = 0
|
100 |
+
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
|
101 |
+
layers = []
|
102 |
+
if any(x > 0 for x in pad):
|
103 |
+
layers.append(nn.ConstantPad2d(pad, 0.0))
|
104 |
+
groups = math.gcd(in_ch, out_ch) if separable else 1
|
105 |
+
if groups == 1:
|
106 |
+
separable = False
|
107 |
+
layers.append(
|
108 |
+
nn.ConvTranspose2d(
|
109 |
+
in_ch,
|
110 |
+
out_ch,
|
111 |
+
kernel_size=kernel_size,
|
112 |
+
padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
|
113 |
+
output_padding=(0, fpad_),
|
114 |
+
stride=(1, fstride), # Stride over time is always 1
|
115 |
+
dilation=(1, dilation),
|
116 |
+
groups=groups,
|
117 |
+
bias=bias,
|
118 |
+
)
|
119 |
+
)
|
120 |
+
if separable:
|
121 |
+
layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
|
122 |
+
if norm_layer is not None:
|
123 |
+
layers.append(norm_layer(out_ch))
|
124 |
+
if activation_layer is not None:
|
125 |
+
layers.append(activation_layer())
|
126 |
+
super().__init__(*layers)
|
127 |
+
|
128 |
+
|
129 |
+
def convkxf(
|
130 |
+
in_ch: int,
|
131 |
+
out_ch: Optional[int] = None,
|
132 |
+
k: int = 1,
|
133 |
+
f: int = 3,
|
134 |
+
fstride: int = 2,
|
135 |
+
lookahead: int = 0,
|
136 |
+
batch_norm: bool = False,
|
137 |
+
act: nn.Module = nn.ReLU(inplace=True),
|
138 |
+
mode="normal", # must be "normal", "transposed" or "upsample"
|
139 |
+
depthwise: bool = True,
|
140 |
+
complex_in: bool = False,
|
141 |
+
):
|
142 |
+
bias = batch_norm is False
|
143 |
+
assert f % 2 == 1
|
144 |
+
stride = 1 if f == 1 else (1, fstride)
|
145 |
+
if out_ch is None:
|
146 |
+
out_ch = in_ch * 2 if mode == "normal" else in_ch // 2
|
147 |
+
fpad = (f - 1) // 2
|
148 |
+
convpad = (0, fpad)
|
149 |
+
modules = []
|
150 |
+
# Manually pad for time axis kernel to not introduce delay
|
151 |
+
pad = (0, 0, k - 1 - lookahead, lookahead)
|
152 |
+
if any(p > 0 for p in pad):
|
153 |
+
modules.append(("pad", nn.ConstantPad2d(pad, 0.0)))
|
154 |
+
if depthwise:
|
155 |
+
groups = min(in_ch, out_ch)
|
156 |
+
else:
|
157 |
+
groups = 1
|
158 |
+
if in_ch % groups != 0 or out_ch % groups != 0:
|
159 |
+
groups = 1
|
160 |
+
if complex_in and groups % 2 == 0:
|
161 |
+
groups //= 2
|
162 |
+
convkwargs = {
|
163 |
+
"in_channels": in_ch,
|
164 |
+
"out_channels": out_ch,
|
165 |
+
"kernel_size": (k, f),
|
166 |
+
"stride": stride,
|
167 |
+
"groups": groups,
|
168 |
+
"bias": bias,
|
169 |
+
}
|
170 |
+
if mode == "normal":
|
171 |
+
modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
|
172 |
+
elif mode == "transposed":
|
173 |
+
# Since pytorch's transposed conv padding does not correspond to the actual padding but
|
174 |
+
# rather the padding that was used in the encoder conv, we need to set time axis padding
|
175 |
+
# according to k. E.g., this disables the padding for k=2:
|
176 |
+
# dilation - (k - 1) - padding
|
177 |
+
# = 1 - (2 - 1) - 1 = 0; => padding = fpad (=1 for k=2)
|
178 |
+
padding = (k - 1, fpad)
|
179 |
+
modules.append(
|
180 |
+
("sconvt", nn.ConvTranspose2d(padding=padding, output_padding=convpad, **convkwargs))
|
181 |
+
)
|
182 |
+
elif mode == "upsample":
|
183 |
+
modules.append(("upsample", FreqUpsample(fstride)))
|
184 |
+
convkwargs["stride"] = 1
|
185 |
+
modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
|
186 |
+
else:
|
187 |
+
raise NotImplementedError()
|
188 |
+
if groups > 1:
|
189 |
+
modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False)))
|
190 |
+
if batch_norm:
|
191 |
+
modules.append(("norm", nn.BatchNorm2d(out_ch)))
|
192 |
+
modules.append(("act", act))
|
193 |
+
return nn.Sequential(OrderedDict(modules))
|
194 |
+
|
195 |
+
|
196 |
+
class FreqUpsample(nn.Module):
|
197 |
+
def __init__(self, factor: int, mode="nearest"):
|
198 |
+
super().__init__()
|
199 |
+
self.f = float(factor)
|
200 |
+
self.mode = mode
|
201 |
+
|
202 |
+
def forward(self, x: Tensor) -> Tensor:
|
203 |
+
return F.interpolate(x, scale_factor=[1.0, self.f], mode=self.mode)
|
204 |
+
|
205 |
+
|
206 |
+
def erb_fb(widths: np.ndarray, sr: int, normalized: bool = True, inverse: bool = False) -> Tensor:
|
207 |
+
n_freqs = int(np.sum(widths))
|
208 |
+
all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1]
|
209 |
+
|
210 |
+
b_pts = np.cumsum([0] + widths.tolist()).astype(int)[:-1]
|
211 |
+
|
212 |
+
fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0]))
|
213 |
+
for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())):
|
214 |
+
fb[b : b + w, i] = 1
|
215 |
+
# Normalize to constant energy per resulting band
|
216 |
+
if inverse:
|
217 |
+
fb = fb.t()
|
218 |
+
if not normalized:
|
219 |
+
fb /= fb.sum(dim=1, keepdim=True)
|
220 |
+
else:
|
221 |
+
if normalized:
|
222 |
+
fb /= fb.sum(dim=0)
|
223 |
+
return fb.to(device=get_device())
|
224 |
+
|
225 |
+
|
226 |
+
class Mask(nn.Module):
|
227 |
+
def __init__(self, erb_inv_fb: Tensor, post_filter: bool = False, eps: float = 1e-12):
|
228 |
+
super().__init__()
|
229 |
+
self.erb_inv_fb: Tensor
|
230 |
+
self.register_buffer("erb_inv_fb", erb_inv_fb)
|
231 |
+
self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
|
232 |
+
self.post_filter = post_filter
|
233 |
+
self.eps = eps
|
234 |
+
|
235 |
+
def pf(self, mask: Tensor, beta: float = 0.02) -> Tensor:
|
236 |
+
"""Post-Filter proposed by Valin et al. [1].
|
237 |
+
|
238 |
+
Args:
|
239 |
+
mask (Tensor): Real valued mask, typically of shape [B, C, T, F].
|
240 |
+
beta: Global gain factor.
|
241 |
+
Refs:
|
242 |
+
[1]: Valin et al.: A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
|
243 |
+
"""
|
244 |
+
mask_sin = mask * torch.sin(np.pi * mask / 2)
|
245 |
+
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
|
246 |
+
return mask_pf
|
247 |
+
|
248 |
+
def forward(self, spec: Tensor, mask: Tensor, atten_lim: Optional[Tensor] = None) -> Tensor:
|
249 |
+
# spec (real) [B, 1, T, F, 2], F: freq_bins
|
250 |
+
# mask (real): [B, 1, T, Fe], Fe: erb_bins
|
251 |
+
# atten_lim: [B]
|
252 |
+
if not self.training and self.post_filter:
|
253 |
+
mask = self.pf(mask)
|
254 |
+
if atten_lim is not None:
|
255 |
+
# dB to amplitude
|
256 |
+
atten_lim = 10 ** (-atten_lim / 20)
|
257 |
+
# Greater equal (__ge__) not implemented for TorchVersion.
|
258 |
+
if self.clamp_tensor:
|
259 |
+
# Supported by torch >= 1.9
|
260 |
+
mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
|
261 |
+
else:
|
262 |
+
m_out = []
|
263 |
+
for i in range(atten_lim.shape[0]):
|
264 |
+
m_out.append(mask[i].clamp_min(atten_lim[i].item()))
|
265 |
+
mask = torch.stack(m_out, dim=0)
|
266 |
+
mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F]
|
267 |
+
return spec * mask.unsqueeze(4)
|
268 |
+
|
269 |
+
|
270 |
+
class ExponentialUnitNorm(nn.Module):
|
271 |
+
"""Unit norm for a complex spectrogram.
|
272 |
+
|
273 |
+
This should match the rust code:
|
274 |
+
```rust
|
275 |
+
for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
|
276 |
+
*s = x.norm() * (1. - alpha) + *s * alpha;
|
277 |
+
*x /= s.sqrt();
|
278 |
+
}
|
279 |
+
```
|
280 |
+
"""
|
281 |
+
|
282 |
+
alpha: Final[float]
|
283 |
+
eps: Final[float]
|
284 |
+
|
285 |
+
def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14):
|
286 |
+
super().__init__()
|
287 |
+
self.alpha = alpha
|
288 |
+
self.eps = eps
|
289 |
+
self.init_state: Tensor
|
290 |
+
s = torch.from_numpy(unit_norm_init(num_freq_bins)).view(1, 1, num_freq_bins, 1)
|
291 |
+
self.register_buffer("init_state", s)
|
292 |
+
|
293 |
+
def forward(self, x: Tensor) -> Tensor:
|
294 |
+
# x: [B, C, T, F, 2]
|
295 |
+
b, c, t, f, _ = x.shape
|
296 |
+
x_abs = x.square().sum(dim=-1, keepdim=True).clamp_min(self.eps).sqrt()
|
297 |
+
state = self.init_state.clone().expand(b, c, f, 1)
|
298 |
+
out_states: List[Tensor] = []
|
299 |
+
for t in range(t):
|
300 |
+
state = x_abs[:, :, t] * (1 - self.alpha) + state * self.alpha
|
301 |
+
out_states.append(state)
|
302 |
+
return x / torch.stack(out_states, 2).sqrt()
|
303 |
+
|
304 |
+
|
305 |
+
class DfOp(nn.Module):
|
306 |
+
df_order: Final[int]
|
307 |
+
df_bins: Final[int]
|
308 |
+
df_lookahead: Final[int]
|
309 |
+
freq_bins: Final[int]
|
310 |
+
|
311 |
+
def __init__(
|
312 |
+
self,
|
313 |
+
df_bins: int,
|
314 |
+
df_order: int = 5,
|
315 |
+
df_lookahead: int = 0,
|
316 |
+
method: str = "complex_strided",
|
317 |
+
freq_bins: int = 0,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
self.df_order = df_order
|
321 |
+
self.df_bins = df_bins
|
322 |
+
self.df_lookahead = df_lookahead
|
323 |
+
self.freq_bins = freq_bins
|
324 |
+
self.set_forward(method)
|
325 |
+
|
326 |
+
def set_forward(self, method: str):
|
327 |
+
# All forward methods should be mathematically similar.
|
328 |
+
# DeepFilterNet results are obtained with 'real_unfold'.
|
329 |
+
forward_methods = {
|
330 |
+
"real_loop": self.forward_real_loop,
|
331 |
+
"real_strided": self.forward_real_strided,
|
332 |
+
"real_unfold": self.forward_real_unfold,
|
333 |
+
"complex_strided": self.forward_complex_strided,
|
334 |
+
"real_one_step": self.forward_real_no_pad_one_step,
|
335 |
+
"real_hidden_state_loop": self.forward_real_hidden_state_loop,
|
336 |
+
}
|
337 |
+
if method not in forward_methods.keys():
|
338 |
+
raise NotImplementedError(f"`method` must be one of {forward_methods.keys()}")
|
339 |
+
if method == "real_hidden_state_loop":
|
340 |
+
assert self.freq_bins >= self.df_bins
|
341 |
+
self.spec_buf: Tensor
|
342 |
+
# Currently only designed for batch size of 1
|
343 |
+
self.register_buffer(
|
344 |
+
"spec_buf", torch.zeros(1, 1, self.df_order, self.freq_bins, 2), persistent=False
|
345 |
+
)
|
346 |
+
self.forward = forward_methods[method]
|
347 |
+
|
348 |
+
def forward_real_loop(
|
349 |
+
self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
|
350 |
+
) -> Tensor:
|
351 |
+
# Version 0: Manual loop over df_order, maybe best for onnx export?
|
352 |
+
b, _, t, _, _ = spec.shape
|
353 |
+
f = self.df_bins
|
354 |
+
padded = spec_pad(
|
355 |
+
spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
|
356 |
+
)
|
357 |
+
|
358 |
+
spec_f = torch.zeros((b, t, f, 2), device=spec.device)
|
359 |
+
for i in range(self.df_order):
|
360 |
+
spec_f[..., 0] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 0]
|
361 |
+
spec_f[..., 0] -= padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 1]
|
362 |
+
spec_f[..., 1] += padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 0]
|
363 |
+
spec_f[..., 1] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 1]
|
364 |
+
return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
|
365 |
+
|
366 |
+
def forward_real_strided(
|
367 |
+
self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
|
368 |
+
) -> Tensor:
|
369 |
+
# Version1: Use as_strided instead of unfold
|
370 |
+
# spec (real) [B, 1, T, F, 2], O: df_order
|
371 |
+
# coefs (real) [B, T, O, F, 2]
|
372 |
+
# alpha (real) [B, T, 1]
|
373 |
+
padded = as_strided(
|
374 |
+
spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
|
375 |
+
)
|
376 |
+
# Complex numbers are not supported by onnx
|
377 |
+
re = padded[..., 0] * coefs[..., 0]
|
378 |
+
re -= padded[..., 1] * coefs[..., 1]
|
379 |
+
im = padded[..., 1] * coefs[..., 0]
|
380 |
+
im += padded[..., 0] * coefs[..., 1]
|
381 |
+
spec_f = torch.stack((re, im), -1).sum(2)
|
382 |
+
return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
|
383 |
+
|
384 |
+
def forward_real_unfold(
|
385 |
+
self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
|
386 |
+
) -> Tensor:
|
387 |
+
# Version2: Unfold
|
388 |
+
# spec (real) [B, 1, T, F, 2], O: df_order
|
389 |
+
# coefs (real) [B, T, O, F, 2]
|
390 |
+
# alpha (real) [B, T, 1]
|
391 |
+
padded = spec_pad(
|
392 |
+
spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
|
393 |
+
)
|
394 |
+
padded = padded.unfold(dimension=1, size=self.df_order, step=1) # [B, T, F, 2, O]
|
395 |
+
padded = padded.permute(0, 1, 4, 2, 3)
|
396 |
+
spec_f = torch.empty_like(padded)
|
397 |
+
spec_f[..., 0] = padded[..., 0] * coefs[..., 0] # re1
|
398 |
+
spec_f[..., 0] -= padded[..., 1] * coefs[..., 1] # re2
|
399 |
+
spec_f[..., 1] = padded[..., 1] * coefs[..., 0] # im1
|
400 |
+
spec_f[..., 1] += padded[..., 0] * coefs[..., 1] # im2
|
401 |
+
spec_f = spec_f.sum(dim=2)
|
402 |
+
return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
|
403 |
+
|
404 |
+
def forward_complex_strided(
|
405 |
+
self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
|
406 |
+
) -> Tensor:
|
407 |
+
# Version3: Complex strided; definatly nicest, no permute, no indexing, but complex gradient
|
408 |
+
# spec (real) [B, 1, T, F, 2], O: df_order
|
409 |
+
# coefs (real) [B, T, O, F, 2]
|
410 |
+
# alpha (real) [B, T, 1]
|
411 |
+
padded = as_strided(
|
412 |
+
spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
|
413 |
+
)
|
414 |
+
spec_f = torch.sum(torch.view_as_complex(padded) * torch.view_as_complex(coefs), dim=2)
|
415 |
+
spec_f = torch.view_as_real(spec_f)
|
416 |
+
return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)
|
417 |
+
|
418 |
+
def forward_real_no_pad_one_step(
|
419 |
+
self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
|
420 |
+
) -> Tensor:
|
421 |
+
# Version4: Only viable for onnx handling. `spec` needs external (ring-)buffer handling.
|
422 |
+
# Thus, time steps `t` must be equal to `df_order`.
|
423 |
+
|
424 |
+
# spec (real) [B, 1, O, F', 2]
|
425 |
+
# coefs (real) [B, 1, O, F, 2]
|
426 |
+
assert (
|
427 |
+
spec.shape[2] == self.df_order
|
428 |
+
), "This forward method needs spectrogram buffer with `df_order` time steps as input"
|
429 |
+
assert coefs.shape[1] == 1, "This forward method is only valid for 1 time step"
|
430 |
+
sre, sim = spec[..., : self.df_bins, :].split(1, -1)
|
431 |
+
cre, cim = coefs.split(1, -1)
|
432 |
+
outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
|
433 |
+
outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
|
434 |
+
spec_f = torch.stack((outr, outi), dim=-1)
|
435 |
+
return assign_df(
|
436 |
+
spec[:, :, self.df_order - self.df_lookahead - 1],
|
437 |
+
spec_f.unsqueeze(1),
|
438 |
+
self.df_bins,
|
439 |
+
alpha,
|
440 |
+
)
|
441 |
+
|
442 |
+
def forward_real_hidden_state_loop(self, spec: Tensor, coefs: Tensor, alpha: Tensor) -> Tensor:
|
443 |
+
# Version5: Designed for onnx export. `spec` buffer handling is done via a torch buffer.
|
444 |
+
|
445 |
+
# spec (real) [B, 1, T, F', 2]
|
446 |
+
# coefs (real) [B, T, O, F, 2]
|
447 |
+
b, _, t, _, _ = spec.shape
|
448 |
+
spec_out = torch.empty((b, 1, t, self.freq_bins, 2), device=spec.device)
|
449 |
+
for t in range(spec.shape[2]):
|
450 |
+
self.spec_buf = self.spec_buf.roll(-1, dims=2)
|
451 |
+
self.spec_buf[:, :, -1] = spec[:, :, t]
|
452 |
+
sre, sim = self.spec_buf[..., : self.df_bins, :].split(1, -1)
|
453 |
+
cre, cim = coefs[:, t : t + 1].split(1, -1)
|
454 |
+
outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
|
455 |
+
outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
|
456 |
+
spec_f = torch.stack((outr, outi), dim=-1)
|
457 |
+
spec_out[:, :, t] = assign_df(
|
458 |
+
self.spec_buf[:, :, self.df_order - self.df_lookahead - 1].unsqueeze(2),
|
459 |
+
spec_f.unsqueeze(1),
|
460 |
+
self.df_bins,
|
461 |
+
alpha[:, t],
|
462 |
+
).squeeze(2)
|
463 |
+
return spec_out
|
464 |
+
|
465 |
+
|
466 |
+
def assign_df(spec: Tensor, spec_f: Tensor, df_bins: int, alpha: Optional[Tensor]):
|
467 |
+
spec_out = spec.clone()
|
468 |
+
if alpha is not None:
|
469 |
+
b = spec.shape[0]
|
470 |
+
alpha = alpha.view(b, 1, -1, 1, 1)
|
471 |
+
spec_out[..., :df_bins, :] = spec_f * alpha + spec[..., :df_bins, :] * (1 - alpha)
|
472 |
+
else:
|
473 |
+
spec_out[..., :df_bins, :] = spec_f
|
474 |
+
return spec_out
|
475 |
+
|
476 |
+
|
477 |
+
def spec_pad(x: Tensor, window_size: int, lookahead: int, dim: int = 0) -> Tensor:
|
478 |
+
pad = [0] * x.dim() * 2
|
479 |
+
if dim >= 0:
|
480 |
+
pad[(x.dim() - dim - 1) * 2] = window_size - lookahead - 1
|
481 |
+
pad[(x.dim() - dim - 1) * 2 + 1] = lookahead
|
482 |
+
else:
|
483 |
+
pad[(-dim - 1) * 2] = window_size - lookahead - 1
|
484 |
+
pad[(-dim - 1) * 2 + 1] = lookahead
|
485 |
+
return F.pad(x, pad)
|
486 |
+
|
487 |
+
|
488 |
+
def as_strided(x: Tensor, window_size: int, lookahead: int, step: int = 1, dim: int = 0) -> Tensor:
|
489 |
+
shape = list(x.shape)
|
490 |
+
shape.insert(dim + 1, window_size)
|
491 |
+
x = spec_pad(x, window_size, lookahead, dim=dim)
|
492 |
+
# torch.fx workaround
|
493 |
+
step = 1
|
494 |
+
stride = [x.stride(0), x.stride(1), x.stride(2), x.stride(3)]
|
495 |
+
stride.insert(dim, stride[dim] * step)
|
496 |
+
return torch.as_strided(x, shape, stride)
|
497 |
+
|
498 |
+
|
499 |
+
class GroupedGRULayer(nn.Module):
|
500 |
+
input_size: Final[int]
|
501 |
+
hidden_size: Final[int]
|
502 |
+
out_size: Final[int]
|
503 |
+
bidirectional: Final[bool]
|
504 |
+
num_directions: Final[int]
|
505 |
+
groups: Final[int]
|
506 |
+
batch_first: Final[bool]
|
507 |
+
|
508 |
+
def __init__(
|
509 |
+
self,
|
510 |
+
input_size: int,
|
511 |
+
hidden_size: int,
|
512 |
+
groups: int,
|
513 |
+
batch_first: bool = True,
|
514 |
+
bias: bool = True,
|
515 |
+
dropout: float = 0,
|
516 |
+
bidirectional: bool = False,
|
517 |
+
):
|
518 |
+
super().__init__()
|
519 |
+
assert input_size % groups == 0
|
520 |
+
assert hidden_size % groups == 0
|
521 |
+
kwargs = {
|
522 |
+
"bias": bias,
|
523 |
+
"batch_first": batch_first,
|
524 |
+
"dropout": dropout,
|
525 |
+
"bidirectional": bidirectional,
|
526 |
+
}
|
527 |
+
self.input_size = input_size // groups
|
528 |
+
self.hidden_size = hidden_size // groups
|
529 |
+
self.out_size = hidden_size
|
530 |
+
self.bidirectional = bidirectional
|
531 |
+
self.num_directions = 2 if bidirectional else 1
|
532 |
+
self.groups = groups
|
533 |
+
self.batch_first = batch_first
|
534 |
+
assert (self.hidden_size % groups) == 0, "Hidden size must be divisible by groups"
|
535 |
+
self.layers = nn.ModuleList(
|
536 |
+
(nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups))
|
537 |
+
)
|
538 |
+
|
539 |
+
def flatten_parameters(self):
|
540 |
+
for layer in self.layers:
|
541 |
+
layer.flatten_parameters()
|
542 |
+
|
543 |
+
def get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")):
|
544 |
+
return torch.zeros(
|
545 |
+
self.groups * self.num_directions,
|
546 |
+
batch_size,
|
547 |
+
self.hidden_size,
|
548 |
+
device=device,
|
549 |
+
)
|
550 |
+
|
551 |
+
def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
|
552 |
+
# input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size
|
553 |
+
# state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size
|
554 |
+
|
555 |
+
if h0 is None:
|
556 |
+
dim0, dim1 = input.shape[:2]
|
557 |
+
bs = dim0 if self.batch_first else dim1
|
558 |
+
h0 = self.get_h0(bs, device=input.device)
|
559 |
+
outputs: List[Tensor] = []
|
560 |
+
outstates: List[Tensor] = []
|
561 |
+
for i, layer in enumerate(self.layers):
|
562 |
+
o, s = layer(
|
563 |
+
input[..., i * self.input_size : (i + 1) * self.input_size],
|
564 |
+
h0[i * self.num_directions : (i + 1) * self.num_directions].detach(),
|
565 |
+
)
|
566 |
+
outputs.append(o)
|
567 |
+
outstates.append(s)
|
568 |
+
output = torch.cat(outputs, dim=-1)
|
569 |
+
h = torch.cat(outstates, dim=0)
|
570 |
+
return output, h
|
571 |
+
|
572 |
+
|
573 |
+
class GroupedGRU(nn.Module):
|
574 |
+
groups: Final[int]
|
575 |
+
num_layers: Final[int]
|
576 |
+
batch_first: Final[bool]
|
577 |
+
hidden_size: Final[int]
|
578 |
+
bidirectional: Final[bool]
|
579 |
+
num_directions: Final[int]
|
580 |
+
shuffle: Final[bool]
|
581 |
+
add_outputs: Final[bool]
|
582 |
+
|
583 |
+
def __init__(
|
584 |
+
self,
|
585 |
+
input_size: int,
|
586 |
+
hidden_size: int,
|
587 |
+
num_layers: int = 1,
|
588 |
+
groups: int = 4,
|
589 |
+
bias: bool = True,
|
590 |
+
batch_first: bool = True,
|
591 |
+
dropout: float = 0,
|
592 |
+
bidirectional: bool = False,
|
593 |
+
shuffle: bool = True,
|
594 |
+
add_outputs: bool = False,
|
595 |
+
):
|
596 |
+
super().__init__()
|
597 |
+
kwargs = {
|
598 |
+
"groups": groups,
|
599 |
+
"bias": bias,
|
600 |
+
"batch_first": batch_first,
|
601 |
+
"dropout": dropout,
|
602 |
+
"bidirectional": bidirectional,
|
603 |
+
}
|
604 |
+
assert input_size % groups == 0
|
605 |
+
assert hidden_size % groups == 0
|
606 |
+
assert num_layers > 0
|
607 |
+
self.input_size = input_size
|
608 |
+
self.groups = groups
|
609 |
+
self.num_layers = num_layers
|
610 |
+
self.batch_first = batch_first
|
611 |
+
self.hidden_size = hidden_size // groups
|
612 |
+
self.bidirectional = bidirectional
|
613 |
+
self.num_directions = 2 if bidirectional else 1
|
614 |
+
if groups == 1:
|
615 |
+
shuffle = False # Fully connected, no need to shuffle
|
616 |
+
self.shuffle = shuffle
|
617 |
+
self.add_outputs = add_outputs
|
618 |
+
self.grus: List[GroupedGRULayer] = nn.ModuleList() # type: ignore
|
619 |
+
self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs))
|
620 |
+
for _ in range(1, num_layers):
|
621 |
+
self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs))
|
622 |
+
self.flatten_parameters()
|
623 |
+
|
624 |
+
def flatten_parameters(self):
|
625 |
+
for gru in self.grus:
|
626 |
+
gru.flatten_parameters()
|
627 |
+
|
628 |
+
def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor:
|
629 |
+
return torch.zeros(
|
630 |
+
(self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size),
|
631 |
+
device=device,
|
632 |
+
)
|
633 |
+
|
634 |
+
def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
|
635 |
+
dim0, dim1, _ = input.shape
|
636 |
+
b = dim0 if self.batch_first else dim1
|
637 |
+
if state is None:
|
638 |
+
state = self.get_h0(b, input.device)
|
639 |
+
output = torch.zeros(
|
640 |
+
dim0, dim1, self.hidden_size * self.num_directions * self.groups, device=input.device
|
641 |
+
)
|
642 |
+
outstates = []
|
643 |
+
h = self.groups * self.num_directions
|
644 |
+
for i, gru in enumerate(self.grus):
|
645 |
+
input, s = gru(input, state[i * h : (i + 1) * h])
|
646 |
+
outstates.append(s)
|
647 |
+
if self.shuffle and i < self.num_layers - 1:
|
648 |
+
input = (
|
649 |
+
input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1)
|
650 |
+
)
|
651 |
+
if self.add_outputs:
|
652 |
+
output += input
|
653 |
+
else:
|
654 |
+
output = input
|
655 |
+
outstate = torch.cat(outstates, dim=0)
|
656 |
+
return output, outstate
|
657 |
+
|
658 |
+
|
659 |
+
class SqueezedGRU(nn.Module):
|
660 |
+
input_size: Final[int]
|
661 |
+
hidden_size: Final[int]
|
662 |
+
|
663 |
+
def __init__(
|
664 |
+
self,
|
665 |
+
input_size: int,
|
666 |
+
hidden_size: int,
|
667 |
+
output_size: Optional[int] = None,
|
668 |
+
num_layers: int = 1,
|
669 |
+
linear_groups: int = 8,
|
670 |
+
batch_first: bool = True,
|
671 |
+
gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None,
|
672 |
+
linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity,
|
673 |
+
):
|
674 |
+
super().__init__()
|
675 |
+
self.input_size = input_size
|
676 |
+
self.hidden_size = hidden_size
|
677 |
+
self.linear_in = nn.Sequential(
|
678 |
+
GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer()
|
679 |
+
)
|
680 |
+
self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first)
|
681 |
+
self.gru_skip = gru_skip_op() if gru_skip_op is not None else None
|
682 |
+
if output_size is not None:
|
683 |
+
self.linear_out = nn.Sequential(
|
684 |
+
GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer()
|
685 |
+
)
|
686 |
+
else:
|
687 |
+
self.linear_out = nn.Identity()
|
688 |
+
|
689 |
+
def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]:
|
690 |
+
input = self.linear_in(input)
|
691 |
+
x, h = self.gru(input, h)
|
692 |
+
if self.gru_skip is not None:
|
693 |
+
x = x + self.gru_skip(input)
|
694 |
+
x = self.linear_out(x)
|
695 |
+
return x, h
|
696 |
+
|
697 |
+
|
698 |
+
class GroupedLinearEinsum(nn.Module):
|
699 |
+
input_size: Final[int]
|
700 |
+
hidden_size: Final[int]
|
701 |
+
groups: Final[int]
|
702 |
+
|
703 |
+
def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
|
704 |
+
super().__init__()
|
705 |
+
# self.weight: Tensor
|
706 |
+
self.input_size = input_size
|
707 |
+
self.hidden_size = hidden_size
|
708 |
+
self.groups = groups
|
709 |
+
assert input_size % groups == 0
|
710 |
+
self.ws = input_size // groups
|
711 |
+
self.register_parameter(
|
712 |
+
"weight",
|
713 |
+
Parameter(
|
714 |
+
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
|
715 |
+
),
|
716 |
+
)
|
717 |
+
self.reset_parameters()
|
718 |
+
|
719 |
+
def reset_parameters(self):
|
720 |
+
init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
|
721 |
+
|
722 |
+
def forward(self, x: Tensor) -> Tensor:
|
723 |
+
# x: [..., I]
|
724 |
+
x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
|
725 |
+
x = torch.einsum("...gi,...gih->...gh", x, self.weight) # [..., G, H/G]
|
726 |
+
x = x.flatten(2, 3) # [B, T, H]
|
727 |
+
return x
|
728 |
+
|
729 |
+
|
730 |
+
class GroupedLinear(nn.Module):
|
731 |
+
input_size: Final[int]
|
732 |
+
hidden_size: Final[int]
|
733 |
+
groups: Final[int]
|
734 |
+
shuffle: Final[bool]
|
735 |
+
|
736 |
+
def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True):
|
737 |
+
super().__init__()
|
738 |
+
assert input_size % groups == 0
|
739 |
+
assert hidden_size % groups == 0
|
740 |
+
self.groups = groups
|
741 |
+
self.input_size = input_size // groups
|
742 |
+
self.hidden_size = hidden_size // groups
|
743 |
+
if groups == 1:
|
744 |
+
shuffle = False
|
745 |
+
self.shuffle = shuffle
|
746 |
+
self.layers = nn.ModuleList(
|
747 |
+
nn.Linear(self.input_size, self.hidden_size) for _ in range(groups)
|
748 |
+
)
|
749 |
+
|
750 |
+
def forward(self, x: Tensor) -> Tensor:
|
751 |
+
outputs: List[Tensor] = []
|
752 |
+
for i, layer in enumerate(self.layers):
|
753 |
+
outputs.append(layer(x[..., i * self.input_size : (i + 1) * self.input_size]))
|
754 |
+
output = torch.cat(outputs, dim=-1)
|
755 |
+
if self.shuffle:
|
756 |
+
orig_shape = output.shape
|
757 |
+
output = (
|
758 |
+
output.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(orig_shape)
|
759 |
+
)
|
760 |
+
return output
|
761 |
+
|
762 |
+
|
763 |
+
class LocalSnrTarget(nn.Module):
|
764 |
+
def __init__(
|
765 |
+
self, ws: int = 20, db: bool = True, ws_ns: Optional[int] = None, target_snr_range=None
|
766 |
+
):
|
767 |
+
super().__init__()
|
768 |
+
self.ws = self.calc_ws(ws)
|
769 |
+
self.ws_ns = self.ws * 2 if ws_ns is None else self.calc_ws(ws_ns)
|
770 |
+
self.db = db
|
771 |
+
self.range = target_snr_range
|
772 |
+
|
773 |
+
def calc_ws(self, ws_ms: int) -> int:
|
774 |
+
# Calculates windows size in stft domain given a window size in ms
|
775 |
+
p = ModelParams()
|
776 |
+
ws = ws_ms - p.fft_size / p.sr * 1000 # length ms of an fft_window
|
777 |
+
ws = 1 + ws / (p.hop_size / p.sr * 1000) # consider hop_size
|
778 |
+
return max(int(round(ws)), 1)
|
779 |
+
|
780 |
+
def forward(self, clean: Tensor, noise: Tensor, max_bin: Optional[int] = None) -> Tensor:
|
781 |
+
# clean: [B, 1, T, F]
|
782 |
+
# out: [B, T']
|
783 |
+
if max_bin is not None:
|
784 |
+
clean = as_complex(clean[..., :max_bin])
|
785 |
+
noise = as_complex(noise[..., :max_bin])
|
786 |
+
return (
|
787 |
+
local_snr(clean, noise, window_size=self.ws, db=self.db, window_size_ns=self.ws_ns)[0]
|
788 |
+
.clamp(self.range[0], self.range[1])
|
789 |
+
.squeeze(1)
|
790 |
+
)
|
791 |
+
|
792 |
+
|
793 |
+
def _local_energy(x: Tensor, ws: int, device: torch.device) -> Tensor:
|
794 |
+
if (ws % 2) == 0:
|
795 |
+
ws += 1
|
796 |
+
ws_half = ws // 2
|
797 |
+
x = F.pad(x.pow(2).sum(-1).sum(-1), (ws_half, ws_half, 0, 0))
|
798 |
+
w = torch.hann_window(ws, device=device, dtype=x.dtype)
|
799 |
+
x = x.unfold(-1, size=ws, step=1) * w
|
800 |
+
return torch.sum(x, dim=-1).div(ws)
|
801 |
+
|
802 |
+
|
803 |
+
def local_snr(
|
804 |
+
clean: Tensor,
|
805 |
+
noise: Tensor,
|
806 |
+
window_size: int,
|
807 |
+
db: bool = False,
|
808 |
+
window_size_ns: Optional[int] = None,
|
809 |
+
eps: float = 1e-12,
|
810 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
811 |
+
# clean shape: [B, C, T, F]
|
812 |
+
clean = as_real(clean)
|
813 |
+
noise = as_real(noise)
|
814 |
+
assert clean.dim() == 5
|
815 |
+
|
816 |
+
E_speech = _local_energy(clean, window_size, clean.device)
|
817 |
+
window_size_ns = window_size if window_size_ns is None else window_size_ns
|
818 |
+
E_noise = _local_energy(noise, window_size_ns, clean.device)
|
819 |
+
|
820 |
+
snr = E_speech / E_noise.clamp_min(eps)
|
821 |
+
if db:
|
822 |
+
snr = snr.clamp_min(eps).log10().mul(10)
|
823 |
+
return snr, E_speech, E_noise
|
824 |
+
|
825 |
+
|
826 |
+
def test_grouped_gru():
|
827 |
+
from icecream import ic
|
828 |
+
|
829 |
+
g = 2 # groups
|
830 |
+
h = 4 # hidden_size
|
831 |
+
i = 2 # input_size
|
832 |
+
b = 1 # batch_size
|
833 |
+
t = 5 # time_steps
|
834 |
+
m = GroupedGRULayer(i, h, g, batch_first=True)
|
835 |
+
ic(m)
|
836 |
+
input = torch.randn((b, t, i))
|
837 |
+
h0 = m.get_h0(b)
|
838 |
+
assert list(h0.shape) == [g, b, h // g]
|
839 |
+
out, hout = m(input, h0)
|
840 |
+
|
841 |
+
# Should be exportable as raw nn.Module
|
842 |
+
torch.onnx.export(
|
843 |
+
m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
|
844 |
+
)
|
845 |
+
# Should be exportable as traced
|
846 |
+
m = torch.jit.trace(m, (input, h0))
|
847 |
+
torch.onnx.export(
|
848 |
+
m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
|
849 |
+
)
|
850 |
+
# and as scripted module
|
851 |
+
m = torch.jit.script(m)
|
852 |
+
torch.onnx.export(
|
853 |
+
m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
|
854 |
+
)
|
855 |
+
|
856 |
+
# now grouped gru
|
857 |
+
num = 2
|
858 |
+
m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True)
|
859 |
+
ic(m)
|
860 |
+
h0 = m.get_h0(b)
|
861 |
+
assert list(h0.shape) == [num * g, b, h // g]
|
862 |
+
out, hout = m(input, h0)
|
863 |
+
|
864 |
+
# Should be exportable as traced
|
865 |
+
m = torch.jit.trace(m, (input, h0))
|
866 |
+
torch.onnx.export(
|
867 |
+
m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
|
868 |
+
)
|
869 |
+
# and scripted module
|
870 |
+
m = torch.jit.script(m)
|
871 |
+
torch.onnx.export(
|
872 |
+
m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
|
873 |
+
)
|
874 |
+
|
875 |
+
|
876 |
+
def test_erb():
|
877 |
+
import libdf
|
878 |
+
from df.config import config
|
879 |
+
|
880 |
+
config.use_defaults()
|
881 |
+
p = ModelParams()
|
882 |
+
n_freq = p.fft_size // 2 + 1
|
883 |
+
df_state = libdf.DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
|
884 |
+
erb = erb_fb(df_state.erb_widths(), p.sr)
|
885 |
+
erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
|
886 |
+
input = torch.randn((1, 1, 1, n_freq), dtype=torch.complex64)
|
887 |
+
input_abs = input.abs().square()
|
888 |
+
erb_widths = df_state.erb_widths()
|
889 |
+
df_erb = torch.from_numpy(libdf.erb(input.numpy(), erb_widths, False))
|
890 |
+
py_erb = torch.matmul(input_abs, erb)
|
891 |
+
assert torch.allclose(df_erb, py_erb)
|
892 |
+
df_out = torch.from_numpy(libdf.erb_inv(df_erb.numpy(), erb_widths))
|
893 |
+
py_out = torch.matmul(py_erb, erb_inverse)
|
894 |
+
assert torch.allclose(df_out, py_out)
|
895 |
+
|
896 |
+
|
897 |
+
def test_unit_norm():
|
898 |
+
from df.config import config
|
899 |
+
from libdf import unit_norm
|
900 |
+
|
901 |
+
config.use_defaults()
|
902 |
+
p = ModelParams()
|
903 |
+
b = 2
|
904 |
+
F = p.nb_df
|
905 |
+
t = 100
|
906 |
+
spec = torch.randn(b, 1, t, F, 2)
|
907 |
+
alpha = get_norm_alpha(log=False)
|
908 |
+
# Expects complex input of shape [C, T, F]
|
909 |
+
norm_lib = torch.as_tensor(unit_norm(torch.view_as_complex(spec).squeeze(1).numpy(), alpha))
|
910 |
+
m = ExponentialUnitNorm(alpha, F)
|
911 |
+
norm_torch = torch.view_as_complex(m(spec).squeeze(1))
|
912 |
+
assert torch.allclose(norm_lib.real, norm_torch.real)
|
913 |
+
assert torch.allclose(norm_lib.imag, norm_torch.imag)
|
914 |
+
assert torch.allclose(norm_lib.abs(), norm_torch.abs())
|
915 |
+
|
916 |
+
|
917 |
+
def test_dfop():
|
918 |
+
from df.config import config
|
919 |
+
|
920 |
+
config.use_defaults()
|
921 |
+
p = ModelParams()
|
922 |
+
f = p.nb_df
|
923 |
+
F = f * 2
|
924 |
+
o = p.df_order
|
925 |
+
d = p.df_lookahead
|
926 |
+
t = 100
|
927 |
+
spec = torch.randn(1, 1, t, F, 2)
|
928 |
+
coefs = torch.randn(1, t, o, f, 2)
|
929 |
+
alpha = torch.randn(1, t, 1)
|
930 |
+
dfop = DfOp(df_bins=p.nb_df)
|
931 |
+
dfop.set_forward("real_loop")
|
932 |
+
out1 = dfop(spec, coefs, alpha)
|
933 |
+
dfop.set_forward("real_strided")
|
934 |
+
out2 = dfop(spec, coefs, alpha)
|
935 |
+
dfop.set_forward("real_unfold")
|
936 |
+
out3 = dfop(spec, coefs, alpha)
|
937 |
+
dfop.set_forward("complex_strided")
|
938 |
+
out4 = dfop(spec, coefs, alpha)
|
939 |
+
torch.testing.assert_allclose(out1, out2)
|
940 |
+
torch.testing.assert_allclose(out1, out3)
|
941 |
+
torch.testing.assert_allclose(out1, out4)
|
942 |
+
# This forward method requires external padding/lookahead as well as spectrogram buffer
|
943 |
+
# handling, i.e. via a ring buffer. Could be used in real time usage.
|
944 |
+
dfop.set_forward("real_one_step")
|
945 |
+
spec_padded = spec_pad(spec, o, d, dim=-3)
|
946 |
+
out5 = torch.zeros_like(out1)
|
947 |
+
for i in range(t):
|
948 |
+
out5[:, :, i] = dfop(
|
949 |
+
spec_padded[:, :, i : i + o], coefs[:, i].unsqueeze(1), alpha[:, i].unsqueeze(1)
|
950 |
+
)
|
951 |
+
torch.testing.assert_allclose(out1, out5)
|
952 |
+
# Forward method that does the padding/lookahead handling using an internal hidden state.
|
953 |
+
dfop.freq_bins = F
|
954 |
+
dfop.set_forward("real_hidden_state_loop")
|
955 |
+
out6 = dfop(spec, coefs, alpha)
|
956 |
+
torch.testing.assert_allclose(out1, out6)
|
df/multiframe.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict, Final
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor, nn
|
7 |
+
|
8 |
+
|
9 |
+
class MultiFrameModule(nn.Module, ABC):
|
10 |
+
"""Multi-frame speech enhancement modules.
|
11 |
+
|
12 |
+
Signal model and notation:
|
13 |
+
Noisy: `x = s + n`
|
14 |
+
Enhanced: `y = f(x)`
|
15 |
+
Objective: `min ||s - y||`
|
16 |
+
|
17 |
+
PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD.
|
18 |
+
IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
|
19 |
+
"""
|
20 |
+
|
21 |
+
num_freqs: Final[int]
|
22 |
+
frame_size: Final[int]
|
23 |
+
need_unfold: Final[bool]
|
24 |
+
|
25 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
|
26 |
+
"""Multi-Frame filtering module.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
num_freqs (int): Number of frequency bins used for filtering.
|
30 |
+
frame_size (int): Frame size in FD domain.
|
31 |
+
lookahead (int): Lookahead, may be used to select the output time step. Note: This
|
32 |
+
module does not add additional padding according to lookahead!
|
33 |
+
"""
|
34 |
+
super().__init__()
|
35 |
+
self.num_freqs = num_freqs
|
36 |
+
self.frame_size = frame_size
|
37 |
+
self.pad = nn.ConstantPad2d((0, 0, frame_size - 1, 0), 0.0)
|
38 |
+
self.need_unfold = frame_size > 1
|
39 |
+
self.lookahead = lookahead
|
40 |
+
|
41 |
+
def spec_unfold(self, spec: Tensor):
|
42 |
+
"""Pads and unfolds the spectrogram according to frame_size.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
spec (complex Tensor): Spectrogram of shape [B, C, T, F]
|
46 |
+
Returns:
|
47 |
+
spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
|
48 |
+
"""
|
49 |
+
if self.need_unfold:
|
50 |
+
return self.pad(spec).unfold(2, self.frame_size, 1)
|
51 |
+
return spec.unsqueeze(-1)
|
52 |
+
|
53 |
+
def forward(self, spec: Tensor, coefs: Tensor):
|
54 |
+
"""Pads and unfolds the spectrogram and forwards to impl.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
spec (Tensor): Spectrogram of shape [B, C, T, F, 2]
|
58 |
+
coefs (Tensor): Spectrogram of shape [B, C, T, F, 2]
|
59 |
+
"""
|
60 |
+
spec_u = self.spec_unfold(torch.view_as_complex(spec))
|
61 |
+
coefs = torch.view_as_complex(coefs)
|
62 |
+
spec_f = spec_u.narrow(-2, 0, self.num_freqs)
|
63 |
+
spec_f = self.forward_impl(spec_f, coefs)
|
64 |
+
if self.training:
|
65 |
+
spec = spec.clone()
|
66 |
+
spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f)
|
67 |
+
return spec
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor:
|
71 |
+
"""Forward impl taking complex spectrogram and coefficients.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N]
|
75 |
+
coefs (complex Tensor): Coefficients [B, C2, T, F]
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
spec (complex Tensor): Enhanced spectrogram of shape [B, C1, T, F]
|
79 |
+
"""
|
80 |
+
...
|
81 |
+
|
82 |
+
@abstractmethod
|
83 |
+
def num_channels(self) -> int:
|
84 |
+
"""Return the number of required channels.
|
85 |
+
|
86 |
+
If multiple inputs are required, then all these should be combined in one Tensor containing
|
87 |
+
the summed channels.
|
88 |
+
"""
|
89 |
+
...
|
90 |
+
|
91 |
+
|
92 |
+
def psd(x: Tensor, n: int) -> Tensor:
|
93 |
+
"""Compute the PSD correlation matrix Rxx for a spectrogram.
|
94 |
+
|
95 |
+
That is, `X*conj(X)`, where `*` is the outer product.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
x (complex Tensor): Spectrogram of shape [B, C, T, F]. Will be unfolded with `n` steps over
|
99 |
+
the time axis.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Rxx (complex Tensor): Correlation matrix of shape [B, C, T, F, N, N]
|
103 |
+
"""
|
104 |
+
x = F.pad(x, (0, 0, n - 1, 0)).unfold(-2, n, 1)
|
105 |
+
return torch.einsum("...n,...m->...mn", x, x.conj())
|
106 |
+
|
107 |
+
|
108 |
+
def df(spec: Tensor, coefs: Tensor) -> Tensor:
|
109 |
+
"""Deep filter implemenation using `torch.einsum`. Requires unfolded spectrogram.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
spec (complex Tensor): Spectrogram of shape [B, C, T, F, N]
|
113 |
+
coefs (complex Tensor): Spectrogram of shape [B, C, N, T, F]
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
spec (complex Tensor): Spectrogram of shape [B, C, T, F]
|
117 |
+
"""
|
118 |
+
return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
|
119 |
+
|
120 |
+
|
121 |
+
class CRM(MultiFrameModule):
|
122 |
+
"""Complex ratio mask."""
|
123 |
+
|
124 |
+
def __init__(self, num_freqs: int, frame_size: int = 1, lookahead: int = 0):
|
125 |
+
assert frame_size == 1 and lookahead == 0, (frame_size, lookahead)
|
126 |
+
super().__init__(num_freqs, 1)
|
127 |
+
|
128 |
+
def forward_impl(self, spec: Tensor, coefs: Tensor):
|
129 |
+
return spec.squeeze(-1).mul(coefs)
|
130 |
+
|
131 |
+
def num_channels(self):
|
132 |
+
return 2
|
133 |
+
|
134 |
+
|
135 |
+
class DF(MultiFrameModule):
|
136 |
+
conj: Final[bool]
|
137 |
+
"""Deep Filtering."""
|
138 |
+
|
139 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
|
140 |
+
super().__init__(num_freqs, frame_size, lookahead)
|
141 |
+
self.conj = conj
|
142 |
+
|
143 |
+
def forward_impl(self, spec: Tensor, coefs: Tensor):
|
144 |
+
coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:])
|
145 |
+
if self.conj:
|
146 |
+
coefs = coefs.conj()
|
147 |
+
return df(spec, coefs)
|
148 |
+
|
149 |
+
def num_channels(self):
|
150 |
+
return self.frame_size * 2
|
151 |
+
|
152 |
+
|
153 |
+
class MfWf(MultiFrameModule):
|
154 |
+
"""Multi-frame Wiener filter base module."""
|
155 |
+
|
156 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
|
157 |
+
"""Multi-frame Wiener Filter.
|
158 |
+
|
159 |
+
Several implementation methods are available resulting in different number of required input
|
160 |
+
coefficient channels.
|
161 |
+
|
162 |
+
Methods:
|
163 |
+
psd_ifc: Predict PSD `Rxx` and IFC `rss`.
|
164 |
+
df: Use deep filtering to predict speech and noisy spectrograms. These will be used for
|
165 |
+
PSD calculation for Wiener filtering. Alias: `df_sx`
|
166 |
+
c: Directly predict Wiener filter coefficients. Computation same as deep filtering.
|
167 |
+
|
168 |
+
"""
|
169 |
+
super().__init__(num_freqs, frame_size, lookahead=0)
|
170 |
+
self.idx = -lookahead
|
171 |
+
|
172 |
+
def num_channels(self):
|
173 |
+
return self.num_channels
|
174 |
+
|
175 |
+
@staticmethod
|
176 |
+
def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> Tensor:
|
177 |
+
return torch.einsum(
|
178 |
+
"...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss
|
179 |
+
) # [T, F, N]
|
180 |
+
|
181 |
+
@abstractmethod
|
182 |
+
def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor:
|
183 |
+
"""Multi-frame Wiener filter impl taking complex spectrogram and coefficients.
|
184 |
+
|
185 |
+
Coefficients may be split into multiple parts w.g. for multiple DF coefs or PSDs.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N]
|
189 |
+
coefs (complex Tensor): Coefficients [B, C2, T, F]
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
c (complex Tensor): MfWf coefs of shape [B, C1, T, F, N]
|
193 |
+
"""
|
194 |
+
...
|
195 |
+
|
196 |
+
def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor:
|
197 |
+
coefs = self.mfwf(spec, coefs)
|
198 |
+
return self.apply_coefs(spec, coefs)
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def apply_coefs(spec: Tensor, coefs: Tensor) -> Tensor:
|
202 |
+
# spec: [B, C, T, F, N]
|
203 |
+
# coefs: [B, C, T, F, N]
|
204 |
+
return torch.einsum("...n,...n->...", spec, coefs)
|
205 |
+
|
206 |
+
|
207 |
+
class MfWfDf(MfWf):
|
208 |
+
eps_diag: Final[float]
|
209 |
+
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
num_freqs: int,
|
213 |
+
frame_size: int,
|
214 |
+
lookahead: int = 0,
|
215 |
+
eps_diag: float = 1e-7,
|
216 |
+
eps: float = 1e-7,
|
217 |
+
):
|
218 |
+
super().__init__(num_freqs, frame_size, lookahead)
|
219 |
+
self.eps_diag = eps_diag
|
220 |
+
self.eps = eps
|
221 |
+
|
222 |
+
def num_channels(self):
|
223 |
+
# frame_size/df_order * 2 (x/s) * 2 (re/im)
|
224 |
+
return self.frame_size * 4
|
225 |
+
|
226 |
+
def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor:
|
227 |
+
coefs.chunk
|
228 |
+
df_s, df_x = torch.chunk(coefs, 2, 1) # [B, C, T, F, N]
|
229 |
+
df_s = df_s.unflatten(1, (-1, self.frame_size))
|
230 |
+
df_x = df_x.unflatten(1, (-1, self.frame_size))
|
231 |
+
spec_s = df(spec, df_s) # [B, C, T, F]
|
232 |
+
spec_x = df(spec, df_x)
|
233 |
+
Rss = psd(spec_s, self.frame_size) # [B, C, T, F, N. N]
|
234 |
+
Rxx = psd(spec_x, self.frame_size)
|
235 |
+
rss = Rss[..., -1] # TODO: use -1 or self.idx?
|
236 |
+
c = self.solve(Rxx, rss, self.eps_diag, self.eps) # [B, C, T, F, N]
|
237 |
+
return c
|
238 |
+
|
239 |
+
|
240 |
+
class MfWfPsd(MfWf):
|
241 |
+
"""Multi-frame Wiener filter by predicting noisy PSD `Rxx` and speech IFC `rss`."""
|
242 |
+
|
243 |
+
def num_channels(self):
|
244 |
+
# (Rxx + rss) * 2 (re/im)
|
245 |
+
return (self.frame_size**2 + self.frame_size) * 2
|
246 |
+
|
247 |
+
def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore
|
248 |
+
Rxx, rss = torch.split(coefs.movedim(1, -1), [self.frame_size**2, self.frame_size], -1)
|
249 |
+
c = self.solve(Rxx.unflatten(-1, (self.frame_size, self.frame_size)), rss)
|
250 |
+
return c
|
251 |
+
|
252 |
+
|
253 |
+
class MfWfC(MfWf):
|
254 |
+
"""Multi-frame Wiener filter by directly predicting the MfWf coefficients."""
|
255 |
+
|
256 |
+
def num_channels(self):
|
257 |
+
# mfwf coefs * 2 (re/im)
|
258 |
+
return self.frame_size * 2
|
259 |
+
|
260 |
+
def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore
|
261 |
+
coefs = coefs.unflatten(1, (-1, self.frame_size)).permute(
|
262 |
+
0, 1, 3, 4, 2
|
263 |
+
) # [B, C*N, T, F] -> [B, C, T, F, N]
|
264 |
+
return coefs
|
265 |
+
|
266 |
+
|
267 |
+
class MvdrSouden(MultiFrameModule):
|
268 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
|
269 |
+
super().__init__(num_freqs, frame_size, lookahead)
|
270 |
+
|
271 |
+
|
272 |
+
class MvdrEvd(MultiFrameModule):
|
273 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
|
274 |
+
super().__init__(num_freqs, frame_size, lookahead)
|
275 |
+
|
276 |
+
|
277 |
+
class MvdrRtfPower(MultiFrameModule):
|
278 |
+
def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0):
|
279 |
+
super().__init__(num_freqs, frame_size, lookahead)
|
280 |
+
|
281 |
+
|
282 |
+
MF_METHODS: Dict[str, MultiFrameModule] = {
|
283 |
+
"crm": CRM,
|
284 |
+
"df": DF,
|
285 |
+
"mfwf_df": MfWfDf,
|
286 |
+
"mfwf_df_sx": MfWfDf,
|
287 |
+
"mfwf_psd": MfWfPsd,
|
288 |
+
"mfwf_psd_ifc": MfWfPsd,
|
289 |
+
"mfwf_c": MfWfC,
|
290 |
+
}
|
291 |
+
|
292 |
+
|
293 |
+
# From torchaudio
|
294 |
+
def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
|
295 |
+
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
|
296 |
+
Args:
|
297 |
+
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
|
298 |
+
dim1 (int, optional): the first dimension of the diagonal matrix
|
299 |
+
(Default: -1)
|
300 |
+
dim2 (int, optional): the second dimension of the diagonal matrix
|
301 |
+
(Default: -2)
|
302 |
+
Returns:
|
303 |
+
Tensor: trace of the input Tensor
|
304 |
+
"""
|
305 |
+
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
|
306 |
+
assert (
|
307 |
+
input.shape[dim1] == input.shape[dim2]
|
308 |
+
), "The size of ``dim1`` and ``dim2`` must be the same."
|
309 |
+
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
|
310 |
+
return input.sum(dim=-1)
|
311 |
+
|
312 |
+
|
313 |
+
def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
|
314 |
+
"""Perform Tikhonov regularization (only modifying real part).
|
315 |
+
Args:
|
316 |
+
mat (torch.Tensor): input matrix (..., channel, channel)
|
317 |
+
reg (float, optional): regularization factor (Default: 1e-8)
|
318 |
+
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)
|
319 |
+
Returns:
|
320 |
+
Tensor: regularized matrix (..., channel, channel)
|
321 |
+
"""
|
322 |
+
# Add eps
|
323 |
+
C = mat.size(-1)
|
324 |
+
eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
|
325 |
+
epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
|
326 |
+
# in case that correlation_matrix is all-zero
|
327 |
+
epsilon = epsilon + eps
|
328 |
+
mat = mat + epsilon * eye[..., :, :]
|
329 |
+
return mat
|
df/utils.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import subprocess
|
6 |
+
from socket import gethostname
|
7 |
+
from typing import Any, Dict, Set, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from loguru import logger
|
12 |
+
from torch import Tensor
|
13 |
+
#from torch._six import string_classes
|
14 |
+
from torch.autograd import Function
|
15 |
+
from torch.types import Number
|
16 |
+
|
17 |
+
from df.config import config
|
18 |
+
from df.model import ModelParams
|
19 |
+
|
20 |
+
try:
|
21 |
+
from torchaudio.functional import resample as ta_resample
|
22 |
+
except ImportError:
|
23 |
+
from torchaudio.compliance.kaldi import resample_waveform as ta_resample # type: ignore
|
24 |
+
|
25 |
+
|
26 |
+
def get_resample_params(method: str) -> Dict[str, Any]:
|
27 |
+
params = {
|
28 |
+
"sinc_fast": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 16},
|
29 |
+
"sinc_best": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 64},
|
30 |
+
"kaiser_fast": {
|
31 |
+
"resampling_method": "kaiser_window",
|
32 |
+
"lowpass_filter_width": 16,
|
33 |
+
"rolloff": 0.85,
|
34 |
+
"beta": 8.555504641634386,
|
35 |
+
},
|
36 |
+
"kaiser_best": {
|
37 |
+
"resampling_method": "kaiser_window",
|
38 |
+
"lowpass_filter_width": 16,
|
39 |
+
"rolloff": 0.9475937167399596,
|
40 |
+
"beta": 14.769656459379492,
|
41 |
+
},
|
42 |
+
}
|
43 |
+
assert method in params.keys(), f"method must be one of {list(params.keys())}"
|
44 |
+
return params[method]
|
45 |
+
|
46 |
+
|
47 |
+
def resample(audio: Tensor, orig_sr: int, new_sr: int, method="sinc_fast"):
|
48 |
+
params = get_resample_params(method)
|
49 |
+
return ta_resample(audio, orig_sr, new_sr, **params)
|
50 |
+
|
51 |
+
|
52 |
+
def get_device():
|
53 |
+
s = config("DEVICE", default="", section="train")
|
54 |
+
if s == "":
|
55 |
+
if torch.cuda.is_available():
|
56 |
+
DEVICE = torch.device("cuda:0")
|
57 |
+
else:
|
58 |
+
DEVICE = torch.device("cpu")
|
59 |
+
else:
|
60 |
+
DEVICE = torch.device(s)
|
61 |
+
return DEVICE
|
62 |
+
|
63 |
+
|
64 |
+
def as_complex(x: Tensor):
|
65 |
+
if torch.is_complex(x):
|
66 |
+
return x
|
67 |
+
if x.shape[-1] != 2:
|
68 |
+
raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
|
69 |
+
if x.stride(-1) != 1:
|
70 |
+
x = x.contiguous()
|
71 |
+
return torch.view_as_complex(x)
|
72 |
+
|
73 |
+
|
74 |
+
def as_real(x: Tensor):
|
75 |
+
if torch.is_complex(x):
|
76 |
+
return torch.view_as_real(x)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class angle_re_im(Function):
|
81 |
+
"""Similar to torch.angle but robustify the gradient for zero magnitude."""
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def forward(ctx, re: Tensor, im: Tensor):
|
85 |
+
ctx.save_for_backward(re, im)
|
86 |
+
return torch.atan2(im, re)
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def backward(ctx, grad: Tensor) -> Tuple[Tensor, Tensor]:
|
90 |
+
re, im = ctx.saved_tensors
|
91 |
+
grad_inv = grad / (re.square() + im.square()).clamp_min_(1e-10)
|
92 |
+
return -im * grad_inv, re * grad_inv
|
93 |
+
|
94 |
+
|
95 |
+
class angle(Function):
|
96 |
+
"""Similar to torch.angle but robustify the gradient for zero magnitude."""
|
97 |
+
|
98 |
+
@staticmethod
|
99 |
+
def forward(ctx, x: Tensor):
|
100 |
+
ctx.save_for_backward(x)
|
101 |
+
return torch.atan2(x.imag, x.real)
|
102 |
+
|
103 |
+
@staticmethod
|
104 |
+
def backward(ctx, grad: Tensor):
|
105 |
+
(x,) = ctx.saved_tensors
|
106 |
+
grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10)
|
107 |
+
return torch.view_as_complex(torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1))
|
108 |
+
|
109 |
+
|
110 |
+
def check_finite_module(obj, name="Module", _raise=True) -> Set[str]:
|
111 |
+
out: Set[str] = set()
|
112 |
+
if isinstance(obj, torch.nn.Module):
|
113 |
+
for name, child in obj.named_children():
|
114 |
+
out = out | check_finite_module(child, name)
|
115 |
+
for name, param in obj.named_parameters():
|
116 |
+
out = out | check_finite_module(param, name)
|
117 |
+
for name, buf in obj.named_buffers():
|
118 |
+
out = out | check_finite_module(buf, name)
|
119 |
+
if _raise and len(out) > 0:
|
120 |
+
raise ValueError(f"{name} not finite during checkpoint writing including: {out}")
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
def make_np(x: Union[Tensor, np.ndarray, Number]) -> np.ndarray:
|
125 |
+
"""Transforms Tensor to numpy.
|
126 |
+
Args:
|
127 |
+
x: An instance of torch tensor or caffe blob name
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
numpy.array: Numpy array
|
131 |
+
"""
|
132 |
+
if isinstance(x, np.ndarray):
|
133 |
+
return x
|
134 |
+
if np.isscalar(x):
|
135 |
+
return np.array([x])
|
136 |
+
if isinstance(x, Tensor):
|
137 |
+
return x.detach().cpu().numpy()
|
138 |
+
raise NotImplementedError(
|
139 |
+
"Got {}, but numpy array, scalar, or torch tensor are expected.".format(type(x))
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def get_norm_alpha(log: bool = True) -> float:
|
144 |
+
p = ModelParams()
|
145 |
+
a_ = _calculate_norm_alpha(sr=p.sr, hop_size=p.hop_size, tau=p.norm_tau)
|
146 |
+
precision = 3
|
147 |
+
a = 1.0
|
148 |
+
while a >= 1.0:
|
149 |
+
a = round(a_, precision)
|
150 |
+
precision += 1
|
151 |
+
if log:
|
152 |
+
logger.info(f"Running with normalization window alpha = '{a}'")
|
153 |
+
return a
|
154 |
+
|
155 |
+
|
156 |
+
def _calculate_norm_alpha(sr: int, hop_size: int, tau: float):
|
157 |
+
"""Exponential decay factor alpha for a given tau (decay window size [s])."""
|
158 |
+
dt = hop_size / sr
|
159 |
+
return math.exp(-dt / tau)
|
160 |
+
|
161 |
+
|
162 |
+
def check_manual_seed(seed: int = None):
|
163 |
+
"""If manual seed is not specified, choose a random one and communicate it to the user."""
|
164 |
+
seed = seed or random.randint(1, 10000)
|
165 |
+
np.random.seed(seed)
|
166 |
+
random.seed(seed)
|
167 |
+
torch.manual_seed(seed)
|
168 |
+
return seed
|
169 |
+
|
170 |
+
|
171 |
+
def get_git_root():
|
172 |
+
git_local_dir = os.path.dirname(os.path.abspath(__file__))
|
173 |
+
args = ["git", "-C", git_local_dir, "rev-parse", "--show-toplevel"]
|
174 |
+
return subprocess.check_output(args).strip().decode()
|
175 |
+
|
176 |
+
|
177 |
+
def get_commit_hash():
|
178 |
+
"""Returns the current git commit."""
|
179 |
+
try:
|
180 |
+
git_dir = get_git_root()
|
181 |
+
args = ["git", "-C", git_dir, "rev-parse", "--short", "--verify", "HEAD"]
|
182 |
+
commit = subprocess.check_output(args).strip().decode()
|
183 |
+
except subprocess.CalledProcessError:
|
184 |
+
# probably not in git repo
|
185 |
+
commit = None
|
186 |
+
return commit
|
187 |
+
|
188 |
+
|
189 |
+
def get_host() -> str:
|
190 |
+
return gethostname()
|
191 |
+
|
192 |
+
|
193 |
+
def get_branch_name():
|
194 |
+
try:
|
195 |
+
git_dir = os.path.dirname(os.path.abspath(__file__))
|
196 |
+
args = ["git", "-C", git_dir, "rev-parse", "--abbrev-ref", "HEAD"]
|
197 |
+
branch = subprocess.check_output(args).strip().decode()
|
198 |
+
except subprocess.CalledProcessError:
|
199 |
+
# probably not in git repo
|
200 |
+
branch = None
|
201 |
+
return branch
|
202 |
+
|
203 |
+
|
204 |
+
# from pytorch/ignite:
|
205 |
+
def apply_to_tensor(input_, func):
|
206 |
+
"""Apply a function on a tensor or mapping, or sequence of tensors."""
|
207 |
+
if isinstance(input_, torch.nn.Module):
|
208 |
+
return [apply_to_tensor(c, func) for c in input_.children()]
|
209 |
+
elif isinstance(input_, torch.nn.Parameter):
|
210 |
+
return func(input_.data)
|
211 |
+
elif isinstance(input_, Tensor):
|
212 |
+
return func(input_)
|
213 |
+
elif isinstance(input_, str):
|
214 |
+
return input_
|
215 |
+
elif isinstance(input_, collections.Mapping):
|
216 |
+
return {k: apply_to_tensor(sample, func) for k, sample in input_.items()}
|
217 |
+
elif isinstance(input_, collections.Iterable):
|
218 |
+
return [apply_to_tensor(sample, func) for sample in input_]
|
219 |
+
elif input_ is None:
|
220 |
+
return input_
|
221 |
+
else:
|
222 |
+
return input_
|
223 |
+
|
224 |
+
|
225 |
+
def detach_hidden(hidden: Any) -> Any:
|
226 |
+
"""Cut backpropagation graph.
|
227 |
+
Auxillary function to cut the backpropagation graph by detaching the hidden
|
228 |
+
vector.
|
229 |
+
"""
|
230 |
+
return apply_to_tensor(hidden, Tensor.detach)
|
libdf/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .libdf import *
|
2 |
+
|
3 |
+
__doc__ = libdf.__doc__
|
libdf/__init__.pyi
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Union
|
2 |
+
|
3 |
+
from numpy import ndarray
|
4 |
+
|
5 |
+
class DF:
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
sr: int,
|
9 |
+
fft_size: int,
|
10 |
+
hop_size: int,
|
11 |
+
nb_bands: int,
|
12 |
+
min_nb_erb_freqs: Optional[int] = 1,
|
13 |
+
):
|
14 |
+
"""DeepFilter state used for analysis and synthesis.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
sr (int): Sampling rate.
|
18 |
+
fft_size (int): Window length used for the Fast Fourier transform.
|
19 |
+
hop_size (int): Hop size between two analysis windows. Also called frame size.
|
20 |
+
nb_bands (int): Number of ERB bands.
|
21 |
+
min_nb_erb_freqs (int): Minimum number of frequency bands per ERB band. Defaults to 1.
|
22 |
+
"""
|
23 |
+
...
|
24 |
+
def analysis(self, input: ndarray) -> ndarray:
|
25 |
+
"""Analysis of a time-domain signal.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
input (ndarray): 2D real-valued array of shape [C, T].
|
29 |
+
Output:
|
30 |
+
output (ndarray): 3D complex-valued array of shape [C, T', F], where F is the `fft_size`,
|
31 |
+
and T' the original time T divided by `hop_size`.
|
32 |
+
"""
|
33 |
+
...
|
34 |
+
def synthesis(self, input: ndarray) -> ndarray:
|
35 |
+
"""Synthesis of a frequency-domain signal.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
input (ndarray): 3D complex-valued array of shape [C, T, F].
|
39 |
+
Output:
|
40 |
+
output (ndarray): 2D real-valued array of shape [C, T].
|
41 |
+
"""
|
42 |
+
...
|
43 |
+
def erb_widths(self) -> ndarray: ...
|
44 |
+
def fft_window(self) -> ndarray: ...
|
45 |
+
def sr(self) -> int: ...
|
46 |
+
def fft_size(self) -> int: ...
|
47 |
+
def hop_size(self) -> int: ...
|
48 |
+
def nb_erb(self) -> int: ...
|
49 |
+
def reset(self) -> None: ...
|
50 |
+
|
51 |
+
def erb(
|
52 |
+
input: ndarray, erb_fb: Union[ndarray, List[int]], db: Optional[bool] = None
|
53 |
+
) -> ndarray: ...
|
54 |
+
def erb_inv(input: ndarray, erb_fb: Union[ndarray, List[int]]) -> ndarray: ...
|
55 |
+
def erb_norm(erb: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ...
|
56 |
+
def unit_norm(spec: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ...
|
57 |
+
def unit_norm_init(num_freq_bins: int) -> ndarray: ...
|
libdf/py.typed
ADDED
File without changes
|
model_weights/voice_enhance/checkpoints/model_96.ckpt.best
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb5eccb429e675bb4ec5ec9e280f048bfff9787b40bd3eb835fd11509eb14a3e
|
3 |
+
size 9397209
|
model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
|
3 |
+
size 17090379
|
model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7dfd48d0da24db35ee4a653d0d36a4104cb26873050a5c3584675eee21937621
|
3 |
+
size 69
|
model_weights/voiceover/freevc-24.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:872360b61e6bbe09bec29810e7ad0d16318e379f6195a7ff3b06e50efb08ad31
|
3 |
+
size 1264
|
model_weights/voiceover/freevc-24.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
|
3 |
+
size 472644351
|
model_weights/wavlm_models/WavLM-Large.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
|
3 |
+
size 1261965425
|
model_weights/wavlm_models/WavLM-Large.pt.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a9836bca8ab0e9d0b4797aa78f41b367800d26cfd25ade7b1edcb35bc3c171e4
|
3 |
+
size 52
|
nnet/__init__.py
ADDED
File without changes
|
nnet/attentions.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from nnet import commons
|
7 |
+
from nnet.modules import LayerNorm
|
8 |
+
|
9 |
+
|
10 |
+
class Encoder(nn.Module):
|
11 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
|
12 |
+
super().__init__()
|
13 |
+
self.hidden_channels = hidden_channels
|
14 |
+
self.filter_channels = filter_channels
|
15 |
+
self.n_heads = n_heads
|
16 |
+
self.n_layers = n_layers
|
17 |
+
self.kernel_size = kernel_size
|
18 |
+
self.p_dropout = p_dropout
|
19 |
+
self.window_size = window_size
|
20 |
+
|
21 |
+
self.drop = nn.Dropout(p_dropout)
|
22 |
+
self.attn_layers = nn.ModuleList()
|
23 |
+
self.norm_layers_1 = nn.ModuleList()
|
24 |
+
self.ffn_layers = nn.ModuleList()
|
25 |
+
self.norm_layers_2 = nn.ModuleList()
|
26 |
+
for i in range(self.n_layers):
|
27 |
+
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
|
28 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
29 |
+
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
30 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
31 |
+
|
32 |
+
def forward(self, x, x_mask):
|
33 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
34 |
+
x = x * x_mask
|
35 |
+
for i in range(self.n_layers):
|
36 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
37 |
+
y = self.drop(y)
|
38 |
+
x = self.norm_layers_1[i](x + y)
|
39 |
+
|
40 |
+
y = self.ffn_layers[i](x, x_mask)
|
41 |
+
y = self.drop(y)
|
42 |
+
x = self.norm_layers_2[i](x + y)
|
43 |
+
x = x * x_mask
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class Decoder(nn.Module):
|
48 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
|
49 |
+
super().__init__()
|
50 |
+
self.hidden_channels = hidden_channels
|
51 |
+
self.filter_channels = filter_channels
|
52 |
+
self.n_heads = n_heads
|
53 |
+
self.n_layers = n_layers
|
54 |
+
self.kernel_size = kernel_size
|
55 |
+
self.p_dropout = p_dropout
|
56 |
+
self.proximal_bias = proximal_bias
|
57 |
+
self.proximal_init = proximal_init
|
58 |
+
|
59 |
+
self.drop = nn.Dropout(p_dropout)
|
60 |
+
self.self_attn_layers = nn.ModuleList()
|
61 |
+
self.norm_layers_0 = nn.ModuleList()
|
62 |
+
self.encdec_attn_layers = nn.ModuleList()
|
63 |
+
self.norm_layers_1 = nn.ModuleList()
|
64 |
+
self.ffn_layers = nn.ModuleList()
|
65 |
+
self.norm_layers_2 = nn.ModuleList()
|
66 |
+
for i in range(self.n_layers):
|
67 |
+
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
|
68 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
69 |
+
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
|
70 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
71 |
+
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
|
72 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
73 |
+
|
74 |
+
def forward(self, x, x_mask, h, h_mask):
|
75 |
+
"""
|
76 |
+
x: decoder input
|
77 |
+
h: encoder output
|
78 |
+
"""
|
79 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
80 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
81 |
+
x = x * x_mask
|
82 |
+
for i in range(self.n_layers):
|
83 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
84 |
+
y = self.drop(y)
|
85 |
+
x = self.norm_layers_0[i](x + y)
|
86 |
+
|
87 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
88 |
+
y = self.drop(y)
|
89 |
+
x = self.norm_layers_1[i](x + y)
|
90 |
+
|
91 |
+
y = self.ffn_layers[i](x, x_mask)
|
92 |
+
y = self.drop(y)
|
93 |
+
x = self.norm_layers_2[i](x + y)
|
94 |
+
x = x * x_mask
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class MultiHeadAttention(nn.Module):
|
99 |
+
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
|
100 |
+
super().__init__()
|
101 |
+
assert channels % n_heads == 0
|
102 |
+
|
103 |
+
self.channels = channels
|
104 |
+
self.out_channels = out_channels
|
105 |
+
self.n_heads = n_heads
|
106 |
+
self.p_dropout = p_dropout
|
107 |
+
self.window_size = window_size
|
108 |
+
self.heads_share = heads_share
|
109 |
+
self.block_length = block_length
|
110 |
+
self.proximal_bias = proximal_bias
|
111 |
+
self.proximal_init = proximal_init
|
112 |
+
self.attn = None
|
113 |
+
|
114 |
+
self.k_channels = channels // n_heads
|
115 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
116 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
117 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
118 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
119 |
+
self.drop = nn.Dropout(p_dropout)
|
120 |
+
|
121 |
+
if window_size is not None:
|
122 |
+
n_heads_rel = 1 if heads_share else n_heads
|
123 |
+
rel_stddev = self.k_channels**-0.5
|
124 |
+
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
125 |
+
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
126 |
+
|
127 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
128 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
129 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
130 |
+
if proximal_init:
|
131 |
+
with torch.no_grad():
|
132 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
133 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
134 |
+
|
135 |
+
def forward(self, x, c, attn_mask=None):
|
136 |
+
q = self.conv_q(x)
|
137 |
+
k = self.conv_k(c)
|
138 |
+
v = self.conv_v(c)
|
139 |
+
|
140 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
141 |
+
|
142 |
+
x = self.conv_o(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
def attention(self, query, key, value, mask=None):
|
146 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
147 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
148 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
149 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
150 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
151 |
+
|
152 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
153 |
+
if self.window_size is not None:
|
154 |
+
assert t_s == t_t, "Relative attention is only available for self-attention."
|
155 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
156 |
+
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
|
157 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
158 |
+
scores = scores + scores_local
|
159 |
+
if self.proximal_bias:
|
160 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
161 |
+
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
162 |
+
if mask is not None:
|
163 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
164 |
+
if self.block_length is not None:
|
165 |
+
assert t_s == t_t, "Local attention is only available for self-attention."
|
166 |
+
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
167 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
168 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
169 |
+
p_attn = self.drop(p_attn)
|
170 |
+
output = torch.matmul(p_attn, value)
|
171 |
+
if self.window_size is not None:
|
172 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
173 |
+
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
174 |
+
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
175 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
176 |
+
return output, p_attn
|
177 |
+
|
178 |
+
def _matmul_with_relative_values(self, x, y):
|
179 |
+
"""
|
180 |
+
x: [b, h, l, m]
|
181 |
+
y: [h or 1, m, d]
|
182 |
+
ret: [b, h, l, d]
|
183 |
+
"""
|
184 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
185 |
+
return ret
|
186 |
+
|
187 |
+
def _matmul_with_relative_keys(self, x, y):
|
188 |
+
"""
|
189 |
+
x: [b, h, l, d]
|
190 |
+
y: [h or 1, m, d]
|
191 |
+
ret: [b, h, l, m]
|
192 |
+
"""
|
193 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
194 |
+
return ret
|
195 |
+
|
196 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
197 |
+
max_relative_position = 2 * self.window_size + 1
|
198 |
+
# Pad first before slice to avoid using cond ops.
|
199 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
200 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
201 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
202 |
+
if pad_length > 0:
|
203 |
+
padded_relative_embeddings = F.pad(
|
204 |
+
relative_embeddings,
|
205 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
206 |
+
else:
|
207 |
+
padded_relative_embeddings = relative_embeddings
|
208 |
+
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
|
209 |
+
return used_relative_embeddings
|
210 |
+
|
211 |
+
def _relative_position_to_absolute_position(self, x):
|
212 |
+
"""
|
213 |
+
x: [b, h, l, 2*l-1]
|
214 |
+
ret: [b, h, l, l]
|
215 |
+
"""
|
216 |
+
batch, heads, length, _ = x.size()
|
217 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
218 |
+
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
|
219 |
+
|
220 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
221 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
222 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
|
223 |
+
|
224 |
+
# Reshape and slice out the padded elements.
|
225 |
+
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
|
226 |
+
return x_final
|
227 |
+
|
228 |
+
def _absolute_position_to_relative_position(self, x):
|
229 |
+
"""
|
230 |
+
x: [b, h, l, l]
|
231 |
+
ret: [b, h, l, 2*l-1]
|
232 |
+
"""
|
233 |
+
batch, heads, length, _ = x.size()
|
234 |
+
# padd along column
|
235 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
|
236 |
+
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
|
237 |
+
# add 0's in the beginning that will skew the elements after reshape
|
238 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
239 |
+
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
|
240 |
+
return x_final
|
241 |
+
|
242 |
+
def _attention_bias_proximal(self, length):
|
243 |
+
"""Bias for self-attention to encourage attention to close positions.
|
244 |
+
Args:
|
245 |
+
length: an integer scalar.
|
246 |
+
Returns:
|
247 |
+
a Tensor with shape [1, 1, length, length]
|
248 |
+
"""
|
249 |
+
r = torch.arange(length, dtype=torch.float32)
|
250 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
251 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
252 |
+
|
253 |
+
|
254 |
+
class FFN(nn.Module):
|
255 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
|
256 |
+
super().__init__()
|
257 |
+
self.in_channels = in_channels
|
258 |
+
self.out_channels = out_channels
|
259 |
+
self.filter_channels = filter_channels
|
260 |
+
self.kernel_size = kernel_size
|
261 |
+
self.p_dropout = p_dropout
|
262 |
+
self.activation = activation
|
263 |
+
self.causal = causal
|
264 |
+
|
265 |
+
if causal:
|
266 |
+
self.padding = self._causal_padding
|
267 |
+
else:
|
268 |
+
self.padding = self._same_padding
|
269 |
+
|
270 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
271 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
272 |
+
self.drop = nn.Dropout(p_dropout)
|
273 |
+
|
274 |
+
def forward(self, x, x_mask):
|
275 |
+
x = self.conv_1(self.padding(x * x_mask))
|
276 |
+
if self.activation == "gelu":
|
277 |
+
x = x * torch.sigmoid(1.702 * x)
|
278 |
+
else:
|
279 |
+
x = torch.relu(x)
|
280 |
+
x = self.drop(x)
|
281 |
+
x = self.conv_2(self.padding(x * x_mask))
|
282 |
+
return x * x_mask
|
283 |
+
|
284 |
+
def _causal_padding(self, x):
|
285 |
+
if self.kernel_size == 1:
|
286 |
+
return x
|
287 |
+
pad_l = self.kernel_size - 1
|
288 |
+
pad_r = 0
|
289 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
290 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
291 |
+
return x
|
292 |
+
|
293 |
+
def _same_padding(self, x):
|
294 |
+
if self.kernel_size == 1:
|
295 |
+
return x
|
296 |
+
pad_l = (self.kernel_size - 1) // 2
|
297 |
+
pad_r = self.kernel_size // 2
|
298 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
299 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
300 |
+
return x
|