Hugo Flores commited on
Commit
50f034f
0 Parent(s):

first commit

Browse files
.dockerignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.wav
2
+ runs/
.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
2
+ tests/assets/cpu_test_file.pt filter=lfs diff=lfs merge=lfs -text
3
+ tests/assets/quick.pth filter=lfs diff=lfs merge=lfs -text
4
+ tests/assets/slow.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/env.sh
108
+ venv/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # Files created by experiments
131
+ output/
132
+ snapshot/
133
+ *.m4a
134
+ *.wav
135
+ notebooks/scratch.ipynb
136
+ notebooks/inspect.ipynb
137
+ notebooks/effects.ipynb
138
+ notebooks/*.ipynb
139
+ notebooks/*.gif
140
+ notebooks/*.wav
141
+ notebooks/*.mp4
142
+ *runs/
143
+ boards/
144
+ samples/
145
+ *.ipynb
146
+
147
+ results.json
148
+ metrics.csv
149
+ mprofile_*
150
+ mem.png
151
+
152
+ results/
153
+ mprofile*
154
+ *.png
155
+ # do not ignore the test wav file
156
+ !tests/audio/short_test_audio.wav
157
+ !tests/audio/output.wav
158
+ */.DS_Store
159
+ .DS_Store
160
+ env.sh
161
+ _codebraid/
162
+ **/*.html
163
+ **/*.exec.md
164
+ flagged/
165
+ log.txt
166
+ ckpt/
167
+ .syncthing*
168
+ tests/assets/
169
+ archived/
170
+
171
+ scratch/
172
+
173
+ runs-archive
.pre-commit-config.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/asottile/reorder_python_imports
3
+ rev: v2.5.0
4
+ hooks:
5
+ - id: reorder-python-imports
6
+ - repo: https://github.com/psf/black
7
+ rev: 23.1.0
8
+ hooks:
9
+ - id: black
10
+ language_version: python3
11
+ - repo: https://github.com/pre-commit/pre-commit-hooks
12
+ rev: v4.0.1
13
+ hooks:
14
+ - id: end-of-file-fixer
15
+ - id: trailing-whitespace
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM us.gcr.io/lyrebird-research/research-image/audio:beta
2
+
3
+ COPY requirements.txt requirements.txt
4
+ ARG GITHUB_TOKEN
5
+ RUN echo machine github.com login ${GITHUB_TOKEN} > ~/.netrc
6
+
7
+ COPY env/alias.sh /alias.sh
8
+ COPY env/entry_script.sh /entry_script.sh
9
+ RUN cat /alias.sh >> ~/.zshrc
10
+
11
+ # USER researcher
12
+ RUN pip install --upgrade -r requirements.txt
13
+ RUN pip install --upgrade tensorflow
14
+ RUN pip install --upgrade librosa
15
+ RUN pip install --upgrade numba
16
+ ENV PYTHONPATH "$PYTHONPATH:/u/home/src"
17
+ ENV NUMBA_CACHE_DIR=/tmp/
18
+
19
+ USER root
20
+ RUN wget https://github.com/jgm/pandoc/releases/download/2.18/pandoc-2.18-1-amd64.deb
21
+ RUN dpkg -i pandoc-2.18-1-amd64.deb
22
+ RUN apt-get update && apt-get install task-spooler
23
+
24
+ RUN head -n -1 /entry_script.sh > /entry_script_jupyter.sh
25
+ RUN head -n -1 /entry_script.sh > /entry_script_tensorboard.sh
26
+ RUN head -n -1 /entry_script.sh > /entry_script_gradio.sh
27
+
28
+ RUN echo \
29
+ 'su -p ${USER} -c "source ~/.zshrc && jupyter lab --ip=0.0.0.0"' >> \
30
+ /entry_script_jupyter.sh
31
+ RUN echo \
32
+ 'su -p ${USER} -c "source ~/.zshrc && tensorboard --logdir=$TENSORBOARD_PATH --samples_per_plugin audio=500 --bind_all"' >> \
33
+ /entry_script_tensorboard.sh
34
+ RUN echo \
35
+ 'su -p ${USER} -c "source ~/.zshrc && python app.py --args.load=conf/app.yml"' >> \
36
+ /entry_script_gradio.sh
README.md ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lyrebird Wav2Wav
2
+
3
+ This repository contains recipes for training Wav2Wav models.
4
+
5
+ ## Install hooks
6
+
7
+ First install the pre-commit util:
8
+
9
+ https://pre-commit.com/#install
10
+
11
+ pip install pre-commit # with pip
12
+ brew install pre-commit # on Mac
13
+
14
+ Then install the git hooks
15
+
16
+ pre-commit install
17
+ # check .pre-commit-config.yaml for details of hooks
18
+
19
+ Upon `git commit`, the pre-commit hooks will be run automatically on the stage files (i.e. added by `git add`)
20
+
21
+ **N.B. By default, pre-commit checks only run on staged files**
22
+
23
+ If you need to run it on all files:
24
+
25
+ pre-commit run --all-files
26
+
27
+ ## Usage & model zoo
28
+
29
+ To download the model, one must be authenticated to the `lyrebird-research` project on Google Cloud.
30
+ To see all available models, run
31
+
32
+ ```bash
33
+ python -m wav2wav.list_models
34
+ ```
35
+
36
+ which outputs something like this:
37
+
38
+ ```
39
+ gs://research-models/wav2wav
40
+ └── prod
41
+ └── v3
42
+ └── ckpt
43
+ ├── best
44
+ │ └── generator
45
+ │ ├── ❌ model.onnx
46
+ │ ├── ❌ nvidia_geforce_rtx_2080_ti_11_7.trt
47
+ │ ├── ✅ package.pth
48
+ │ ├── ❌ tesla_t4_11_7.trt
49
+ │ └── ✅ weights.pth
50
+ └── latest
51
+ └── generator
52
+ ├── ❌ package.pth
53
+ └── ❌ weights.pth
54
+ └── v2
55
+ ...
56
+ └── dev
57
+ ...
58
+ ```
59
+
60
+ This will show all the models that are available on GCP. Models that are available locally are marked with a ✅, while those not available locally
61
+ are marked with ❌. `.onnx` indicates a model that must be run with
62
+ the `ONNX` runtime, while `.trt` indicate models that have been optimized
63
+ with TensorRT. Note that TensorRT models are specific to GPU and CUDA
64
+ runtime, and their file names indicate what to use to run them.
65
+
66
+ `package.pth` is a version of the model that is saved using `torch.package`,
67
+ and contains a copy of the model code within it, which allow it to work
68
+ even if the model code in `wav2wav/modules/generator.py` changes. `weights.pth`
69
+ contains the model weights, and the code must match the code used
70
+ to create the model.
71
+
72
+ To use a model from this list, simply write its path and give it to the `enhance` script,
73
+ like so:
74
+
75
+ ```
76
+ python -m wav2wav.interface \
77
+ [input_path]
78
+ --model_path=prod/v3/ckpt/best/generator/weights.pth
79
+ --output_path [output_path]
80
+ ```
81
+
82
+ Models are downloaded to the location set by the environment variable `MODEL_LOCAL_PATH`, and defaults to `~/.wav2wav/models`. Similarly,
83
+ The model bucket is determined by `MODEL_GCS_PATH` and defaults to
84
+ `gs://research-models/wav2wav/`.
85
+
86
+ ## Development
87
+ ### Setting everything up
88
+
89
+ Run the setup script to set up your environment via:
90
+
91
+ ```bash
92
+ python env/setup.py
93
+ ```
94
+
95
+ The setup script does not require any dependencies beyond just Python.
96
+ Once run, follow the instructions it prints out to create your
97
+ environment file, which will be at `env/env.sh`.
98
+
99
+ Note that if this is a new machine, and
100
+ the data is not downloaded somewhere on it already, it will ask you
101
+ for a directory to download the data to.
102
+
103
+ For Github setup, if you don't have a .netrc token, create one by going to your Github profile -> Developer settings -> Personal access tokens -> Generate new token. Copy the token and [keep it secret, keep it safe](https://www.youtube.com/watch?v=iThtELZvfPs).
104
+
105
+ When complete, run:
106
+
107
+ ```bash
108
+ source env/env.sh
109
+ ```
110
+
111
+ Now build and launch the Docker containers:
112
+
113
+ ```bash
114
+ docker compose up -d
115
+ ```
116
+
117
+ This builds and runs a Jupyter notebook and Tensorboard
118
+ in the background, which points to your `TENSORBOARD_PATH`
119
+ env. variable.
120
+
121
+ Now, launch your development environment via:
122
+
123
+ ```bash
124
+ docker compose run dev
125
+ ```
126
+
127
+ To tear down your development environment, just do
128
+
129
+ ```bash
130
+ docker compose down
131
+ ```
132
+
133
+ ### Downloading data and pre-processing
134
+ Next, from within the Docker environment (or an appropriately configured Conda environment with environment variables set as above), do the following:
135
+
136
+ ```
137
+ python -m wav2wav.preprocess.download
138
+ ```
139
+
140
+ This will download all the necessary data, which are referenced by
141
+ the CSV files in `conf/audio/*`. These CSVs were generated via
142
+ `python -m wav2wav.preprocess.organize`.
143
+
144
+ ### Launching an experiment
145
+
146
+ Experiments are first _staged_ by running the `stage` command (which corresponds to the script `scripts/exp/stage.py`).
147
+
148
+ `stage` creates a directory with a copy of all of the Git-tracked files in the root repository.`stage` launches a shell into said directory, so all commands are run on the
149
+ copy of the original repository code. This is useful for rewinding to an old experiment
150
+ and resuming it, for example. Even if the repository code changes, the snapshot in the experiment directory is unchanged from the original run, so it can be re-used.
151
+
152
+ Then, the experiment can be run via:
153
+
154
+ ```bash
155
+ torchrun --nproc_per_node gpu \
156
+ scripts/exp/train.py \
157
+ --args.load=conf/args.yml \
158
+ ```
159
+
160
+ The full settings are in [conf/daps/train.yml](conf/daps/train.yml).
161
+
162
+ ### Evaluating an experiment
163
+
164
+ There are two ways to evaluate an experiment: quantitative and qualitative.
165
+ For the first, we can use the `scripts/exp/evaluate.py` script. This script evaluates the model over the `val_data` and `test_data`, defined in your
166
+ `train` script, and takes as input an experiment directory. The metrics
167
+ computed by this script are saved to the same folder.
168
+
169
+ The other way is via a preference test. Let's say we want to compare
170
+ the v3 prod model against the v2 prod model. to do this, we use the
171
+ `scripts/exp/qa.py` script. This script creates a zip file containing all
172
+ the samples and an HTML page for easy viewing. It also creates a Papaya
173
+ preference test. Use it like this:
174
+
175
+ ```bash
176
+ WAV2WAV_MODELS=a,b python scripts/exp/qa.py \
177
+ --a/model_path prod/v3/ckpt/best/generator/package.pth \
178
+ --b/model_path prod/v2/ckpt/best/generator/package.pth \
179
+ --a/name "v3" --b/name "v2" \
180
+ --device cuda:0 \
181
+ --n_samples 20 \
182
+ --zip_path "samples/out.zip"
183
+ ```
184
+
185
+ ### Useful commands
186
+
187
+ #### Monitoring the machine
188
+
189
+ There's a useful `tmux` workspace that you can launch via:
190
+
191
+ ```bash
192
+ tmuxp load ./workspace.yml
193
+ ```
194
+
195
+ which will have a split pane with a shell to launch commands on the left,
196
+ and GPU monitoring, `htop`, and a script that watches for changes in your
197
+ directory on the right, in three split panes.
198
+
199
+ #### Cleaning up after a run
200
+
201
+ Sometimes DDP runs fail to clear themselves out of the machine. To fix this, run
202
+
203
+ ```bash
204
+ cleanup
205
+ ```
206
+
207
+ ### Deploying a new model to production
208
+
209
+ Okay, so you ran a model and it seems promising and you want to upload it
210
+ to GCS so it can be QA'd fully, and then shipped. First, upload
211
+ your experiment to the `dev` bucket on GCS via:
212
+
213
+ ```bash
214
+ gsutil cp -r /path/to/{exp_name} gs://research-models/wav2wav/dev/{exp_name}
215
+ ```
216
+
217
+ Once uploaded, QA can access the models by specifying
218
+ `model_path=dev/{exp_name}/ckpt/{best,latest}/generator/package.pth` when using the
219
+ `wav2wav.interface.enhance` function. If it passes QA, and is scheduled to
220
+ ship to production, then next we have to generate the TensorRT model file,
221
+ which requires us to have a machine that matches that of a production machine.
222
+
223
+ There is a script that automates this procedure, that does not require any
224
+ fiddling from our end. Navigate to the repository root and run:
225
+
226
+ ```
227
+ python scripts/utils/convert_on_gcp.py dev/{exp_name}/ckpt/{best,latest}//generator/weights.pth
228
+ ```
229
+
230
+ This will provision the machine, download the relevant model from GCS, optimize it on
231
+ the production GPU with the correct CUDA runtime, and then upload the generated `.trt`
232
+ and `.onnx` models back to the bucket.
233
+
234
+ Finally, copy the model to the `prod` bucket, incrementing the version number by one:
235
+
236
+ ```bash
237
+ gsutil cp -r gs://research-models/wav2wav/dev/{exp_name} gs://research-models/wav2wav/prod/v{N}
238
+ ```
239
+
240
+ where `N` is the next version (e.g. if v3 is the latest, the new one is v4). Then, update
241
+ the model table in [Notion](https://www.notion.so/descript/fc04de4b46e6417eba1d06bdc8de6c75?v=e56db4e6b37c4d9b9eca8d9be15c826a) with the new model.
242
+
243
+ Once the above is all done, we update the code in two places:
244
+
245
+ 1. In `interface.py`, we update `PROD_MODEL_PATH` to point to the `weights.pth`
246
+ for whichever tag ended up shipping (either `best` or `latest`).
247
+ 2. In `interface.py`, we update `PROD_TRT_PATH` to point the generated
248
+ TensorRT checkpoint generated by the script above.
249
+
250
+ After merging to master, a new Docker image will be created, and one can update the relevant lines
251
+ in descript-workflows like in this [PR](https://github.com/descriptinc/descript-workflows/pull/477/files).
252
+
253
+ We have Github action workflows in [.github/workflows/deploy.yml](.github/workflows/deploy.yml) to build and deploy new docker images. Two images are built - one for staging and another for production.
254
+ To deploy a new release version, follow the instructions in [this coda doc](https://coda.io/d/Research-Engineering_dOABAWL46p-/Deploying-Services_su1am#_lu7E8).
255
+
256
+ Coda doc with informations about deploying speech-enhance worker is [here](https://coda.io/d/Research-Engineering_dOABAWL46p-/Deploying-Services_su1am#_lu7E8).
257
+
258
+ And that's it! Once the new staging is built, you're done.
259
+
260
+ ## Testing
261
+
262
+ ### Profiling and Regression testing
263
+
264
+ - The [profiling script](tests/profile_inference.py) profiles the `wav2wav.interface.enhance` function.
265
+ - NOTE: ALWAYS run the profiler on a T4 GPU. ALWAYS run the profiling in isolation i.e kill all other processes on the GPU. Recommended vm size on GCP is `n1-standard-32` as the stress test of six hours of audio requires ~35GB of system memory.
266
+ - To run profiling use the [profiling script](tests/profile_inference.py) via command `python3 -m tests.profile_inference`. Results will be printed after `1` run.
267
+ - Use the [test_regression.py](tests/test_regression.py) script to run tests that
268
+ - compare performance stats of current model with known best model
269
+ - test for output deviation from the last model
270
+ - Run `git lfs checkout` to checkout input file and model weights required for testing the model.
271
+ - To launch these tests, run `python3 -m pytest tests/test_regression.py -v`.
272
+ - As a side effect, this will update the `tests/stat.csv` file if the current model performs better than last best known model as per `tests/stat.csv`.
273
+ - NOTE: In case of architecture change, purge the weights files : `tests/assets/{quick|slow}.pth` and reference stat file : `tests/assets/baseline.json` file. Running the [test_regression.py](tests/test_regression.py) script in absence of reference stat file, will generate new baseline referece stats as well as append new performance stats to stats file. In the absence of saved weights, new weights are generated and saved on disk. Make sure to commit these files (stat.csv, baseline.json, *.pth) when the model architecture changes.
274
+
275
+ ### Unit tests
276
+ Regular unit tests that test functionality such as training resume etc. These are run on CPU. Update them when new features are added.
277
+
278
+ ### Profiling tests
279
+ These tests profile the model's resource consumption. They are run on T4 GPU with 32 cores and >35GB memory. Their usage is reported in the above sections.
280
+
281
+ ### Functional tests
282
+ These tests detect deviation from known baseline model. A category of these tests ensure that a new pytorch model doesn't deviate from the previous one. Another category ensures that the TensorRT version of the current pytorch model doens't deviate from it. These tests are marked with the marker `output_qa` and can be run via the command line `python3 -m pytest -v -m output_qa`. Some of these tests require a GPU.
283
+
284
+ ### CI tests
285
+ - The tests are divided into two categories depending on the platform requirement - CPU tests and GPU tests.
286
+ - The CPU tests contains unit tests.
287
+ - The GPU tests contain a subset of functional tests. These tests can be run by the command `python3 -m pytest -v -m gpu_ci_test`.
changelog.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v0.3.4
2
+ - Upgraded nussl, removing unneeded code from interface, and improving interaction.
3
+
4
+ # v0.3.1
5
+ - Rich interaction with the model.
6
+
7
+ # v0.2.0
8
+ - Release with script to process zipfiles for QA + gsutil sync for model weights.
9
+
10
+ # v0.1.0
11
+ - Initial release.
conf/vampnet-c2f.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/vampnet.yml
3
+
4
+ VampNet.n_codebooks: 9
5
+ VampNet.n_conditioning_codebooks: 3
6
+
7
+ train/AudioDataset.duration: 3
8
+ val/AudioDataset.duration: 3
9
+ test/AudioDataset.duration: 3
conf/vampnet.yml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ wav2wav_ckpt: /u/home/src/runs/codec-ckpt/codec.pth
3
+ save_path: ckpt
4
+ max_epochs: 1000000
5
+ epoch_length: 1000
6
+ save_audio_epochs: 2
7
+ val_idx: [0,1,2,3,4,5,6,7,8,9]
8
+
9
+ prefix_amt: 0.0
10
+ suffix_amt: 0.0
11
+ prefix_dropout: 0.1
12
+ suffix_dropout: 0.1
13
+
14
+ batch_size: 120
15
+ num_workers: 80
16
+
17
+ # Optimization
18
+ detect_anomaly: false
19
+ amp: false
20
+
21
+ CrossEntropyLoss.label_smoothing: 0.1
22
+
23
+ AdamW.lr: 0.001
24
+
25
+ NoamScheduler.factor: 2.0
26
+ NoamScheduler.warmup: 10000
27
+
28
+ VampNet.vocab_size: 1024
29
+ VampNet.n_codebooks: 3
30
+ VampNet.n_conditioning_codebooks: 0
31
+ VampNet.r_cond_dim: 64
32
+ VampNet.embedding_dim: 1280
33
+ VampNet.n_layers: 16
34
+ VampNet.n_heads: 20
35
+ VampNet.flash_attn: false
36
+ VampNet.dropout: 0.05
37
+
38
+ AudioLoader.relative_path: /data/
39
+ AudioDataset.loudness_cutoff: -30.0
40
+ AudioDataset.without_replacement: true
41
+ AudioLoader.shuffle: true
42
+
43
+ train/AudioDataset.duration: 5.0
44
+ train/AudioDataset.n_examples: 10000000
45
+ train/AudioLoader.sources:
46
+ - /data/spotdl/audio/train
47
+
48
+ val/AudioDataset.duration: 5.0
49
+ val/AudioDataset.n_examples: 2000
50
+ val/AudioLoader.sources:
51
+ - /data/spotdl/audio/val
52
+
53
+ test/AudioDataset.duration: 5.0
54
+ test/AudioDataset.n_examples: 1000
55
+ test/AudioLoader.sources:
56
+ - /data/spotdl/audio/test
docker-compose.yml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ version: "3.5"
3
+ services:
4
+ tensorrt:
5
+ build:
6
+ context: .
7
+ dockerfile: ./deployment_build/dockerfile
8
+ args:
9
+ GITHUB_TOKEN: ${GITHUB_TOKEN}
10
+ profiles:
11
+ - tensorrt
12
+ volumes:
13
+ - ./:/u/home/src
14
+ - ~/.config/gcloud:/root/.config/gcloud
15
+ deploy:
16
+ resources:
17
+ limits:
18
+ # match production limits
19
+ cpus: '7'
20
+ memory: 25000M
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: 1
25
+ capabilities: [gpu]
26
+ working_dir: /u/home/src
27
+ entrypoint:
28
+ - python
29
+ - -m
30
+ - wav2wav.converter
31
+ base:
32
+ build:
33
+ context: .
34
+ dockerfile: ./Dockerfile
35
+ args:
36
+ GITHUB_TOKEN: ${GITHUB_TOKEN}
37
+ volumes:
38
+ - .:/u/home/src
39
+ - ~/.wav2wav:/u/home/.wav2wav
40
+ - ${PATH_TO_DATA}:/data
41
+ - ${PATH_TO_RUNS}:/runs
42
+ - ~/.config/gcloud:/u/home/.config/gcloud
43
+ - ~/.zsh_history:/u/home/.zsh_history
44
+ environment:
45
+ - GITHUB_TOKEN
46
+ - DISCOURSE_API_USERNAME
47
+ - DISCOURSE_SERVER
48
+ - DISCOURSE_API_KEY
49
+ - HOST_USER_ID
50
+ - HOST_USER_GID
51
+ - JUPYTER_TOKEN
52
+ - PATH_TO_DATA=/data
53
+ - PATH_TO_RUNS=/runs
54
+ - TENSORBOARD_PATH
55
+ - MPLCONFIGDIR=/u/home/.mplconfig
56
+ shm_size: 32G
57
+ working_dir: /u/home/src
58
+ deploy:
59
+ resources:
60
+ reservations:
61
+ devices:
62
+ - driver: nvidia
63
+ capabilities: [gpu]
64
+ dev:
65
+ extends: base
66
+ profiles:
67
+ - interactive
68
+ stdin_open: true
69
+ tty: true
70
+ jupyter:
71
+ extends: base
72
+ ports:
73
+ - ${JUPYTER_PORT}:8888
74
+ entrypoint:
75
+ - /bin/bash
76
+ - /entry_script_jupyter.sh
77
+ tensorboard:
78
+ extends: base
79
+ ports:
80
+ - ${TENSORBOARD_PORT}:6006
81
+ entrypoint:
82
+ - /bin/bash
83
+ - /entry_script_tensorboard.sh
84
+ gradio:
85
+ extends: base
86
+ ports:
87
+ - 7860:7860
88
+ entrypoint:
89
+ - /bin/bash
90
+ - /entry_script_gradio.sh
env/alias.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ alias cleanup="pkill python && echo -en '\e[?25h'"
2
+ alias stage="python ./scripts/utils/stage.py"
3
+ alias fix_cursor="echo -en '\e[?25h'"
env/data.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export PATH_TO_DATA=~/data
2
+
3
+ if [[ $(hostname) == "oon17" ]]; then
4
+ export PATH_TO_DATA=/home/prem/shared/data/
5
+ fi
6
+
7
+ if [[ $(hostname) == "oon19" ]]; then
8
+ export PATH_TO_DATA=/home/prem/shared/data/
9
+ fi
10
+
11
+ if [[ $(hostname) == "lucas-ssound-trt-vm" ]]; then
12
+ export PATH_TO_DATA=~/data
13
+ fi
14
+
15
+ if [[ $(hostname) == "a100-ssound" ]]; then
16
+ export PATH_TO_DATA=~/data
17
+ fi
18
+
19
+ if [[ $(hostname) == "oon25" ]]; then
20
+ export PATH_TO_DATA=/data
21
+ fi
22
+
23
+ if [[ $(hostname) == "macbook-pro-2.lan" ]]; then
24
+ export PATH_TO_DATA=~/data
25
+ fi
26
+
27
+ if [[ $(hostname) == "oon11" ]]; then
28
+ export PATH_TO_DATA=/data2/syncthing_lucas/data
29
+ fi
30
+
31
+ if [[ $(hostname) == "oon12" ]]; then
32
+ export PATH_TO_DATA=/data
33
+ fi
34
+ if [[ $(hostname) == "oon26" ]]; then
35
+ export PATH_TO_DATA=/data
36
+ fi
env/entry_script.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ if [ -z "${USER}" ]; then
5
+ echo "We need USER to be set!"; exit 100
6
+ fi
7
+
8
+ # check if host uid and gid are set
9
+ if [ -z "${HOST_USER_ID}" ]; then
10
+ echo "Please set HOST_USER_ID env. variables to continue." ; exit 0
11
+ fi
12
+
13
+ if [ -z "${HOST_USER_GID}" ]; then
14
+ echo "Please set HOST_USER_GID env. variables to continue." ; exit 0
15
+ fi
16
+
17
+ USER_ID=$HOST_USER_ID
18
+ USER_GID=$HOST_USER_GID
19
+ USER_HOME=/u/home
20
+
21
+ # modify uid and gid to match host
22
+ sed -i -e "s/^${USER}:\([^:]*\):[0-9]*:[0-9]*/${USER}:\1:${USER_ID}:${USER_GID}/" /etc/passwd
23
+
24
+ # create a group for host gid
25
+ groupadd -f --gid "${USER_GID}" "host_group"
26
+
27
+ chown $USER_ID $USER_HOME
28
+ chown $USER_ID /u/home/.zshrc
29
+ chown $USER_ID /u/home/.oh-my-zsh
30
+
31
+ mkdir -p /u/home/.cache
32
+ chown -R $USER_ID:$USER_GID /u/home/.cache/
33
+
34
+ _term() {
35
+ echo "Caught SIGTERM signal!"
36
+ kill -TERM "$child" 2>/dev/null
37
+ }
38
+
39
+ trap _term SIGTERM
40
+
41
+ su -p "${USER}"
env/setup.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script guides the user through setting up their env.sh
2
+ # if env.sh does not exist. Should have no dependencies other
3
+ # than Python standard library.
4
+ import shlex
5
+ import socket
6
+ import subprocess
7
+ import textwrap
8
+
9
+
10
+ def run(cmd):
11
+ return subprocess.check_output(shlex.split(cmd)).decode("utf-8")
12
+
13
+
14
+ print("1. Setting up Google Cloud access")
15
+ print("---------------------------------")
16
+ gcloud_authorized = "gs://research-data-raw" in run("gsutil ls")
17
+ if not gcloud_authorized:
18
+ run("gcloud auth login")
19
+
20
+ run("gcloud config set project lyrebird-research")
21
+ run("gcloud auth configure-docker")
22
+
23
+ print()
24
+ print("2. Setting up Github access")
25
+ print("---------------------------")
26
+
27
+ lines = textwrap.wrap(
28
+ "First, let's get your Github token, so all "
29
+ "packages can be installed. Create one by going to your "
30
+ "Github profile -> Developer settings -> Personal access tokens -> "
31
+ "Generate new token. Copy the token below."
32
+ )
33
+ [print(l) for l in lines]
34
+
35
+ GITHUB_TOKEN = input("\nGithub token: ") or "undefined"
36
+
37
+ print()
38
+ print("3. Setting up Jupyter and Tensorboard")
39
+ print("-------------------------------------")
40
+
41
+ JUPYTER_TOKEN = input("Password for Jupyter server (default:password): ") or "password"
42
+ JUPYTER_PORT = input("Jupyter port to run on (default:8888): ") or "8888"
43
+ TENSORBOARD_PORT = input("Tensorboard port to run on (default:6006): ") or "6006"
44
+
45
+ print()
46
+ print("4. Setting up paths.")
47
+ print("--------------------")
48
+
49
+ PATH_TO_RUNS = input("Where runs should go (default:./runs/): ") or "./runs/"
50
+ TENSORBOARD_PATH = (
51
+ input("Bucket/dir for tensorboard logs (default=PATH_TO_RUNS): ") or PATH_TO_RUNS
52
+ )
53
+
54
+ with open("env/data.sh") as f:
55
+ data_script = f.read()
56
+
57
+ write_to_data_sh = False
58
+ if socket.gethostname() not in data_script:
59
+ print("Looks like the data path for this machine is not setup.")
60
+ PATH_TO_DATA = input(f"Path to data on {socket.gethostname()}: ") or "~/data"
61
+
62
+ data_command = f"""
63
+ if [[ $(hostname) == "{socket.gethostname()}" ]]; then
64
+ export PATH_TO_DATA={PATH_TO_DATA}
65
+ fi
66
+ """
67
+ write_to_data_sh = True
68
+
69
+
70
+ print()
71
+ print("5. Setting up Papaya")
72
+ print("-----------------------------------------")
73
+
74
+ PAPAYA_USER_TOKEN = input("Papaya user token: ") or "undefined"
75
+
76
+ env_script = f"""
77
+ source env/alias.sh
78
+ source env/data.sh
79
+ export GITHUB_TOKEN={GITHUB_TOKEN}
80
+
81
+ export PAPAYA_USER_TOKEN={PAPAYA_USER_TOKEN}
82
+
83
+ export HOST_USER_ID=$(id -u)
84
+ export HOST_USER_GID=$(id -g)
85
+
86
+ export JUPYTER_TOKEN={JUPYTER_TOKEN}
87
+ export JUPYTER_PORT={JUPYTER_PORT}
88
+ export TENSORBOARD_PORT={TENSORBOARD_PORT}
89
+
90
+ export PATH_TO_RUNS={PATH_TO_RUNS}
91
+ export TENSORBOARD_PATH={TENSORBOARD_PATH}
92
+ """
93
+
94
+ print()
95
+ print("6. Potential file contents.")
96
+ print("---------------------------")
97
+
98
+ print("env/env.sh: \n")
99
+ print("##################")
100
+ print(env_script)
101
+ print("##################")
102
+
103
+ if write_to_data_sh:
104
+ data_script += data_command
105
+
106
+ print("env/data.sh:")
107
+ print("##################")
108
+ print(data_script)
109
+ print("##################")
110
+
111
+ print()
112
+ write_to_files = input("Write to file [yn]? ") or "n"
113
+ if write_to_files == "y":
114
+ with open("env/env.sh", "w") as f:
115
+ f.write(env_script.strip())
116
+ with open("env/data.sh", "w") as f:
117
+ f.write(data_script.strip())
118
+
119
+ print()
120
+ print("8. Finalize setup.")
121
+ print("------------------")
122
+ print("Run the following command to complete setup.")
123
+ print("source env/env.sh")
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ argbind>=0.3.1
2
+ pytorch-ignite
3
+ rich
4
+ audiotools @ git+https://github.com/descriptinc/[email protected]
5
+ tqdm
6
+ tensorboard
7
+ google-cloud-logging==2.2.0
8
+ pytest
9
+ pytest-cov
10
+ papaya_client @ git+https://github.com/descriptinc/lyrebird-papaya.git@master
11
+ pynvml
12
+ psutil
13
+ pandas
14
+ onnx
15
+ onnx-simplifier
16
+ seaborn
17
+ jupyterlab
18
+ jupyterlab-link-share
19
+ pandas
20
+ watchdog
21
+ pesq
22
+ tabulate
23
+ torchmetrics
24
+ codebraid==0.5.0
25
+ jupyter-client==6.1.12
26
+ tensorboardX
27
+ gradio
28
+ einops
29
+ flash-attn
scripts/generative/eval.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import imp
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import argbind
7
+ import audiotools
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from flatten_dict import flatten
12
+ from rich.progress import track
13
+ from torch.utils.tensorboard import SummaryWriter
14
+
15
+ import wav2wav
16
+
17
+ train = imp.load_source("train", str(Path(__file__).absolute().parent / "train.py"))
18
+
19
+
20
+ @argbind.bind(without_prefix=True)
21
+ def evaluate(
22
+ args,
23
+ model_tag: str = "ckpt/best",
24
+ device: str = "cuda",
25
+ exp: str = None,
26
+ overwrite: bool = False,
27
+ ):
28
+ assert exp is not None
29
+
30
+ sisdr_loss = audiotools.metrics.distance.SISDRLoss()
31
+ stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
32
+ mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
33
+
34
+ with audiotools.util.chdir(exp):
35
+ vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
36
+ f"{model_tag}/vampnet/package.pth"
37
+ )
38
+ vampnet = vampnet.to(device)
39
+ if vampnet.cond_dim > 0:
40
+ condnet = wav2wav.modules.condnet.transformer.CondNet.load(
41
+ f"{model_tag}/condnet/package.pth"
42
+ )
43
+ condnet = condnet.to(device)
44
+ else:
45
+ condnet = None
46
+
47
+ vqvae = wav2wav.modules.generator.Generator.load(
48
+ f"{model_tag}/vqvae/package.pth"
49
+ )
50
+
51
+ _, _, test_data = train.build_datasets(args, vqvae.sample_rate)
52
+
53
+ with audiotools.util.chdir(exp):
54
+ datasets = {
55
+ "test": test_data,
56
+ }
57
+
58
+ metrics_path = Path(f"{model_tag}/metrics")
59
+ metrics_path.mkdir(parents=True, exist_ok=True)
60
+
61
+ for key, dataset in datasets.items():
62
+ csv_path = metrics_path / f"{key}.csv"
63
+ if csv_path.exists() and not overwrite:
64
+ break
65
+ metrics = []
66
+ for i in track(range(len(dataset))):
67
+ # TODO: for coarse2fine
68
+ # grab the signal
69
+ # mask all the codebooks except the conditioning ones
70
+ # and infer
71
+ # then compute metrics
72
+ # for a baseline, just use the coarsest codebook
73
+
74
+ try:
75
+ visqol = audiotools.metrics.quality.visqol(
76
+ enhanced, clean, "audio"
77
+ ).item()
78
+ except:
79
+ visqol = None
80
+
81
+ sisdr = sisdr_loss(enhanced, clean)
82
+ stft = stft_loss(enhanced, clean)
83
+ mel = mel_loss(enhanced, clean)
84
+
85
+ metrics.append(
86
+ {
87
+ "visqol": visqol,
88
+ "sisdr": sisdr.item(),
89
+ "stft": stft.item(),
90
+ "mel": mel.item(),
91
+ "dataset": key,
92
+ "condition": exp,
93
+ }
94
+ )
95
+ print(metrics[-1])
96
+
97
+ transform_args = flatten(item["transform_args"], "dot")
98
+ for k, v in transform_args.items():
99
+ if torch.is_tensor(v):
100
+ if len(v.shape) == 0:
101
+ metrics[-1][k] = v.item()
102
+
103
+ metrics = pd.DataFrame.from_dict(metrics)
104
+ with open(csv_path, "w") as f:
105
+ metrics.to_csv(f)
106
+
107
+ data = summary(model_tag).to_dict()
108
+ metrics = {}
109
+ for k1, v1 in data.items():
110
+ for k2, v2 in v1.items():
111
+ metrics[f"metrics/{k2}/{k1}"] = v2
112
+
113
+ # Number of steps to record
114
+ writer = SummaryWriter(log_dir=metrics_path)
115
+ num_steps = 10
116
+ for k, v in metrics.items():
117
+ for i in range(num_steps):
118
+ writer.add_scalar(k, v, i)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ args = argbind.parse_args()
123
+ with argbind.scope(args):
124
+ evaluate(args)
scripts/generative/train.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import time
4
+ import warnings
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import argbind
9
+ import audiotools as at
10
+ import torch
11
+ import torch.nn as nn
12
+ from audiotools import AudioSignal
13
+ from audiotools.data import transforms
14
+ from einops import rearrange
15
+ from rich import pretty
16
+ from rich.traceback import install
17
+ from tensorboardX import SummaryWriter
18
+
19
+ import vampnet
20
+ from vampnet.modules.transformer import VampNet
21
+ from lac.model.lac import LAC
22
+
23
+
24
+ # Enable cudnn autotuner to speed up training
25
+ # (can be altered by the funcs.seed function)
26
+ torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
27
+ # Uncomment to trade memory for speed.
28
+
29
+ # Install to make things look nice
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+ pretty.install()
32
+ install()
33
+
34
+ # optim
35
+ Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
36
+ CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
37
+ AdamW = argbind.bind(torch.optim.AdamW)
38
+ NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
39
+
40
+ # transforms
41
+ filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
42
+ "BaseTransform",
43
+ "Compose",
44
+ "Choose",
45
+ ]
46
+ tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
47
+
48
+ # model
49
+ VampNet = argbind.bind(VampNet)
50
+
51
+
52
+ # data
53
+ AudioLoader = argbind.bind(at.datasets.AudioLoader)
54
+ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
55
+
56
+ IGNORE_INDEX = -100
57
+
58
+
59
+ @argbind.bind("train", "val", without_prefix=True)
60
+ def build_transform():
61
+ transform = transforms.Compose(
62
+ tfm.VolumeNorm(("uniform", -32, -14)),
63
+ tfm.VolumeChange(("uniform", -6, 3)),
64
+ tfm.RescaleAudio(),
65
+ )
66
+ return transform
67
+
68
+
69
+ @torch.no_grad()
70
+ def apply_transform(transform_fn, batch):
71
+ sig: AudioSignal = batch["signal"]
72
+ kwargs = batch["transform_args"]
73
+
74
+ sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
75
+ return sig
76
+
77
+
78
+ def build_datasets(args, sample_rate: int):
79
+ with argbind.scope(args, "train"):
80
+ train_data = AudioDataset(
81
+ AudioLoader(), sample_rate, transform=build_transform()
82
+ )
83
+ with argbind.scope(args, "val"):
84
+ val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
85
+ with argbind.scope(args, "test"):
86
+ test_data = AudioDataset(
87
+ AudioLoader(), sample_rate, transform=build_transform()
88
+ )
89
+ return train_data, val_data, test_data
90
+
91
+
92
+ def rand_float(shape, low, high, rng):
93
+ return rng.draw(shape)[:, 0] * (high - low) + low
94
+
95
+
96
+ def flip_coin(shape, p, rng):
97
+ return rng.draw(shape)[:, 0] < p
98
+
99
+
100
+ @argbind.bind(without_prefix=True)
101
+ def load(
102
+ args,
103
+ accel: at.ml.Accelerator,
104
+ save_path: str,
105
+ resume: bool = False,
106
+ tag: str = "latest",
107
+ load_weights: bool = False,
108
+ ):
109
+ model, v_extra = None, {}
110
+
111
+ if resume:
112
+ kwargs = {
113
+ "folder": f"{save_path}/{tag}",
114
+ "map_location": "cpu",
115
+ "package": not load_weights,
116
+ }
117
+ if (Path(kwargs["folder"]) / "model").exists():
118
+ model, v_extra = model.load_from_folder(**kwargs)
119
+
120
+ codec = LAC.load(args["codec_ckpt"], map_location="cpu")
121
+ codec.eval()
122
+ model = VampNet() if model is None else model
123
+
124
+ model = accel.prepare_model(model)
125
+
126
+ # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
127
+ assert (
128
+ accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
129
+ )
130
+
131
+ optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
132
+ scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
133
+ scheduler.step()
134
+
135
+ trainer_state = {"state_dict": None, "start_idx": 0}
136
+
137
+ if "optimizer.pth" in v_extra:
138
+ optimizer.load_state_dict(v_extra["optimizer.pth"])
139
+ if "scheduler.pth" in v_extra:
140
+ scheduler.load_state_dict(v_extra["scheduler.pth"])
141
+ if "trainer.pth" in v_extra:
142
+ trainer_state = v_extra["trainer.pth"]
143
+
144
+ return {
145
+ "model": model,
146
+ "codec": codec,
147
+ "optimizer": optimizer,
148
+ "scheduler": scheduler,
149
+ "trainer_state": trainer_state,
150
+ }
151
+
152
+
153
+ def get_gpu_memory_map():
154
+ """Get the current gpu usage.
155
+
156
+ Returns
157
+ -------
158
+ usage: dict
159
+ Keys are device ids as integers.
160
+ Values are memory usage as integers in MB.
161
+ """
162
+ result = subprocess.check_output(
163
+ ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
164
+ encoding="utf-8",
165
+ )
166
+ # Convert lines into a dictionary
167
+ gpu_memory = [int(x) for x in result.strip().split("\n")]
168
+ gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
169
+ gpu_memory_map = {f"gpu/{k}": v / 1024 for k, v in gpu_memory_map.items()}
170
+ return gpu_memory_map
171
+
172
+
173
+ def num_params_hook(o, p):
174
+ return o + f" {p/1e6:<.3f}M params."
175
+
176
+
177
+ def add_num_params_repr_hook(model):
178
+ import numpy as np
179
+ from functools import partial
180
+
181
+ for n, m in model.named_modules():
182
+ o = m.extra_repr()
183
+ p = sum([np.prod(p.size()) for p in m.parameters()])
184
+
185
+ setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
186
+
187
+
188
+ def accuracy(
189
+ preds: torch.Tensor,
190
+ target: torch.Tensor,
191
+ top_k: int = 1,
192
+ ignore_index: Optional[int] = None,
193
+ **kwargs,
194
+ ) -> torch.Tensor:
195
+ # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
196
+ preds = rearrange(preds, "b p s -> (b s) p")
197
+ target = rearrange(target, "b s -> (b s)")
198
+
199
+ # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
200
+ if ignore_index is not None:
201
+ # Create a mask for the ignored index
202
+ mask = target != ignore_index
203
+ # Apply the mask to the target and predictions
204
+ preds = preds[mask]
205
+ target = target[mask]
206
+
207
+ # Get the top-k predicted classes and their indices
208
+ _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
209
+
210
+ # Determine if the true target is in the top-k predicted classes
211
+ correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
212
+
213
+ # Calculate the accuracy
214
+ accuracy = torch.mean(correct.float())
215
+
216
+ return accuracy
217
+
218
+
219
+ @argbind.bind(without_prefix=True)
220
+ def train(
221
+ args,
222
+ accel: at.ml.Accelerator,
223
+ codec_ckpt: str = None,
224
+ seed: int = 0,
225
+ save_path: str = "ckpt",
226
+ max_epochs: int = int(100e3),
227
+ epoch_length: int = 1000,
228
+ save_audio_epochs: int = 10,
229
+ batch_size: int = 48,
230
+ grad_acc_steps: int = 1,
231
+ val_idx: list = [0, 1, 2, 3, 4],
232
+ num_workers: int = 20,
233
+ detect_anomaly: bool = False,
234
+ grad_clip_val: float = 5.0,
235
+ prefix_amt: float = 0.0,
236
+ suffix_amt: float = 0.0,
237
+ prefix_dropout: float = 0.1,
238
+ suffix_dropout: float = 0.1,
239
+ quiet: bool = False,
240
+ ):
241
+ assert codec_ckpt is not None, "codec_ckpt is required"
242
+
243
+ at.util.seed(seed)
244
+ writer = None
245
+
246
+ if accel.local_rank == 0:
247
+ writer = SummaryWriter(log_dir=f"{save_path}/logs/")
248
+ argbind.dump_args(args, f"{save_path}/args.yml")
249
+
250
+ # load the codec model
251
+ loaded = load(args, accel, save_path)
252
+ model = loaded["model"]
253
+ codec = loaded["codec"]
254
+ optimizer = loaded["optimizer"]
255
+ scheduler = loaded["scheduler"]
256
+ trainer_state = loaded["trainer_state"]
257
+
258
+ sample_rate = codec.sample_rate
259
+
260
+ # a better rng for sampling from our schedule
261
+ rng = torch.quasirandom.SobolEngine(1, scramble=True)
262
+
263
+ # log a model summary w/ num params
264
+ if accel.local_rank == 0:
265
+ add_num_params_repr_hook(accel.unwrap(model))
266
+ with open(f"{save_path}/model.txt", "w") as f:
267
+ f.write(repr(accel.unwrap(model)))
268
+
269
+ # load the datasets
270
+ train_data, val_data, _ = build_datasets(args, sample_rate)
271
+ train_dataloader = accel.prepare_dataloader(
272
+ train_data,
273
+ start_idx=trainer_state["start_idx"],
274
+ num_workers=num_workers,
275
+ batch_size=batch_size,
276
+ collate_fn=train_data.collate,
277
+ )
278
+ val_dataloader = accel.prepare_dataloader(
279
+ val_data,
280
+ start_idx=0,
281
+ num_workers=num_workers,
282
+ batch_size=batch_size,
283
+ collate_fn=val_data.collate,
284
+ )
285
+
286
+ criterion = CrossEntropyLoss()
287
+
288
+ class Trainer(at.ml.BaseTrainer):
289
+ _last_grad_norm = 0.0
290
+
291
+ def metrics(self, vn, z_hat, r, target, flat_mask, output):
292
+ for r_range in [(0, 0.5), (0.5, 1.0)]:
293
+ unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
294
+ masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
295
+
296
+ assert target.shape[0] == r.shape[0]
297
+ # grab the indices of the r values that are in the range
298
+ r_idx = (r >= r_range[0]) & (r < r_range[1])
299
+
300
+ # grab the target and z_hat values that are in the range
301
+ r_unmasked_target = unmasked_target[r_idx]
302
+ r_masked_target = masked_target[r_idx]
303
+ r_z_hat = z_hat[r_idx]
304
+
305
+ for topk in (1, 25):
306
+ s, e = r_range
307
+ tag = f"accuracy-{s}-{e}/top{topk}"
308
+
309
+ output[f"{tag}/unmasked"] = accuracy(
310
+ preds=r_z_hat,
311
+ target=r_unmasked_target,
312
+ ignore_index=IGNORE_INDEX,
313
+ top_k=topk,
314
+ task="multiclass",
315
+ num_classes=vn.vocab_size,
316
+ )
317
+ output[f"{tag}/masked"] = accuracy(
318
+ preds=r_z_hat,
319
+ target=r_masked_target,
320
+ ignore_index=IGNORE_INDEX,
321
+ top_k=topk,
322
+ task="multiclass",
323
+ num_classes=vn.vocab_size,
324
+ )
325
+
326
+ def train_loop(self, engine, batch):
327
+
328
+ model.train()
329
+ batch = at.util.prepare_batch(batch, accel.device)
330
+ signal = apply_transform(train_data.transform, batch)
331
+
332
+ output = {}
333
+ vn = accel.unwrap(model)
334
+ with accel.autocast():
335
+ with torch.inference_mode():
336
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
337
+ z = z[:, : vn.n_codebooks, :]
338
+
339
+ n_batch = z.shape[0]
340
+ r = rng.draw(n_batch)[:, 0].to(accel.device)
341
+
342
+ if prefix_amt > 0.0:
343
+ prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
344
+ n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
345
+ else:
346
+ n_prefix = None
347
+ if suffix_amt > 0.0:
348
+ suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
349
+ n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
350
+ else:
351
+ n_suffix = None
352
+
353
+ z_mask, mask = vn.add_noise(
354
+ z, r, n_prefix=n_prefix, n_suffix=n_suffix
355
+ )
356
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
357
+
358
+ dtype = torch.bfloat16 if accel.amp else None
359
+ with accel.autocast(dtype=dtype):
360
+ z_hat = model(z_mask_latent, r)
361
+ # for mask mode
362
+ z_hat = vn.add_truth_to_logits(z, z_hat, mask)
363
+
364
+ target = vn.embedding.flatten(
365
+ z[:, vn.n_conditioning_codebooks :, :],
366
+ n_codebooks=vn.n_predict_codebooks,
367
+ )
368
+
369
+ flat_mask = vn.embedding.flatten(
370
+ mask[:, vn.n_conditioning_codebooks :, :],
371
+ n_codebooks=vn.n_predict_codebooks,
372
+ )
373
+
374
+ if vn.noise_mode == "mask":
375
+ # replace target with ignore index for masked tokens
376
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
377
+ output["loss"] = criterion(z_hat, t_masked)
378
+ else:
379
+ output["loss"] = criterion(z_hat, target)
380
+
381
+ self.metrics(
382
+ vn=vn,
383
+ r=r,
384
+ z_hat=z_hat,
385
+ target=target,
386
+ flat_mask=flat_mask,
387
+ output=output,
388
+ )
389
+
390
+
391
+ accel.backward(output["loss"] / grad_acc_steps)
392
+
393
+ output["other/learning_rate"] = optimizer.param_groups[0]["lr"]
394
+ output["other/batch_size"] = z.shape[0]
395
+
396
+ output.update(get_gpu_memory_map())
397
+
398
+ if (
399
+ (engine.state.iteration % grad_acc_steps == 0)
400
+ or (engine.state.iteration % epoch_length == 0)
401
+ or (engine.state.iteration % epoch_length == 1)
402
+ ): # (or we reached the end of the epoch)
403
+ accel.scaler.unscale_(optimizer)
404
+ output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
405
+ model.parameters(), grad_clip_val
406
+ )
407
+ self._last_grad_norm = output["other/grad_norm"]
408
+
409
+ accel.step(optimizer)
410
+ optimizer.zero_grad()
411
+
412
+ scheduler.step()
413
+ accel.update()
414
+ else:
415
+ output["other/grad_norm"] = self._last_grad_norm
416
+
417
+ return {k: v for k, v in sorted(output.items())}
418
+
419
+ @torch.no_grad()
420
+ def val_loop(self, engine, batch):
421
+ model.eval()
422
+ codec.eval()
423
+ batch = at.util.prepare_batch(batch, accel.device)
424
+ signal = apply_transform(val_data.transform, batch)
425
+
426
+ vn = accel.unwrap(model)
427
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
428
+ z = z[:, : vn.n_codebooks, :]
429
+
430
+ n_batch = z.shape[0]
431
+ r = rng.draw(n_batch)[:, 0].to(accel.device)
432
+
433
+ if prefix_amt > 0.0:
434
+ prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
435
+ n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
436
+ else:
437
+ n_prefix = None
438
+ if suffix_amt > 0.0:
439
+ suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
440
+ n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
441
+ else:
442
+ n_suffix = None
443
+
444
+ z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
445
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
446
+
447
+ z_hat = model(z_mask_latent, r)
448
+ # for mask mode
449
+ z_hat = vn.add_truth_to_logits(z, z_hat, mask)
450
+
451
+ target = vn.embedding.flatten(
452
+ z[:, vn.n_conditioning_codebooks :, :],
453
+ n_codebooks=vn.n_predict_codebooks,
454
+ )
455
+
456
+ flat_mask = vn.embedding.flatten(
457
+ mask[:, vn.n_conditioning_codebooks :, :],
458
+ n_codebooks=vn.n_predict_codebooks,
459
+ )
460
+
461
+ output = {}
462
+ if vn.noise_mode == "mask":
463
+ # replace target with ignore index for masked tokens
464
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
465
+ output["loss"] = criterion(z_hat, t_masked)
466
+ else:
467
+ output["loss"] = criterion(z_hat, target)
468
+
469
+ self.metrics(
470
+ vn=vn,
471
+ r=r,
472
+ z_hat=z_hat,
473
+ target=target,
474
+ flat_mask=flat_mask,
475
+ output=output,
476
+ )
477
+
478
+ return output
479
+
480
+ def checkpoint(self, engine):
481
+ if accel.local_rank != 0:
482
+ print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
483
+ return
484
+
485
+ metadata = {"logs": dict(engine.state.logs["epoch"])}
486
+
487
+ if self.state.epoch % save_audio_epochs == 0:
488
+ self.save_samples()
489
+
490
+ tags = ["latest"]
491
+ loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
492
+ self.print(f"Saving to {str(Path('.').absolute())}")
493
+
494
+ if self.is_best(engine, loss_key):
495
+ self.print(f"Best model so far")
496
+ tags.append("best")
497
+
498
+ for tag in tags:
499
+ model_extra = {
500
+ "optimizer.pth": optimizer.state_dict(),
501
+ "scheduler.pth": scheduler.state_dict(),
502
+ "trainer.pth": {
503
+ "start_idx": self.state.iteration * batch_size,
504
+ "state_dict": self.state_dict(),
505
+ },
506
+ "metadata.pth": metadata,
507
+ }
508
+
509
+ accel.unwrap(model).metadata = metadata
510
+ accel.unwrap(model).save_to_folder(
511
+ f"{save_path}/{tag}", model_extra
512
+ )
513
+
514
+ def save_sampled(self, z):
515
+ num_samples = z.shape[0]
516
+
517
+ for i in range(num_samples):
518
+ sampled = accel.unwrap(model).sample(
519
+ codec,
520
+ time_steps=z.shape[-1],
521
+ start_tokens=z[i : i + 1],
522
+ )
523
+ sampled.cpu().write_audio_to_tb(
524
+ f"sampled/{i}",
525
+ self.writer,
526
+ step=self.state.epoch,
527
+ plot_fn=None,
528
+ )
529
+
530
+ def save_imputation(self, z: torch.Tensor):
531
+ # imputations
532
+ mask_begin = z.shape[-1] // 4
533
+ mask_end = (z.shape[-1] * 3) // 4
534
+
535
+ imp_mask = torch.zeros(z.shape[0], z.shape[-1]).to(accel.device).int()
536
+ imp_mask[:, mask_begin:mask_end] = 1
537
+
538
+ imp_noisy = (
539
+ z * (1 - imp_mask[:, None, :])
540
+ + torch.randint_like(z, 0, accel.unwrap(model).vocab_size)
541
+ * imp_mask[:, None, :]
542
+ )
543
+ imputed_noisy = accel.unwrap(model).to_signal(imp_noisy, codec)
544
+ imputed_true = accel.unwrap(model).to_signal(z, codec)
545
+
546
+ imputed = []
547
+ for i in range(len(z)):
548
+ imputed.append(
549
+ accel.unwrap(model).sample(
550
+ codec,
551
+ time_steps=z.shape[-1],
552
+ start_tokens=z[i][None, ...],
553
+ mask=imp_mask[i][None, ...],
554
+ )
555
+ )
556
+ imputed = AudioSignal.batch(imputed)
557
+
558
+ for i in range(len(val_idx)):
559
+ imputed_noisy[i].cpu().write_audio_to_tb(
560
+ f"imputed_noisy/{i}",
561
+ self.writer,
562
+ step=self.state.epoch,
563
+ plot_fn=None,
564
+ )
565
+ imputed[i].cpu().write_audio_to_tb(
566
+ f"imputed/{i}",
567
+ self.writer,
568
+ step=self.state.epoch,
569
+ plot_fn=None,
570
+ )
571
+ imputed_true[i].cpu().write_audio_to_tb(
572
+ f"imputed_true/{i}",
573
+ self.writer,
574
+ step=self.state.epoch,
575
+ plot_fn=None,
576
+ )
577
+
578
+ @torch.no_grad()
579
+ def save_samples(self):
580
+ model.eval()
581
+ codec.eval()
582
+ vn = accel.unwrap(model)
583
+
584
+ batch = [val_data[i] for i in val_idx]
585
+ batch = at.util.prepare_batch(val_data.collate(batch), accel.device)
586
+
587
+ signal = apply_transform(val_data.transform, batch)
588
+
589
+ z = codec.encode(signal.samples, signal.sample_rate)["codes"]
590
+ z = z[:, : vn.n_codebooks, :]
591
+
592
+ r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
593
+
594
+ n_batch = z.shape[0]
595
+
596
+ if prefix_amt > 0.0:
597
+ prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
598
+ n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
599
+ else:
600
+ n_prefix = None
601
+ if suffix_amt > 0.0:
602
+ suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
603
+ n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
604
+ else:
605
+ n_suffix = None
606
+
607
+ z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
608
+ z_mask_latent = vn.embedding.from_codes(z_mask, codec)
609
+
610
+ z_hat = model(z_mask_latent, r)
611
+ # for mask mode
612
+ z_hat = vn.add_truth_to_logits(z, z_hat, mask)
613
+
614
+ z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
615
+ z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
616
+ z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
617
+
618
+ print("z_mask", z_mask.shape)
619
+ generated = vn.to_signal(z_pred, codec)
620
+ reconstructed = vn.to_signal(z, codec)
621
+ masked = vn.to_signal(z_mask.squeeze(1), codec)
622
+
623
+ for i in range(generated.batch_size):
624
+ audio_dict = {
625
+ "original": signal[i],
626
+ "masked": masked[i],
627
+ "generated": generated[i],
628
+ "reconstructed": reconstructed[i],
629
+ }
630
+ for k, v in audio_dict.items():
631
+ v.cpu().write_audio_to_tb(
632
+ f"samples/_{i}.r={r[i]:0.2f}/{k}",
633
+ self.writer,
634
+ step=self.state.epoch,
635
+ plot_fn=None,
636
+ )
637
+
638
+ self.save_sampled(z)
639
+ self.save_imputation(z)
640
+
641
+ trainer = Trainer(writer=writer, quiet=quiet)
642
+
643
+ if trainer_state["state_dict"] is not None:
644
+ trainer.load_state_dict(trainer_state["state_dict"])
645
+ if hasattr(train_dataloader.sampler, "set_epoch"):
646
+ train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch)
647
+
648
+ trainer.run(
649
+ train_dataloader,
650
+ val_dataloader,
651
+ num_epochs=max_epochs,
652
+ epoch_length=epoch_length,
653
+ detect_anomaly=detect_anomaly,
654
+ )
655
+
656
+
657
+ if __name__ == "__main__":
658
+ args = argbind.parse_args()
659
+ args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
660
+ with argbind.scope(args):
661
+ with Accelerator() as accel:
662
+ train(args, accel)
scripts/utils/README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scripts
2
+
3
+ ## process_zip.py
4
+
5
+ Some requirements that may not be installed in the docker image:
6
+ * argbind
7
+ * wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
8
+
9
+ ### zip folder structure
10
+
11
+ The zip folder should have the following internal structure:
12
+
13
+ ```
14
+ base_folder/
15
+ test_case_1/
16
+ before.wav
17
+ test_case_2/
18
+ before.wav
19
+ ...
20
+ test_case_n/
21
+ before.wav
22
+ ```
23
+
24
+ Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. IF you want/need to use a zip file with a different folder structure, adjust this:
25
+ https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
26
+
27
+ ### Execution
28
+ `python process_zip.py <path/to/zip> -tag <string>`
scripts/utils/stage.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import rich
7
+ from audiotools.ml import Experiment
8
+
9
+
10
+ @argbind.bind(without_prefix=True)
11
+ def run(
12
+ run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
13
+ name: str = None,
14
+ recent: bool = False,
15
+ ):
16
+ if recent:
17
+ paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
18
+ paths = [p.name for p in paths if p.is_dir()]
19
+ if paths:
20
+ name = paths[-1]
21
+
22
+ with Experiment(run_dir, name) as exp:
23
+ exp.snapshot()
24
+ rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
25
+
26
+
27
+ if __name__ == "__main__":
28
+ args = argbind.parse_args()
29
+ with argbind.scope(args):
30
+ run()
setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages
2
+ from setuptools import setup
3
+
4
+ with open("README.md") as f:
5
+ long_description = f.read()
6
+
7
+ setup(
8
+ name="vampnet",
9
+ version="0.0.1",
10
+ classifiers=[
11
+ "Intended Audience :: Developers",
12
+ "Natural Language :: English",
13
+ "Programming Language :: Python :: 3.7",
14
+ "Topic :: Artistic Software",
15
+ "Topic :: Multimedia",
16
+ "Topic :: Multimedia :: Sound/Audio",
17
+ "Topic :: Multimedia :: Sound/Audio :: Editors",
18
+ "Topic :: Software Development :: Libraries",
19
+ ],
20
+ description="Generative Music Modeling.",
21
+ long_description=long_description,
22
+ long_description_content_type="text/markdown",
23
+ author="Hugo Flores García",
24
+ author_email="[email protected]",
25
+ url="https://github.com/descriptinc/lyrebird-vampnet",
26
+ license="MIT",
27
+ packages=find_packages(),
28
+ install_requires=[
29
+ "torch<=1.11.0",
30
+ "argbind>=0.3.2",
31
+ "pytorch-ignite",
32
+ "rich",
33
+ "audiotools @ git+https://github.com/descriptinc/[email protected]",
34
+ "tqdm",
35
+ "tensorboard",
36
+ "google-cloud-logging==2.2.0",
37
+ "torchmetrics>=0.7.3",
38
+ "einops",
39
+ ],
40
+ )
vampnet/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ from . import modules
3
+ from . import scheduler
4
+ from . import enchilada
5
+
6
+ __version__ = "0.0.1"
vampnet/enchilada.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from audiotools import AudioSignal
6
+
7
+ from .modules.transformer import VampNet
8
+ from lac.model.lac import LAC
9
+
10
+
11
+ class TheWholeEnchilada:
12
+ def __init__(
13
+ self,
14
+ coarse_ckpt: str,
15
+ coarse2fine_ckpt: str,
16
+ codec_ckpt: str,
17
+ device: str = "cpu",
18
+ ):
19
+ self.codec = LAC.load(Path(codec_ckpt))
20
+ self.codec.eval()
21
+ self.codec.to(device)
22
+
23
+ self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
24
+ self.coarse.to(device)
25
+ self.coarse.eval()
26
+
27
+ self.coarse2fine = VampNet.load(
28
+ location=Path(coarse2fine_ckpt), map_location="cpu"
29
+ )
30
+ # FIXME
31
+ print(
32
+ f"WARNING: PATCHING coarse2fine seq_len to 288, for backwards compatibility with a specific jazzpop model. it used to be {self.coarse2fine.seq_len}"
33
+ )
34
+ self.coarse2fine.seq_len = 288
35
+
36
+ self.coarse2fine.to(device)
37
+ self.coarse2fine.eval()
38
+
39
+ self.device = device
40
+
41
+ def seconds_to_tokens(self, seconds: float):
42
+ return int(seconds * self.codec.sample_rate / self.codec.hop_length)
43
+
44
+ def to(self, device):
45
+ self.device = device
46
+ self.coarse.to(device)
47
+ self.coarse2fine.to(device)
48
+ self.codec.to(device)
49
+ return self
50
+
51
+ def encode(self, signal: AudioSignal):
52
+ with torch.inference_mode():
53
+ # coarse z
54
+ cz = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
55
+
56
+ return cz
57
+
58
+ def vamp(
59
+ self,
60
+ signal,
61
+ prefix_dur_s: float = 1.25,
62
+ suffix_dur_s: float = 1.25,
63
+ downsample_hint: bool = True,
64
+ downsample_factor: int = 4,
65
+ num_loops: int = 3,
66
+ **kwargs,
67
+ ):
68
+ """
69
+ Loop imputation of a signal.
70
+ """
71
+ signal.to(self.device).resample(self.codec.sample_rate).to_mono()
72
+
73
+ z = self.encode(signal)
74
+
75
+ cz = z[:, : self.coarse.n_codebooks, :].clone()
76
+ original_cz = cz.clone()
77
+ seq_len = original_cz.shape[-1]
78
+ assert (
79
+ seq_len == self.coarse.seq_len
80
+ ), f"expected seq_len {self.coarse.seq_len}, got {seq_len} for token sequence length. Is your signal the same duration as the model was trained with? "
81
+
82
+ vamp_hop_s = prefix_dur_s
83
+ vamp_hop = self.seconds_to_tokens(vamp_hop_s)
84
+
85
+ cmask = torch.ones_like(cz)
86
+
87
+ if downsample_hint:
88
+ # downsample by factor of 4
89
+ for i in range(cmask.shape[-1]):
90
+ if i % downsample_factor == 0:
91
+ cmask[:, :, i] = 0
92
+
93
+ if prefix_dur_s > 0:
94
+ prefix_len = self.seconds_to_tokens(prefix_dur_s)
95
+ cmask[:, :, :prefix_len] = 0
96
+ print(f"prefix_len: {prefix_len}")
97
+ else:
98
+ prefix_len = 0
99
+
100
+ if suffix_dur_s > 0:
101
+ suffix_len = self.seconds_to_tokens(suffix_dur_s)
102
+ cmask[:, :, -suffix_len:] = 0
103
+ print(f"suffix_len: {suffix_len}")
104
+ else:
105
+ suffix_len = 0
106
+
107
+ prefix_z = cz[:, :, :prefix_len]
108
+
109
+ coarse_vamp = [prefix_z.clone()]
110
+ for i in range(num_loops):
111
+ sampled_cz = self.coarse.sample(
112
+ codec=self.codec,
113
+ time_steps=seq_len,
114
+ mask=cmask,
115
+ start_tokens=cz,
116
+ return_signal=False,
117
+ **kwargs,
118
+ )
119
+
120
+ new_prefix = sampled_cz[:, :, prefix_len : prefix_len + vamp_hop]
121
+ coarse_vamp.append(new_prefix.clone())
122
+
123
+ # replace the prefix in cz with the new prefix
124
+ # don't worry about a copy of the prefix still being
125
+ # in the mask area, since that will be masked out
126
+ cz[:, :, :vamp_hop] = new_prefix.clone()
127
+ print("to append and to prefix")
128
+
129
+ # we're done, so add the suffix
130
+ coarse_vamp.append(sampled_cz[:, :, prefix_len + vamp_hop :])
131
+
132
+ # concatenate the vamps
133
+ coarse_vamp = torch.cat(coarse_vamp, dim=-1)
134
+
135
+ # add a layer of
136
+ fine_prefix = z[:, self.coarse.n_codebooks :, :prefix_len]
137
+ fine_suffix = z[:, self.coarse.n_codebooks :, -suffix_len:]
138
+ fine_vamp = torch.randint(
139
+ 0,
140
+ self.coarse2fine.vocab_size,
141
+ (
142
+ coarse_vamp.shape[0],
143
+ self.coarse2fine.n_predict_codebooks,
144
+ coarse_vamp.shape[-1],
145
+ ),
146
+ ).to(self.device)
147
+ fine_vamp[:, :, :prefix_len] = fine_prefix
148
+ fine_vamp[:, :, -suffix_len:] = fine_suffix
149
+
150
+ vamp_z = torch.cat([coarse_vamp, fine_vamp], dim=1)
151
+
152
+ # now we sample from the coarse2fine model
153
+ # to get the fine details
154
+ start_pos = 0
155
+
156
+ c2f_vamp = []
157
+ while start_pos < vamp_z.shape[-1]:
158
+ end_pos = min(start_pos + self.coarse2fine.seq_len, vamp_z.shape[-1])
159
+
160
+ c2fz = vamp_z[:, :, start_pos:end_pos]
161
+ self.coarse2fine: VampNet
162
+ sampled_c2fz = self.coarse2fine.sample(
163
+ codec=self.codec,
164
+ start_tokens=c2fz,
165
+ return_signal=False,
166
+ mask=None,
167
+ )
168
+ c2f_vamp.append(sampled_c2fz)
169
+ start_pos += self.coarse2fine.seq_len
170
+
171
+ c2f_vamp = torch.cat(c2f_vamp, dim=-1)
172
+
173
+ # make it a signal
174
+ vamp_signal = self.coarse2fine.to_signal(c2f_vamp, self.codec)
175
+
176
+ return {
177
+ "full": vamp_signal,
178
+ "coarse": self.coarse.to_signal(coarse_vamp, self.codec),
179
+ }
vampnet/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import audiotools
2
+
3
+ audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
+ audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention"]
vampnet/modules/base.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import audiotools as at
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from tqdm import tqdm
12
+
13
+
14
+ def log(t, eps=1e-20):
15
+ return torch.log(t + eps)
16
+
17
+
18
+ def gumbel_noise(t):
19
+ noise = torch.zeros_like(t).uniform_(0, 1)
20
+ return -log(-log(noise))
21
+
22
+
23
+ def gumbel_sample(t, temperature=1.0, dim=-1):
24
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
25
+
26
+
27
+ class VampBase(at.ml.BaseModel):
28
+ def forward(self, x: torch.Tensor, r: torch.Tensor):
29
+ raise NotImplementedError
30
+
31
+ def add_noise(
32
+ self,
33
+ x: torch.Tensor,
34
+ r: torch.Tensor,
35
+ random_x: Optional[torch.Tensor] = None,
36
+ mask: Optional[torch.Tensor] = None,
37
+ n_prefix: Optional[torch.Tensor] = None,
38
+ n_suffix: Optional[torch.Tensor] = None,
39
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
40
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
41
+
42
+ if mask is None:
43
+ r = self.gamma(r)[:, None, None]
44
+ probs = torch.ones_like(x) * r
45
+
46
+ # if we have a prefix or suffix, set their mask prob to 0
47
+ if n_prefix is not None:
48
+ for i, n in enumerate(n_prefix):
49
+ probs[i, :, :n] = 0.0
50
+ if n_suffix is not None:
51
+ for i, n in enumerate(n_suffix):
52
+ probs[i, :, -n:] = 0.0
53
+
54
+ mask = torch.bernoulli(probs)
55
+ mask = mask.round().long()
56
+
57
+ # if we have any conditioning codebooks, set their mask to 0
58
+ mask[:, : self.n_conditioning_codebooks, :] = 0
59
+ else:
60
+ assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
61
+ assert mask.shape == x.shape, "mask must be same shape as x"
62
+
63
+ if random_x is None:
64
+ random_x = torch.randint_like(x, 0, self.vocab_size)
65
+
66
+ if self.noise_mode == "mask":
67
+ random_x = torch.full_like(x, self.mask_token)
68
+ elif self.noise_mode == "random":
69
+ if random_x is None:
70
+ random_x = torch.randint_like(x, 0, self.vocab_size)
71
+ else:
72
+ raise ValueError(f"invalid noise mode {self.noise_mode}")
73
+
74
+ x = x * (1 - mask) + random_x * mask
75
+ return x, mask
76
+
77
+ def add_truth_to_logits(
78
+ self,
79
+ z_true,
80
+ z_hat,
81
+ mask,
82
+ ):
83
+ if self.noise_mode == "mask":
84
+ z_true = z_true[:, self.n_conditioning_codebooks :, :]
85
+ mask = mask[:, self.n_conditioning_codebooks :, :]
86
+
87
+ truth = F.one_hot(z_true, self.vocab_size)
88
+ print(truth.shape)
89
+ # truth = rearrange(truth, "b c t p -> b p (t c)")
90
+ mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
91
+ z_hat = rearrange(
92
+ z_hat,
93
+ "b p (t c) -> b c t p",
94
+ c=self.n_codebooks - self.n_conditioning_codebooks,
95
+ )
96
+
97
+ z_hat = z_hat * mask + truth * (1 - mask)
98
+
99
+ z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
100
+
101
+ return z_hat
102
+
103
+ def gamma(self, r):
104
+ return (r * torch.pi / 2).cos()
105
+
106
+ def r_embed(self, r, max_positions=10000):
107
+ """ """
108
+ assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
109
+
110
+ if self.r_cond_dim > 0:
111
+ dtype = r.dtype
112
+
113
+ r = self.gamma(r) * max_positions
114
+ half_dim = self.r_cond_dim // 2
115
+
116
+ emb = math.log(max_positions) / (half_dim - 1)
117
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
118
+
119
+ emb = r[:, None] * emb[None, :]
120
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
121
+
122
+ if self.r_cond_dim % 2 == 1: # zero pad
123
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
124
+
125
+ return emb.to(dtype)
126
+ else:
127
+ return r
128
+
129
+ @torch.no_grad()
130
+ def to_signal(self, z, vqvae):
131
+ if z.ndim == 2:
132
+ z = self.embedding.unflatten(z)
133
+ assert z.ndim == 3
134
+
135
+ signal = at.AudioSignal(
136
+ vqvae.decode(
137
+ vqvae.quantizer.from_latents(self.embedding.from_codes(z, vqvae))[0]
138
+ )["audio"],
139
+ vqvae.sample_rate,
140
+ )
141
+
142
+ return signal
143
+
144
+ @torch.no_grad()
145
+ def sample(self, **kwargs):
146
+ if self.noise_mode == "mask":
147
+ return self.maskgit_sample(**kwargs)
148
+ else:
149
+ return self.paella_sample(**kwargs)
150
+
151
+ def paella_sample(
152
+ self,
153
+ vqvae,
154
+ time_steps: int = 400,
155
+ sampling_steps: int = 12,
156
+ start_tokens: Optional[torch.Tensor] = None,
157
+ mask: Optional[torch.Tensor] = None,
158
+ device: str = "cpu",
159
+ temperature: Union[float, Tuple[float, float]] = 1.0,
160
+ top_k: int = None,
161
+ sample: str = "gumbel",
162
+ renoise_mode: str = "start",
163
+ renoise_steps=None,
164
+ typical_filtering=True,
165
+ typical_mass=0.2,
166
+ typical_min_tokens=1,
167
+ return_signal=True,
168
+ ):
169
+ r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(device)
170
+ if renoise_steps == None:
171
+ renoise_steps = sampling_steps - 1
172
+
173
+ if isinstance(temperature, float):
174
+ temperature = torch.tensor(temperature).repeat(sampling_steps)
175
+ elif isinstance(temperature, tuple):
176
+ assert len(temperature) == 2
177
+ l, h = temperature
178
+ temperature = torch.linspace(l, h, sampling_steps)
179
+ else:
180
+ raise TypeError(f"invalid type for temperature")
181
+
182
+ if self.n_conditioning_codebooks > 0:
183
+ assert (
184
+ start_tokens is not None
185
+ ), "must provide start_tokens if n_conditioning_codebooks > 0"
186
+
187
+ if start_tokens is None:
188
+ if self.noise_mode == "noise":
189
+ z = torch.randint(
190
+ 0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
191
+ ).to(device)
192
+ elif self.noise_mode == "mask":
193
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
194
+ else:
195
+ z = start_tokens
196
+ assert (
197
+ z.ndim == 3
198
+ ), f"start_tokens must be shape (batch, n_codebooks, seq_len), got {z.shape}"
199
+ assert z.shape[0] == 1, f"batch size must be 1"
200
+
201
+ if mask is None:
202
+ mask = torch.ones(z.shape[0], z.shape[-1]).to(device).int()
203
+
204
+ # apply mask
205
+ assert mask.shape == (
206
+ z.shape[0],
207
+ z.shape[-1],
208
+ ), f"mask must be shape (batch, seq_len), got {mask.shape}"
209
+ mask = mask[:, None, :]
210
+ mask = mask.repeat(1, z.shape[1], 1)
211
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
212
+
213
+ if self.noise_mode == "mask":
214
+ z_true = z.clone()
215
+
216
+ z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
217
+ z_init = z.clone()
218
+ for i, tmpt in enumerate(temperature):
219
+ if renoise_mode == "prev":
220
+ z_prev = z.clone()
221
+
222
+ latents = self.embedding.from_codes(z, vqvae)
223
+ logits = self.forward(latents, r[i])
224
+
225
+ # for mask mode
226
+ logits = self.add_truth_to_logits(z_true, logits, mask)
227
+
228
+ # Apply topk sampling
229
+ logits = logits.permute(0, 2, 1)
230
+
231
+ z = self.sample_from_logits(
232
+ logits,
233
+ tmpt,
234
+ top_k,
235
+ sample=sample,
236
+ typical_filtering=typical_filtering,
237
+ typical_mass=typical_mass,
238
+ typical_min_tokens=typical_min_tokens,
239
+ )
240
+
241
+ # add back in conditioning codebooks
242
+ z = self.embedding.unflatten(z, n_codebooks=self.n_predict_codebooks)
243
+ z = torch.cat(
244
+ [z_init[:, : self.n_conditioning_codebooks, :], z], dim=1
245
+ ).int()
246
+
247
+ if i < renoise_steps:
248
+ if renoise_mode == "prev":
249
+ z, _ = self.add_noise(z, r[i + 1], random_x=z_prev)
250
+ elif renoise_mode == "start":
251
+ z, _ = self.add_noise(z, r[i + 1], random_x=z_init)
252
+ elif renoise_mode == "rand":
253
+ z, _ = self.add_noise(z, r[i + 1])
254
+ else:
255
+ raise ValueError(f"Invalid renoise_mode: {renoise_mode}")
256
+
257
+ if mask is not None:
258
+ z = start_tokens * (1 - mask) + z * mask
259
+
260
+ if return_signal:
261
+ return self.to_signal(z, vqvae)
262
+ else:
263
+ return z
264
+
265
+ def maskgit_sample(
266
+ self,
267
+ vqvae,
268
+ time_steps: int = 300,
269
+ sampling_steps: int = 24,
270
+ start_tokens: Optional[torch.Tensor] = None,
271
+ mask: Optional[torch.Tensor] = None,
272
+ temperature: Union[float, Tuple[float, float]] = 1.0,
273
+ top_k: int = None,
274
+ sample: str = "multinomial",
275
+ typical_filtering=False,
276
+ typical_mass=0.2,
277
+ typical_min_tokens=1,
278
+ return_signal=True,
279
+ ):
280
+ if isinstance(temperature, float):
281
+ temperature = torch.tensor(temperature).repeat(sampling_steps)
282
+ elif isinstance(temperature, tuple):
283
+ assert len(temperature) == 2
284
+ l, h = temperature
285
+ temperature = torch.linspace(l, h, sampling_steps)
286
+ else:
287
+ raise TypeError(f"invalid type for temperature")
288
+
289
+ def flatten(codes):
290
+ return rearrange(codes, "b c t -> b (t c)")
291
+
292
+ def unflatten(codes, c):
293
+ return rearrange(codes, "b (t c) -> b c t", c=c)
294
+
295
+ z = start_tokens
296
+
297
+ if z is None:
298
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
299
+ self.device
300
+ )
301
+
302
+ if mask is None:
303
+ mask = torch.ones_like(z).to(self.device).int()
304
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
305
+ if mask.ndim == 2:
306
+ mask = mask[:, None, :].repeat(1, z.shape[1], 1)
307
+
308
+ # figure out which timesteps we're keeping
309
+ keep_mask = 1 - mask
310
+
311
+ # any conditioning codebook levels need to be in the keep mask
312
+ # if self.n_conditioning_codebooks > 0:
313
+ # cond_mask = torch.ones(z.shape[0], self.n_conditioning_codebooks, z.shape[-1]).to(z.device)
314
+ # keep_mask = torch.cat([cond_mask, keep_mask], dim=1)
315
+
316
+ # flatten
317
+ keep_mask = flatten(keep_mask)
318
+
319
+ # our r steps
320
+ r_steps = torch.linspace(0, 1, sampling_steps + 1)[1:].to(self.device)
321
+
322
+ # how many tokens did we keep on init?
323
+ num_kept_on_init = keep_mask.sum()
324
+
325
+ # how many codebooks are we inferring vs conditioning on?
326
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
327
+
328
+ for i in tqdm(range(sampling_steps)):
329
+ # our current temperature
330
+ tmpt = temperature[i]
331
+
332
+ # our current schedule step
333
+ r = r_steps[i : i + 1]
334
+
335
+ with torch.inference_mode():
336
+ # mask our z
337
+ keep_mask_unflat = unflatten(keep_mask, c=self.n_codebooks)
338
+ z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
339
+
340
+ # get latents
341
+ latents = self.embedding.from_codes(z_masked, vqvae)
342
+
343
+ # infer from latents
344
+ logits = self.forward(latents, r)
345
+ logits = logits.permute(0, 2, 1) # b, seq, prob
346
+
347
+ # the schedule determines how many samples to keep
348
+ num_tokens_to_infer = (z.shape[-1] * z.shape[-2]) - num_kept_on_init
349
+ num_to_keep = num_kept_on_init + int(
350
+ num_tokens_to_infer * (self.gamma(1 - r))
351
+ )
352
+
353
+ # figure out which logits we wanna keep
354
+ if num_to_keep > 0:
355
+ probs = logits.softmax(dim=-1)
356
+
357
+ keep_probs = F.one_hot(z, self.vocab_size)[:, :, :]
358
+
359
+ probs = rearrange(
360
+ probs, "b (t c) p -> b c t p", c=n_infer_codebooks
361
+ )
362
+ probs = torch.cat(
363
+ [keep_probs[:, : self.n_conditioning_codebooks, ...], probs],
364
+ dim=1,
365
+ )
366
+
367
+ keep_probs = rearrange(
368
+ keep_probs, "b c t p -> b (t c) p", c=self.n_codebooks
369
+ )
370
+ probs = rearrange(probs, "b c t p -> b (t c) p", c=self.n_codebooks)
371
+
372
+ keep_prob_mask = keep_mask.unsqueeze(-1).repeat(
373
+ 1, 1, self.vocab_size
374
+ )
375
+ probs = (keep_prob_mask.long() * keep_probs) + (
376
+ 1 - keep_prob_mask.long()
377
+ ) * probs
378
+
379
+ highest_probs = probs.max(dim=-1, keepdim=False)[0]
380
+ v, _ = highest_probs.topk(num_to_keep, dim=-1)
381
+
382
+ keep_mask = torch.ones_like(keep_mask).bool().clone()
383
+ keep_mask[highest_probs < v[..., [-1]]] = 0
384
+
385
+ logits = torch.log(probs)
386
+
387
+ z_inferred = self.sample_from_logits(
388
+ logits=logits,
389
+ top_k=top_k,
390
+ temperature=tmpt,
391
+ sample=sample,
392
+ typical_filtering=typical_filtering,
393
+ typical_mass=typical_mass,
394
+ typical_min_tokens=typical_min_tokens,
395
+ )
396
+
397
+ z = rearrange(z_inferred, "b (t c) -> b c t", c=self.n_codebooks)
398
+
399
+ # add conditioning codebooks back
400
+ # z = torch.cat([z[:, :self.n_conditioning_codebooks, :], z_inferred], dim=1)
401
+
402
+ if return_signal:
403
+ return self.to_signal(z, vqvae)
404
+ else:
405
+ return z
406
+
407
+ def sample_from_logits(
408
+ self,
409
+ logits,
410
+ top_k: int = None,
411
+ temperature: float = 1.0,
412
+ sample: str = "multinomial",
413
+ typical_filtering=False,
414
+ typical_mass=0.2,
415
+ typical_min_tokens=1,
416
+ ):
417
+ # add temperature
418
+ logits = logits / temperature
419
+
420
+ # add topk
421
+ if top_k is not None:
422
+ v, topk_idx = logits.topk(top_k)
423
+ logits[logits < v[..., [-1]]] = -float("inf")
424
+
425
+ if typical_filtering:
426
+ assert top_k is None
427
+ nb, nt, _ = logits.shape
428
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
429
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
430
+ x_flat_norm_p = torch.exp(x_flat_norm)
431
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
432
+
433
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
434
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
435
+ x_flat_cumsum = (
436
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
437
+ )
438
+
439
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
440
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
441
+ 1, last_ind.view(-1, 1)
442
+ )
443
+ if typical_min_tokens > 1:
444
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
445
+ indices_to_remove = sorted_indices_to_remove.scatter(
446
+ 1, x_flat_indices, sorted_indices_to_remove
447
+ )
448
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
449
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
450
+
451
+ if sample == "multinomial":
452
+ probs = torch.softmax(logits, dim=-1)
453
+ inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
454
+ elif sample == "argmax":
455
+ inferred = torch.softmax(probs, dim=-1).argmax(dim=-1)
456
+ elif sample == "gumbel":
457
+ inferred = gumbel_sample(logits, dim=-1)
458
+ else:
459
+ raise ValueError(f"invalid sampling method: {sample}")
460
+
461
+ return inferred
vampnet/modules/modules.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Optional
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ def num_params(model):
13
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
14
+
15
+
16
+ def recurse_children(module, fn):
17
+ for child in module.children():
18
+ if isinstance(child, nn.ModuleList):
19
+ for c in child:
20
+ yield recurse_children(c, fn)
21
+ if isinstance(child, nn.ModuleDict):
22
+ for c in child.values():
23
+ yield recurse_children(c, fn)
24
+
25
+ yield recurse_children(child, fn)
26
+ yield fn(child)
27
+
28
+
29
+ # Scripting this brings model speed up 1.4x
30
+ @torch.jit.script
31
+ def snake(x, alpha):
32
+ shape = x.shape
33
+ x = x.reshape(shape[0], shape[1], -1)
34
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
35
+ x = x.reshape(shape)
36
+ return x
37
+
38
+
39
+ class Snake1d(nn.Module):
40
+ def __init__(self, channels):
41
+ super().__init__()
42
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
43
+
44
+ def forward(self, x):
45
+ return snake(x, self.alpha)
46
+
47
+
48
+ def WNConv1d(*args, **kwargs):
49
+ return weight_norm(nn.Conv1d(*args, **kwargs))
50
+
51
+
52
+ def WNConvTranspose1d(*args, **kwargs):
53
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
54
+
55
+
56
+ class SequentialWithFiLM(nn.Module):
57
+ """
58
+ handy wrapper for nn.Sequential that allows FiLM layers to be
59
+ inserted in between other layers.
60
+ """
61
+
62
+ def __init__(self, *layers):
63
+ super().__init__()
64
+ self.layers = nn.ModuleList(layers)
65
+
66
+ @staticmethod
67
+ def has_film(module):
68
+ mod_has_film = any(
69
+ [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
70
+ )
71
+ return mod_has_film
72
+
73
+ def forward(self, x, cond):
74
+ for layer in self.layers:
75
+ if self.has_film(layer):
76
+ x = layer(x, cond)
77
+ else:
78
+ x = layer(x)
79
+ return x
80
+
81
+
82
+ class FiLM(nn.Module):
83
+ def __init__(self, input_dim: int, output_dim: int):
84
+ super().__init__()
85
+
86
+ self.input_dim = input_dim
87
+ self.output_dim = output_dim
88
+
89
+ if input_dim > 0:
90
+ self.beta = nn.Linear(input_dim, output_dim)
91
+ self.gamma = nn.Linear(input_dim, output_dim)
92
+
93
+ def forward(self, x, r):
94
+ if self.input_dim == 0:
95
+ return x
96
+ else:
97
+ beta, gamma = self.beta(r), self.gamma(r)
98
+ beta, gamma = (
99
+ beta.view(x.size(0), self.output_dim, 1),
100
+ gamma.view(x.size(0), self.output_dim, 1),
101
+ )
102
+ x = x * (gamma + 1) + beta
103
+ return x
104
+
105
+
106
+ class CodebookEmbedding(nn.Module):
107
+ def __init__(
108
+ self,
109
+ vocab_size: int,
110
+ latent_dim: int,
111
+ n_codebooks: int,
112
+ emb_dim: int,
113
+ special_tokens: Optional[Tuple[str]] = None,
114
+ ):
115
+ super().__init__()
116
+ self.n_codebooks = n_codebooks
117
+ self.emb_dim = emb_dim
118
+ self.latent_dim = latent_dim
119
+ self.vocab_size = vocab_size
120
+
121
+ if special_tokens is not None:
122
+ for tkn in special_tokens:
123
+ self.special = nn.ParameterDict(
124
+ {
125
+ tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
126
+ for tkn in special_tokens
127
+ }
128
+ )
129
+ self.special_idxs = {
130
+ tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
131
+ }
132
+
133
+ self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
134
+
135
+ def from_codes(self, codes: torch.Tensor, vqvae):
136
+ n_codebooks = codes.shape[1]
137
+ latent = []
138
+ for i in range(n_codebooks):
139
+ c = codes[:, i, :]
140
+
141
+ lookup_table = vqvae.quantizer.quantizers[i].codebook.weight
142
+ if hasattr(self, "special"):
143
+ special_lookup = torch.cat(
144
+ [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
145
+ )
146
+ lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
147
+
148
+ l = F.embedding(c, lookup_table).transpose(1, 2)
149
+ latent.append(l)
150
+
151
+ latent = torch.cat(latent, dim=1)
152
+ return latent
153
+
154
+ def forward(self, latents: torch.Tensor):
155
+ x = self.out_proj(latents)
156
+ return x
157
+
158
+ def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
159
+ n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
160
+ return rearrange(tokens, "b c t -> b (t c)", c=n_c)
161
+
162
+ def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
163
+ nb, nt = flat_tokens.shape
164
+
165
+ n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
166
+ tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
167
+
168
+ return tokens
vampnet/modules/transformer.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+ from .base import VampBase
10
+ from .modules import CodebookEmbedding
11
+ from .modules import FiLM
12
+ from .modules import SequentialWithFiLM
13
+ from .modules import WNConv1d
14
+
15
+
16
+ class RMSNorm(nn.Module):
17
+ def __init__(self, hidden_size: int, eps=1e-6):
18
+ super().__init__()
19
+ self.weight = nn.Parameter(torch.ones(hidden_size))
20
+ self.var_eps = eps
21
+
22
+ def forward(self, x):
23
+ """Returns root mean square normalized version of input `x`
24
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known
25
+ # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
26
+ # thus varience is calculated w/o mean and there is no bias
27
+ Parameters
28
+ ----------
29
+ x : Tensor[B x T x D]
30
+ Returns
31
+ -------
32
+ Tensor[B x T x D]
33
+ """
34
+ var = x.pow(2).mean(-1, keepdim=True)
35
+ x = x * torch.rsqrt(var + self.var_eps)
36
+
37
+ return self.weight * x
38
+
39
+
40
+ def get_activation(name: str = "relu"):
41
+ if name == "relu":
42
+ return nn.ReLU
43
+ elif name == "gelu":
44
+ return NewGELU
45
+ elif name == "geglu":
46
+ return GatedGELU
47
+ elif name == "snake":
48
+ return Snake1d
49
+ else:
50
+ raise ValueError(f"Unrecognized activation {name}")
51
+
52
+
53
+ class NewGELU(nn.Module):
54
+ """
55
+ Implementation of the GELU activation function currently in Google BERT repo
56
+ (identical to OpenAI GPT). Also see the Gaussian Error Linear Units
57
+ paper: https://arxiv.org/abs/1606.08415
58
+ """
59
+
60
+ def forward(self, x):
61
+ return (
62
+ 0.5
63
+ * x
64
+ * (
65
+ 1.0
66
+ + torch.tanh(
67
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
68
+ )
69
+ )
70
+ )
71
+
72
+
73
+ class GatedGELU(nn.Module):
74
+ def __init__(self):
75
+ super().__init__()
76
+ self.gelu = NewGELU()
77
+
78
+ def forward(self, x, dim: int = -1):
79
+ p1, p2 = x.chunk(2, dim=dim)
80
+ return p1 * self.gelu(p2)
81
+
82
+
83
+ class Snake1d(nn.Module):
84
+ def __init__(self, channels):
85
+ super().__init__()
86
+ self.alpha = nn.Parameter(torch.ones(channels))
87
+
88
+ def forward(self, x):
89
+ return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(
94
+ self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
95
+ ):
96
+ super().__init__()
97
+ factor = 2 if activation == "geglu" else 1
98
+ self.w_1 = nn.Linear(d_model, d_model * 4, bias=False)
99
+ self.w_2 = nn.Linear(d_model * 4 // factor, d_model, bias=False)
100
+ self.drop = nn.Dropout(dropout)
101
+ self.act = get_activation(activation)()
102
+
103
+ def forward(self, x):
104
+ """Computes position-wise feed-forward layer
105
+ Parameters
106
+ ----------
107
+ x : Tensor[B x T x D]
108
+ Returns
109
+ -------
110
+ Tensor[B x T x D]
111
+ """
112
+ x = self.w_1(x)
113
+ x = self.act(x)
114
+ x = self.drop(x)
115
+ x = self.w_2(x)
116
+ return x
117
+
118
+
119
+ class MultiHeadRelativeAttention(nn.Module):
120
+ def __init__(
121
+ self,
122
+ n_head: int = 8,
123
+ d_model: int = 512,
124
+ dropout: float = 0.1,
125
+ bidirectional: bool = True,
126
+ has_relative_attention_bias: bool = True,
127
+ attention_num_buckets: int = 32,
128
+ attention_max_distance: int = 128,
129
+ ):
130
+ super().__init__()
131
+ d_head = d_model // n_head
132
+ self.n_head = n_head
133
+ self.d_head = d_head
134
+ self.bidirectional = bidirectional
135
+ self.has_relative_attention_bias = has_relative_attention_bias
136
+ self.attention_num_buckets = attention_num_buckets
137
+ self.attention_max_distance = attention_max_distance
138
+
139
+ # Create linear query, key, value projections
140
+ self.w_qs = nn.Linear(d_model, d_model, bias=False)
141
+ self.w_ks = nn.Linear(d_model, d_model, bias=False)
142
+ self.w_vs = nn.Linear(d_model, d_model, bias=False)
143
+
144
+ # Create linear final output projection
145
+ self.fc = nn.Linear(d_model, d_model, bias=False)
146
+
147
+ # Dropout for attention output weights
148
+ self.dropout = nn.Dropout(dropout)
149
+
150
+ # Create relative positional embeddings (if turned on)
151
+ if has_relative_attention_bias:
152
+ self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
153
+
154
+ def _relative_position_bucket(self, relative_position):
155
+ """Converts unbounded relative position into bounded set of buckets
156
+ with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
157
+ buckets
158
+ Parameters
159
+ ----------
160
+ relative_position : Tensor[T_q x T_kv]
161
+ Relative positions between queries and key_value items
162
+ Returns
163
+ -------
164
+ Tensor[T_q x T_kv]
165
+ Input relative positions converted into buckets
166
+ """
167
+ relative_buckets = 0
168
+ num_buckets = self.attention_num_buckets
169
+ max_distance = self.attention_max_distance
170
+
171
+ # Convert relative position for (-inf, inf) to [0, inf]
172
+ # Negative relative positions correspond to past
173
+ # Positive relative positions correspond to future
174
+ if self.bidirectional:
175
+ # use half buckets for each side (past / future)
176
+ num_buckets //= 2
177
+
178
+ # Shift the position positions by `num_buckets` to wrap around
179
+ # negative positions
180
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
181
+ relative_position = torch.abs(relative_position)
182
+ else:
183
+ # If not bidirectional, ignore positive positions and wrap
184
+ # negative positions to positive
185
+ relative_position = -torch.min(
186
+ relative_position, torch.zeros_like(relative_position)
187
+ )
188
+
189
+ # Allocate half of the buckets are for exact increments in positions
190
+ max_exact = num_buckets // 2
191
+ is_small = relative_position < max_exact
192
+
193
+ # The other half of the buckets are for logarithmically bigger bins in
194
+ # positions up to `max_distance`
195
+ relative_postion_if_large = max_exact + (
196
+ torch.log(relative_position.float() / max_exact)
197
+ / math.log(max_distance / max_exact)
198
+ * (num_buckets - max_exact)
199
+ ).to(torch.long)
200
+
201
+ # Clip the max relative position to `num_buckets - 1`
202
+ relative_postion_if_large = torch.min(
203
+ relative_postion_if_large,
204
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
205
+ )
206
+
207
+ # Choose relative buckets based on small or large positions
208
+ relative_buckets += torch.where(
209
+ is_small, relative_position, relative_postion_if_large
210
+ )
211
+
212
+ return relative_buckets
213
+
214
+ def compute_bias(self, query_length, key_length):
215
+ """Computes a position bias scalar for each index in query_length x key_length
216
+ Parameters
217
+ ----------
218
+ query_length : int
219
+ key_length : int
220
+ Returns
221
+ -------
222
+ Tensor[heads x 1 x T_q x T_kv]
223
+ Position bias to be applied on attention logits
224
+ """
225
+
226
+ query_position = torch.arange(query_length, dtype=torch.long)[:, None]
227
+ key_position = torch.arange(key_length, dtype=torch.long)[None, :]
228
+ relative_position = key_position - query_position
229
+
230
+ # Convert relative position to buckets
231
+ relative_position_bucket = self._relative_position_bucket(relative_position)
232
+ relative_position_bucket = relative_position_bucket.to(
233
+ self.relative_attention_bias.weight.device
234
+ )
235
+
236
+ # Index attention bias values
237
+ values = self.relative_attention_bias(relative_position_bucket)
238
+ values = rearrange(values, "q k h -> h 1 q k")
239
+
240
+ return values
241
+
242
+ def forward(self, q, k, v, mask=None, position_bias=None):
243
+ """Computes attention over (keys, values) for every timestep in query
244
+ Parameters
245
+ ----------
246
+ q : Tensor[B x T_q x d_model]
247
+ Query vectors
248
+ k : Tensor[B x T_kv x d_model]
249
+ Key vectors to compute attention over
250
+ v : Tensor[B x T_kv x d_model]
251
+ Value vectors corresponding to the keys
252
+ mask : Tensor[B x T_q x T_kv], optional
253
+ position_bias: Tensor[head x 1 x T_q x T_kv]
254
+ Returns
255
+ -------
256
+ Tensor[B x T_q x d_model]
257
+ Outputs after attending (key, value) using queries
258
+ """
259
+ # Compute query, key, value projections
260
+ q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
261
+ k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
262
+ v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
263
+
264
+ # Compute attention matrix
265
+ attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
266
+
267
+ # Add relative position bias to attention scores
268
+ if position_bias is None:
269
+ if self.has_relative_attention_bias:
270
+ position_bias = self.compute_bias(q.size(-2), k.size(-2))
271
+ else:
272
+ position_bias = torch.zeros_like(attn)
273
+ attn += position_bias
274
+
275
+ # Apply mask to attention scores to prevent looking up invalid locations
276
+ if mask is not None:
277
+ attn = attn.masked_fill(mask[None] == 0, -1e9)
278
+
279
+ # Normalize attention scores and add dropout
280
+ attn = torch.softmax(attn, dim=3)
281
+ attn = self.dropout(attn)
282
+
283
+ # Compute attended outputs (product of attention matrix and values)
284
+ output = torch.einsum("hblt,hbtv->hblv", [attn, v])
285
+ output = rearrange(output, "head b l v -> b l (head v)")
286
+ output = self.fc(output)
287
+
288
+ return output, position_bias
289
+
290
+
291
+ class TransformerLayer(nn.Module):
292
+ def __init__(
293
+ self,
294
+ d_model: int = 512,
295
+ d_cond: int = 64,
296
+ n_heads: int = 8,
297
+ bidirectional: bool = True,
298
+ is_decoder: bool = False,
299
+ has_relative_attention_bias: bool = False,
300
+ flash_attn: bool = False,
301
+ dropout: float = 0.1,
302
+ ):
303
+ super().__init__()
304
+ # Store args
305
+ self.is_decoder = is_decoder
306
+
307
+ # Create self-attention layer
308
+ self.norm_1 = RMSNorm(d_model)
309
+ self.film_1 = FiLM(d_cond, d_model)
310
+ self.flash_attn = flash_attn
311
+
312
+ if flash_attn:
313
+ from flash_attn.flash_attention import FlashMHA
314
+ self.self_attn = FlashMHA(
315
+ embed_dim=d_model,
316
+ num_heads=n_heads,
317
+ attention_dropout=dropout,
318
+ causal=False,
319
+ )
320
+ else:
321
+ self.self_attn = MultiHeadRelativeAttention(
322
+ n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
323
+ )
324
+
325
+ # (Optional) Create cross-attention layer
326
+ if is_decoder:
327
+ self.norm_2 = RMSNorm(d_model)
328
+ self.film_2 = FiLM(d_cond, d_model)
329
+ self.cross_attn = MultiHeadRelativeAttention(
330
+ n_heads,
331
+ d_model,
332
+ dropout,
333
+ bidirectional=True,
334
+ has_relative_attention_bias=False,
335
+ )
336
+
337
+ # Create last feed-forward layer
338
+ self.norm_3 = RMSNorm(d_model)
339
+ self.film_3 = FiLM(d_cond, d_model)
340
+ self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
341
+
342
+ # Create dropout
343
+ self.dropout = nn.Dropout(dropout)
344
+
345
+ def forward(
346
+ self,
347
+ x,
348
+ x_mask,
349
+ cond,
350
+ src=None,
351
+ src_mask=None,
352
+ position_bias=None,
353
+ encoder_decoder_position_bias=None,
354
+ ):
355
+ """Computes one transformer layer consisting of self attention, (op) cross attention
356
+ and feedforward layer
357
+ Parameters
358
+ ----------
359
+ x : Tensor[B x T_q x D]
360
+ x_mask : Tensor[B x T_q]
361
+ src : Tensor[B x T_kv x D], optional
362
+ src_mask : Tensor[B x T_kv x D], optional
363
+ position_bias : Tensor[heads x B x T_q x T_q], optional
364
+ Relative position bias for self attention layer
365
+ encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
366
+ Relative position bias for cross attention layer
367
+ Returns
368
+ -------
369
+ Tensor[B x T_q x D]
370
+ """
371
+ y = self.norm_1(x)
372
+ y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
373
+ if self.flash_attn:
374
+ with torch.autocast(y.device.type, dtype=torch.bfloat16):
375
+ y = self.self_attn(y)[0]
376
+ else:
377
+ y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
378
+ x = x + self.dropout(y)
379
+
380
+ if self.is_decoder:
381
+ y = self.norm_2(x)
382
+ y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
383
+ y, encoder_decoder_position_bias = self.cross_attn(
384
+ y, src, src, src_mask, encoder_decoder_position_bias
385
+ )
386
+ x = x + self.dropout(y)
387
+
388
+ y = self.norm_3(x)
389
+ y = self.film_3(
390
+ y.permute(
391
+ 0,
392
+ 2,
393
+ 1,
394
+ ),
395
+ cond,
396
+ ).permute(0, 2, 1)
397
+ y = self.feed_forward(y)
398
+ x = x + self.dropout(y)
399
+
400
+ return x, position_bias, encoder_decoder_position_bias
401
+
402
+
403
+ class TransformerStack(nn.Module):
404
+ def __init__(
405
+ self,
406
+ d_model: int = 512,
407
+ d_cond: int = 64,
408
+ n_heads: int = 8,
409
+ n_layers: int = 8,
410
+ last_layer: bool = True,
411
+ bidirectional: bool = True,
412
+ flash_attn: bool = False,
413
+ is_decoder: bool = False,
414
+ dropout: float = 0.1,
415
+ ):
416
+ super().__init__()
417
+ # Store args
418
+ self.bidirectional = bidirectional
419
+ self.is_decoder = is_decoder
420
+
421
+ # Create transformer layers
422
+ # In T5, relative attention bias is shared by all layers in the stack
423
+ self.layers = nn.ModuleList(
424
+ [
425
+ TransformerLayer(
426
+ d_model,
427
+ d_cond,
428
+ n_heads,
429
+ bidirectional,
430
+ is_decoder,
431
+ has_relative_attention_bias=(i == 0),
432
+ flash_attn=flash_attn,
433
+ dropout=dropout,
434
+ )
435
+ for i in range(n_layers)
436
+ ]
437
+ )
438
+
439
+ # Perform last normalization
440
+ self.norm = RMSNorm(d_model) if last_layer else None
441
+
442
+ def subsequent_mask(self, size):
443
+ return torch.ones(1, size, size).tril().bool()
444
+
445
+ def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
446
+ """Computes a full transformer stack
447
+ Parameters
448
+ ----------
449
+ x : Tensor[B x T_q x D]
450
+ x_mask : Tensor[B x T_q]
451
+ src : Tensor[B x T_kv x D], optional
452
+ src_mask : Tensor[B x T_kv], optional
453
+ Returns
454
+ -------
455
+ Tensor[B x T_q x D]
456
+ """
457
+
458
+ # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
459
+ if self.is_decoder:
460
+ src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
461
+
462
+ # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
463
+ x_mask = x_mask.unsqueeze(-2)
464
+ if not self.bidirectional:
465
+ x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
466
+
467
+ # Initialize position biases
468
+ position_bias = None
469
+ encoder_decoder_position_bias = None
470
+
471
+ # Compute transformer layers
472
+ for layer in self.layers:
473
+ x, position_bias, encoder_decoder_position_bias = layer(
474
+ x=x,
475
+ x_mask=x_mask,
476
+ cond=cond,
477
+ src=src,
478
+ src_mask=src_mask,
479
+ position_bias=position_bias,
480
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
481
+ )
482
+
483
+ return self.norm(x) if self.norm is not None else x
484
+
485
+
486
+ class VampNet(VampBase):
487
+ def __init__(
488
+ self,
489
+ n_heads: int = 20,
490
+ n_layers: int = 16,
491
+ r_cond_dim: int = 64,
492
+ n_codebooks: int = 9,
493
+ n_conditioning_codebooks: int = 0,
494
+ latent_dim: int = 8,
495
+ embedding_dim: int = 1280,
496
+ vocab_size: int = 1024,
497
+ flash_attn: bool = True,
498
+ noise_mode: str = "mask",
499
+ seq_len: int = 313,
500
+ ):
501
+ super().__init__()
502
+ self.n_heads = n_heads
503
+ self.n_layers = n_layers
504
+ self.r_cond_dim = r_cond_dim
505
+ self.n_codebooks = n_codebooks
506
+ self.n_conditioning_codebooks = n_conditioning_codebooks
507
+ self.embedding_dim = embedding_dim
508
+ self.vocab_size = vocab_size
509
+ self.latent_dim = latent_dim
510
+ self.flash_attn = flash_attn
511
+ self.noise_mode = noise_mode
512
+ self.seq_len = seq_len
513
+
514
+ if noise_mode == "mask":
515
+ special_tokens = ["MASK"]
516
+ elif noise_mode == "random":
517
+ special_tokens = None
518
+ else:
519
+ raise ValueError(f"Unknown noise mode: {noise_mode}")
520
+
521
+ self.embedding = CodebookEmbedding(
522
+ latent_dim=latent_dim,
523
+ n_codebooks=n_codebooks,
524
+ vocab_size=vocab_size,
525
+ emb_dim=embedding_dim,
526
+ special_tokens=special_tokens,
527
+ )
528
+
529
+ if noise_mode == "mask":
530
+ self.mask_token = self.embedding.special_idxs["MASK"]
531
+
532
+ self.transformer = TransformerStack(
533
+ d_model=embedding_dim,
534
+ d_cond=r_cond_dim,
535
+ n_heads=n_heads,
536
+ n_layers=n_layers,
537
+ last_layer=True,
538
+ bidirectional=True,
539
+ flash_attn=flash_attn,
540
+ is_decoder=False,
541
+ dropout=0.1,
542
+ )
543
+
544
+ # Add final conv layer
545
+ self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
546
+ self.classifier = SequentialWithFiLM(
547
+ WNConv1d(
548
+ embedding_dim,
549
+ vocab_size * self.n_predict_codebooks,
550
+ kernel_size=1,
551
+ padding="same",
552
+ # groups=self.n_predict_codebooks,
553
+ ),
554
+ )
555
+
556
+ def forward(self, x, cond):
557
+ x = self.embedding(x)
558
+ x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
559
+
560
+ cond = self.r_embed(cond)
561
+
562
+ x = rearrange(x, "b d n -> b n d")
563
+ out = self.transformer(x=x, x_mask=x_mask, cond=cond)
564
+ out = rearrange(out, "b n d -> b d n")
565
+
566
+ out = self.classifier(out, cond)
567
+
568
+ out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
569
+
570
+ return out
571
+
572
+
573
+ if __name__ == "__main__":
574
+ # import argbind
575
+ from .modules import num_params
576
+
577
+ VampNet = argbind.bind(VampNet)
578
+
579
+ @argbind.bind(without_prefix=True)
580
+ def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
581
+ seq_len = int(32000 / 512 * seq_len_s)
582
+
583
+ model = VampNet().to(device)
584
+
585
+ z = torch.randint(
586
+ 0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
587
+ ).to(device)
588
+
589
+ r = torch.zeros(batch_size).to(device)
590
+ z_mask, mask = model.add_noise(z, r)
591
+
592
+ z_mask_latent = torch.rand(
593
+ batch_size, model.latent_dim * model.n_codebooks, seq_len
594
+ ).to(device)
595
+ z_hat = model(z_mask_latent, r)
596
+
597
+ pred = z_hat.argmax(dim=1)
598
+ pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
599
+
600
+ print(f"model has {num_params(model)/1e6:<.3f}M parameters")
601
+ print(f"prediction has shape {pred.shape}")
602
+ breakpoint()
603
+
604
+ args = argbind.parse_args()
605
+ with argbind.scope(args):
606
+ try_model()
vampnet/scheduler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import List
3
+
4
+ import torch
5
+
6
+ class NoamScheduler:
7
+ """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
8
+ Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
9
+ """
10
+
11
+ def __init__(
12
+ self,
13
+ optimizer: torch.optim.Optimizer,
14
+ d_model: int = 512,
15
+ factor: float = 1.0,
16
+ warmup: int = 4000,
17
+ ):
18
+ # Store hparams
19
+ self.warmup = warmup
20
+ self.factor = factor
21
+ self.d_model = d_model
22
+
23
+ # Initialize variables `lr` and `steps`
24
+ self.lr = None
25
+ self.steps = 0
26
+
27
+ # Store the optimizer
28
+ self.optimizer = optimizer
29
+
30
+ def state_dict(self):
31
+ return {
32
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
33
+ }
34
+
35
+ def load_state_dict(self, state_dict):
36
+ self.__dict__.update(state_dict)
37
+
38
+ def step(self):
39
+ self.steps += 1
40
+ self.lr = self.factor * (
41
+ self.d_model ** (-0.5)
42
+ * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
43
+ )
44
+
45
+ for p in self.optimizer.param_groups:
46
+ p["lr"] = self.lr
47
+