Spaces:
Sleeping
Sleeping
Hugo Flores
commited on
Commit
•
50f034f
0
Parent(s):
first commit
Browse files- .dockerignore +2 -0
- .gitattributes +4 -0
- .gitignore +173 -0
- .pre-commit-config.yaml +15 -0
- Dockerfile +36 -0
- README.md +287 -0
- changelog.md +11 -0
- conf/vampnet-c2f.yml +9 -0
- conf/vampnet.yml +56 -0
- docker-compose.yml +90 -0
- env/alias.sh +3 -0
- env/data.sh +36 -0
- env/entry_script.sh +41 -0
- env/setup.py +123 -0
- requirements.txt +29 -0
- scripts/generative/eval.py +124 -0
- scripts/generative/train.py +662 -0
- scripts/utils/README.md +28 -0
- scripts/utils/stage.py +30 -0
- setup.py +40 -0
- vampnet/__init__.py +6 -0
- vampnet/enchilada.py +179 -0
- vampnet/modules/__init__.py +4 -0
- vampnet/modules/base.py +461 -0
- vampnet/modules/modules.py +168 -0
- vampnet/modules/transformer.py +606 -0
- vampnet/scheduler.py +47 -0
.dockerignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.wav
|
2 |
+
runs/
|
.gitattributes
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
2 |
+
tests/assets/cpu_test_file.pt filter=lfs diff=lfs merge=lfs -text
|
3 |
+
tests/assets/quick.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
+
tests/assets/slow.pth filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/env.sh
|
108 |
+
venv/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
# Files created by experiments
|
131 |
+
output/
|
132 |
+
snapshot/
|
133 |
+
*.m4a
|
134 |
+
*.wav
|
135 |
+
notebooks/scratch.ipynb
|
136 |
+
notebooks/inspect.ipynb
|
137 |
+
notebooks/effects.ipynb
|
138 |
+
notebooks/*.ipynb
|
139 |
+
notebooks/*.gif
|
140 |
+
notebooks/*.wav
|
141 |
+
notebooks/*.mp4
|
142 |
+
*runs/
|
143 |
+
boards/
|
144 |
+
samples/
|
145 |
+
*.ipynb
|
146 |
+
|
147 |
+
results.json
|
148 |
+
metrics.csv
|
149 |
+
mprofile_*
|
150 |
+
mem.png
|
151 |
+
|
152 |
+
results/
|
153 |
+
mprofile*
|
154 |
+
*.png
|
155 |
+
# do not ignore the test wav file
|
156 |
+
!tests/audio/short_test_audio.wav
|
157 |
+
!tests/audio/output.wav
|
158 |
+
*/.DS_Store
|
159 |
+
.DS_Store
|
160 |
+
env.sh
|
161 |
+
_codebraid/
|
162 |
+
**/*.html
|
163 |
+
**/*.exec.md
|
164 |
+
flagged/
|
165 |
+
log.txt
|
166 |
+
ckpt/
|
167 |
+
.syncthing*
|
168 |
+
tests/assets/
|
169 |
+
archived/
|
170 |
+
|
171 |
+
scratch/
|
172 |
+
|
173 |
+
runs-archive
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/asottile/reorder_python_imports
|
3 |
+
rev: v2.5.0
|
4 |
+
hooks:
|
5 |
+
- id: reorder-python-imports
|
6 |
+
- repo: https://github.com/psf/black
|
7 |
+
rev: 23.1.0
|
8 |
+
hooks:
|
9 |
+
- id: black
|
10 |
+
language_version: python3
|
11 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
12 |
+
rev: v4.0.1
|
13 |
+
hooks:
|
14 |
+
- id: end-of-file-fixer
|
15 |
+
- id: trailing-whitespace
|
Dockerfile
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM us.gcr.io/lyrebird-research/research-image/audio:beta
|
2 |
+
|
3 |
+
COPY requirements.txt requirements.txt
|
4 |
+
ARG GITHUB_TOKEN
|
5 |
+
RUN echo machine github.com login ${GITHUB_TOKEN} > ~/.netrc
|
6 |
+
|
7 |
+
COPY env/alias.sh /alias.sh
|
8 |
+
COPY env/entry_script.sh /entry_script.sh
|
9 |
+
RUN cat /alias.sh >> ~/.zshrc
|
10 |
+
|
11 |
+
# USER researcher
|
12 |
+
RUN pip install --upgrade -r requirements.txt
|
13 |
+
RUN pip install --upgrade tensorflow
|
14 |
+
RUN pip install --upgrade librosa
|
15 |
+
RUN pip install --upgrade numba
|
16 |
+
ENV PYTHONPATH "$PYTHONPATH:/u/home/src"
|
17 |
+
ENV NUMBA_CACHE_DIR=/tmp/
|
18 |
+
|
19 |
+
USER root
|
20 |
+
RUN wget https://github.com/jgm/pandoc/releases/download/2.18/pandoc-2.18-1-amd64.deb
|
21 |
+
RUN dpkg -i pandoc-2.18-1-amd64.deb
|
22 |
+
RUN apt-get update && apt-get install task-spooler
|
23 |
+
|
24 |
+
RUN head -n -1 /entry_script.sh > /entry_script_jupyter.sh
|
25 |
+
RUN head -n -1 /entry_script.sh > /entry_script_tensorboard.sh
|
26 |
+
RUN head -n -1 /entry_script.sh > /entry_script_gradio.sh
|
27 |
+
|
28 |
+
RUN echo \
|
29 |
+
'su -p ${USER} -c "source ~/.zshrc && jupyter lab --ip=0.0.0.0"' >> \
|
30 |
+
/entry_script_jupyter.sh
|
31 |
+
RUN echo \
|
32 |
+
'su -p ${USER} -c "source ~/.zshrc && tensorboard --logdir=$TENSORBOARD_PATH --samples_per_plugin audio=500 --bind_all"' >> \
|
33 |
+
/entry_script_tensorboard.sh
|
34 |
+
RUN echo \
|
35 |
+
'su -p ${USER} -c "source ~/.zshrc && python app.py --args.load=conf/app.yml"' >> \
|
36 |
+
/entry_script_gradio.sh
|
README.md
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Lyrebird Wav2Wav
|
2 |
+
|
3 |
+
This repository contains recipes for training Wav2Wav models.
|
4 |
+
|
5 |
+
## Install hooks
|
6 |
+
|
7 |
+
First install the pre-commit util:
|
8 |
+
|
9 |
+
https://pre-commit.com/#install
|
10 |
+
|
11 |
+
pip install pre-commit # with pip
|
12 |
+
brew install pre-commit # on Mac
|
13 |
+
|
14 |
+
Then install the git hooks
|
15 |
+
|
16 |
+
pre-commit install
|
17 |
+
# check .pre-commit-config.yaml for details of hooks
|
18 |
+
|
19 |
+
Upon `git commit`, the pre-commit hooks will be run automatically on the stage files (i.e. added by `git add`)
|
20 |
+
|
21 |
+
**N.B. By default, pre-commit checks only run on staged files**
|
22 |
+
|
23 |
+
If you need to run it on all files:
|
24 |
+
|
25 |
+
pre-commit run --all-files
|
26 |
+
|
27 |
+
## Usage & model zoo
|
28 |
+
|
29 |
+
To download the model, one must be authenticated to the `lyrebird-research` project on Google Cloud.
|
30 |
+
To see all available models, run
|
31 |
+
|
32 |
+
```bash
|
33 |
+
python -m wav2wav.list_models
|
34 |
+
```
|
35 |
+
|
36 |
+
which outputs something like this:
|
37 |
+
|
38 |
+
```
|
39 |
+
gs://research-models/wav2wav
|
40 |
+
└── prod
|
41 |
+
└── v3
|
42 |
+
└── ckpt
|
43 |
+
├── best
|
44 |
+
│ └── generator
|
45 |
+
│ ├── ❌ model.onnx
|
46 |
+
│ ├── ❌ nvidia_geforce_rtx_2080_ti_11_7.trt
|
47 |
+
│ ├── ✅ package.pth
|
48 |
+
│ ├── ❌ tesla_t4_11_7.trt
|
49 |
+
│ └── ✅ weights.pth
|
50 |
+
└── latest
|
51 |
+
└── generator
|
52 |
+
├── ❌ package.pth
|
53 |
+
└── ❌ weights.pth
|
54 |
+
└── v2
|
55 |
+
...
|
56 |
+
└── dev
|
57 |
+
...
|
58 |
+
```
|
59 |
+
|
60 |
+
This will show all the models that are available on GCP. Models that are available locally are marked with a ✅, while those not available locally
|
61 |
+
are marked with ❌. `.onnx` indicates a model that must be run with
|
62 |
+
the `ONNX` runtime, while `.trt` indicate models that have been optimized
|
63 |
+
with TensorRT. Note that TensorRT models are specific to GPU and CUDA
|
64 |
+
runtime, and their file names indicate what to use to run them.
|
65 |
+
|
66 |
+
`package.pth` is a version of the model that is saved using `torch.package`,
|
67 |
+
and contains a copy of the model code within it, which allow it to work
|
68 |
+
even if the model code in `wav2wav/modules/generator.py` changes. `weights.pth`
|
69 |
+
contains the model weights, and the code must match the code used
|
70 |
+
to create the model.
|
71 |
+
|
72 |
+
To use a model from this list, simply write its path and give it to the `enhance` script,
|
73 |
+
like so:
|
74 |
+
|
75 |
+
```
|
76 |
+
python -m wav2wav.interface \
|
77 |
+
[input_path]
|
78 |
+
--model_path=prod/v3/ckpt/best/generator/weights.pth
|
79 |
+
--output_path [output_path]
|
80 |
+
```
|
81 |
+
|
82 |
+
Models are downloaded to the location set by the environment variable `MODEL_LOCAL_PATH`, and defaults to `~/.wav2wav/models`. Similarly,
|
83 |
+
The model bucket is determined by `MODEL_GCS_PATH` and defaults to
|
84 |
+
`gs://research-models/wav2wav/`.
|
85 |
+
|
86 |
+
## Development
|
87 |
+
### Setting everything up
|
88 |
+
|
89 |
+
Run the setup script to set up your environment via:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
python env/setup.py
|
93 |
+
```
|
94 |
+
|
95 |
+
The setup script does not require any dependencies beyond just Python.
|
96 |
+
Once run, follow the instructions it prints out to create your
|
97 |
+
environment file, which will be at `env/env.sh`.
|
98 |
+
|
99 |
+
Note that if this is a new machine, and
|
100 |
+
the data is not downloaded somewhere on it already, it will ask you
|
101 |
+
for a directory to download the data to.
|
102 |
+
|
103 |
+
For Github setup, if you don't have a .netrc token, create one by going to your Github profile -> Developer settings -> Personal access tokens -> Generate new token. Copy the token and [keep it secret, keep it safe](https://www.youtube.com/watch?v=iThtELZvfPs).
|
104 |
+
|
105 |
+
When complete, run:
|
106 |
+
|
107 |
+
```bash
|
108 |
+
source env/env.sh
|
109 |
+
```
|
110 |
+
|
111 |
+
Now build and launch the Docker containers:
|
112 |
+
|
113 |
+
```bash
|
114 |
+
docker compose up -d
|
115 |
+
```
|
116 |
+
|
117 |
+
This builds and runs a Jupyter notebook and Tensorboard
|
118 |
+
in the background, which points to your `TENSORBOARD_PATH`
|
119 |
+
env. variable.
|
120 |
+
|
121 |
+
Now, launch your development environment via:
|
122 |
+
|
123 |
+
```bash
|
124 |
+
docker compose run dev
|
125 |
+
```
|
126 |
+
|
127 |
+
To tear down your development environment, just do
|
128 |
+
|
129 |
+
```bash
|
130 |
+
docker compose down
|
131 |
+
```
|
132 |
+
|
133 |
+
### Downloading data and pre-processing
|
134 |
+
Next, from within the Docker environment (or an appropriately configured Conda environment with environment variables set as above), do the following:
|
135 |
+
|
136 |
+
```
|
137 |
+
python -m wav2wav.preprocess.download
|
138 |
+
```
|
139 |
+
|
140 |
+
This will download all the necessary data, which are referenced by
|
141 |
+
the CSV files in `conf/audio/*`. These CSVs were generated via
|
142 |
+
`python -m wav2wav.preprocess.organize`.
|
143 |
+
|
144 |
+
### Launching an experiment
|
145 |
+
|
146 |
+
Experiments are first _staged_ by running the `stage` command (which corresponds to the script `scripts/exp/stage.py`).
|
147 |
+
|
148 |
+
`stage` creates a directory with a copy of all of the Git-tracked files in the root repository.`stage` launches a shell into said directory, so all commands are run on the
|
149 |
+
copy of the original repository code. This is useful for rewinding to an old experiment
|
150 |
+
and resuming it, for example. Even if the repository code changes, the snapshot in the experiment directory is unchanged from the original run, so it can be re-used.
|
151 |
+
|
152 |
+
Then, the experiment can be run via:
|
153 |
+
|
154 |
+
```bash
|
155 |
+
torchrun --nproc_per_node gpu \
|
156 |
+
scripts/exp/train.py \
|
157 |
+
--args.load=conf/args.yml \
|
158 |
+
```
|
159 |
+
|
160 |
+
The full settings are in [conf/daps/train.yml](conf/daps/train.yml).
|
161 |
+
|
162 |
+
### Evaluating an experiment
|
163 |
+
|
164 |
+
There are two ways to evaluate an experiment: quantitative and qualitative.
|
165 |
+
For the first, we can use the `scripts/exp/evaluate.py` script. This script evaluates the model over the `val_data` and `test_data`, defined in your
|
166 |
+
`train` script, and takes as input an experiment directory. The metrics
|
167 |
+
computed by this script are saved to the same folder.
|
168 |
+
|
169 |
+
The other way is via a preference test. Let's say we want to compare
|
170 |
+
the v3 prod model against the v2 prod model. to do this, we use the
|
171 |
+
`scripts/exp/qa.py` script. This script creates a zip file containing all
|
172 |
+
the samples and an HTML page for easy viewing. It also creates a Papaya
|
173 |
+
preference test. Use it like this:
|
174 |
+
|
175 |
+
```bash
|
176 |
+
WAV2WAV_MODELS=a,b python scripts/exp/qa.py \
|
177 |
+
--a/model_path prod/v3/ckpt/best/generator/package.pth \
|
178 |
+
--b/model_path prod/v2/ckpt/best/generator/package.pth \
|
179 |
+
--a/name "v3" --b/name "v2" \
|
180 |
+
--device cuda:0 \
|
181 |
+
--n_samples 20 \
|
182 |
+
--zip_path "samples/out.zip"
|
183 |
+
```
|
184 |
+
|
185 |
+
### Useful commands
|
186 |
+
|
187 |
+
#### Monitoring the machine
|
188 |
+
|
189 |
+
There's a useful `tmux` workspace that you can launch via:
|
190 |
+
|
191 |
+
```bash
|
192 |
+
tmuxp load ./workspace.yml
|
193 |
+
```
|
194 |
+
|
195 |
+
which will have a split pane with a shell to launch commands on the left,
|
196 |
+
and GPU monitoring, `htop`, and a script that watches for changes in your
|
197 |
+
directory on the right, in three split panes.
|
198 |
+
|
199 |
+
#### Cleaning up after a run
|
200 |
+
|
201 |
+
Sometimes DDP runs fail to clear themselves out of the machine. To fix this, run
|
202 |
+
|
203 |
+
```bash
|
204 |
+
cleanup
|
205 |
+
```
|
206 |
+
|
207 |
+
### Deploying a new model to production
|
208 |
+
|
209 |
+
Okay, so you ran a model and it seems promising and you want to upload it
|
210 |
+
to GCS so it can be QA'd fully, and then shipped. First, upload
|
211 |
+
your experiment to the `dev` bucket on GCS via:
|
212 |
+
|
213 |
+
```bash
|
214 |
+
gsutil cp -r /path/to/{exp_name} gs://research-models/wav2wav/dev/{exp_name}
|
215 |
+
```
|
216 |
+
|
217 |
+
Once uploaded, QA can access the models by specifying
|
218 |
+
`model_path=dev/{exp_name}/ckpt/{best,latest}/generator/package.pth` when using the
|
219 |
+
`wav2wav.interface.enhance` function. If it passes QA, and is scheduled to
|
220 |
+
ship to production, then next we have to generate the TensorRT model file,
|
221 |
+
which requires us to have a machine that matches that of a production machine.
|
222 |
+
|
223 |
+
There is a script that automates this procedure, that does not require any
|
224 |
+
fiddling from our end. Navigate to the repository root and run:
|
225 |
+
|
226 |
+
```
|
227 |
+
python scripts/utils/convert_on_gcp.py dev/{exp_name}/ckpt/{best,latest}//generator/weights.pth
|
228 |
+
```
|
229 |
+
|
230 |
+
This will provision the machine, download the relevant model from GCS, optimize it on
|
231 |
+
the production GPU with the correct CUDA runtime, and then upload the generated `.trt`
|
232 |
+
and `.onnx` models back to the bucket.
|
233 |
+
|
234 |
+
Finally, copy the model to the `prod` bucket, incrementing the version number by one:
|
235 |
+
|
236 |
+
```bash
|
237 |
+
gsutil cp -r gs://research-models/wav2wav/dev/{exp_name} gs://research-models/wav2wav/prod/v{N}
|
238 |
+
```
|
239 |
+
|
240 |
+
where `N` is the next version (e.g. if v3 is the latest, the new one is v4). Then, update
|
241 |
+
the model table in [Notion](https://www.notion.so/descript/fc04de4b46e6417eba1d06bdc8de6c75?v=e56db4e6b37c4d9b9eca8d9be15c826a) with the new model.
|
242 |
+
|
243 |
+
Once the above is all done, we update the code in two places:
|
244 |
+
|
245 |
+
1. In `interface.py`, we update `PROD_MODEL_PATH` to point to the `weights.pth`
|
246 |
+
for whichever tag ended up shipping (either `best` or `latest`).
|
247 |
+
2. In `interface.py`, we update `PROD_TRT_PATH` to point the generated
|
248 |
+
TensorRT checkpoint generated by the script above.
|
249 |
+
|
250 |
+
After merging to master, a new Docker image will be created, and one can update the relevant lines
|
251 |
+
in descript-workflows like in this [PR](https://github.com/descriptinc/descript-workflows/pull/477/files).
|
252 |
+
|
253 |
+
We have Github action workflows in [.github/workflows/deploy.yml](.github/workflows/deploy.yml) to build and deploy new docker images. Two images are built - one for staging and another for production.
|
254 |
+
To deploy a new release version, follow the instructions in [this coda doc](https://coda.io/d/Research-Engineering_dOABAWL46p-/Deploying-Services_su1am#_lu7E8).
|
255 |
+
|
256 |
+
Coda doc with informations about deploying speech-enhance worker is [here](https://coda.io/d/Research-Engineering_dOABAWL46p-/Deploying-Services_su1am#_lu7E8).
|
257 |
+
|
258 |
+
And that's it! Once the new staging is built, you're done.
|
259 |
+
|
260 |
+
## Testing
|
261 |
+
|
262 |
+
### Profiling and Regression testing
|
263 |
+
|
264 |
+
- The [profiling script](tests/profile_inference.py) profiles the `wav2wav.interface.enhance` function.
|
265 |
+
- NOTE: ALWAYS run the profiler on a T4 GPU. ALWAYS run the profiling in isolation i.e kill all other processes on the GPU. Recommended vm size on GCP is `n1-standard-32` as the stress test of six hours of audio requires ~35GB of system memory.
|
266 |
+
- To run profiling use the [profiling script](tests/profile_inference.py) via command `python3 -m tests.profile_inference`. Results will be printed after `1` run.
|
267 |
+
- Use the [test_regression.py](tests/test_regression.py) script to run tests that
|
268 |
+
- compare performance stats of current model with known best model
|
269 |
+
- test for output deviation from the last model
|
270 |
+
- Run `git lfs checkout` to checkout input file and model weights required for testing the model.
|
271 |
+
- To launch these tests, run `python3 -m pytest tests/test_regression.py -v`.
|
272 |
+
- As a side effect, this will update the `tests/stat.csv` file if the current model performs better than last best known model as per `tests/stat.csv`.
|
273 |
+
- NOTE: In case of architecture change, purge the weights files : `tests/assets/{quick|slow}.pth` and reference stat file : `tests/assets/baseline.json` file. Running the [test_regression.py](tests/test_regression.py) script in absence of reference stat file, will generate new baseline referece stats as well as append new performance stats to stats file. In the absence of saved weights, new weights are generated and saved on disk. Make sure to commit these files (stat.csv, baseline.json, *.pth) when the model architecture changes.
|
274 |
+
|
275 |
+
### Unit tests
|
276 |
+
Regular unit tests that test functionality such as training resume etc. These are run on CPU. Update them when new features are added.
|
277 |
+
|
278 |
+
### Profiling tests
|
279 |
+
These tests profile the model's resource consumption. They are run on T4 GPU with 32 cores and >35GB memory. Their usage is reported in the above sections.
|
280 |
+
|
281 |
+
### Functional tests
|
282 |
+
These tests detect deviation from known baseline model. A category of these tests ensure that a new pytorch model doesn't deviate from the previous one. Another category ensures that the TensorRT version of the current pytorch model doens't deviate from it. These tests are marked with the marker `output_qa` and can be run via the command line `python3 -m pytest -v -m output_qa`. Some of these tests require a GPU.
|
283 |
+
|
284 |
+
### CI tests
|
285 |
+
- The tests are divided into two categories depending on the platform requirement - CPU tests and GPU tests.
|
286 |
+
- The CPU tests contains unit tests.
|
287 |
+
- The GPU tests contain a subset of functional tests. These tests can be run by the command `python3 -m pytest -v -m gpu_ci_test`.
|
changelog.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v0.3.4
|
2 |
+
- Upgraded nussl, removing unneeded code from interface, and improving interaction.
|
3 |
+
|
4 |
+
# v0.3.1
|
5 |
+
- Rich interaction with the model.
|
6 |
+
|
7 |
+
# v0.2.0
|
8 |
+
- Release with script to process zipfiles for QA + gsutil sync for model weights.
|
9 |
+
|
10 |
+
# v0.1.0
|
11 |
+
- Initial release.
|
conf/vampnet-c2f.yml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
VampNet.n_codebooks: 9
|
5 |
+
VampNet.n_conditioning_codebooks: 3
|
6 |
+
|
7 |
+
train/AudioDataset.duration: 3
|
8 |
+
val/AudioDataset.duration: 3
|
9 |
+
test/AudioDataset.duration: 3
|
conf/vampnet.yml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
wav2wav_ckpt: /u/home/src/runs/codec-ckpt/codec.pth
|
3 |
+
save_path: ckpt
|
4 |
+
max_epochs: 1000000
|
5 |
+
epoch_length: 1000
|
6 |
+
save_audio_epochs: 2
|
7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
+
|
9 |
+
prefix_amt: 0.0
|
10 |
+
suffix_amt: 0.0
|
11 |
+
prefix_dropout: 0.1
|
12 |
+
suffix_dropout: 0.1
|
13 |
+
|
14 |
+
batch_size: 120
|
15 |
+
num_workers: 80
|
16 |
+
|
17 |
+
# Optimization
|
18 |
+
detect_anomaly: false
|
19 |
+
amp: false
|
20 |
+
|
21 |
+
CrossEntropyLoss.label_smoothing: 0.1
|
22 |
+
|
23 |
+
AdamW.lr: 0.001
|
24 |
+
|
25 |
+
NoamScheduler.factor: 2.0
|
26 |
+
NoamScheduler.warmup: 10000
|
27 |
+
|
28 |
+
VampNet.vocab_size: 1024
|
29 |
+
VampNet.n_codebooks: 3
|
30 |
+
VampNet.n_conditioning_codebooks: 0
|
31 |
+
VampNet.r_cond_dim: 64
|
32 |
+
VampNet.embedding_dim: 1280
|
33 |
+
VampNet.n_layers: 16
|
34 |
+
VampNet.n_heads: 20
|
35 |
+
VampNet.flash_attn: false
|
36 |
+
VampNet.dropout: 0.05
|
37 |
+
|
38 |
+
AudioLoader.relative_path: /data/
|
39 |
+
AudioDataset.loudness_cutoff: -30.0
|
40 |
+
AudioDataset.without_replacement: true
|
41 |
+
AudioLoader.shuffle: true
|
42 |
+
|
43 |
+
train/AudioDataset.duration: 5.0
|
44 |
+
train/AudioDataset.n_examples: 10000000
|
45 |
+
train/AudioLoader.sources:
|
46 |
+
- /data/spotdl/audio/train
|
47 |
+
|
48 |
+
val/AudioDataset.duration: 5.0
|
49 |
+
val/AudioDataset.n_examples: 2000
|
50 |
+
val/AudioLoader.sources:
|
51 |
+
- /data/spotdl/audio/val
|
52 |
+
|
53 |
+
test/AudioDataset.duration: 5.0
|
54 |
+
test/AudioDataset.n_examples: 1000
|
55 |
+
test/AudioLoader.sources:
|
56 |
+
- /data/spotdl/audio/test
|
docker-compose.yml
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
version: "3.5"
|
3 |
+
services:
|
4 |
+
tensorrt:
|
5 |
+
build:
|
6 |
+
context: .
|
7 |
+
dockerfile: ./deployment_build/dockerfile
|
8 |
+
args:
|
9 |
+
GITHUB_TOKEN: ${GITHUB_TOKEN}
|
10 |
+
profiles:
|
11 |
+
- tensorrt
|
12 |
+
volumes:
|
13 |
+
- ./:/u/home/src
|
14 |
+
- ~/.config/gcloud:/root/.config/gcloud
|
15 |
+
deploy:
|
16 |
+
resources:
|
17 |
+
limits:
|
18 |
+
# match production limits
|
19 |
+
cpus: '7'
|
20 |
+
memory: 25000M
|
21 |
+
reservations:
|
22 |
+
devices:
|
23 |
+
- driver: nvidia
|
24 |
+
count: 1
|
25 |
+
capabilities: [gpu]
|
26 |
+
working_dir: /u/home/src
|
27 |
+
entrypoint:
|
28 |
+
- python
|
29 |
+
- -m
|
30 |
+
- wav2wav.converter
|
31 |
+
base:
|
32 |
+
build:
|
33 |
+
context: .
|
34 |
+
dockerfile: ./Dockerfile
|
35 |
+
args:
|
36 |
+
GITHUB_TOKEN: ${GITHUB_TOKEN}
|
37 |
+
volumes:
|
38 |
+
- .:/u/home/src
|
39 |
+
- ~/.wav2wav:/u/home/.wav2wav
|
40 |
+
- ${PATH_TO_DATA}:/data
|
41 |
+
- ${PATH_TO_RUNS}:/runs
|
42 |
+
- ~/.config/gcloud:/u/home/.config/gcloud
|
43 |
+
- ~/.zsh_history:/u/home/.zsh_history
|
44 |
+
environment:
|
45 |
+
- GITHUB_TOKEN
|
46 |
+
- DISCOURSE_API_USERNAME
|
47 |
+
- DISCOURSE_SERVER
|
48 |
+
- DISCOURSE_API_KEY
|
49 |
+
- HOST_USER_ID
|
50 |
+
- HOST_USER_GID
|
51 |
+
- JUPYTER_TOKEN
|
52 |
+
- PATH_TO_DATA=/data
|
53 |
+
- PATH_TO_RUNS=/runs
|
54 |
+
- TENSORBOARD_PATH
|
55 |
+
- MPLCONFIGDIR=/u/home/.mplconfig
|
56 |
+
shm_size: 32G
|
57 |
+
working_dir: /u/home/src
|
58 |
+
deploy:
|
59 |
+
resources:
|
60 |
+
reservations:
|
61 |
+
devices:
|
62 |
+
- driver: nvidia
|
63 |
+
capabilities: [gpu]
|
64 |
+
dev:
|
65 |
+
extends: base
|
66 |
+
profiles:
|
67 |
+
- interactive
|
68 |
+
stdin_open: true
|
69 |
+
tty: true
|
70 |
+
jupyter:
|
71 |
+
extends: base
|
72 |
+
ports:
|
73 |
+
- ${JUPYTER_PORT}:8888
|
74 |
+
entrypoint:
|
75 |
+
- /bin/bash
|
76 |
+
- /entry_script_jupyter.sh
|
77 |
+
tensorboard:
|
78 |
+
extends: base
|
79 |
+
ports:
|
80 |
+
- ${TENSORBOARD_PORT}:6006
|
81 |
+
entrypoint:
|
82 |
+
- /bin/bash
|
83 |
+
- /entry_script_tensorboard.sh
|
84 |
+
gradio:
|
85 |
+
extends: base
|
86 |
+
ports:
|
87 |
+
- 7860:7860
|
88 |
+
entrypoint:
|
89 |
+
- /bin/bash
|
90 |
+
- /entry_script_gradio.sh
|
env/alias.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
alias cleanup="pkill python && echo -en '\e[?25h'"
|
2 |
+
alias stage="python ./scripts/utils/stage.py"
|
3 |
+
alias fix_cursor="echo -en '\e[?25h'"
|
env/data.sh
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export PATH_TO_DATA=~/data
|
2 |
+
|
3 |
+
if [[ $(hostname) == "oon17" ]]; then
|
4 |
+
export PATH_TO_DATA=/home/prem/shared/data/
|
5 |
+
fi
|
6 |
+
|
7 |
+
if [[ $(hostname) == "oon19" ]]; then
|
8 |
+
export PATH_TO_DATA=/home/prem/shared/data/
|
9 |
+
fi
|
10 |
+
|
11 |
+
if [[ $(hostname) == "lucas-ssound-trt-vm" ]]; then
|
12 |
+
export PATH_TO_DATA=~/data
|
13 |
+
fi
|
14 |
+
|
15 |
+
if [[ $(hostname) == "a100-ssound" ]]; then
|
16 |
+
export PATH_TO_DATA=~/data
|
17 |
+
fi
|
18 |
+
|
19 |
+
if [[ $(hostname) == "oon25" ]]; then
|
20 |
+
export PATH_TO_DATA=/data
|
21 |
+
fi
|
22 |
+
|
23 |
+
if [[ $(hostname) == "macbook-pro-2.lan" ]]; then
|
24 |
+
export PATH_TO_DATA=~/data
|
25 |
+
fi
|
26 |
+
|
27 |
+
if [[ $(hostname) == "oon11" ]]; then
|
28 |
+
export PATH_TO_DATA=/data2/syncthing_lucas/data
|
29 |
+
fi
|
30 |
+
|
31 |
+
if [[ $(hostname) == "oon12" ]]; then
|
32 |
+
export PATH_TO_DATA=/data
|
33 |
+
fi
|
34 |
+
if [[ $(hostname) == "oon26" ]]; then
|
35 |
+
export PATH_TO_DATA=/data
|
36 |
+
fi
|
env/entry_script.sh
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
set -e
|
3 |
+
|
4 |
+
if [ -z "${USER}" ]; then
|
5 |
+
echo "We need USER to be set!"; exit 100
|
6 |
+
fi
|
7 |
+
|
8 |
+
# check if host uid and gid are set
|
9 |
+
if [ -z "${HOST_USER_ID}" ]; then
|
10 |
+
echo "Please set HOST_USER_ID env. variables to continue." ; exit 0
|
11 |
+
fi
|
12 |
+
|
13 |
+
if [ -z "${HOST_USER_GID}" ]; then
|
14 |
+
echo "Please set HOST_USER_GID env. variables to continue." ; exit 0
|
15 |
+
fi
|
16 |
+
|
17 |
+
USER_ID=$HOST_USER_ID
|
18 |
+
USER_GID=$HOST_USER_GID
|
19 |
+
USER_HOME=/u/home
|
20 |
+
|
21 |
+
# modify uid and gid to match host
|
22 |
+
sed -i -e "s/^${USER}:\([^:]*\):[0-9]*:[0-9]*/${USER}:\1:${USER_ID}:${USER_GID}/" /etc/passwd
|
23 |
+
|
24 |
+
# create a group for host gid
|
25 |
+
groupadd -f --gid "${USER_GID}" "host_group"
|
26 |
+
|
27 |
+
chown $USER_ID $USER_HOME
|
28 |
+
chown $USER_ID /u/home/.zshrc
|
29 |
+
chown $USER_ID /u/home/.oh-my-zsh
|
30 |
+
|
31 |
+
mkdir -p /u/home/.cache
|
32 |
+
chown -R $USER_ID:$USER_GID /u/home/.cache/
|
33 |
+
|
34 |
+
_term() {
|
35 |
+
echo "Caught SIGTERM signal!"
|
36 |
+
kill -TERM "$child" 2>/dev/null
|
37 |
+
}
|
38 |
+
|
39 |
+
trap _term SIGTERM
|
40 |
+
|
41 |
+
su -p "${USER}"
|
env/setup.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script guides the user through setting up their env.sh
|
2 |
+
# if env.sh does not exist. Should have no dependencies other
|
3 |
+
# than Python standard library.
|
4 |
+
import shlex
|
5 |
+
import socket
|
6 |
+
import subprocess
|
7 |
+
import textwrap
|
8 |
+
|
9 |
+
|
10 |
+
def run(cmd):
|
11 |
+
return subprocess.check_output(shlex.split(cmd)).decode("utf-8")
|
12 |
+
|
13 |
+
|
14 |
+
print("1. Setting up Google Cloud access")
|
15 |
+
print("---------------------------------")
|
16 |
+
gcloud_authorized = "gs://research-data-raw" in run("gsutil ls")
|
17 |
+
if not gcloud_authorized:
|
18 |
+
run("gcloud auth login")
|
19 |
+
|
20 |
+
run("gcloud config set project lyrebird-research")
|
21 |
+
run("gcloud auth configure-docker")
|
22 |
+
|
23 |
+
print()
|
24 |
+
print("2. Setting up Github access")
|
25 |
+
print("---------------------------")
|
26 |
+
|
27 |
+
lines = textwrap.wrap(
|
28 |
+
"First, let's get your Github token, so all "
|
29 |
+
"packages can be installed. Create one by going to your "
|
30 |
+
"Github profile -> Developer settings -> Personal access tokens -> "
|
31 |
+
"Generate new token. Copy the token below."
|
32 |
+
)
|
33 |
+
[print(l) for l in lines]
|
34 |
+
|
35 |
+
GITHUB_TOKEN = input("\nGithub token: ") or "undefined"
|
36 |
+
|
37 |
+
print()
|
38 |
+
print("3. Setting up Jupyter and Tensorboard")
|
39 |
+
print("-------------------------------------")
|
40 |
+
|
41 |
+
JUPYTER_TOKEN = input("Password for Jupyter server (default:password): ") or "password"
|
42 |
+
JUPYTER_PORT = input("Jupyter port to run on (default:8888): ") or "8888"
|
43 |
+
TENSORBOARD_PORT = input("Tensorboard port to run on (default:6006): ") or "6006"
|
44 |
+
|
45 |
+
print()
|
46 |
+
print("4. Setting up paths.")
|
47 |
+
print("--------------------")
|
48 |
+
|
49 |
+
PATH_TO_RUNS = input("Where runs should go (default:./runs/): ") or "./runs/"
|
50 |
+
TENSORBOARD_PATH = (
|
51 |
+
input("Bucket/dir for tensorboard logs (default=PATH_TO_RUNS): ") or PATH_TO_RUNS
|
52 |
+
)
|
53 |
+
|
54 |
+
with open("env/data.sh") as f:
|
55 |
+
data_script = f.read()
|
56 |
+
|
57 |
+
write_to_data_sh = False
|
58 |
+
if socket.gethostname() not in data_script:
|
59 |
+
print("Looks like the data path for this machine is not setup.")
|
60 |
+
PATH_TO_DATA = input(f"Path to data on {socket.gethostname()}: ") or "~/data"
|
61 |
+
|
62 |
+
data_command = f"""
|
63 |
+
if [[ $(hostname) == "{socket.gethostname()}" ]]; then
|
64 |
+
export PATH_TO_DATA={PATH_TO_DATA}
|
65 |
+
fi
|
66 |
+
"""
|
67 |
+
write_to_data_sh = True
|
68 |
+
|
69 |
+
|
70 |
+
print()
|
71 |
+
print("5. Setting up Papaya")
|
72 |
+
print("-----------------------------------------")
|
73 |
+
|
74 |
+
PAPAYA_USER_TOKEN = input("Papaya user token: ") or "undefined"
|
75 |
+
|
76 |
+
env_script = f"""
|
77 |
+
source env/alias.sh
|
78 |
+
source env/data.sh
|
79 |
+
export GITHUB_TOKEN={GITHUB_TOKEN}
|
80 |
+
|
81 |
+
export PAPAYA_USER_TOKEN={PAPAYA_USER_TOKEN}
|
82 |
+
|
83 |
+
export HOST_USER_ID=$(id -u)
|
84 |
+
export HOST_USER_GID=$(id -g)
|
85 |
+
|
86 |
+
export JUPYTER_TOKEN={JUPYTER_TOKEN}
|
87 |
+
export JUPYTER_PORT={JUPYTER_PORT}
|
88 |
+
export TENSORBOARD_PORT={TENSORBOARD_PORT}
|
89 |
+
|
90 |
+
export PATH_TO_RUNS={PATH_TO_RUNS}
|
91 |
+
export TENSORBOARD_PATH={TENSORBOARD_PATH}
|
92 |
+
"""
|
93 |
+
|
94 |
+
print()
|
95 |
+
print("6. Potential file contents.")
|
96 |
+
print("---------------------------")
|
97 |
+
|
98 |
+
print("env/env.sh: \n")
|
99 |
+
print("##################")
|
100 |
+
print(env_script)
|
101 |
+
print("##################")
|
102 |
+
|
103 |
+
if write_to_data_sh:
|
104 |
+
data_script += data_command
|
105 |
+
|
106 |
+
print("env/data.sh:")
|
107 |
+
print("##################")
|
108 |
+
print(data_script)
|
109 |
+
print("##################")
|
110 |
+
|
111 |
+
print()
|
112 |
+
write_to_files = input("Write to file [yn]? ") or "n"
|
113 |
+
if write_to_files == "y":
|
114 |
+
with open("env/env.sh", "w") as f:
|
115 |
+
f.write(env_script.strip())
|
116 |
+
with open("env/data.sh", "w") as f:
|
117 |
+
f.write(data_script.strip())
|
118 |
+
|
119 |
+
print()
|
120 |
+
print("8. Finalize setup.")
|
121 |
+
print("------------------")
|
122 |
+
print("Run the following command to complete setup.")
|
123 |
+
print("source env/env.sh")
|
requirements.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
argbind>=0.3.1
|
2 |
+
pytorch-ignite
|
3 |
+
rich
|
4 |
+
audiotools @ git+https://github.com/descriptinc/[email protected]
|
5 |
+
tqdm
|
6 |
+
tensorboard
|
7 |
+
google-cloud-logging==2.2.0
|
8 |
+
pytest
|
9 |
+
pytest-cov
|
10 |
+
papaya_client @ git+https://github.com/descriptinc/lyrebird-papaya.git@master
|
11 |
+
pynvml
|
12 |
+
psutil
|
13 |
+
pandas
|
14 |
+
onnx
|
15 |
+
onnx-simplifier
|
16 |
+
seaborn
|
17 |
+
jupyterlab
|
18 |
+
jupyterlab-link-share
|
19 |
+
pandas
|
20 |
+
watchdog
|
21 |
+
pesq
|
22 |
+
tabulate
|
23 |
+
torchmetrics
|
24 |
+
codebraid==0.5.0
|
25 |
+
jupyter-client==6.1.12
|
26 |
+
tensorboardX
|
27 |
+
gradio
|
28 |
+
einops
|
29 |
+
flash-attn
|
scripts/generative/eval.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import imp
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import argbind
|
7 |
+
import audiotools
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
from flatten_dict import flatten
|
12 |
+
from rich.progress import track
|
13 |
+
from torch.utils.tensorboard import SummaryWriter
|
14 |
+
|
15 |
+
import wav2wav
|
16 |
+
|
17 |
+
train = imp.load_source("train", str(Path(__file__).absolute().parent / "train.py"))
|
18 |
+
|
19 |
+
|
20 |
+
@argbind.bind(without_prefix=True)
|
21 |
+
def evaluate(
|
22 |
+
args,
|
23 |
+
model_tag: str = "ckpt/best",
|
24 |
+
device: str = "cuda",
|
25 |
+
exp: str = None,
|
26 |
+
overwrite: bool = False,
|
27 |
+
):
|
28 |
+
assert exp is not None
|
29 |
+
|
30 |
+
sisdr_loss = audiotools.metrics.distance.SISDRLoss()
|
31 |
+
stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
32 |
+
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
33 |
+
|
34 |
+
with audiotools.util.chdir(exp):
|
35 |
+
vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
|
36 |
+
f"{model_tag}/vampnet/package.pth"
|
37 |
+
)
|
38 |
+
vampnet = vampnet.to(device)
|
39 |
+
if vampnet.cond_dim > 0:
|
40 |
+
condnet = wav2wav.modules.condnet.transformer.CondNet.load(
|
41 |
+
f"{model_tag}/condnet/package.pth"
|
42 |
+
)
|
43 |
+
condnet = condnet.to(device)
|
44 |
+
else:
|
45 |
+
condnet = None
|
46 |
+
|
47 |
+
vqvae = wav2wav.modules.generator.Generator.load(
|
48 |
+
f"{model_tag}/vqvae/package.pth"
|
49 |
+
)
|
50 |
+
|
51 |
+
_, _, test_data = train.build_datasets(args, vqvae.sample_rate)
|
52 |
+
|
53 |
+
with audiotools.util.chdir(exp):
|
54 |
+
datasets = {
|
55 |
+
"test": test_data,
|
56 |
+
}
|
57 |
+
|
58 |
+
metrics_path = Path(f"{model_tag}/metrics")
|
59 |
+
metrics_path.mkdir(parents=True, exist_ok=True)
|
60 |
+
|
61 |
+
for key, dataset in datasets.items():
|
62 |
+
csv_path = metrics_path / f"{key}.csv"
|
63 |
+
if csv_path.exists() and not overwrite:
|
64 |
+
break
|
65 |
+
metrics = []
|
66 |
+
for i in track(range(len(dataset))):
|
67 |
+
# TODO: for coarse2fine
|
68 |
+
# grab the signal
|
69 |
+
# mask all the codebooks except the conditioning ones
|
70 |
+
# and infer
|
71 |
+
# then compute metrics
|
72 |
+
# for a baseline, just use the coarsest codebook
|
73 |
+
|
74 |
+
try:
|
75 |
+
visqol = audiotools.metrics.quality.visqol(
|
76 |
+
enhanced, clean, "audio"
|
77 |
+
).item()
|
78 |
+
except:
|
79 |
+
visqol = None
|
80 |
+
|
81 |
+
sisdr = sisdr_loss(enhanced, clean)
|
82 |
+
stft = stft_loss(enhanced, clean)
|
83 |
+
mel = mel_loss(enhanced, clean)
|
84 |
+
|
85 |
+
metrics.append(
|
86 |
+
{
|
87 |
+
"visqol": visqol,
|
88 |
+
"sisdr": sisdr.item(),
|
89 |
+
"stft": stft.item(),
|
90 |
+
"mel": mel.item(),
|
91 |
+
"dataset": key,
|
92 |
+
"condition": exp,
|
93 |
+
}
|
94 |
+
)
|
95 |
+
print(metrics[-1])
|
96 |
+
|
97 |
+
transform_args = flatten(item["transform_args"], "dot")
|
98 |
+
for k, v in transform_args.items():
|
99 |
+
if torch.is_tensor(v):
|
100 |
+
if len(v.shape) == 0:
|
101 |
+
metrics[-1][k] = v.item()
|
102 |
+
|
103 |
+
metrics = pd.DataFrame.from_dict(metrics)
|
104 |
+
with open(csv_path, "w") as f:
|
105 |
+
metrics.to_csv(f)
|
106 |
+
|
107 |
+
data = summary(model_tag).to_dict()
|
108 |
+
metrics = {}
|
109 |
+
for k1, v1 in data.items():
|
110 |
+
for k2, v2 in v1.items():
|
111 |
+
metrics[f"metrics/{k2}/{k1}"] = v2
|
112 |
+
|
113 |
+
# Number of steps to record
|
114 |
+
writer = SummaryWriter(log_dir=metrics_path)
|
115 |
+
num_steps = 10
|
116 |
+
for k, v in metrics.items():
|
117 |
+
for i in range(num_steps):
|
118 |
+
writer.add_scalar(k, v, i)
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
args = argbind.parse_args()
|
123 |
+
with argbind.scope(args):
|
124 |
+
evaluate(args)
|
scripts/generative/train.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import argbind
|
9 |
+
import audiotools as at
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from audiotools import AudioSignal
|
13 |
+
from audiotools.data import transforms
|
14 |
+
from einops import rearrange
|
15 |
+
from rich import pretty
|
16 |
+
from rich.traceback import install
|
17 |
+
from tensorboardX import SummaryWriter
|
18 |
+
|
19 |
+
import vampnet
|
20 |
+
from vampnet.modules.transformer import VampNet
|
21 |
+
from lac.model.lac import LAC
|
22 |
+
|
23 |
+
|
24 |
+
# Enable cudnn autotuner to speed up training
|
25 |
+
# (can be altered by the funcs.seed function)
|
26 |
+
torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
|
27 |
+
# Uncomment to trade memory for speed.
|
28 |
+
|
29 |
+
# Install to make things look nice
|
30 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
31 |
+
pretty.install()
|
32 |
+
install()
|
33 |
+
|
34 |
+
# optim
|
35 |
+
Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
|
36 |
+
CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
|
37 |
+
AdamW = argbind.bind(torch.optim.AdamW)
|
38 |
+
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
39 |
+
|
40 |
+
# transforms
|
41 |
+
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
42 |
+
"BaseTransform",
|
43 |
+
"Compose",
|
44 |
+
"Choose",
|
45 |
+
]
|
46 |
+
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
47 |
+
|
48 |
+
# model
|
49 |
+
VampNet = argbind.bind(VampNet)
|
50 |
+
|
51 |
+
|
52 |
+
# data
|
53 |
+
AudioLoader = argbind.bind(at.datasets.AudioLoader)
|
54 |
+
AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
|
55 |
+
|
56 |
+
IGNORE_INDEX = -100
|
57 |
+
|
58 |
+
|
59 |
+
@argbind.bind("train", "val", without_prefix=True)
|
60 |
+
def build_transform():
|
61 |
+
transform = transforms.Compose(
|
62 |
+
tfm.VolumeNorm(("uniform", -32, -14)),
|
63 |
+
tfm.VolumeChange(("uniform", -6, 3)),
|
64 |
+
tfm.RescaleAudio(),
|
65 |
+
)
|
66 |
+
return transform
|
67 |
+
|
68 |
+
|
69 |
+
@torch.no_grad()
|
70 |
+
def apply_transform(transform_fn, batch):
|
71 |
+
sig: AudioSignal = batch["signal"]
|
72 |
+
kwargs = batch["transform_args"]
|
73 |
+
|
74 |
+
sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
|
75 |
+
return sig
|
76 |
+
|
77 |
+
|
78 |
+
def build_datasets(args, sample_rate: int):
|
79 |
+
with argbind.scope(args, "train"):
|
80 |
+
train_data = AudioDataset(
|
81 |
+
AudioLoader(), sample_rate, transform=build_transform()
|
82 |
+
)
|
83 |
+
with argbind.scope(args, "val"):
|
84 |
+
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
85 |
+
with argbind.scope(args, "test"):
|
86 |
+
test_data = AudioDataset(
|
87 |
+
AudioLoader(), sample_rate, transform=build_transform()
|
88 |
+
)
|
89 |
+
return train_data, val_data, test_data
|
90 |
+
|
91 |
+
|
92 |
+
def rand_float(shape, low, high, rng):
|
93 |
+
return rng.draw(shape)[:, 0] * (high - low) + low
|
94 |
+
|
95 |
+
|
96 |
+
def flip_coin(shape, p, rng):
|
97 |
+
return rng.draw(shape)[:, 0] < p
|
98 |
+
|
99 |
+
|
100 |
+
@argbind.bind(without_prefix=True)
|
101 |
+
def load(
|
102 |
+
args,
|
103 |
+
accel: at.ml.Accelerator,
|
104 |
+
save_path: str,
|
105 |
+
resume: bool = False,
|
106 |
+
tag: str = "latest",
|
107 |
+
load_weights: bool = False,
|
108 |
+
):
|
109 |
+
model, v_extra = None, {}
|
110 |
+
|
111 |
+
if resume:
|
112 |
+
kwargs = {
|
113 |
+
"folder": f"{save_path}/{tag}",
|
114 |
+
"map_location": "cpu",
|
115 |
+
"package": not load_weights,
|
116 |
+
}
|
117 |
+
if (Path(kwargs["folder"]) / "model").exists():
|
118 |
+
model, v_extra = model.load_from_folder(**kwargs)
|
119 |
+
|
120 |
+
codec = LAC.load(args["codec_ckpt"], map_location="cpu")
|
121 |
+
codec.eval()
|
122 |
+
model = VampNet() if model is None else model
|
123 |
+
|
124 |
+
model = accel.prepare_model(model)
|
125 |
+
|
126 |
+
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
127 |
+
assert (
|
128 |
+
accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
|
129 |
+
)
|
130 |
+
|
131 |
+
optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
|
132 |
+
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
133 |
+
scheduler.step()
|
134 |
+
|
135 |
+
trainer_state = {"state_dict": None, "start_idx": 0}
|
136 |
+
|
137 |
+
if "optimizer.pth" in v_extra:
|
138 |
+
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
139 |
+
if "scheduler.pth" in v_extra:
|
140 |
+
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
141 |
+
if "trainer.pth" in v_extra:
|
142 |
+
trainer_state = v_extra["trainer.pth"]
|
143 |
+
|
144 |
+
return {
|
145 |
+
"model": model,
|
146 |
+
"codec": codec,
|
147 |
+
"optimizer": optimizer,
|
148 |
+
"scheduler": scheduler,
|
149 |
+
"trainer_state": trainer_state,
|
150 |
+
}
|
151 |
+
|
152 |
+
|
153 |
+
def get_gpu_memory_map():
|
154 |
+
"""Get the current gpu usage.
|
155 |
+
|
156 |
+
Returns
|
157 |
+
-------
|
158 |
+
usage: dict
|
159 |
+
Keys are device ids as integers.
|
160 |
+
Values are memory usage as integers in MB.
|
161 |
+
"""
|
162 |
+
result = subprocess.check_output(
|
163 |
+
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
|
164 |
+
encoding="utf-8",
|
165 |
+
)
|
166 |
+
# Convert lines into a dictionary
|
167 |
+
gpu_memory = [int(x) for x in result.strip().split("\n")]
|
168 |
+
gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
|
169 |
+
gpu_memory_map = {f"gpu/{k}": v / 1024 for k, v in gpu_memory_map.items()}
|
170 |
+
return gpu_memory_map
|
171 |
+
|
172 |
+
|
173 |
+
def num_params_hook(o, p):
|
174 |
+
return o + f" {p/1e6:<.3f}M params."
|
175 |
+
|
176 |
+
|
177 |
+
def add_num_params_repr_hook(model):
|
178 |
+
import numpy as np
|
179 |
+
from functools import partial
|
180 |
+
|
181 |
+
for n, m in model.named_modules():
|
182 |
+
o = m.extra_repr()
|
183 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
184 |
+
|
185 |
+
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
186 |
+
|
187 |
+
|
188 |
+
def accuracy(
|
189 |
+
preds: torch.Tensor,
|
190 |
+
target: torch.Tensor,
|
191 |
+
top_k: int = 1,
|
192 |
+
ignore_index: Optional[int] = None,
|
193 |
+
**kwargs,
|
194 |
+
) -> torch.Tensor:
|
195 |
+
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
196 |
+
preds = rearrange(preds, "b p s -> (b s) p")
|
197 |
+
target = rearrange(target, "b s -> (b s)")
|
198 |
+
|
199 |
+
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
200 |
+
if ignore_index is not None:
|
201 |
+
# Create a mask for the ignored index
|
202 |
+
mask = target != ignore_index
|
203 |
+
# Apply the mask to the target and predictions
|
204 |
+
preds = preds[mask]
|
205 |
+
target = target[mask]
|
206 |
+
|
207 |
+
# Get the top-k predicted classes and their indices
|
208 |
+
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
209 |
+
|
210 |
+
# Determine if the true target is in the top-k predicted classes
|
211 |
+
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
212 |
+
|
213 |
+
# Calculate the accuracy
|
214 |
+
accuracy = torch.mean(correct.float())
|
215 |
+
|
216 |
+
return accuracy
|
217 |
+
|
218 |
+
|
219 |
+
@argbind.bind(without_prefix=True)
|
220 |
+
def train(
|
221 |
+
args,
|
222 |
+
accel: at.ml.Accelerator,
|
223 |
+
codec_ckpt: str = None,
|
224 |
+
seed: int = 0,
|
225 |
+
save_path: str = "ckpt",
|
226 |
+
max_epochs: int = int(100e3),
|
227 |
+
epoch_length: int = 1000,
|
228 |
+
save_audio_epochs: int = 10,
|
229 |
+
batch_size: int = 48,
|
230 |
+
grad_acc_steps: int = 1,
|
231 |
+
val_idx: list = [0, 1, 2, 3, 4],
|
232 |
+
num_workers: int = 20,
|
233 |
+
detect_anomaly: bool = False,
|
234 |
+
grad_clip_val: float = 5.0,
|
235 |
+
prefix_amt: float = 0.0,
|
236 |
+
suffix_amt: float = 0.0,
|
237 |
+
prefix_dropout: float = 0.1,
|
238 |
+
suffix_dropout: float = 0.1,
|
239 |
+
quiet: bool = False,
|
240 |
+
):
|
241 |
+
assert codec_ckpt is not None, "codec_ckpt is required"
|
242 |
+
|
243 |
+
at.util.seed(seed)
|
244 |
+
writer = None
|
245 |
+
|
246 |
+
if accel.local_rank == 0:
|
247 |
+
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
248 |
+
argbind.dump_args(args, f"{save_path}/args.yml")
|
249 |
+
|
250 |
+
# load the codec model
|
251 |
+
loaded = load(args, accel, save_path)
|
252 |
+
model = loaded["model"]
|
253 |
+
codec = loaded["codec"]
|
254 |
+
optimizer = loaded["optimizer"]
|
255 |
+
scheduler = loaded["scheduler"]
|
256 |
+
trainer_state = loaded["trainer_state"]
|
257 |
+
|
258 |
+
sample_rate = codec.sample_rate
|
259 |
+
|
260 |
+
# a better rng for sampling from our schedule
|
261 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
262 |
+
|
263 |
+
# log a model summary w/ num params
|
264 |
+
if accel.local_rank == 0:
|
265 |
+
add_num_params_repr_hook(accel.unwrap(model))
|
266 |
+
with open(f"{save_path}/model.txt", "w") as f:
|
267 |
+
f.write(repr(accel.unwrap(model)))
|
268 |
+
|
269 |
+
# load the datasets
|
270 |
+
train_data, val_data, _ = build_datasets(args, sample_rate)
|
271 |
+
train_dataloader = accel.prepare_dataloader(
|
272 |
+
train_data,
|
273 |
+
start_idx=trainer_state["start_idx"],
|
274 |
+
num_workers=num_workers,
|
275 |
+
batch_size=batch_size,
|
276 |
+
collate_fn=train_data.collate,
|
277 |
+
)
|
278 |
+
val_dataloader = accel.prepare_dataloader(
|
279 |
+
val_data,
|
280 |
+
start_idx=0,
|
281 |
+
num_workers=num_workers,
|
282 |
+
batch_size=batch_size,
|
283 |
+
collate_fn=val_data.collate,
|
284 |
+
)
|
285 |
+
|
286 |
+
criterion = CrossEntropyLoss()
|
287 |
+
|
288 |
+
class Trainer(at.ml.BaseTrainer):
|
289 |
+
_last_grad_norm = 0.0
|
290 |
+
|
291 |
+
def metrics(self, vn, z_hat, r, target, flat_mask, output):
|
292 |
+
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
293 |
+
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
294 |
+
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
295 |
+
|
296 |
+
assert target.shape[0] == r.shape[0]
|
297 |
+
# grab the indices of the r values that are in the range
|
298 |
+
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
299 |
+
|
300 |
+
# grab the target and z_hat values that are in the range
|
301 |
+
r_unmasked_target = unmasked_target[r_idx]
|
302 |
+
r_masked_target = masked_target[r_idx]
|
303 |
+
r_z_hat = z_hat[r_idx]
|
304 |
+
|
305 |
+
for topk in (1, 25):
|
306 |
+
s, e = r_range
|
307 |
+
tag = f"accuracy-{s}-{e}/top{topk}"
|
308 |
+
|
309 |
+
output[f"{tag}/unmasked"] = accuracy(
|
310 |
+
preds=r_z_hat,
|
311 |
+
target=r_unmasked_target,
|
312 |
+
ignore_index=IGNORE_INDEX,
|
313 |
+
top_k=topk,
|
314 |
+
task="multiclass",
|
315 |
+
num_classes=vn.vocab_size,
|
316 |
+
)
|
317 |
+
output[f"{tag}/masked"] = accuracy(
|
318 |
+
preds=r_z_hat,
|
319 |
+
target=r_masked_target,
|
320 |
+
ignore_index=IGNORE_INDEX,
|
321 |
+
top_k=topk,
|
322 |
+
task="multiclass",
|
323 |
+
num_classes=vn.vocab_size,
|
324 |
+
)
|
325 |
+
|
326 |
+
def train_loop(self, engine, batch):
|
327 |
+
|
328 |
+
model.train()
|
329 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
330 |
+
signal = apply_transform(train_data.transform, batch)
|
331 |
+
|
332 |
+
output = {}
|
333 |
+
vn = accel.unwrap(model)
|
334 |
+
with accel.autocast():
|
335 |
+
with torch.inference_mode():
|
336 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
337 |
+
z = z[:, : vn.n_codebooks, :]
|
338 |
+
|
339 |
+
n_batch = z.shape[0]
|
340 |
+
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
341 |
+
|
342 |
+
if prefix_amt > 0.0:
|
343 |
+
prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
|
344 |
+
n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
|
345 |
+
else:
|
346 |
+
n_prefix = None
|
347 |
+
if suffix_amt > 0.0:
|
348 |
+
suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
|
349 |
+
n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
|
350 |
+
else:
|
351 |
+
n_suffix = None
|
352 |
+
|
353 |
+
z_mask, mask = vn.add_noise(
|
354 |
+
z, r, n_prefix=n_prefix, n_suffix=n_suffix
|
355 |
+
)
|
356 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
357 |
+
|
358 |
+
dtype = torch.bfloat16 if accel.amp else None
|
359 |
+
with accel.autocast(dtype=dtype):
|
360 |
+
z_hat = model(z_mask_latent, r)
|
361 |
+
# for mask mode
|
362 |
+
z_hat = vn.add_truth_to_logits(z, z_hat, mask)
|
363 |
+
|
364 |
+
target = vn.embedding.flatten(
|
365 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
366 |
+
n_codebooks=vn.n_predict_codebooks,
|
367 |
+
)
|
368 |
+
|
369 |
+
flat_mask = vn.embedding.flatten(
|
370 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
371 |
+
n_codebooks=vn.n_predict_codebooks,
|
372 |
+
)
|
373 |
+
|
374 |
+
if vn.noise_mode == "mask":
|
375 |
+
# replace target with ignore index for masked tokens
|
376 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
377 |
+
output["loss"] = criterion(z_hat, t_masked)
|
378 |
+
else:
|
379 |
+
output["loss"] = criterion(z_hat, target)
|
380 |
+
|
381 |
+
self.metrics(
|
382 |
+
vn=vn,
|
383 |
+
r=r,
|
384 |
+
z_hat=z_hat,
|
385 |
+
target=target,
|
386 |
+
flat_mask=flat_mask,
|
387 |
+
output=output,
|
388 |
+
)
|
389 |
+
|
390 |
+
|
391 |
+
accel.backward(output["loss"] / grad_acc_steps)
|
392 |
+
|
393 |
+
output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
|
394 |
+
output["other/batch_size"] = z.shape[0]
|
395 |
+
|
396 |
+
output.update(get_gpu_memory_map())
|
397 |
+
|
398 |
+
if (
|
399 |
+
(engine.state.iteration % grad_acc_steps == 0)
|
400 |
+
or (engine.state.iteration % epoch_length == 0)
|
401 |
+
or (engine.state.iteration % epoch_length == 1)
|
402 |
+
): # (or we reached the end of the epoch)
|
403 |
+
accel.scaler.unscale_(optimizer)
|
404 |
+
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
405 |
+
model.parameters(), grad_clip_val
|
406 |
+
)
|
407 |
+
self._last_grad_norm = output["other/grad_norm"]
|
408 |
+
|
409 |
+
accel.step(optimizer)
|
410 |
+
optimizer.zero_grad()
|
411 |
+
|
412 |
+
scheduler.step()
|
413 |
+
accel.update()
|
414 |
+
else:
|
415 |
+
output["other/grad_norm"] = self._last_grad_norm
|
416 |
+
|
417 |
+
return {k: v for k, v in sorted(output.items())}
|
418 |
+
|
419 |
+
@torch.no_grad()
|
420 |
+
def val_loop(self, engine, batch):
|
421 |
+
model.eval()
|
422 |
+
codec.eval()
|
423 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
424 |
+
signal = apply_transform(val_data.transform, batch)
|
425 |
+
|
426 |
+
vn = accel.unwrap(model)
|
427 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
428 |
+
z = z[:, : vn.n_codebooks, :]
|
429 |
+
|
430 |
+
n_batch = z.shape[0]
|
431 |
+
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
432 |
+
|
433 |
+
if prefix_amt > 0.0:
|
434 |
+
prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
|
435 |
+
n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
|
436 |
+
else:
|
437 |
+
n_prefix = None
|
438 |
+
if suffix_amt > 0.0:
|
439 |
+
suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
|
440 |
+
n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
|
441 |
+
else:
|
442 |
+
n_suffix = None
|
443 |
+
|
444 |
+
z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
|
445 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
446 |
+
|
447 |
+
z_hat = model(z_mask_latent, r)
|
448 |
+
# for mask mode
|
449 |
+
z_hat = vn.add_truth_to_logits(z, z_hat, mask)
|
450 |
+
|
451 |
+
target = vn.embedding.flatten(
|
452 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
453 |
+
n_codebooks=vn.n_predict_codebooks,
|
454 |
+
)
|
455 |
+
|
456 |
+
flat_mask = vn.embedding.flatten(
|
457 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
458 |
+
n_codebooks=vn.n_predict_codebooks,
|
459 |
+
)
|
460 |
+
|
461 |
+
output = {}
|
462 |
+
if vn.noise_mode == "mask":
|
463 |
+
# replace target with ignore index for masked tokens
|
464 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
465 |
+
output["loss"] = criterion(z_hat, t_masked)
|
466 |
+
else:
|
467 |
+
output["loss"] = criterion(z_hat, target)
|
468 |
+
|
469 |
+
self.metrics(
|
470 |
+
vn=vn,
|
471 |
+
r=r,
|
472 |
+
z_hat=z_hat,
|
473 |
+
target=target,
|
474 |
+
flat_mask=flat_mask,
|
475 |
+
output=output,
|
476 |
+
)
|
477 |
+
|
478 |
+
return output
|
479 |
+
|
480 |
+
def checkpoint(self, engine):
|
481 |
+
if accel.local_rank != 0:
|
482 |
+
print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
483 |
+
return
|
484 |
+
|
485 |
+
metadata = {"logs": dict(engine.state.logs["epoch"])}
|
486 |
+
|
487 |
+
if self.state.epoch % save_audio_epochs == 0:
|
488 |
+
self.save_samples()
|
489 |
+
|
490 |
+
tags = ["latest"]
|
491 |
+
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
492 |
+
self.print(f"Saving to {str(Path('.').absolute())}")
|
493 |
+
|
494 |
+
if self.is_best(engine, loss_key):
|
495 |
+
self.print(f"Best model so far")
|
496 |
+
tags.append("best")
|
497 |
+
|
498 |
+
for tag in tags:
|
499 |
+
model_extra = {
|
500 |
+
"optimizer.pth": optimizer.state_dict(),
|
501 |
+
"scheduler.pth": scheduler.state_dict(),
|
502 |
+
"trainer.pth": {
|
503 |
+
"start_idx": self.state.iteration * batch_size,
|
504 |
+
"state_dict": self.state_dict(),
|
505 |
+
},
|
506 |
+
"metadata.pth": metadata,
|
507 |
+
}
|
508 |
+
|
509 |
+
accel.unwrap(model).metadata = metadata
|
510 |
+
accel.unwrap(model).save_to_folder(
|
511 |
+
f"{save_path}/{tag}", model_extra
|
512 |
+
)
|
513 |
+
|
514 |
+
def save_sampled(self, z):
|
515 |
+
num_samples = z.shape[0]
|
516 |
+
|
517 |
+
for i in range(num_samples):
|
518 |
+
sampled = accel.unwrap(model).sample(
|
519 |
+
codec,
|
520 |
+
time_steps=z.shape[-1],
|
521 |
+
start_tokens=z[i : i + 1],
|
522 |
+
)
|
523 |
+
sampled.cpu().write_audio_to_tb(
|
524 |
+
f"sampled/{i}",
|
525 |
+
self.writer,
|
526 |
+
step=self.state.epoch,
|
527 |
+
plot_fn=None,
|
528 |
+
)
|
529 |
+
|
530 |
+
def save_imputation(self, z: torch.Tensor):
|
531 |
+
# imputations
|
532 |
+
mask_begin = z.shape[-1] // 4
|
533 |
+
mask_end = (z.shape[-1] * 3) // 4
|
534 |
+
|
535 |
+
imp_mask = torch.zeros(z.shape[0], z.shape[-1]).to(accel.device).int()
|
536 |
+
imp_mask[:, mask_begin:mask_end] = 1
|
537 |
+
|
538 |
+
imp_noisy = (
|
539 |
+
z * (1 - imp_mask[:, None, :])
|
540 |
+
+ torch.randint_like(z, 0, accel.unwrap(model).vocab_size)
|
541 |
+
* imp_mask[:, None, :]
|
542 |
+
)
|
543 |
+
imputed_noisy = accel.unwrap(model).to_signal(imp_noisy, codec)
|
544 |
+
imputed_true = accel.unwrap(model).to_signal(z, codec)
|
545 |
+
|
546 |
+
imputed = []
|
547 |
+
for i in range(len(z)):
|
548 |
+
imputed.append(
|
549 |
+
accel.unwrap(model).sample(
|
550 |
+
codec,
|
551 |
+
time_steps=z.shape[-1],
|
552 |
+
start_tokens=z[i][None, ...],
|
553 |
+
mask=imp_mask[i][None, ...],
|
554 |
+
)
|
555 |
+
)
|
556 |
+
imputed = AudioSignal.batch(imputed)
|
557 |
+
|
558 |
+
for i in range(len(val_idx)):
|
559 |
+
imputed_noisy[i].cpu().write_audio_to_tb(
|
560 |
+
f"imputed_noisy/{i}",
|
561 |
+
self.writer,
|
562 |
+
step=self.state.epoch,
|
563 |
+
plot_fn=None,
|
564 |
+
)
|
565 |
+
imputed[i].cpu().write_audio_to_tb(
|
566 |
+
f"imputed/{i}",
|
567 |
+
self.writer,
|
568 |
+
step=self.state.epoch,
|
569 |
+
plot_fn=None,
|
570 |
+
)
|
571 |
+
imputed_true[i].cpu().write_audio_to_tb(
|
572 |
+
f"imputed_true/{i}",
|
573 |
+
self.writer,
|
574 |
+
step=self.state.epoch,
|
575 |
+
plot_fn=None,
|
576 |
+
)
|
577 |
+
|
578 |
+
@torch.no_grad()
|
579 |
+
def save_samples(self):
|
580 |
+
model.eval()
|
581 |
+
codec.eval()
|
582 |
+
vn = accel.unwrap(model)
|
583 |
+
|
584 |
+
batch = [val_data[i] for i in val_idx]
|
585 |
+
batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
|
586 |
+
|
587 |
+
signal = apply_transform(val_data.transform, batch)
|
588 |
+
|
589 |
+
z = codec.encode(signal.samples, signal.sample_rate)["codes"]
|
590 |
+
z = z[:, : vn.n_codebooks, :]
|
591 |
+
|
592 |
+
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
593 |
+
|
594 |
+
n_batch = z.shape[0]
|
595 |
+
|
596 |
+
if prefix_amt > 0.0:
|
597 |
+
prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
|
598 |
+
n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
|
599 |
+
else:
|
600 |
+
n_prefix = None
|
601 |
+
if suffix_amt > 0.0:
|
602 |
+
suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
|
603 |
+
n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
|
604 |
+
else:
|
605 |
+
n_suffix = None
|
606 |
+
|
607 |
+
z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
|
608 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, codec)
|
609 |
+
|
610 |
+
z_hat = model(z_mask_latent, r)
|
611 |
+
# for mask mode
|
612 |
+
z_hat = vn.add_truth_to_logits(z, z_hat, mask)
|
613 |
+
|
614 |
+
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
615 |
+
z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
|
616 |
+
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
617 |
+
|
618 |
+
print("z_mask", z_mask.shape)
|
619 |
+
generated = vn.to_signal(z_pred, codec)
|
620 |
+
reconstructed = vn.to_signal(z, codec)
|
621 |
+
masked = vn.to_signal(z_mask.squeeze(1), codec)
|
622 |
+
|
623 |
+
for i in range(generated.batch_size):
|
624 |
+
audio_dict = {
|
625 |
+
"original": signal[i],
|
626 |
+
"masked": masked[i],
|
627 |
+
"generated": generated[i],
|
628 |
+
"reconstructed": reconstructed[i],
|
629 |
+
}
|
630 |
+
for k, v in audio_dict.items():
|
631 |
+
v.cpu().write_audio_to_tb(
|
632 |
+
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
633 |
+
self.writer,
|
634 |
+
step=self.state.epoch,
|
635 |
+
plot_fn=None,
|
636 |
+
)
|
637 |
+
|
638 |
+
self.save_sampled(z)
|
639 |
+
self.save_imputation(z)
|
640 |
+
|
641 |
+
trainer = Trainer(writer=writer, quiet=quiet)
|
642 |
+
|
643 |
+
if trainer_state["state_dict"] is not None:
|
644 |
+
trainer.load_state_dict(trainer_state["state_dict"])
|
645 |
+
if hasattr(train_dataloader.sampler, "set_epoch"):
|
646 |
+
train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
|
647 |
+
|
648 |
+
trainer.run(
|
649 |
+
train_dataloader,
|
650 |
+
val_dataloader,
|
651 |
+
num_epochs=max_epochs,
|
652 |
+
epoch_length=epoch_length,
|
653 |
+
detect_anomaly=detect_anomaly,
|
654 |
+
)
|
655 |
+
|
656 |
+
|
657 |
+
if __name__ == "__main__":
|
658 |
+
args = argbind.parse_args()
|
659 |
+
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
660 |
+
with argbind.scope(args):
|
661 |
+
with Accelerator() as accel:
|
662 |
+
train(args, accel)
|
scripts/utils/README.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scripts
|
2 |
+
|
3 |
+
## process_zip.py
|
4 |
+
|
5 |
+
Some requirements that may not be installed in the docker image:
|
6 |
+
* argbind
|
7 |
+
* wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
|
8 |
+
|
9 |
+
### zip folder structure
|
10 |
+
|
11 |
+
The zip folder should have the following internal structure:
|
12 |
+
|
13 |
+
```
|
14 |
+
base_folder/
|
15 |
+
test_case_1/
|
16 |
+
before.wav
|
17 |
+
test_case_2/
|
18 |
+
before.wav
|
19 |
+
...
|
20 |
+
test_case_n/
|
21 |
+
before.wav
|
22 |
+
```
|
23 |
+
|
24 |
+
Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. IF you want/need to use a zip file with a different folder structure, adjust this:
|
25 |
+
https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
|
26 |
+
|
27 |
+
### Execution
|
28 |
+
`python process_zip.py <path/to/zip> -tag <string>`
|
scripts/utils/stage.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
import rich
|
7 |
+
from audiotools.ml import Experiment
|
8 |
+
|
9 |
+
|
10 |
+
@argbind.bind(without_prefix=True)
|
11 |
+
def run(
|
12 |
+
run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
|
13 |
+
name: str = None,
|
14 |
+
recent: bool = False,
|
15 |
+
):
|
16 |
+
if recent:
|
17 |
+
paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
|
18 |
+
paths = [p.name for p in paths if p.is_dir()]
|
19 |
+
if paths:
|
20 |
+
name = paths[-1]
|
21 |
+
|
22 |
+
with Experiment(run_dir, name) as exp:
|
23 |
+
exp.snapshot()
|
24 |
+
rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
args = argbind.parse_args()
|
29 |
+
with argbind.scope(args):
|
30 |
+
run()
|
setup.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import find_packages
|
2 |
+
from setuptools import setup
|
3 |
+
|
4 |
+
with open("README.md") as f:
|
5 |
+
long_description = f.read()
|
6 |
+
|
7 |
+
setup(
|
8 |
+
name="vampnet",
|
9 |
+
version="0.0.1",
|
10 |
+
classifiers=[
|
11 |
+
"Intended Audience :: Developers",
|
12 |
+
"Natural Language :: English",
|
13 |
+
"Programming Language :: Python :: 3.7",
|
14 |
+
"Topic :: Artistic Software",
|
15 |
+
"Topic :: Multimedia",
|
16 |
+
"Topic :: Multimedia :: Sound/Audio",
|
17 |
+
"Topic :: Multimedia :: Sound/Audio :: Editors",
|
18 |
+
"Topic :: Software Development :: Libraries",
|
19 |
+
],
|
20 |
+
description="Generative Music Modeling.",
|
21 |
+
long_description=long_description,
|
22 |
+
long_description_content_type="text/markdown",
|
23 |
+
author="Hugo Flores García",
|
24 |
+
author_email="[email protected]",
|
25 |
+
url="https://github.com/descriptinc/lyrebird-vampnet",
|
26 |
+
license="MIT",
|
27 |
+
packages=find_packages(),
|
28 |
+
install_requires=[
|
29 |
+
"torch<=1.11.0",
|
30 |
+
"argbind>=0.3.2",
|
31 |
+
"pytorch-ignite",
|
32 |
+
"rich",
|
33 |
+
"audiotools @ git+https://github.com/descriptinc/[email protected]",
|
34 |
+
"tqdm",
|
35 |
+
"tensorboard",
|
36 |
+
"google-cloud-logging==2.2.0",
|
37 |
+
"torchmetrics>=0.7.3",
|
38 |
+
"einops",
|
39 |
+
],
|
40 |
+
)
|
vampnet/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from . import modules
|
3 |
+
from . import scheduler
|
4 |
+
from . import enchilada
|
5 |
+
|
6 |
+
__version__ = "0.0.1"
|
vampnet/enchilada.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from audiotools import AudioSignal
|
6 |
+
|
7 |
+
from .modules.transformer import VampNet
|
8 |
+
from lac.model.lac import LAC
|
9 |
+
|
10 |
+
|
11 |
+
class TheWholeEnchilada:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
coarse_ckpt: str,
|
15 |
+
coarse2fine_ckpt: str,
|
16 |
+
codec_ckpt: str,
|
17 |
+
device: str = "cpu",
|
18 |
+
):
|
19 |
+
self.codec = LAC.load(Path(codec_ckpt))
|
20 |
+
self.codec.eval()
|
21 |
+
self.codec.to(device)
|
22 |
+
|
23 |
+
self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
|
24 |
+
self.coarse.to(device)
|
25 |
+
self.coarse.eval()
|
26 |
+
|
27 |
+
self.coarse2fine = VampNet.load(
|
28 |
+
location=Path(coarse2fine_ckpt), map_location="cpu"
|
29 |
+
)
|
30 |
+
# FIXME
|
31 |
+
print(
|
32 |
+
f"WARNING: PATCHING coarse2fine seq_len to 288, for backwards compatibility with a specific jazzpop model. it used to be {self.coarse2fine.seq_len}"
|
33 |
+
)
|
34 |
+
self.coarse2fine.seq_len = 288
|
35 |
+
|
36 |
+
self.coarse2fine.to(device)
|
37 |
+
self.coarse2fine.eval()
|
38 |
+
|
39 |
+
self.device = device
|
40 |
+
|
41 |
+
def seconds_to_tokens(self, seconds: float):
|
42 |
+
return int(seconds * self.codec.sample_rate / self.codec.hop_length)
|
43 |
+
|
44 |
+
def to(self, device):
|
45 |
+
self.device = device
|
46 |
+
self.coarse.to(device)
|
47 |
+
self.coarse2fine.to(device)
|
48 |
+
self.codec.to(device)
|
49 |
+
return self
|
50 |
+
|
51 |
+
def encode(self, signal: AudioSignal):
|
52 |
+
with torch.inference_mode():
|
53 |
+
# coarse z
|
54 |
+
cz = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
55 |
+
|
56 |
+
return cz
|
57 |
+
|
58 |
+
def vamp(
|
59 |
+
self,
|
60 |
+
signal,
|
61 |
+
prefix_dur_s: float = 1.25,
|
62 |
+
suffix_dur_s: float = 1.25,
|
63 |
+
downsample_hint: bool = True,
|
64 |
+
downsample_factor: int = 4,
|
65 |
+
num_loops: int = 3,
|
66 |
+
**kwargs,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Loop imputation of a signal.
|
70 |
+
"""
|
71 |
+
signal.to(self.device).resample(self.codec.sample_rate).to_mono()
|
72 |
+
|
73 |
+
z = self.encode(signal)
|
74 |
+
|
75 |
+
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
76 |
+
original_cz = cz.clone()
|
77 |
+
seq_len = original_cz.shape[-1]
|
78 |
+
assert (
|
79 |
+
seq_len == self.coarse.seq_len
|
80 |
+
), f"expected seq_len {self.coarse.seq_len}, got {seq_len} for token sequence length. Is your signal the same duration as the model was trained with? "
|
81 |
+
|
82 |
+
vamp_hop_s = prefix_dur_s
|
83 |
+
vamp_hop = self.seconds_to_tokens(vamp_hop_s)
|
84 |
+
|
85 |
+
cmask = torch.ones_like(cz)
|
86 |
+
|
87 |
+
if downsample_hint:
|
88 |
+
# downsample by factor of 4
|
89 |
+
for i in range(cmask.shape[-1]):
|
90 |
+
if i % downsample_factor == 0:
|
91 |
+
cmask[:, :, i] = 0
|
92 |
+
|
93 |
+
if prefix_dur_s > 0:
|
94 |
+
prefix_len = self.seconds_to_tokens(prefix_dur_s)
|
95 |
+
cmask[:, :, :prefix_len] = 0
|
96 |
+
print(f"prefix_len: {prefix_len}")
|
97 |
+
else:
|
98 |
+
prefix_len = 0
|
99 |
+
|
100 |
+
if suffix_dur_s > 0:
|
101 |
+
suffix_len = self.seconds_to_tokens(suffix_dur_s)
|
102 |
+
cmask[:, :, -suffix_len:] = 0
|
103 |
+
print(f"suffix_len: {suffix_len}")
|
104 |
+
else:
|
105 |
+
suffix_len = 0
|
106 |
+
|
107 |
+
prefix_z = cz[:, :, :prefix_len]
|
108 |
+
|
109 |
+
coarse_vamp = [prefix_z.clone()]
|
110 |
+
for i in range(num_loops):
|
111 |
+
sampled_cz = self.coarse.sample(
|
112 |
+
codec=self.codec,
|
113 |
+
time_steps=seq_len,
|
114 |
+
mask=cmask,
|
115 |
+
start_tokens=cz,
|
116 |
+
return_signal=False,
|
117 |
+
**kwargs,
|
118 |
+
)
|
119 |
+
|
120 |
+
new_prefix = sampled_cz[:, :, prefix_len : prefix_len + vamp_hop]
|
121 |
+
coarse_vamp.append(new_prefix.clone())
|
122 |
+
|
123 |
+
# replace the prefix in cz with the new prefix
|
124 |
+
# don't worry about a copy of the prefix still being
|
125 |
+
# in the mask area, since that will be masked out
|
126 |
+
cz[:, :, :vamp_hop] = new_prefix.clone()
|
127 |
+
print("to append and to prefix")
|
128 |
+
|
129 |
+
# we're done, so add the suffix
|
130 |
+
coarse_vamp.append(sampled_cz[:, :, prefix_len + vamp_hop :])
|
131 |
+
|
132 |
+
# concatenate the vamps
|
133 |
+
coarse_vamp = torch.cat(coarse_vamp, dim=-1)
|
134 |
+
|
135 |
+
# add a layer of
|
136 |
+
fine_prefix = z[:, self.coarse.n_codebooks :, :prefix_len]
|
137 |
+
fine_suffix = z[:, self.coarse.n_codebooks :, -suffix_len:]
|
138 |
+
fine_vamp = torch.randint(
|
139 |
+
0,
|
140 |
+
self.coarse2fine.vocab_size,
|
141 |
+
(
|
142 |
+
coarse_vamp.shape[0],
|
143 |
+
self.coarse2fine.n_predict_codebooks,
|
144 |
+
coarse_vamp.shape[-1],
|
145 |
+
),
|
146 |
+
).to(self.device)
|
147 |
+
fine_vamp[:, :, :prefix_len] = fine_prefix
|
148 |
+
fine_vamp[:, :, -suffix_len:] = fine_suffix
|
149 |
+
|
150 |
+
vamp_z = torch.cat([coarse_vamp, fine_vamp], dim=1)
|
151 |
+
|
152 |
+
# now we sample from the coarse2fine model
|
153 |
+
# to get the fine details
|
154 |
+
start_pos = 0
|
155 |
+
|
156 |
+
c2f_vamp = []
|
157 |
+
while start_pos < vamp_z.shape[-1]:
|
158 |
+
end_pos = min(start_pos + self.coarse2fine.seq_len, vamp_z.shape[-1])
|
159 |
+
|
160 |
+
c2fz = vamp_z[:, :, start_pos:end_pos]
|
161 |
+
self.coarse2fine: VampNet
|
162 |
+
sampled_c2fz = self.coarse2fine.sample(
|
163 |
+
codec=self.codec,
|
164 |
+
start_tokens=c2fz,
|
165 |
+
return_signal=False,
|
166 |
+
mask=None,
|
167 |
+
)
|
168 |
+
c2f_vamp.append(sampled_c2fz)
|
169 |
+
start_pos += self.coarse2fine.seq_len
|
170 |
+
|
171 |
+
c2f_vamp = torch.cat(c2f_vamp, dim=-1)
|
172 |
+
|
173 |
+
# make it a signal
|
174 |
+
vamp_signal = self.coarse2fine.to_signal(c2f_vamp, self.codec)
|
175 |
+
|
176 |
+
return {
|
177 |
+
"full": vamp_signal,
|
178 |
+
"coarse": self.coarse.to_signal(coarse_vamp, self.codec),
|
179 |
+
}
|
vampnet/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import audiotools
|
2 |
+
|
3 |
+
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
+
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention"]
|
vampnet/modules/base.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
from typing import Tuple
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import audiotools as at
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def log(t, eps=1e-20):
|
15 |
+
return torch.log(t + eps)
|
16 |
+
|
17 |
+
|
18 |
+
def gumbel_noise(t):
|
19 |
+
noise = torch.zeros_like(t).uniform_(0, 1)
|
20 |
+
return -log(-log(noise))
|
21 |
+
|
22 |
+
|
23 |
+
def gumbel_sample(t, temperature=1.0, dim=-1):
|
24 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
|
25 |
+
|
26 |
+
|
27 |
+
class VampBase(at.ml.BaseModel):
|
28 |
+
def forward(self, x: torch.Tensor, r: torch.Tensor):
|
29 |
+
raise NotImplementedError
|
30 |
+
|
31 |
+
def add_noise(
|
32 |
+
self,
|
33 |
+
x: torch.Tensor,
|
34 |
+
r: torch.Tensor,
|
35 |
+
random_x: Optional[torch.Tensor] = None,
|
36 |
+
mask: Optional[torch.Tensor] = None,
|
37 |
+
n_prefix: Optional[torch.Tensor] = None,
|
38 |
+
n_suffix: Optional[torch.Tensor] = None,
|
39 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
40 |
+
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
41 |
+
|
42 |
+
if mask is None:
|
43 |
+
r = self.gamma(r)[:, None, None]
|
44 |
+
probs = torch.ones_like(x) * r
|
45 |
+
|
46 |
+
# if we have a prefix or suffix, set their mask prob to 0
|
47 |
+
if n_prefix is not None:
|
48 |
+
for i, n in enumerate(n_prefix):
|
49 |
+
probs[i, :, :n] = 0.0
|
50 |
+
if n_suffix is not None:
|
51 |
+
for i, n in enumerate(n_suffix):
|
52 |
+
probs[i, :, -n:] = 0.0
|
53 |
+
|
54 |
+
mask = torch.bernoulli(probs)
|
55 |
+
mask = mask.round().long()
|
56 |
+
|
57 |
+
# if we have any conditioning codebooks, set their mask to 0
|
58 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0
|
59 |
+
else:
|
60 |
+
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
61 |
+
assert mask.shape == x.shape, "mask must be same shape as x"
|
62 |
+
|
63 |
+
if random_x is None:
|
64 |
+
random_x = torch.randint_like(x, 0, self.vocab_size)
|
65 |
+
|
66 |
+
if self.noise_mode == "mask":
|
67 |
+
random_x = torch.full_like(x, self.mask_token)
|
68 |
+
elif self.noise_mode == "random":
|
69 |
+
if random_x is None:
|
70 |
+
random_x = torch.randint_like(x, 0, self.vocab_size)
|
71 |
+
else:
|
72 |
+
raise ValueError(f"invalid noise mode {self.noise_mode}")
|
73 |
+
|
74 |
+
x = x * (1 - mask) + random_x * mask
|
75 |
+
return x, mask
|
76 |
+
|
77 |
+
def add_truth_to_logits(
|
78 |
+
self,
|
79 |
+
z_true,
|
80 |
+
z_hat,
|
81 |
+
mask,
|
82 |
+
):
|
83 |
+
if self.noise_mode == "mask":
|
84 |
+
z_true = z_true[:, self.n_conditioning_codebooks :, :]
|
85 |
+
mask = mask[:, self.n_conditioning_codebooks :, :]
|
86 |
+
|
87 |
+
truth = F.one_hot(z_true, self.vocab_size)
|
88 |
+
print(truth.shape)
|
89 |
+
# truth = rearrange(truth, "b c t p -> b p (t c)")
|
90 |
+
mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
|
91 |
+
z_hat = rearrange(
|
92 |
+
z_hat,
|
93 |
+
"b p (t c) -> b c t p",
|
94 |
+
c=self.n_codebooks - self.n_conditioning_codebooks,
|
95 |
+
)
|
96 |
+
|
97 |
+
z_hat = z_hat * mask + truth * (1 - mask)
|
98 |
+
|
99 |
+
z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
|
100 |
+
|
101 |
+
return z_hat
|
102 |
+
|
103 |
+
def gamma(self, r):
|
104 |
+
return (r * torch.pi / 2).cos()
|
105 |
+
|
106 |
+
def r_embed(self, r, max_positions=10000):
|
107 |
+
""" """
|
108 |
+
assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
|
109 |
+
|
110 |
+
if self.r_cond_dim > 0:
|
111 |
+
dtype = r.dtype
|
112 |
+
|
113 |
+
r = self.gamma(r) * max_positions
|
114 |
+
half_dim = self.r_cond_dim // 2
|
115 |
+
|
116 |
+
emb = math.log(max_positions) / (half_dim - 1)
|
117 |
+
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
118 |
+
|
119 |
+
emb = r[:, None] * emb[None, :]
|
120 |
+
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
121 |
+
|
122 |
+
if self.r_cond_dim % 2 == 1: # zero pad
|
123 |
+
emb = nn.functional.pad(emb, (0, 1), mode="constant")
|
124 |
+
|
125 |
+
return emb.to(dtype)
|
126 |
+
else:
|
127 |
+
return r
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def to_signal(self, z, vqvae):
|
131 |
+
if z.ndim == 2:
|
132 |
+
z = self.embedding.unflatten(z)
|
133 |
+
assert z.ndim == 3
|
134 |
+
|
135 |
+
signal = at.AudioSignal(
|
136 |
+
vqvae.decode(
|
137 |
+
vqvae.quantizer.from_latents(self.embedding.from_codes(z, vqvae))[0]
|
138 |
+
)["audio"],
|
139 |
+
vqvae.sample_rate,
|
140 |
+
)
|
141 |
+
|
142 |
+
return signal
|
143 |
+
|
144 |
+
@torch.no_grad()
|
145 |
+
def sample(self, **kwargs):
|
146 |
+
if self.noise_mode == "mask":
|
147 |
+
return self.maskgit_sample(**kwargs)
|
148 |
+
else:
|
149 |
+
return self.paella_sample(**kwargs)
|
150 |
+
|
151 |
+
def paella_sample(
|
152 |
+
self,
|
153 |
+
vqvae,
|
154 |
+
time_steps: int = 400,
|
155 |
+
sampling_steps: int = 12,
|
156 |
+
start_tokens: Optional[torch.Tensor] = None,
|
157 |
+
mask: Optional[torch.Tensor] = None,
|
158 |
+
device: str = "cpu",
|
159 |
+
temperature: Union[float, Tuple[float, float]] = 1.0,
|
160 |
+
top_k: int = None,
|
161 |
+
sample: str = "gumbel",
|
162 |
+
renoise_mode: str = "start",
|
163 |
+
renoise_steps=None,
|
164 |
+
typical_filtering=True,
|
165 |
+
typical_mass=0.2,
|
166 |
+
typical_min_tokens=1,
|
167 |
+
return_signal=True,
|
168 |
+
):
|
169 |
+
r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(device)
|
170 |
+
if renoise_steps == None:
|
171 |
+
renoise_steps = sampling_steps - 1
|
172 |
+
|
173 |
+
if isinstance(temperature, float):
|
174 |
+
temperature = torch.tensor(temperature).repeat(sampling_steps)
|
175 |
+
elif isinstance(temperature, tuple):
|
176 |
+
assert len(temperature) == 2
|
177 |
+
l, h = temperature
|
178 |
+
temperature = torch.linspace(l, h, sampling_steps)
|
179 |
+
else:
|
180 |
+
raise TypeError(f"invalid type for temperature")
|
181 |
+
|
182 |
+
if self.n_conditioning_codebooks > 0:
|
183 |
+
assert (
|
184 |
+
start_tokens is not None
|
185 |
+
), "must provide start_tokens if n_conditioning_codebooks > 0"
|
186 |
+
|
187 |
+
if start_tokens is None:
|
188 |
+
if self.noise_mode == "noise":
|
189 |
+
z = torch.randint(
|
190 |
+
0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
|
191 |
+
).to(device)
|
192 |
+
elif self.noise_mode == "mask":
|
193 |
+
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
|
194 |
+
else:
|
195 |
+
z = start_tokens
|
196 |
+
assert (
|
197 |
+
z.ndim == 3
|
198 |
+
), f"start_tokens must be shape (batch, n_codebooks, seq_len), got {z.shape}"
|
199 |
+
assert z.shape[0] == 1, f"batch size must be 1"
|
200 |
+
|
201 |
+
if mask is None:
|
202 |
+
mask = torch.ones(z.shape[0], z.shape[-1]).to(device).int()
|
203 |
+
|
204 |
+
# apply mask
|
205 |
+
assert mask.shape == (
|
206 |
+
z.shape[0],
|
207 |
+
z.shape[-1],
|
208 |
+
), f"mask must be shape (batch, seq_len), got {mask.shape}"
|
209 |
+
mask = mask[:, None, :]
|
210 |
+
mask = mask.repeat(1, z.shape[1], 1)
|
211 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
212 |
+
|
213 |
+
if self.noise_mode == "mask":
|
214 |
+
z_true = z.clone()
|
215 |
+
|
216 |
+
z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
|
217 |
+
z_init = z.clone()
|
218 |
+
for i, tmpt in enumerate(temperature):
|
219 |
+
if renoise_mode == "prev":
|
220 |
+
z_prev = z.clone()
|
221 |
+
|
222 |
+
latents = self.embedding.from_codes(z, vqvae)
|
223 |
+
logits = self.forward(latents, r[i])
|
224 |
+
|
225 |
+
# for mask mode
|
226 |
+
logits = self.add_truth_to_logits(z_true, logits, mask)
|
227 |
+
|
228 |
+
# Apply topk sampling
|
229 |
+
logits = logits.permute(0, 2, 1)
|
230 |
+
|
231 |
+
z = self.sample_from_logits(
|
232 |
+
logits,
|
233 |
+
tmpt,
|
234 |
+
top_k,
|
235 |
+
sample=sample,
|
236 |
+
typical_filtering=typical_filtering,
|
237 |
+
typical_mass=typical_mass,
|
238 |
+
typical_min_tokens=typical_min_tokens,
|
239 |
+
)
|
240 |
+
|
241 |
+
# add back in conditioning codebooks
|
242 |
+
z = self.embedding.unflatten(z, n_codebooks=self.n_predict_codebooks)
|
243 |
+
z = torch.cat(
|
244 |
+
[z_init[:, : self.n_conditioning_codebooks, :], z], dim=1
|
245 |
+
).int()
|
246 |
+
|
247 |
+
if i < renoise_steps:
|
248 |
+
if renoise_mode == "prev":
|
249 |
+
z, _ = self.add_noise(z, r[i + 1], random_x=z_prev)
|
250 |
+
elif renoise_mode == "start":
|
251 |
+
z, _ = self.add_noise(z, r[i + 1], random_x=z_init)
|
252 |
+
elif renoise_mode == "rand":
|
253 |
+
z, _ = self.add_noise(z, r[i + 1])
|
254 |
+
else:
|
255 |
+
raise ValueError(f"Invalid renoise_mode: {renoise_mode}")
|
256 |
+
|
257 |
+
if mask is not None:
|
258 |
+
z = start_tokens * (1 - mask) + z * mask
|
259 |
+
|
260 |
+
if return_signal:
|
261 |
+
return self.to_signal(z, vqvae)
|
262 |
+
else:
|
263 |
+
return z
|
264 |
+
|
265 |
+
def maskgit_sample(
|
266 |
+
self,
|
267 |
+
vqvae,
|
268 |
+
time_steps: int = 300,
|
269 |
+
sampling_steps: int = 24,
|
270 |
+
start_tokens: Optional[torch.Tensor] = None,
|
271 |
+
mask: Optional[torch.Tensor] = None,
|
272 |
+
temperature: Union[float, Tuple[float, float]] = 1.0,
|
273 |
+
top_k: int = None,
|
274 |
+
sample: str = "multinomial",
|
275 |
+
typical_filtering=False,
|
276 |
+
typical_mass=0.2,
|
277 |
+
typical_min_tokens=1,
|
278 |
+
return_signal=True,
|
279 |
+
):
|
280 |
+
if isinstance(temperature, float):
|
281 |
+
temperature = torch.tensor(temperature).repeat(sampling_steps)
|
282 |
+
elif isinstance(temperature, tuple):
|
283 |
+
assert len(temperature) == 2
|
284 |
+
l, h = temperature
|
285 |
+
temperature = torch.linspace(l, h, sampling_steps)
|
286 |
+
else:
|
287 |
+
raise TypeError(f"invalid type for temperature")
|
288 |
+
|
289 |
+
def flatten(codes):
|
290 |
+
return rearrange(codes, "b c t -> b (t c)")
|
291 |
+
|
292 |
+
def unflatten(codes, c):
|
293 |
+
return rearrange(codes, "b (t c) -> b c t", c=c)
|
294 |
+
|
295 |
+
z = start_tokens
|
296 |
+
|
297 |
+
if z is None:
|
298 |
+
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
|
299 |
+
self.device
|
300 |
+
)
|
301 |
+
|
302 |
+
if mask is None:
|
303 |
+
mask = torch.ones_like(z).to(self.device).int()
|
304 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
305 |
+
if mask.ndim == 2:
|
306 |
+
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
307 |
+
|
308 |
+
# figure out which timesteps we're keeping
|
309 |
+
keep_mask = 1 - mask
|
310 |
+
|
311 |
+
# any conditioning codebook levels need to be in the keep mask
|
312 |
+
# if self.n_conditioning_codebooks > 0:
|
313 |
+
# cond_mask = torch.ones(z.shape[0], self.n_conditioning_codebooks, z.shape[-1]).to(z.device)
|
314 |
+
# keep_mask = torch.cat([cond_mask, keep_mask], dim=1)
|
315 |
+
|
316 |
+
# flatten
|
317 |
+
keep_mask = flatten(keep_mask)
|
318 |
+
|
319 |
+
# our r steps
|
320 |
+
r_steps = torch.linspace(0, 1, sampling_steps + 1)[1:].to(self.device)
|
321 |
+
|
322 |
+
# how many tokens did we keep on init?
|
323 |
+
num_kept_on_init = keep_mask.sum()
|
324 |
+
|
325 |
+
# how many codebooks are we inferring vs conditioning on?
|
326 |
+
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
327 |
+
|
328 |
+
for i in tqdm(range(sampling_steps)):
|
329 |
+
# our current temperature
|
330 |
+
tmpt = temperature[i]
|
331 |
+
|
332 |
+
# our current schedule step
|
333 |
+
r = r_steps[i : i + 1]
|
334 |
+
|
335 |
+
with torch.inference_mode():
|
336 |
+
# mask our z
|
337 |
+
keep_mask_unflat = unflatten(keep_mask, c=self.n_codebooks)
|
338 |
+
z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
|
339 |
+
|
340 |
+
# get latents
|
341 |
+
latents = self.embedding.from_codes(z_masked, vqvae)
|
342 |
+
|
343 |
+
# infer from latents
|
344 |
+
logits = self.forward(latents, r)
|
345 |
+
logits = logits.permute(0, 2, 1) # b, seq, prob
|
346 |
+
|
347 |
+
# the schedule determines how many samples to keep
|
348 |
+
num_tokens_to_infer = (z.shape[-1] * z.shape[-2]) - num_kept_on_init
|
349 |
+
num_to_keep = num_kept_on_init + int(
|
350 |
+
num_tokens_to_infer * (self.gamma(1 - r))
|
351 |
+
)
|
352 |
+
|
353 |
+
# figure out which logits we wanna keep
|
354 |
+
if num_to_keep > 0:
|
355 |
+
probs = logits.softmax(dim=-1)
|
356 |
+
|
357 |
+
keep_probs = F.one_hot(z, self.vocab_size)[:, :, :]
|
358 |
+
|
359 |
+
probs = rearrange(
|
360 |
+
probs, "b (t c) p -> b c t p", c=n_infer_codebooks
|
361 |
+
)
|
362 |
+
probs = torch.cat(
|
363 |
+
[keep_probs[:, : self.n_conditioning_codebooks, ...], probs],
|
364 |
+
dim=1,
|
365 |
+
)
|
366 |
+
|
367 |
+
keep_probs = rearrange(
|
368 |
+
keep_probs, "b c t p -> b (t c) p", c=self.n_codebooks
|
369 |
+
)
|
370 |
+
probs = rearrange(probs, "b c t p -> b (t c) p", c=self.n_codebooks)
|
371 |
+
|
372 |
+
keep_prob_mask = keep_mask.unsqueeze(-1).repeat(
|
373 |
+
1, 1, self.vocab_size
|
374 |
+
)
|
375 |
+
probs = (keep_prob_mask.long() * keep_probs) + (
|
376 |
+
1 - keep_prob_mask.long()
|
377 |
+
) * probs
|
378 |
+
|
379 |
+
highest_probs = probs.max(dim=-1, keepdim=False)[0]
|
380 |
+
v, _ = highest_probs.topk(num_to_keep, dim=-1)
|
381 |
+
|
382 |
+
keep_mask = torch.ones_like(keep_mask).bool().clone()
|
383 |
+
keep_mask[highest_probs < v[..., [-1]]] = 0
|
384 |
+
|
385 |
+
logits = torch.log(probs)
|
386 |
+
|
387 |
+
z_inferred = self.sample_from_logits(
|
388 |
+
logits=logits,
|
389 |
+
top_k=top_k,
|
390 |
+
temperature=tmpt,
|
391 |
+
sample=sample,
|
392 |
+
typical_filtering=typical_filtering,
|
393 |
+
typical_mass=typical_mass,
|
394 |
+
typical_min_tokens=typical_min_tokens,
|
395 |
+
)
|
396 |
+
|
397 |
+
z = rearrange(z_inferred, "b (t c) -> b c t", c=self.n_codebooks)
|
398 |
+
|
399 |
+
# add conditioning codebooks back
|
400 |
+
# z = torch.cat([z[:, :self.n_conditioning_codebooks, :], z_inferred], dim=1)
|
401 |
+
|
402 |
+
if return_signal:
|
403 |
+
return self.to_signal(z, vqvae)
|
404 |
+
else:
|
405 |
+
return z
|
406 |
+
|
407 |
+
def sample_from_logits(
|
408 |
+
self,
|
409 |
+
logits,
|
410 |
+
top_k: int = None,
|
411 |
+
temperature: float = 1.0,
|
412 |
+
sample: str = "multinomial",
|
413 |
+
typical_filtering=False,
|
414 |
+
typical_mass=0.2,
|
415 |
+
typical_min_tokens=1,
|
416 |
+
):
|
417 |
+
# add temperature
|
418 |
+
logits = logits / temperature
|
419 |
+
|
420 |
+
# add topk
|
421 |
+
if top_k is not None:
|
422 |
+
v, topk_idx = logits.topk(top_k)
|
423 |
+
logits[logits < v[..., [-1]]] = -float("inf")
|
424 |
+
|
425 |
+
if typical_filtering:
|
426 |
+
assert top_k is None
|
427 |
+
nb, nt, _ = logits.shape
|
428 |
+
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
429 |
+
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
430 |
+
x_flat_norm_p = torch.exp(x_flat_norm)
|
431 |
+
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
|
432 |
+
|
433 |
+
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
434 |
+
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
435 |
+
x_flat_cumsum = (
|
436 |
+
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
437 |
+
)
|
438 |
+
|
439 |
+
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
440 |
+
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
441 |
+
1, last_ind.view(-1, 1)
|
442 |
+
)
|
443 |
+
if typical_min_tokens > 1:
|
444 |
+
sorted_indices_to_remove[..., :typical_min_tokens] = 0
|
445 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
446 |
+
1, x_flat_indices, sorted_indices_to_remove
|
447 |
+
)
|
448 |
+
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
|
449 |
+
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
|
450 |
+
|
451 |
+
if sample == "multinomial":
|
452 |
+
probs = torch.softmax(logits, dim=-1)
|
453 |
+
inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
|
454 |
+
elif sample == "argmax":
|
455 |
+
inferred = torch.softmax(probs, dim=-1).argmax(dim=-1)
|
456 |
+
elif sample == "gumbel":
|
457 |
+
inferred = gumbel_sample(logits, dim=-1)
|
458 |
+
else:
|
459 |
+
raise ValueError(f"invalid sampling method: {sample}")
|
460 |
+
|
461 |
+
return inferred
|
vampnet/modules/modules.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Optional
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
def num_params(model):
|
13 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
14 |
+
|
15 |
+
|
16 |
+
def recurse_children(module, fn):
|
17 |
+
for child in module.children():
|
18 |
+
if isinstance(child, nn.ModuleList):
|
19 |
+
for c in child:
|
20 |
+
yield recurse_children(c, fn)
|
21 |
+
if isinstance(child, nn.ModuleDict):
|
22 |
+
for c in child.values():
|
23 |
+
yield recurse_children(c, fn)
|
24 |
+
|
25 |
+
yield recurse_children(child, fn)
|
26 |
+
yield fn(child)
|
27 |
+
|
28 |
+
|
29 |
+
# Scripting this brings model speed up 1.4x
|
30 |
+
@torch.jit.script
|
31 |
+
def snake(x, alpha):
|
32 |
+
shape = x.shape
|
33 |
+
x = x.reshape(shape[0], shape[1], -1)
|
34 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
35 |
+
x = x.reshape(shape)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class Snake1d(nn.Module):
|
40 |
+
def __init__(self, channels):
|
41 |
+
super().__init__()
|
42 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
return snake(x, self.alpha)
|
46 |
+
|
47 |
+
|
48 |
+
def WNConv1d(*args, **kwargs):
|
49 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
50 |
+
|
51 |
+
|
52 |
+
def WNConvTranspose1d(*args, **kwargs):
|
53 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
54 |
+
|
55 |
+
|
56 |
+
class SequentialWithFiLM(nn.Module):
|
57 |
+
"""
|
58 |
+
handy wrapper for nn.Sequential that allows FiLM layers to be
|
59 |
+
inserted in between other layers.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, *layers):
|
63 |
+
super().__init__()
|
64 |
+
self.layers = nn.ModuleList(layers)
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def has_film(module):
|
68 |
+
mod_has_film = any(
|
69 |
+
[res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
|
70 |
+
)
|
71 |
+
return mod_has_film
|
72 |
+
|
73 |
+
def forward(self, x, cond):
|
74 |
+
for layer in self.layers:
|
75 |
+
if self.has_film(layer):
|
76 |
+
x = layer(x, cond)
|
77 |
+
else:
|
78 |
+
x = layer(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
class FiLM(nn.Module):
|
83 |
+
def __init__(self, input_dim: int, output_dim: int):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
self.input_dim = input_dim
|
87 |
+
self.output_dim = output_dim
|
88 |
+
|
89 |
+
if input_dim > 0:
|
90 |
+
self.beta = nn.Linear(input_dim, output_dim)
|
91 |
+
self.gamma = nn.Linear(input_dim, output_dim)
|
92 |
+
|
93 |
+
def forward(self, x, r):
|
94 |
+
if self.input_dim == 0:
|
95 |
+
return x
|
96 |
+
else:
|
97 |
+
beta, gamma = self.beta(r), self.gamma(r)
|
98 |
+
beta, gamma = (
|
99 |
+
beta.view(x.size(0), self.output_dim, 1),
|
100 |
+
gamma.view(x.size(0), self.output_dim, 1),
|
101 |
+
)
|
102 |
+
x = x * (gamma + 1) + beta
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
class CodebookEmbedding(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
vocab_size: int,
|
110 |
+
latent_dim: int,
|
111 |
+
n_codebooks: int,
|
112 |
+
emb_dim: int,
|
113 |
+
special_tokens: Optional[Tuple[str]] = None,
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.n_codebooks = n_codebooks
|
117 |
+
self.emb_dim = emb_dim
|
118 |
+
self.latent_dim = latent_dim
|
119 |
+
self.vocab_size = vocab_size
|
120 |
+
|
121 |
+
if special_tokens is not None:
|
122 |
+
for tkn in special_tokens:
|
123 |
+
self.special = nn.ParameterDict(
|
124 |
+
{
|
125 |
+
tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
|
126 |
+
for tkn in special_tokens
|
127 |
+
}
|
128 |
+
)
|
129 |
+
self.special_idxs = {
|
130 |
+
tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
|
131 |
+
}
|
132 |
+
|
133 |
+
self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
|
134 |
+
|
135 |
+
def from_codes(self, codes: torch.Tensor, vqvae):
|
136 |
+
n_codebooks = codes.shape[1]
|
137 |
+
latent = []
|
138 |
+
for i in range(n_codebooks):
|
139 |
+
c = codes[:, i, :]
|
140 |
+
|
141 |
+
lookup_table = vqvae.quantizer.quantizers[i].codebook.weight
|
142 |
+
if hasattr(self, "special"):
|
143 |
+
special_lookup = torch.cat(
|
144 |
+
[self.special[tkn][i : i + 1] for tkn in self.special], dim=0
|
145 |
+
)
|
146 |
+
lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
|
147 |
+
|
148 |
+
l = F.embedding(c, lookup_table).transpose(1, 2)
|
149 |
+
latent.append(l)
|
150 |
+
|
151 |
+
latent = torch.cat(latent, dim=1)
|
152 |
+
return latent
|
153 |
+
|
154 |
+
def forward(self, latents: torch.Tensor):
|
155 |
+
x = self.out_proj(latents)
|
156 |
+
return x
|
157 |
+
|
158 |
+
def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
|
159 |
+
n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
|
160 |
+
return rearrange(tokens, "b c t -> b (t c)", c=n_c)
|
161 |
+
|
162 |
+
def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
|
163 |
+
nb, nt = flat_tokens.shape
|
164 |
+
|
165 |
+
n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
|
166 |
+
tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
|
167 |
+
|
168 |
+
return tokens
|
vampnet/modules/transformer.py
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
from .base import VampBase
|
10 |
+
from .modules import CodebookEmbedding
|
11 |
+
from .modules import FiLM
|
12 |
+
from .modules import SequentialWithFiLM
|
13 |
+
from .modules import WNConv1d
|
14 |
+
|
15 |
+
|
16 |
+
class RMSNorm(nn.Module):
|
17 |
+
def __init__(self, hidden_size: int, eps=1e-6):
|
18 |
+
super().__init__()
|
19 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
20 |
+
self.var_eps = eps
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
"""Returns root mean square normalized version of input `x`
|
24 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known
|
25 |
+
# as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
|
26 |
+
# thus varience is calculated w/o mean and there is no bias
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
x : Tensor[B x T x D]
|
30 |
+
Returns
|
31 |
+
-------
|
32 |
+
Tensor[B x T x D]
|
33 |
+
"""
|
34 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
35 |
+
x = x * torch.rsqrt(var + self.var_eps)
|
36 |
+
|
37 |
+
return self.weight * x
|
38 |
+
|
39 |
+
|
40 |
+
def get_activation(name: str = "relu"):
|
41 |
+
if name == "relu":
|
42 |
+
return nn.ReLU
|
43 |
+
elif name == "gelu":
|
44 |
+
return NewGELU
|
45 |
+
elif name == "geglu":
|
46 |
+
return GatedGELU
|
47 |
+
elif name == "snake":
|
48 |
+
return Snake1d
|
49 |
+
else:
|
50 |
+
raise ValueError(f"Unrecognized activation {name}")
|
51 |
+
|
52 |
+
|
53 |
+
class NewGELU(nn.Module):
|
54 |
+
"""
|
55 |
+
Implementation of the GELU activation function currently in Google BERT repo
|
56 |
+
(identical to OpenAI GPT). Also see the Gaussian Error Linear Units
|
57 |
+
paper: https://arxiv.org/abs/1606.08415
|
58 |
+
"""
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
return (
|
62 |
+
0.5
|
63 |
+
* x
|
64 |
+
* (
|
65 |
+
1.0
|
66 |
+
+ torch.tanh(
|
67 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
|
68 |
+
)
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
class GatedGELU(nn.Module):
|
74 |
+
def __init__(self):
|
75 |
+
super().__init__()
|
76 |
+
self.gelu = NewGELU()
|
77 |
+
|
78 |
+
def forward(self, x, dim: int = -1):
|
79 |
+
p1, p2 = x.chunk(2, dim=dim)
|
80 |
+
return p1 * self.gelu(p2)
|
81 |
+
|
82 |
+
|
83 |
+
class Snake1d(nn.Module):
|
84 |
+
def __init__(self, channels):
|
85 |
+
super().__init__()
|
86 |
+
self.alpha = nn.Parameter(torch.ones(channels))
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
|
90 |
+
|
91 |
+
|
92 |
+
class FeedForward(nn.Module):
|
93 |
+
def __init__(
|
94 |
+
self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
factor = 2 if activation == "geglu" else 1
|
98 |
+
self.w_1 = nn.Linear(d_model, d_model * 4, bias=False)
|
99 |
+
self.w_2 = nn.Linear(d_model * 4 // factor, d_model, bias=False)
|
100 |
+
self.drop = nn.Dropout(dropout)
|
101 |
+
self.act = get_activation(activation)()
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
"""Computes position-wise feed-forward layer
|
105 |
+
Parameters
|
106 |
+
----------
|
107 |
+
x : Tensor[B x T x D]
|
108 |
+
Returns
|
109 |
+
-------
|
110 |
+
Tensor[B x T x D]
|
111 |
+
"""
|
112 |
+
x = self.w_1(x)
|
113 |
+
x = self.act(x)
|
114 |
+
x = self.drop(x)
|
115 |
+
x = self.w_2(x)
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class MultiHeadRelativeAttention(nn.Module):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
n_head: int = 8,
|
123 |
+
d_model: int = 512,
|
124 |
+
dropout: float = 0.1,
|
125 |
+
bidirectional: bool = True,
|
126 |
+
has_relative_attention_bias: bool = True,
|
127 |
+
attention_num_buckets: int = 32,
|
128 |
+
attention_max_distance: int = 128,
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
d_head = d_model // n_head
|
132 |
+
self.n_head = n_head
|
133 |
+
self.d_head = d_head
|
134 |
+
self.bidirectional = bidirectional
|
135 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
136 |
+
self.attention_num_buckets = attention_num_buckets
|
137 |
+
self.attention_max_distance = attention_max_distance
|
138 |
+
|
139 |
+
# Create linear query, key, value projections
|
140 |
+
self.w_qs = nn.Linear(d_model, d_model, bias=False)
|
141 |
+
self.w_ks = nn.Linear(d_model, d_model, bias=False)
|
142 |
+
self.w_vs = nn.Linear(d_model, d_model, bias=False)
|
143 |
+
|
144 |
+
# Create linear final output projection
|
145 |
+
self.fc = nn.Linear(d_model, d_model, bias=False)
|
146 |
+
|
147 |
+
# Dropout for attention output weights
|
148 |
+
self.dropout = nn.Dropout(dropout)
|
149 |
+
|
150 |
+
# Create relative positional embeddings (if turned on)
|
151 |
+
if has_relative_attention_bias:
|
152 |
+
self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
|
153 |
+
|
154 |
+
def _relative_position_bucket(self, relative_position):
|
155 |
+
"""Converts unbounded relative position into bounded set of buckets
|
156 |
+
with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
|
157 |
+
buckets
|
158 |
+
Parameters
|
159 |
+
----------
|
160 |
+
relative_position : Tensor[T_q x T_kv]
|
161 |
+
Relative positions between queries and key_value items
|
162 |
+
Returns
|
163 |
+
-------
|
164 |
+
Tensor[T_q x T_kv]
|
165 |
+
Input relative positions converted into buckets
|
166 |
+
"""
|
167 |
+
relative_buckets = 0
|
168 |
+
num_buckets = self.attention_num_buckets
|
169 |
+
max_distance = self.attention_max_distance
|
170 |
+
|
171 |
+
# Convert relative position for (-inf, inf) to [0, inf]
|
172 |
+
# Negative relative positions correspond to past
|
173 |
+
# Positive relative positions correspond to future
|
174 |
+
if self.bidirectional:
|
175 |
+
# use half buckets for each side (past / future)
|
176 |
+
num_buckets //= 2
|
177 |
+
|
178 |
+
# Shift the position positions by `num_buckets` to wrap around
|
179 |
+
# negative positions
|
180 |
+
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
181 |
+
relative_position = torch.abs(relative_position)
|
182 |
+
else:
|
183 |
+
# If not bidirectional, ignore positive positions and wrap
|
184 |
+
# negative positions to positive
|
185 |
+
relative_position = -torch.min(
|
186 |
+
relative_position, torch.zeros_like(relative_position)
|
187 |
+
)
|
188 |
+
|
189 |
+
# Allocate half of the buckets are for exact increments in positions
|
190 |
+
max_exact = num_buckets // 2
|
191 |
+
is_small = relative_position < max_exact
|
192 |
+
|
193 |
+
# The other half of the buckets are for logarithmically bigger bins in
|
194 |
+
# positions up to `max_distance`
|
195 |
+
relative_postion_if_large = max_exact + (
|
196 |
+
torch.log(relative_position.float() / max_exact)
|
197 |
+
/ math.log(max_distance / max_exact)
|
198 |
+
* (num_buckets - max_exact)
|
199 |
+
).to(torch.long)
|
200 |
+
|
201 |
+
# Clip the max relative position to `num_buckets - 1`
|
202 |
+
relative_postion_if_large = torch.min(
|
203 |
+
relative_postion_if_large,
|
204 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
205 |
+
)
|
206 |
+
|
207 |
+
# Choose relative buckets based on small or large positions
|
208 |
+
relative_buckets += torch.where(
|
209 |
+
is_small, relative_position, relative_postion_if_large
|
210 |
+
)
|
211 |
+
|
212 |
+
return relative_buckets
|
213 |
+
|
214 |
+
def compute_bias(self, query_length, key_length):
|
215 |
+
"""Computes a position bias scalar for each index in query_length x key_length
|
216 |
+
Parameters
|
217 |
+
----------
|
218 |
+
query_length : int
|
219 |
+
key_length : int
|
220 |
+
Returns
|
221 |
+
-------
|
222 |
+
Tensor[heads x 1 x T_q x T_kv]
|
223 |
+
Position bias to be applied on attention logits
|
224 |
+
"""
|
225 |
+
|
226 |
+
query_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
227 |
+
key_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
228 |
+
relative_position = key_position - query_position
|
229 |
+
|
230 |
+
# Convert relative position to buckets
|
231 |
+
relative_position_bucket = self._relative_position_bucket(relative_position)
|
232 |
+
relative_position_bucket = relative_position_bucket.to(
|
233 |
+
self.relative_attention_bias.weight.device
|
234 |
+
)
|
235 |
+
|
236 |
+
# Index attention bias values
|
237 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
238 |
+
values = rearrange(values, "q k h -> h 1 q k")
|
239 |
+
|
240 |
+
return values
|
241 |
+
|
242 |
+
def forward(self, q, k, v, mask=None, position_bias=None):
|
243 |
+
"""Computes attention over (keys, values) for every timestep in query
|
244 |
+
Parameters
|
245 |
+
----------
|
246 |
+
q : Tensor[B x T_q x d_model]
|
247 |
+
Query vectors
|
248 |
+
k : Tensor[B x T_kv x d_model]
|
249 |
+
Key vectors to compute attention over
|
250 |
+
v : Tensor[B x T_kv x d_model]
|
251 |
+
Value vectors corresponding to the keys
|
252 |
+
mask : Tensor[B x T_q x T_kv], optional
|
253 |
+
position_bias: Tensor[head x 1 x T_q x T_kv]
|
254 |
+
Returns
|
255 |
+
-------
|
256 |
+
Tensor[B x T_q x d_model]
|
257 |
+
Outputs after attending (key, value) using queries
|
258 |
+
"""
|
259 |
+
# Compute query, key, value projections
|
260 |
+
q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
|
261 |
+
k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
|
262 |
+
v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
|
263 |
+
|
264 |
+
# Compute attention matrix
|
265 |
+
attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
|
266 |
+
|
267 |
+
# Add relative position bias to attention scores
|
268 |
+
if position_bias is None:
|
269 |
+
if self.has_relative_attention_bias:
|
270 |
+
position_bias = self.compute_bias(q.size(-2), k.size(-2))
|
271 |
+
else:
|
272 |
+
position_bias = torch.zeros_like(attn)
|
273 |
+
attn += position_bias
|
274 |
+
|
275 |
+
# Apply mask to attention scores to prevent looking up invalid locations
|
276 |
+
if mask is not None:
|
277 |
+
attn = attn.masked_fill(mask[None] == 0, -1e9)
|
278 |
+
|
279 |
+
# Normalize attention scores and add dropout
|
280 |
+
attn = torch.softmax(attn, dim=3)
|
281 |
+
attn = self.dropout(attn)
|
282 |
+
|
283 |
+
# Compute attended outputs (product of attention matrix and values)
|
284 |
+
output = torch.einsum("hblt,hbtv->hblv", [attn, v])
|
285 |
+
output = rearrange(output, "head b l v -> b l (head v)")
|
286 |
+
output = self.fc(output)
|
287 |
+
|
288 |
+
return output, position_bias
|
289 |
+
|
290 |
+
|
291 |
+
class TransformerLayer(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
d_model: int = 512,
|
295 |
+
d_cond: int = 64,
|
296 |
+
n_heads: int = 8,
|
297 |
+
bidirectional: bool = True,
|
298 |
+
is_decoder: bool = False,
|
299 |
+
has_relative_attention_bias: bool = False,
|
300 |
+
flash_attn: bool = False,
|
301 |
+
dropout: float = 0.1,
|
302 |
+
):
|
303 |
+
super().__init__()
|
304 |
+
# Store args
|
305 |
+
self.is_decoder = is_decoder
|
306 |
+
|
307 |
+
# Create self-attention layer
|
308 |
+
self.norm_1 = RMSNorm(d_model)
|
309 |
+
self.film_1 = FiLM(d_cond, d_model)
|
310 |
+
self.flash_attn = flash_attn
|
311 |
+
|
312 |
+
if flash_attn:
|
313 |
+
from flash_attn.flash_attention import FlashMHA
|
314 |
+
self.self_attn = FlashMHA(
|
315 |
+
embed_dim=d_model,
|
316 |
+
num_heads=n_heads,
|
317 |
+
attention_dropout=dropout,
|
318 |
+
causal=False,
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
self.self_attn = MultiHeadRelativeAttention(
|
322 |
+
n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
|
323 |
+
)
|
324 |
+
|
325 |
+
# (Optional) Create cross-attention layer
|
326 |
+
if is_decoder:
|
327 |
+
self.norm_2 = RMSNorm(d_model)
|
328 |
+
self.film_2 = FiLM(d_cond, d_model)
|
329 |
+
self.cross_attn = MultiHeadRelativeAttention(
|
330 |
+
n_heads,
|
331 |
+
d_model,
|
332 |
+
dropout,
|
333 |
+
bidirectional=True,
|
334 |
+
has_relative_attention_bias=False,
|
335 |
+
)
|
336 |
+
|
337 |
+
# Create last feed-forward layer
|
338 |
+
self.norm_3 = RMSNorm(d_model)
|
339 |
+
self.film_3 = FiLM(d_cond, d_model)
|
340 |
+
self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
|
341 |
+
|
342 |
+
# Create dropout
|
343 |
+
self.dropout = nn.Dropout(dropout)
|
344 |
+
|
345 |
+
def forward(
|
346 |
+
self,
|
347 |
+
x,
|
348 |
+
x_mask,
|
349 |
+
cond,
|
350 |
+
src=None,
|
351 |
+
src_mask=None,
|
352 |
+
position_bias=None,
|
353 |
+
encoder_decoder_position_bias=None,
|
354 |
+
):
|
355 |
+
"""Computes one transformer layer consisting of self attention, (op) cross attention
|
356 |
+
and feedforward layer
|
357 |
+
Parameters
|
358 |
+
----------
|
359 |
+
x : Tensor[B x T_q x D]
|
360 |
+
x_mask : Tensor[B x T_q]
|
361 |
+
src : Tensor[B x T_kv x D], optional
|
362 |
+
src_mask : Tensor[B x T_kv x D], optional
|
363 |
+
position_bias : Tensor[heads x B x T_q x T_q], optional
|
364 |
+
Relative position bias for self attention layer
|
365 |
+
encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
|
366 |
+
Relative position bias for cross attention layer
|
367 |
+
Returns
|
368 |
+
-------
|
369 |
+
Tensor[B x T_q x D]
|
370 |
+
"""
|
371 |
+
y = self.norm_1(x)
|
372 |
+
y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
|
373 |
+
if self.flash_attn:
|
374 |
+
with torch.autocast(y.device.type, dtype=torch.bfloat16):
|
375 |
+
y = self.self_attn(y)[0]
|
376 |
+
else:
|
377 |
+
y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
|
378 |
+
x = x + self.dropout(y)
|
379 |
+
|
380 |
+
if self.is_decoder:
|
381 |
+
y = self.norm_2(x)
|
382 |
+
y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
|
383 |
+
y, encoder_decoder_position_bias = self.cross_attn(
|
384 |
+
y, src, src, src_mask, encoder_decoder_position_bias
|
385 |
+
)
|
386 |
+
x = x + self.dropout(y)
|
387 |
+
|
388 |
+
y = self.norm_3(x)
|
389 |
+
y = self.film_3(
|
390 |
+
y.permute(
|
391 |
+
0,
|
392 |
+
2,
|
393 |
+
1,
|
394 |
+
),
|
395 |
+
cond,
|
396 |
+
).permute(0, 2, 1)
|
397 |
+
y = self.feed_forward(y)
|
398 |
+
x = x + self.dropout(y)
|
399 |
+
|
400 |
+
return x, position_bias, encoder_decoder_position_bias
|
401 |
+
|
402 |
+
|
403 |
+
class TransformerStack(nn.Module):
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
d_model: int = 512,
|
407 |
+
d_cond: int = 64,
|
408 |
+
n_heads: int = 8,
|
409 |
+
n_layers: int = 8,
|
410 |
+
last_layer: bool = True,
|
411 |
+
bidirectional: bool = True,
|
412 |
+
flash_attn: bool = False,
|
413 |
+
is_decoder: bool = False,
|
414 |
+
dropout: float = 0.1,
|
415 |
+
):
|
416 |
+
super().__init__()
|
417 |
+
# Store args
|
418 |
+
self.bidirectional = bidirectional
|
419 |
+
self.is_decoder = is_decoder
|
420 |
+
|
421 |
+
# Create transformer layers
|
422 |
+
# In T5, relative attention bias is shared by all layers in the stack
|
423 |
+
self.layers = nn.ModuleList(
|
424 |
+
[
|
425 |
+
TransformerLayer(
|
426 |
+
d_model,
|
427 |
+
d_cond,
|
428 |
+
n_heads,
|
429 |
+
bidirectional,
|
430 |
+
is_decoder,
|
431 |
+
has_relative_attention_bias=(i == 0),
|
432 |
+
flash_attn=flash_attn,
|
433 |
+
dropout=dropout,
|
434 |
+
)
|
435 |
+
for i in range(n_layers)
|
436 |
+
]
|
437 |
+
)
|
438 |
+
|
439 |
+
# Perform last normalization
|
440 |
+
self.norm = RMSNorm(d_model) if last_layer else None
|
441 |
+
|
442 |
+
def subsequent_mask(self, size):
|
443 |
+
return torch.ones(1, size, size).tril().bool()
|
444 |
+
|
445 |
+
def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
|
446 |
+
"""Computes a full transformer stack
|
447 |
+
Parameters
|
448 |
+
----------
|
449 |
+
x : Tensor[B x T_q x D]
|
450 |
+
x_mask : Tensor[B x T_q]
|
451 |
+
src : Tensor[B x T_kv x D], optional
|
452 |
+
src_mask : Tensor[B x T_kv], optional
|
453 |
+
Returns
|
454 |
+
-------
|
455 |
+
Tensor[B x T_q x D]
|
456 |
+
"""
|
457 |
+
|
458 |
+
# Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
|
459 |
+
if self.is_decoder:
|
460 |
+
src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
|
461 |
+
|
462 |
+
# Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
|
463 |
+
x_mask = x_mask.unsqueeze(-2)
|
464 |
+
if not self.bidirectional:
|
465 |
+
x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
|
466 |
+
|
467 |
+
# Initialize position biases
|
468 |
+
position_bias = None
|
469 |
+
encoder_decoder_position_bias = None
|
470 |
+
|
471 |
+
# Compute transformer layers
|
472 |
+
for layer in self.layers:
|
473 |
+
x, position_bias, encoder_decoder_position_bias = layer(
|
474 |
+
x=x,
|
475 |
+
x_mask=x_mask,
|
476 |
+
cond=cond,
|
477 |
+
src=src,
|
478 |
+
src_mask=src_mask,
|
479 |
+
position_bias=position_bias,
|
480 |
+
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
481 |
+
)
|
482 |
+
|
483 |
+
return self.norm(x) if self.norm is not None else x
|
484 |
+
|
485 |
+
|
486 |
+
class VampNet(VampBase):
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
n_heads: int = 20,
|
490 |
+
n_layers: int = 16,
|
491 |
+
r_cond_dim: int = 64,
|
492 |
+
n_codebooks: int = 9,
|
493 |
+
n_conditioning_codebooks: int = 0,
|
494 |
+
latent_dim: int = 8,
|
495 |
+
embedding_dim: int = 1280,
|
496 |
+
vocab_size: int = 1024,
|
497 |
+
flash_attn: bool = True,
|
498 |
+
noise_mode: str = "mask",
|
499 |
+
seq_len: int = 313,
|
500 |
+
):
|
501 |
+
super().__init__()
|
502 |
+
self.n_heads = n_heads
|
503 |
+
self.n_layers = n_layers
|
504 |
+
self.r_cond_dim = r_cond_dim
|
505 |
+
self.n_codebooks = n_codebooks
|
506 |
+
self.n_conditioning_codebooks = n_conditioning_codebooks
|
507 |
+
self.embedding_dim = embedding_dim
|
508 |
+
self.vocab_size = vocab_size
|
509 |
+
self.latent_dim = latent_dim
|
510 |
+
self.flash_attn = flash_attn
|
511 |
+
self.noise_mode = noise_mode
|
512 |
+
self.seq_len = seq_len
|
513 |
+
|
514 |
+
if noise_mode == "mask":
|
515 |
+
special_tokens = ["MASK"]
|
516 |
+
elif noise_mode == "random":
|
517 |
+
special_tokens = None
|
518 |
+
else:
|
519 |
+
raise ValueError(f"Unknown noise mode: {noise_mode}")
|
520 |
+
|
521 |
+
self.embedding = CodebookEmbedding(
|
522 |
+
latent_dim=latent_dim,
|
523 |
+
n_codebooks=n_codebooks,
|
524 |
+
vocab_size=vocab_size,
|
525 |
+
emb_dim=embedding_dim,
|
526 |
+
special_tokens=special_tokens,
|
527 |
+
)
|
528 |
+
|
529 |
+
if noise_mode == "mask":
|
530 |
+
self.mask_token = self.embedding.special_idxs["MASK"]
|
531 |
+
|
532 |
+
self.transformer = TransformerStack(
|
533 |
+
d_model=embedding_dim,
|
534 |
+
d_cond=r_cond_dim,
|
535 |
+
n_heads=n_heads,
|
536 |
+
n_layers=n_layers,
|
537 |
+
last_layer=True,
|
538 |
+
bidirectional=True,
|
539 |
+
flash_attn=flash_attn,
|
540 |
+
is_decoder=False,
|
541 |
+
dropout=0.1,
|
542 |
+
)
|
543 |
+
|
544 |
+
# Add final conv layer
|
545 |
+
self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
|
546 |
+
self.classifier = SequentialWithFiLM(
|
547 |
+
WNConv1d(
|
548 |
+
embedding_dim,
|
549 |
+
vocab_size * self.n_predict_codebooks,
|
550 |
+
kernel_size=1,
|
551 |
+
padding="same",
|
552 |
+
# groups=self.n_predict_codebooks,
|
553 |
+
),
|
554 |
+
)
|
555 |
+
|
556 |
+
def forward(self, x, cond):
|
557 |
+
x = self.embedding(x)
|
558 |
+
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
559 |
+
|
560 |
+
cond = self.r_embed(cond)
|
561 |
+
|
562 |
+
x = rearrange(x, "b d n -> b n d")
|
563 |
+
out = self.transformer(x=x, x_mask=x_mask, cond=cond)
|
564 |
+
out = rearrange(out, "b n d -> b d n")
|
565 |
+
|
566 |
+
out = self.classifier(out, cond)
|
567 |
+
|
568 |
+
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
569 |
+
|
570 |
+
return out
|
571 |
+
|
572 |
+
|
573 |
+
if __name__ == "__main__":
|
574 |
+
# import argbind
|
575 |
+
from .modules import num_params
|
576 |
+
|
577 |
+
VampNet = argbind.bind(VampNet)
|
578 |
+
|
579 |
+
@argbind.bind(without_prefix=True)
|
580 |
+
def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
|
581 |
+
seq_len = int(32000 / 512 * seq_len_s)
|
582 |
+
|
583 |
+
model = VampNet().to(device)
|
584 |
+
|
585 |
+
z = torch.randint(
|
586 |
+
0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
|
587 |
+
).to(device)
|
588 |
+
|
589 |
+
r = torch.zeros(batch_size).to(device)
|
590 |
+
z_mask, mask = model.add_noise(z, r)
|
591 |
+
|
592 |
+
z_mask_latent = torch.rand(
|
593 |
+
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
594 |
+
).to(device)
|
595 |
+
z_hat = model(z_mask_latent, r)
|
596 |
+
|
597 |
+
pred = z_hat.argmax(dim=1)
|
598 |
+
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
599 |
+
|
600 |
+
print(f"model has {num_params(model)/1e6:<.3f}M parameters")
|
601 |
+
print(f"prediction has shape {pred.shape}")
|
602 |
+
breakpoint()
|
603 |
+
|
604 |
+
args = argbind.parse_args()
|
605 |
+
with argbind.scope(args):
|
606 |
+
try_model()
|
vampnet/scheduler.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class NoamScheduler:
|
7 |
+
"""OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
|
8 |
+
Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
optimizer: torch.optim.Optimizer,
|
14 |
+
d_model: int = 512,
|
15 |
+
factor: float = 1.0,
|
16 |
+
warmup: int = 4000,
|
17 |
+
):
|
18 |
+
# Store hparams
|
19 |
+
self.warmup = warmup
|
20 |
+
self.factor = factor
|
21 |
+
self.d_model = d_model
|
22 |
+
|
23 |
+
# Initialize variables `lr` and `steps`
|
24 |
+
self.lr = None
|
25 |
+
self.steps = 0
|
26 |
+
|
27 |
+
# Store the optimizer
|
28 |
+
self.optimizer = optimizer
|
29 |
+
|
30 |
+
def state_dict(self):
|
31 |
+
return {
|
32 |
+
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
33 |
+
}
|
34 |
+
|
35 |
+
def load_state_dict(self, state_dict):
|
36 |
+
self.__dict__.update(state_dict)
|
37 |
+
|
38 |
+
def step(self):
|
39 |
+
self.steps += 1
|
40 |
+
self.lr = self.factor * (
|
41 |
+
self.d_model ** (-0.5)
|
42 |
+
* min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
|
43 |
+
)
|
44 |
+
|
45 |
+
for p in self.optimizer.param_groups:
|
46 |
+
p["lr"] = self.lr
|
47 |
+
|