marcusj83 felixkreuk commited on
Commit
8deb5b2
0 Parent(s):

Duplicate from musicgen/MusicGen

Browse files

Co-authored-by: Felix Kreuk <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/actions/audiocraft_build/action.yml +29 -0
  2. .github/workflows/audiocraft_docs.yml +32 -0
  3. .github/workflows/audiocraft_linter.yml +17 -0
  4. .github/workflows/audiocraft_tests.yml +17 -0
  5. .gitignore +55 -0
  6. CHANGELOG.md +9 -0
  7. CODE_OF_CONDUCT.md +80 -0
  8. CONTRIBUTING.md +35 -0
  9. LICENSE +21 -0
  10. LICENSE_weights +157 -0
  11. MANIFEST.in +8 -0
  12. MODEL_CARD.md +81 -0
  13. Makefile +21 -0
  14. README.md +125 -0
  15. app.py +155 -0
  16. app_batched.py +130 -0
  17. assets/bach.mp3 +0 -0
  18. assets/bolero_ravel.mp3 +0 -0
  19. audiocraft/__init__.py +10 -0
  20. audiocraft/data/__init__.py +8 -0
  21. audiocraft/data/audio.py +213 -0
  22. audiocraft/data/audio_dataset.py +525 -0
  23. audiocraft/data/audio_utils.py +169 -0
  24. audiocraft/data/zip.py +74 -0
  25. audiocraft/models/__init__.py +10 -0
  26. audiocraft/models/builders.py +218 -0
  27. audiocraft/models/encodec.py +302 -0
  28. audiocraft/models/lm.py +526 -0
  29. audiocraft/models/loaders.py +92 -0
  30. audiocraft/models/musicgen.py +283 -0
  31. audiocraft/modules/__init__.py +20 -0
  32. audiocraft/modules/activations.py +96 -0
  33. audiocraft/modules/codebooks_patterns.py +539 -0
  34. audiocraft/modules/conditioners.py +986 -0
  35. audiocraft/modules/conv.py +245 -0
  36. audiocraft/modules/lstm.py +25 -0
  37. audiocraft/modules/rope.py +124 -0
  38. audiocraft/modules/seanet.py +258 -0
  39. audiocraft/modules/streaming.py +135 -0
  40. audiocraft/modules/transformer.py +704 -0
  41. audiocraft/py.typed +0 -0
  42. audiocraft/quantization/__init__.py +9 -0
  43. audiocraft/quantization/base.py +107 -0
  44. audiocraft/quantization/core_vq.py +400 -0
  45. audiocraft/quantization/vq.py +116 -0
  46. audiocraft/utils/__init__.py +5 -0
  47. audiocraft/utils/autocast.py +40 -0
  48. audiocraft/utils/export.py +56 -0
  49. audiocraft/utils/notebook.py +32 -0
  50. audiocraft/utils/utils.py +234 -0
.github/actions/audiocraft_build/action.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: audiocraft_build
2
+ description: 'Build audiocraft env.'
3
+ runs:
4
+ using: "composite"
5
+ steps:
6
+ - uses: actions/setup-python@v2
7
+ with:
8
+ python-version: 3.8
9
+ - uses: actions/cache@v2
10
+ id: cache
11
+ with:
12
+ path: env
13
+ key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
14
+
15
+ - if: ${{ steps.cache.outputs.cache-hit != 'true' }}
16
+ name: Install dependencies
17
+ shell: bash
18
+ run: |
19
+ sudo apt-get update
20
+ sudo apt-get install libsndfile1-dev ffmpeg
21
+ python3 -m venv env
22
+ . env/bin/activate
23
+ python -m pip install --upgrade pip
24
+ pip install -e '.[dev]'
25
+ - name: System Dependencies
26
+ shell: bash
27
+ run: |
28
+ sudo apt-get update
29
+ sudo apt-get install libsndfile1-dev ffmpeg
.github/workflows/audiocraft_docs.yml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: audiocraft_docs
2
+ on:
3
+ push:
4
+ branches: [ main ]
5
+
6
+ jobs:
7
+ run_docs:
8
+ name: Run docs
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ - uses: ./.github/actions/audiocraft_build
13
+ - name: Config git
14
+ run: |
15
+ git config --global user.email "[email protected]"
16
+ git config --global user.name "Alexandre Défossez (autodoc)"
17
+
18
+ - name: Reset branch
19
+ run: |
20
+ git branch -f gh-docs main
21
+ git checkout gh-docs
22
+
23
+ - name: Make docs
24
+ run: |
25
+ . env/bin/activate
26
+ make docs
27
+ git add -f docs
28
+ git commit -m docs
29
+
30
+ - name: Push branch
31
+ run: |
32
+ git push -f -u origin gh-docs
.github/workflows/audiocraft_linter.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: audiocraft_linter
2
+ on:
3
+ push:
4
+ branches: [ main ]
5
+ pull_request:
6
+ branches: [ main ]
7
+
8
+ jobs:
9
+ run_linter:
10
+ name: Run linter
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - uses: ./.github/actions/audiocraft_build
15
+ - run: |
16
+ . env/bin/activate
17
+ make linter
.github/workflows/audiocraft_tests.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: audiocraft_tests
2
+ on:
3
+ push:
4
+ branches: [ main ]
5
+ pull_request:
6
+ branches: [ main ]
7
+
8
+ jobs:
9
+ run_tests:
10
+ name: Run tests
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - uses: ./.github/actions/audiocraft_build
15
+ - run: |
16
+ . env/bin/activate
17
+ make tests
.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # macOS dir files
10
+ .DS_Store
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ env/
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ .ipynb_checkpoints
31
+
32
+ # Tests and linter
33
+ .pytest_cache/
34
+ .mypy_cache/
35
+ .coverage
36
+
37
+ # docs
38
+ /docs
39
+
40
+ # dotenv
41
+ .env
42
+ .envrc
43
+
44
+ # virtualenv
45
+ .venv
46
+ venv/
47
+ ENV/
48
+
49
+ # personal notebooks & scripts
50
+ */local_scripts
51
+ */notes
52
+ .vscode/
53
+ /notebooks
54
+ /local_scripts
55
+ /notes
CHANGELOG.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6
+
7
+ ## [0.0.1a] - TBD
8
+
9
+ Initial release, with model evaluation only.
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Audiocraft
2
+
3
+ We want to make contributing to this project as easy and transparent as
4
+ possible.
5
+
6
+ ## Pull Requests
7
+
8
+ Audiocraft is the implementation of a research paper.
9
+ Therefore, we do not plan on accepting many pull requests for new features.
10
+ We certainly welcome them for bug fixes.
11
+
12
+ 1. Fork the repo and create your branch from `main`.
13
+ 2. If you've added code that should be tested, add tests.
14
+ 3. If you've changed APIs, update the documentation.
15
+ 4. Ensure the test suite passes.
16
+ 5. Make sure your code lints.
17
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
18
+
19
+ ## Contributor License Agreement ("CLA")
20
+ In order to accept your pull request, we need you to submit a CLA. You only need
21
+ to do this once to work on any of Meta's open source projects.
22
+
23
+ Complete your CLA here: <https://code.facebook.com/cla>
24
+
25
+ ## Issues
26
+ We use GitHub issues to track public bugs. Please ensure your description is
27
+ clear and has sufficient instructions to be able to reproduce the issue.
28
+
29
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
30
+ disclosure of security bugs. In those cases, please go through the process
31
+ outlined on that page and do not file a public issue.
32
+
33
+ ## License
34
+ By contributing to encodec, you agree that your contributions will be licensed
35
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE_weights ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Attribution-NonCommercial-NoDerivatives 4.0 International
2
+
3
+ > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
4
+ >
5
+ > ### Using Creative Commons Public Licenses
6
+ >
7
+ > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
+ >
9
+ > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10
+ >
11
+ > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12
+
13
+ ## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License
14
+
15
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
+
17
+ ### Section 1 – Definitions.
18
+
19
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
+
21
+ b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
22
+
23
+ e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
24
+
25
+ f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
26
+
27
+ h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
28
+
29
+ i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
30
+
31
+ h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
32
+
33
+ i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
34
+
35
+ j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
36
+
37
+ k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
38
+
39
+ l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
40
+
41
+ ### Section 2 – Scope.
42
+
43
+ a. ___License grant.___
44
+
45
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
46
+
47
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
48
+
49
+ B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only.
50
+
51
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
52
+
53
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
54
+
55
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
56
+
57
+ 5. __Downstream recipients.__
58
+
59
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
60
+
61
+ B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
62
+
63
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
64
+
65
+ b. ___Other rights.___
66
+
67
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
68
+
69
+ 2. Patent and trademark rights are not licensed under this Public License.
70
+
71
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
72
+
73
+ ### Section 3 – License Conditions.
74
+
75
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
76
+
77
+ a. ___Attribution.___
78
+
79
+ 1. If You Share the Licensed Material, You must:
80
+
81
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
82
+
83
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
84
+
85
+ ii. a copyright notice;
86
+
87
+ iii. a notice that refers to this Public License;
88
+
89
+ iv. a notice that refers to the disclaimer of warranties;
90
+
91
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
92
+
93
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
94
+
95
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
96
+
97
+ For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material.
98
+
99
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
100
+
101
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
102
+
103
+ ### Section 4 – Sui Generis Database Rights.
104
+
105
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
106
+
107
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material;
108
+
109
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
110
+
111
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
112
+
113
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
114
+
115
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
116
+
117
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
118
+
119
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
120
+
121
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
122
+
123
+ ### Section 6 – Term and Termination.
124
+
125
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
126
+
127
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
128
+
129
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
130
+
131
+ 2. upon express reinstatement by the Licensor.
132
+
133
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
134
+
135
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
136
+
137
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
138
+
139
+ ### Section 7 – Other Terms and Conditions.
140
+
141
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
142
+
143
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
144
+
145
+ ### Section 8 – Interpretation.
146
+
147
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
148
+
149
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
150
+
151
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
152
+
153
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
154
+
155
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
156
+ >
157
+ > Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org).
MANIFEST.in ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ include Makefile
2
+ include LICENSE
3
+ include LICENSE_weights
4
+ include *.md
5
+ include *.ini
6
+ include requirements.txt
7
+ include audiocraft/py.typed
8
+ include assets/*.mp3
MODEL_CARD.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MusicGen Model Card
2
+
3
+ ## Model details
4
+
5
+ **Organization developing the model:** The FAIR team of Meta AI.
6
+
7
+ **Model date:** MusicGen was trained between April 2023 and May 2023.
8
+
9
+ **Model version:** This is the version 1 of the model.
10
+
11
+ **Model type:** MusicGen consists of an EnCodec model for audio tokenization, an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters ; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation.
12
+
13
+ **Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
14
+
15
+ **Citation details** See [our paper][arxiv]
16
+
17
+ **License** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
18
+
19
+ **Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
20
+
21
+ ## Intended use
22
+ **Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
23
+
24
+ - Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
25
+ - Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs
26
+
27
+ **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models.
28
+
29
+ **Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
30
+
31
+ ## Metrics
32
+
33
+ **Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark:
34
+
35
+ - Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
36
+ - Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
37
+ - CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model
38
+
39
+ Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
40
+
41
+ - Overall quality of the music samples;
42
+ - Text relevance to the provided text input;
43
+ - Adherence to the melody for melody-guided music generation.
44
+
45
+ More details on performance measures and human studies can be found in the paper.
46
+
47
+ **Decision thresholds:** Not applicable.
48
+
49
+ ## Evaluation datasets
50
+
51
+ The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
52
+
53
+ ## Training datasets
54
+
55
+ The model was trained using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
56
+
57
+ ## Quantitative analysis
58
+
59
+ More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section.
60
+
61
+ ## Limitations and biases
62
+
63
+ **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
64
+
65
+ **Mitigations:** All vocals have been removed from the data source using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). The model is therefore not able to produce vocals.
66
+
67
+ **Limitations:**
68
+
69
+ - The model is not able to generate realistic vocals.
70
+ - The model has been trained with English descriptions and will not perform as well in other languages.
71
+ - The model does not perform equally well for all music styles and cultures.
72
+ - The model sometimes generates end of songs, collapsing to silence.
73
+ - It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
74
+
75
+ **Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
76
+
77
+ **Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data.
78
+
79
+ **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
80
+
81
+ [arxiv]: https://arxiv.org/abs/2306.05284
Makefile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default: linter tests
2
+
3
+ install:
4
+ pip install -U pip
5
+ pip install -U -e '.[dev]'
6
+
7
+ linter:
8
+ flake8 audiocraft && mypy audiocraft
9
+ flake8 tests && mypy tests
10
+
11
+ tests:
12
+ coverage run -m pytest tests
13
+ coverage report --include 'audiocraft/*'
14
+
15
+ docs:
16
+ pdoc3 --html -o docs -f audiocraft
17
+
18
+ dist:
19
+ python setup.py sdist
20
+
21
+ .PHONY: linter tests docs dist
README.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MusicGen
3
+ python_version: '3.9'
4
+ tags:
5
+ - music generation
6
+ - language models
7
+ - LLMs
8
+ app_file: app.py
9
+ emoji: 🎵
10
+ colorFrom: white
11
+ colorTo: blue
12
+ sdk: gradio
13
+ sdk_version: 3.34.0
14
+ pinned: true
15
+ suggested_hardware: a10g-large
16
+ license: cc-by-nc-4.0
17
+ duplicated_from: musicgen/MusicGen
18
+ ---
19
+ # Audiocraft
20
+ ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
21
+ ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
22
+ ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
23
+
24
+ Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model.
25
+
26
+ ## MusicGen
27
+
28
+ Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
29
+ Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't not require a self-supervised semantic representation, and it generates
30
+ all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
31
+ them in parallel, thus having only 50 auto-regressive steps per second of audio.
32
+ Check out our [sample page][musicgen_samples] or test the available demo!
33
+
34
+ <a target="_blank" href="https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing">
35
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
36
+ </a>
37
+ <a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
38
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HugginFace"/>
39
+ </a>
40
+ <br>
41
+
42
+ ## Installation
43
+ Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
44
+
45
+ ```shell
46
+ # Best to make sure you have torch installed first, in particular before installing xformers.
47
+ # Don't run this if you already have PyTorch installed.
48
+ pip install 'torch>=2.0'
49
+ # Then proceed to one of the following
50
+ pip install -U audiocraft # stable release
51
+ pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
52
+ pip install -e . # or if you cloned the repo locally
53
+ ```
54
+
55
+ ## Usage
56
+ You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing). Finally, a demo is also available on the [`facebook/MusiGen` HugginFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
57
+
58
+ ## API
59
+
60
+ We provide a simple API and 4 pre-trained models. The pre trained models are:
61
+ - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
62
+ - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
63
+ - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
64
+ - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
65
+
66
+ We observe the best trade-off between quality and compute with the `medium` or `melody` model.
67
+ In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
68
+ GPUs will be able to generate short sequences, or longer sequences with the `small` model.
69
+
70
+ **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
71
+ You can install it with:
72
+ ```
73
+ apt get install ffmpeg
74
+ ```
75
+
76
+ See after a quick example for using the API.
77
+
78
+ ```python
79
+ import torchaudio
80
+ from audiocraft.models import MusicGen
81
+ from audiocraft.data.audio import audio_write
82
+
83
+ model = MusicGen.get_pretrained('melody')
84
+ model.set_generation_params(duration=8) # generate 8 seconds.
85
+ wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
86
+ descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
87
+ wav = model.generate(descriptions) # generates 3 samples.
88
+
89
+ melody, sr = torchaudio.load('./assets/bach.mp3')
90
+ # generates using the melody from the given audio and the provided descriptions.
91
+ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
92
+
93
+ for idx, one_wav in enumerate(wav):
94
+ # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
95
+ audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
96
+ ```
97
+
98
+
99
+ ## Model Card
100
+
101
+ See [the model card page](./MODEL_CARD.md).
102
+
103
+ ## FAQ
104
+
105
+ #### Will the training code be released?
106
+
107
+ Yes. We will soon release the training code for MusicGen and EnCodec.
108
+
109
+
110
+ ## Citation
111
+ ```
112
+ @article{copet2023simple,
113
+ title={Simple and Controllable Music Generation},
114
+ author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
115
+ year={2023},
116
+ journal={arXiv preprint arXiv:2306.05284},
117
+ }
118
+ ```
119
+
120
+ ## License
121
+ * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
122
+ * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
123
+
124
+ [arxiv]: https://arxiv.org/abs/2306.05284
125
+ [musicgen_samples]: https://ai.honu.io/papers/musicgen/
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+
5
+ This source code is licensed under the license found in the
6
+ LICENSE file in the root directory of this source tree.
7
+ """
8
+
9
+ from tempfile import NamedTemporaryFile
10
+ import torch
11
+ import gradio as gr
12
+ import os
13
+ from audiocraft.models import MusicGen
14
+
15
+ from audiocraft.data.audio import audio_write
16
+
17
+
18
+ MODEL = None
19
+ IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ['SPACE_ID']
20
+
21
+ def load_model(version):
22
+ print("Loading model", version)
23
+ return MusicGen.get_pretrained(version)
24
+
25
+
26
+ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
27
+ global MODEL
28
+ topk = int(topk)
29
+ if MODEL is None or MODEL.name != model:
30
+ MODEL = load_model(model)
31
+
32
+ if duration > MODEL.lm.cfg.dataset.segment_duration:
33
+ raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
34
+ MODEL.set_generation_params(
35
+ use_sampling=True,
36
+ top_k=topk,
37
+ top_p=topp,
38
+ temperature=temperature,
39
+ cfg_coef=cfg_coef,
40
+ duration=duration,
41
+ )
42
+
43
+ if melody:
44
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
45
+ print(melody.shape)
46
+ if melody.dim() == 2:
47
+ melody = melody[None]
48
+ melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
49
+ output = MODEL.generate_with_chroma(
50
+ descriptions=[text],
51
+ melody_wavs=melody,
52
+ melody_sample_rate=sr,
53
+ progress=False
54
+ )
55
+ else:
56
+ output = MODEL.generate(descriptions=[text], progress=False)
57
+
58
+ output = output.detach().cpu().float()[0]
59
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
60
+ audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
61
+ waveform_video = gr.make_waveform(file.name)
62
+ return waveform_video
63
+
64
+
65
+ with gr.Blocks() as demo:
66
+ gr.Markdown(
67
+ """
68
+ # MusicGen
69
+ This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
70
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
71
+ """
72
+ )
73
+ if IS_SHARED_SPACE:
74
+ gr.Markdown("""
75
+ ⚠ This Space doesn't work in this shared UI ⚠
76
+
77
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
78
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
79
+ to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
80
+ """
81
+ )
82
+ with gr.Row():
83
+ with gr.Column():
84
+ with gr.Row():
85
+ text = gr.Text(label="Input Text", interactive=True)
86
+ melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
87
+ with gr.Row():
88
+ submit = gr.Button("Submit" if not IS_SHARED_SPACE else "Duplicate the Space to generate", interactive=not IS_SHARED_SPACE)
89
+ with gr.Row():
90
+ model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
91
+ with gr.Row():
92
+ duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
93
+ with gr.Row():
94
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
95
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
96
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
97
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
98
+ with gr.Column():
99
+ output = gr.Video(label="Generated Music")
100
+ submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
101
+ gr.Examples(
102
+ fn=predict,
103
+ examples=[
104
+ [
105
+ "An 80s driving pop song with heavy drums and synth pads in the background",
106
+ "./assets/bach.mp3",
107
+ "melody"
108
+ ],
109
+ [
110
+ "A cheerful country song with acoustic guitars",
111
+ "./assets/bolero_ravel.mp3",
112
+ "melody"
113
+ ],
114
+ [
115
+ "90s rock song with electric guitar and heavy drums",
116
+ None,
117
+ "medium"
118
+ ],
119
+ [
120
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
121
+ "./assets/bach.mp3",
122
+ "melody"
123
+ ],
124
+ [
125
+ "lofi slow bpm electro chill with organic samples",
126
+ None,
127
+ "medium",
128
+ ],
129
+ ],
130
+ inputs=[text, melody, model],
131
+ outputs=[output]
132
+ )
133
+ gr.Markdown(
134
+ """
135
+ ### More details
136
+
137
+ The model will generate a short music extract based on the description you provided.
138
+ You can generate up to 30 seconds of audio.
139
+
140
+ We present 4 model variations:
141
+ 1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
142
+ 2. Small -- a 300M transformer decoder conditioned on text only.
143
+ 3. Medium -- a 1.5B transformer decoder conditioned on text only.
144
+ 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
145
+
146
+ When using `melody`, ou can optionaly provide a reference audio from
147
+ which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
148
+
149
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
150
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
151
+ for more details.
152
+ """
153
+ )
154
+
155
+ demo.launch()
app_batched.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ All rights reserved.
4
+
5
+ This source code is licensed under the license found in the
6
+ LICENSE file in the root directory of this source tree.
7
+ """
8
+
9
+ from tempfile import NamedTemporaryFile
10
+ import torch
11
+ import gradio as gr
12
+ from audiocraft.data.audio_utils import convert_audio
13
+ from audiocraft.data.audio import audio_write
14
+ from audiocraft.models import MusicGen
15
+
16
+
17
+ MODEL = None
18
+
19
+
20
+ def load_model():
21
+ print("Loading model")
22
+ return MusicGen.get_pretrained("melody")
23
+
24
+
25
+ def predict(texts, melodies):
26
+ global MODEL
27
+ if MODEL is None:
28
+ MODEL = load_model()
29
+
30
+ duration = 12
31
+ MODEL.set_generation_params(duration=duration)
32
+
33
+ print(texts, melodies)
34
+ processed_melodies = []
35
+
36
+ target_sr = 32000
37
+ target_ac = 1
38
+ for melody in melodies:
39
+ if melody is None:
40
+ processed_melodies.append(None)
41
+ else:
42
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
43
+ if melody.dim() == 1:
44
+ melody = melody[None]
45
+ melody = melody[..., :int(sr * duration)]
46
+ melody = convert_audio(melody, sr, target_sr, target_ac)
47
+ processed_melodies.append(melody)
48
+
49
+ outputs = MODEL.generate_with_chroma(
50
+ descriptions=texts,
51
+ melody_wavs=processed_melodies,
52
+ melody_sample_rate=target_sr,
53
+ progress=False
54
+ )
55
+
56
+ outputs = outputs.detach().cpu().float()
57
+ out_files = []
58
+ for output in outputs:
59
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
60
+ audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
61
+ waveform_video = gr.make_waveform(file.name)
62
+ out_files.append(waveform_video)
63
+ return [out_files]
64
+
65
+
66
+ with gr.Blocks() as demo:
67
+ gr.Markdown(
68
+ """
69
+ # MusicGen
70
+
71
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
72
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
73
+ <br/>
74
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
75
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
76
+ for longer sequences, more control and no queue.</p>
77
+ """
78
+ )
79
+ with gr.Row():
80
+ with gr.Column():
81
+ with gr.Row():
82
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
83
+ melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
84
+ with gr.Row():
85
+ submit = gr.Button("Generate")
86
+ with gr.Column():
87
+ output = gr.Video(label="Generated Music")
88
+ submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
89
+ gr.Examples(
90
+ fn=predict,
91
+ examples=[
92
+ [
93
+ "An 80s driving pop song with heavy drums and synth pads in the background",
94
+ "./assets/bach.mp3",
95
+ ],
96
+ [
97
+ "A cheerful country song with acoustic guitars",
98
+ "./assets/bolero_ravel.mp3",
99
+ ],
100
+ [
101
+ "90s rock song with electric guitar and heavy drums",
102
+ None,
103
+ ],
104
+ [
105
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
106
+ "./assets/bach.mp3",
107
+ ],
108
+ [
109
+ "lofi slow bpm electro chill with organic samples",
110
+ None,
111
+ ],
112
+ ],
113
+ inputs=[text, melody],
114
+ outputs=[output]
115
+ )
116
+ gr.Markdown("""
117
+ ### More details
118
+
119
+ The model will generate 12 seconds of audio based on the description you provided.
120
+ You can optionaly provide a reference audio from which a broad melody will be extracted.
121
+ The model will then try to follow both the description and melody provided.
122
+ All samples are generated with the `melody` model.
123
+
124
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
125
+
126
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
127
+ for more details.
128
+ """)
129
+
130
+ demo.queue(max_size=15).launch()
assets/bach.mp3 ADDED
Binary file (160 kB). View file
 
assets/bolero_ravel.mp3 ADDED
Binary file (161 kB). View file
 
audiocraft/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from . import data, modules, models
9
+
10
+ __version__ = '0.0.1'
audiocraft/data/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from . import audio, audio_dataset
audiocraft/data/audio.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Audio IO methods are defined in this module (info, read, write),
9
+ We rely on av library for faster read when possible, otherwise on torchaudio.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ import logging
15
+ import typing as tp
16
+
17
+ import numpy as np
18
+ import soundfile
19
+ import torch
20
+ from torch.nn import functional as F
21
+ import torchaudio as ta
22
+
23
+ import av
24
+
25
+ from .audio_utils import f32_pcm, i16_pcm, normalize_audio
26
+
27
+
28
+ _av_initialized = False
29
+
30
+
31
+ def _init_av():
32
+ global _av_initialized
33
+ if _av_initialized:
34
+ return
35
+ logger = logging.getLogger('libav.mp3')
36
+ logger.setLevel(logging.ERROR)
37
+ _av_initialized = True
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class AudioFileInfo:
42
+ sample_rate: int
43
+ duration: float
44
+ channels: int
45
+
46
+
47
+ def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48
+ _init_av()
49
+ with av.open(str(filepath)) as af:
50
+ stream = af.streams.audio[0]
51
+ sample_rate = stream.codec_context.sample_rate
52
+ duration = float(stream.duration * stream.time_base)
53
+ channels = stream.channels
54
+ return AudioFileInfo(sample_rate, duration, channels)
55
+
56
+
57
+ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
58
+ info = soundfile.info(filepath)
59
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
60
+
61
+
62
+ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63
+ # torchaudio no longer returns useful duration informations for some formats like mp3s.
64
+ filepath = Path(filepath)
65
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
66
+ # ffmpeg has some weird issue with flac.
67
+ return _soundfile_info(filepath)
68
+ else:
69
+ return _av_info(filepath)
70
+
71
+
72
+ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
+ """FFMPEG-based audio file reading using PyAV bindings.
74
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75
+
76
+ Args:
77
+ filepath (str or Path): Path to audio file to read.
78
+ seek_time (float): Time at which to start reading in the file.
79
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
80
+ Returns:
81
+ Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
82
+ """
83
+ _init_av()
84
+ with av.open(str(filepath)) as af:
85
+ stream = af.streams.audio[0]
86
+ sr = stream.codec_context.sample_rate
87
+ num_frames = int(sr * duration) if duration >= 0 else -1
88
+ frame_offset = int(sr * seek_time)
89
+ # we need a small negative offset otherwise we get some edge artifact
90
+ # from the mp3 decoder.
91
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
92
+ frames = []
93
+ length = 0
94
+ for frame in af.decode(streams=stream.index):
95
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
96
+ strip = max(0, frame_offset - current_offset)
97
+ buf = torch.from_numpy(frame.to_ndarray())
98
+ if buf.shape[0] != stream.channels:
99
+ buf = buf.view(-1, stream.channels).t()
100
+ buf = buf[:, strip:]
101
+ frames.append(buf)
102
+ length += buf.shape[1]
103
+ if num_frames > 0 and length >= num_frames:
104
+ break
105
+ assert frames
106
+ # If the above assert fails, it is likely because we seeked past the end of file point,
107
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
108
+ # This will need proper debugging, in due time.
109
+ wav = torch.cat(frames, dim=1)
110
+ assert wav.shape[0] == stream.channels
111
+ if num_frames > 0:
112
+ wav = wav[:, :num_frames]
113
+ return f32_pcm(wav), sr
114
+
115
+
116
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
+ """Read audio by picking the most appropriate backend tool based on the audio format.
119
+
120
+ Args:
121
+ filepath (str or Path): Path to audio file to read.
122
+ seek_time (float): Time at which to start reading in the file.
123
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
124
+ pad (bool): Pad output audio if not reaching expected duration.
125
+ Returns:
126
+ Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
127
+ """
128
+ fp = Path(filepath)
129
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
130
+ # There is some bug with ffmpeg and reading flac
131
+ info = _soundfile_info(filepath)
132
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
133
+ frame_offset = int(seek_time * info.sample_rate)
134
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
135
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
136
+ wav = torch.from_numpy(wav).t().contiguous()
137
+ if len(wav.shape) == 1:
138
+ wav = torch.unsqueeze(wav, 0)
139
+ elif (
140
+ fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
141
+ and duration <= 0 and seek_time == 0
142
+ ):
143
+ # Torchaudio is faster if we load an entire file at once.
144
+ wav, sr = ta.load(fp)
145
+ else:
146
+ wav, sr = _av_read(filepath, seek_time, duration)
147
+ if pad and duration > 0:
148
+ expected_frames = int(duration * sr)
149
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
150
+ return wav, sr
151
+
152
+
153
+ def audio_write(stem_name: tp.Union[str, Path],
154
+ wav: torch.Tensor, sample_rate: int,
155
+ format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
156
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158
+ log_clipping: bool = True, make_parent_dir: bool = True,
159
+ add_suffix: bool = True) -> Path:
160
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
161
+
162
+ Args:
163
+ stem_name (str or Path): Filename without extension which will be added automatically.
164
+ format (str): Either "wav" or "mp3".
165
+ mp3_rate (int): kbps when using mp3s.
166
+ normalize (bool): if `True` (default), normalizes according to the prescribed
167
+ strategy (see after). If `False`, the strategy is only used in case clipping
168
+ would happen.
169
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
170
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
171
+ with extra headroom to avoid clipping. 'clip' just clips.
172
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
173
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
174
+ than the `peak_clip` one to avoid further clipping.
175
+ loudness_headroom_db (float): Target loudness for loudness normalization.
176
+ log_clipping (bool): If True, basic logging on stderr when clipping still
177
+ occurs despite strategy (only for 'rms').
178
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
179
+ Returns:
180
+ Path: Path of the saved audio.
181
+ """
182
+ assert wav.dtype.is_floating_point, "wav is not floating point"
183
+ if wav.dim() == 1:
184
+ wav = wav[None]
185
+ elif wav.dim() > 2:
186
+ raise ValueError("Input wav should be at most 2 dimension.")
187
+ assert wav.isfinite().all()
188
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
189
+ rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
190
+ sample_rate=sample_rate, stem_name=str(stem_name))
191
+ kwargs: dict = {}
192
+ if format == 'mp3':
193
+ suffix = '.mp3'
194
+ kwargs.update({"compression": mp3_rate})
195
+ elif format == 'wav':
196
+ wav = i16_pcm(wav)
197
+ suffix = '.wav'
198
+ kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
199
+ else:
200
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
201
+ if not add_suffix:
202
+ suffix = ''
203
+ path = Path(str(stem_name) + suffix)
204
+ if make_parent_dir:
205
+ path.parent.mkdir(exist_ok=True, parents=True)
206
+ try:
207
+ ta.save(path, wav, sample_rate, **kwargs)
208
+ except Exception:
209
+ if path.exists():
210
+ # we do not want to leave half written files around.
211
+ path.unlink()
212
+ raise
213
+ return path
audiocraft/data/audio_dataset.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import copy
9
+ from concurrent.futures import ThreadPoolExecutor, Future
10
+ from dataclasses import dataclass, fields
11
+ from contextlib import ExitStack
12
+ import gzip
13
+ import json
14
+ import logging
15
+ import os
16
+ from pathlib import Path
17
+ import random
18
+ import sys
19
+ import typing as tp
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ from .audio import audio_read, audio_info
25
+ from .audio_utils import convert_audio
26
+ from .zip import PathInZip
27
+
28
+ try:
29
+ import dora
30
+ except ImportError:
31
+ dora = None # type: ignore
32
+
33
+
34
+ @dataclass(order=True)
35
+ class BaseInfo:
36
+
37
+ @classmethod
38
+ def _dict2fields(cls, dictionary: dict):
39
+ return {
40
+ field.name: dictionary[field.name]
41
+ for field in fields(cls) if field.name in dictionary
42
+ }
43
+
44
+ @classmethod
45
+ def from_dict(cls, dictionary: dict):
46
+ _dictionary = cls._dict2fields(dictionary)
47
+ return cls(**_dictionary)
48
+
49
+ def to_dict(self):
50
+ return {
51
+ field.name: self.__getattribute__(field.name)
52
+ for field in fields(self)
53
+ }
54
+
55
+
56
+ @dataclass(order=True)
57
+ class AudioMeta(BaseInfo):
58
+ path: str
59
+ duration: float
60
+ sample_rate: int
61
+ amplitude: tp.Optional[float] = None
62
+ weight: tp.Optional[float] = None
63
+ # info_path is used to load additional information about the audio file that is stored in zip files.
64
+ info_path: tp.Optional[PathInZip] = None
65
+
66
+ @classmethod
67
+ def from_dict(cls, dictionary: dict):
68
+ base = cls._dict2fields(dictionary)
69
+ if 'info_path' in base and base['info_path'] is not None:
70
+ base['info_path'] = PathInZip(base['info_path'])
71
+ return cls(**base)
72
+
73
+ def to_dict(self):
74
+ d = super().to_dict()
75
+ if d['info_path'] is not None:
76
+ d['info_path'] = str(d['info_path'])
77
+ return d
78
+
79
+
80
+ @dataclass(order=True)
81
+ class SegmentInfo(BaseInfo):
82
+ meta: AudioMeta
83
+ seek_time: float
84
+ n_frames: int # actual number of frames without padding
85
+ total_frames: int # total number of frames, padding included
86
+ sample_rate: int # actual sample rate
87
+
88
+
89
+ DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
90
+
91
+ logger = logging.getLogger(__name__)
92
+
93
+
94
+ def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
95
+ """AudioMeta from a path to an audio file.
96
+
97
+ Args:
98
+ file_path (str): Resolved path of valid audio file.
99
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
100
+ Returns:
101
+ AudioMeta: Audio file path and its metadata.
102
+ """
103
+ info = audio_info(file_path)
104
+ amplitude: tp.Optional[float] = None
105
+ if not minimal:
106
+ wav, sr = audio_read(file_path)
107
+ amplitude = wav.abs().max().item()
108
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
109
+
110
+
111
+ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
112
+ """If Dora is available as a dependency, try to resolve potential relative paths
113
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
114
+
115
+ Args:
116
+ m (AudioMeta): Audio meta to resolve.
117
+ fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
118
+ Only valid on Linux/Mac.
119
+ Returns:
120
+ AudioMeta: Audio meta with resolved path.
121
+ """
122
+ def is_abs(m):
123
+ if fast:
124
+ return str(m)[0] == '/'
125
+ else:
126
+ os.path.isabs(str(m))
127
+
128
+ if not dora:
129
+ return m
130
+
131
+ if not is_abs(m.path):
132
+ m.path = dora.git_save.to_absolute_path(m.path)
133
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
134
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
135
+ return m
136
+
137
+
138
+ def find_audio_files(path: tp.Union[Path, str],
139
+ exts: tp.List[str] = DEFAULT_EXTS,
140
+ resolve: bool = True,
141
+ minimal: bool = True,
142
+ progress: bool = False,
143
+ workers: int = 0) -> tp.List[AudioMeta]:
144
+ """Build a list of AudioMeta from a given path,
145
+ collecting relevant audio files and fetching meta info.
146
+
147
+ Args:
148
+ path (str or Path): Path to folder containing audio files.
149
+ exts (list of str): List of file extensions to consider for audio files.
150
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
151
+ progress (bool): Whether to log progress on audio files collection.
152
+ workers (int): number of parallel workers, if 0, use only the current thread.
153
+ Returns:
154
+ List[AudioMeta]: List of audio file path and its metadata.
155
+ """
156
+ audio_files = []
157
+ futures: tp.List[Future] = []
158
+ pool: tp.Optional[ThreadPoolExecutor] = None
159
+ with ExitStack() as stack:
160
+ if workers > 0:
161
+ pool = ThreadPoolExecutor(workers)
162
+ stack.enter_context(pool)
163
+
164
+ if progress:
165
+ print("Finding audio files...")
166
+ for root, folders, files in os.walk(path, followlinks=True):
167
+ for file in files:
168
+ full_path = Path(root) / file
169
+ if full_path.suffix.lower() in exts:
170
+ audio_files.append(full_path)
171
+ if pool is not None:
172
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
173
+ if progress:
174
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
175
+
176
+ if progress:
177
+ print("Getting audio metadata...")
178
+ meta: tp.List[AudioMeta] = []
179
+ for idx, file_path in enumerate(audio_files):
180
+ try:
181
+ if pool is None:
182
+ m = _get_audio_meta(str(file_path), minimal)
183
+ else:
184
+ m = futures[idx].result()
185
+ if resolve:
186
+ m = _resolve_audio_meta(m)
187
+ except Exception as err:
188
+ print("Error with", str(file_path), err, file=sys.stderr)
189
+ continue
190
+ meta.append(m)
191
+ if progress:
192
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
193
+ meta.sort()
194
+ return meta
195
+
196
+
197
+ def load_audio_meta(path: tp.Union[str, Path],
198
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
199
+ """Load list of AudioMeta from an optionally compressed json file.
200
+
201
+ Args:
202
+ path (str or Path): Path to JSON file.
203
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
204
+ fast (bool): activates some tricks to make things faster.
205
+ Returns:
206
+ List[AudioMeta]: List of audio file path and its total duration.
207
+ """
208
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
209
+ with open_fn(path, 'rb') as fp: # type: ignore
210
+ lines = fp.readlines()
211
+ meta = []
212
+ for line in lines:
213
+ d = json.loads(line)
214
+ m = AudioMeta.from_dict(d)
215
+ if resolve:
216
+ m = _resolve_audio_meta(m, fast=fast)
217
+ meta.append(m)
218
+ return meta
219
+
220
+
221
+ def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
222
+ """Save the audio metadata to the file pointer as json.
223
+
224
+ Args:
225
+ path (str or Path): Path to JSON file.
226
+ metadata (list of BaseAudioMeta): List of audio meta to save.
227
+ """
228
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
229
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
230
+ with open_fn(path, 'wb') as fp: # type: ignore
231
+ for m in meta:
232
+ json_str = json.dumps(m.to_dict()) + '\n'
233
+ json_bytes = json_str.encode('utf-8')
234
+ fp.write(json_bytes)
235
+
236
+
237
+ class AudioDataset:
238
+ """Base audio dataset.
239
+
240
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
241
+ and potentially additional information, by creating random segments from the list of audio
242
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
243
+ mixing of channels, padding, etc.
244
+
245
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
246
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
247
+ duration, applying padding if required.
248
+
249
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
250
+ allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
251
+ original audio meta.
252
+
253
+ Args:
254
+ meta (tp.List[AudioMeta]): List of audio files metadata.
255
+ segment_duration (float): Optional segment duration of audio to load.
256
+ If not specified, the dataset will load the full audio segment from the file.
257
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
258
+ sample_rate (int): Target sample rate of the loaded audio samples.
259
+ channels (int): Target number of channels of the loaded audio samples.
260
+ sample_on_duration (bool): Set to `True` to sample segments with probability
261
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
262
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
263
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
264
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
265
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
266
+ is shorter than the desired segment.
267
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
268
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
269
+ min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
270
+ audio shorter than this will be filtered out.
271
+ max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
272
+ audio longer than this will be filtered out.
273
+ """
274
+ def __init__(self,
275
+ meta: tp.List[AudioMeta],
276
+ segment_duration: tp.Optional[float] = None,
277
+ shuffle: bool = True,
278
+ num_samples: int = 10_000,
279
+ sample_rate: int = 48_000,
280
+ channels: int = 2,
281
+ pad: bool = True,
282
+ sample_on_duration: bool = True,
283
+ sample_on_weight: bool = True,
284
+ min_segment_ratio: float = 0.5,
285
+ max_read_retry: int = 10,
286
+ return_info: bool = False,
287
+ min_audio_duration: tp.Optional[float] = None,
288
+ max_audio_duration: tp.Optional[float] = None
289
+ ):
290
+ assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
291
+ assert segment_duration is None or segment_duration > 0
292
+ assert segment_duration is None or min_segment_ratio >= 0
293
+ logging.debug(f'sample_on_duration: {sample_on_duration}')
294
+ logging.debug(f'sample_on_weight: {sample_on_weight}')
295
+ logging.debug(f'pad: {pad}')
296
+ logging.debug(f'min_segment_ratio: {min_segment_ratio}')
297
+
298
+ self.segment_duration = segment_duration
299
+ self.min_segment_ratio = min_segment_ratio
300
+ self.max_audio_duration = max_audio_duration
301
+ self.min_audio_duration = min_audio_duration
302
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
303
+ assert self.min_audio_duration <= self.max_audio_duration
304
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
305
+ assert len(self.meta) # Fail fast if all data has been filtered.
306
+ self.total_duration = sum(d.duration for d in self.meta)
307
+
308
+ if segment_duration is None:
309
+ num_samples = len(self.meta)
310
+ self.num_samples = num_samples
311
+ self.shuffle = shuffle
312
+ self.sample_rate = sample_rate
313
+ self.channels = channels
314
+ self.pad = pad
315
+ self.sample_on_weight = sample_on_weight
316
+ self.sample_on_duration = sample_on_duration
317
+ self.sampling_probabilities = self._get_sampling_probabilities()
318
+ self.max_read_retry = max_read_retry
319
+ self.return_info = return_info
320
+
321
+ def __len__(self):
322
+ return self.num_samples
323
+
324
+ def _get_sampling_probabilities(self, normalized: bool = True):
325
+ """Return the sampling probabilities for each file inside `self.meta`.
326
+ """
327
+ scores: tp.List[float] = []
328
+ for file_meta in self.meta:
329
+ score = 1.
330
+ if self.sample_on_weight and file_meta.weight is not None:
331
+ score *= file_meta.weight
332
+ if self.sample_on_duration:
333
+ score *= file_meta.duration
334
+ scores.append(score)
335
+ probabilities = torch.tensor(scores)
336
+ if normalized:
337
+ probabilities /= probabilities.sum()
338
+ return probabilities
339
+
340
+ def sample_file(self, rng: torch.Generator) -> AudioMeta:
341
+ """Sample a given file from `self.meta`. Can be overriden in subclasses.
342
+ This is only called if `segment_duration` is not None.
343
+
344
+ You must use the provided random number generator `rng` for reproducibility.
345
+ """
346
+ if not self.sample_on_weight and not self.sample_on_duration:
347
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
348
+ else:
349
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
350
+
351
+ return self.meta[file_index]
352
+
353
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
354
+ if self.segment_duration is None:
355
+ file_meta = self.meta[index]
356
+ out, sr = audio_read(file_meta.path)
357
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
358
+ n_frames = out.shape[-1]
359
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
360
+ sample_rate=self.sample_rate)
361
+ else:
362
+ rng = torch.Generator()
363
+ if self.shuffle:
364
+ # We use index, plus extra randomness
365
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
366
+ else:
367
+ # We only use index
368
+ rng.manual_seed(index)
369
+
370
+ for retry in range(self.max_read_retry):
371
+ file_meta = self.sample_file(rng)
372
+ # We add some variance in the file position even if audio file is smaller than segment
373
+ # without ending up with empty segments
374
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
375
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
376
+ try:
377
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
378
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
379
+ n_frames = out.shape[-1]
380
+ target_frames = int(self.segment_duration * self.sample_rate)
381
+ if self.pad:
382
+ out = F.pad(out, (0, target_frames - n_frames))
383
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
384
+ sample_rate=self.sample_rate)
385
+ except Exception as exc:
386
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
387
+ if retry == self.max_read_retry - 1:
388
+ raise
389
+ else:
390
+ break
391
+
392
+ if self.return_info:
393
+ # Returns the wav and additional information on the wave segment
394
+ return out, segment_info
395
+ else:
396
+ return out
397
+
398
+ def collater(self, samples):
399
+ """The collater function has to be provided to the dataloader
400
+ if AudioDataset has return_info=True in order to properly collate
401
+ the samples of a batch.
402
+ """
403
+ if self.segment_duration is None and len(samples) > 1:
404
+ assert self.pad, "Must allow padding when batching examples of different durations."
405
+
406
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
407
+ to_pad = self.segment_duration is None and self.pad
408
+ if to_pad:
409
+ max_len = max([wav.shape[-1] for wav, _ in samples])
410
+
411
+ def _pad_wav(wav):
412
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
413
+
414
+ if self.return_info:
415
+ if len(samples) > 0:
416
+ assert len(samples[0]) == 2
417
+ assert isinstance(samples[0][0], torch.Tensor)
418
+ assert isinstance(samples[0][1], SegmentInfo)
419
+
420
+ wavs = [wav for wav, _ in samples]
421
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
422
+
423
+ if to_pad:
424
+ # Each wav could be of a different duration as they are not segmented.
425
+ for i in range(len(samples)):
426
+ # Determines the total legth of the signal with padding, so we update here as we pad.
427
+ segment_infos[i].total_frames = max_len
428
+ wavs[i] = _pad_wav(wavs[i])
429
+
430
+ wav = torch.stack(wavs)
431
+ return wav, segment_infos
432
+ else:
433
+ assert isinstance(samples[0], torch.Tensor)
434
+ if to_pad:
435
+ samples = [_pad_wav(s) for s in samples]
436
+ return torch.stack(samples)
437
+
438
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
439
+ """Filters out audio files with short durations.
440
+ Removes from meta files that have durations that will not allow to samples examples from them.
441
+ """
442
+ orig_len = len(meta)
443
+
444
+ # Filter data that is too short.
445
+ if self.min_audio_duration is not None:
446
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
447
+
448
+ # Filter data that is too long.
449
+ if self.max_audio_duration is not None:
450
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
451
+
452
+ filtered_len = len(meta)
453
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
454
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
455
+ if removed_percentage < 10:
456
+ logging.debug(msg)
457
+ else:
458
+ logging.warning(msg)
459
+ return meta
460
+
461
+ @classmethod
462
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
463
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
464
+
465
+ Args:
466
+ root (str or Path): Path to root folder containing audio files.
467
+ kwargs: Additional keyword arguments for the AudioDataset.
468
+ """
469
+ root = Path(root)
470
+ if root.is_dir():
471
+ if (root / 'data.jsonl').exists():
472
+ root = root / 'data.jsonl'
473
+ elif (root / 'data.jsonl.gz').exists():
474
+ root = root / 'data.jsonl.gz'
475
+ else:
476
+ raise ValueError("Don't know where to read metadata from in the dir. "
477
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
478
+ meta = load_audio_meta(root)
479
+ return cls(meta, **kwargs)
480
+
481
+ @classmethod
482
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
483
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
484
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
485
+
486
+ Args:
487
+ root (str or Path): Path to root folder containing audio files.
488
+ minimal_meta (bool): Whether to only load minimal metadata or not.
489
+ exts (list of str): Extensions for audio files.
490
+ kwargs: Additional keyword arguments for the AudioDataset.
491
+ """
492
+ root = Path(root)
493
+ if root.is_file():
494
+ meta = load_audio_meta(root, resolve=True)
495
+ else:
496
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
497
+ return cls(meta, **kwargs)
498
+
499
+
500
+ def main():
501
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
502
+ parser = argparse.ArgumentParser(
503
+ prog='audio_dataset',
504
+ description='Generate .jsonl files by scanning a folder.')
505
+ parser.add_argument('root', help='Root folder with all the audio files')
506
+ parser.add_argument('output_meta_file',
507
+ help='Output file to store the metadata, ')
508
+ parser.add_argument('--complete',
509
+ action='store_false', dest='minimal', default=True,
510
+ help='Retrieve all metadata, even the one that are expansive '
511
+ 'to compute (e.g. normalization).')
512
+ parser.add_argument('--resolve',
513
+ action='store_true', default=False,
514
+ help='Resolve the paths to be absolute and with no symlinks.')
515
+ parser.add_argument('--workers',
516
+ default=10, type=int,
517
+ help='Number of workers.')
518
+ args = parser.parse_args()
519
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
520
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
521
+ save_audio_meta(args.output_meta_file, meta)
522
+
523
+
524
+ if __name__ == '__main__':
525
+ main()
audiocraft/data/audio_utils.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ import torchaudio
13
+
14
+
15
+ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
16
+ """Convert audio to the given number of channels.
17
+
18
+ Args:
19
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
20
+ channels (int): Expected number of channels as output.
21
+ Returns:
22
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
23
+ """
24
+ *shape, src_channels, length = wav.shape
25
+ if src_channels == channels:
26
+ pass
27
+ elif channels == 1:
28
+ # Case 1:
29
+ # The caller asked 1-channel audio, and the stream has multiple
30
+ # channels, downmix all channels.
31
+ wav = wav.mean(dim=-2, keepdim=True)
32
+ elif src_channels == 1:
33
+ # Case 2:
34
+ # The caller asked for multiple channels, but the input file has
35
+ # a single channel, replicate the audio over all channels.
36
+ wav = wav.expand(*shape, channels, length)
37
+ elif src_channels >= channels:
38
+ # Case 3:
39
+ # The caller asked for multiple channels, and the input file has
40
+ # more channels than requested. In that case return the first channels.
41
+ wav = wav[..., :channels, :]
42
+ else:
43
+ # Case 4: What is a reasonable choice here?
44
+ raise ValueError('The audio file has less channels than requested but is not mono.')
45
+ return wav
46
+
47
+
48
+ def convert_audio(wav: torch.Tensor, from_rate: float,
49
+ to_rate: float, to_channels: int) -> torch.Tensor:
50
+ """Convert audio to new sample rate and number of audio channels.
51
+ """
52
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
+ wav = convert_audio_channels(wav, to_channels)
54
+ return wav
55
+
56
+
57
+ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 12,
58
+ energy_floor: float = 2e-3):
59
+ """Normalize an input signal to a user loudness in dB LKFS.
60
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61
+
62
+ Args:
63
+ wav (torch.Tensor): Input multichannel audio data.
64
+ sample_rate (int): Sample rate.
65
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66
+ energy_floor (float): anything below that RMS level will not be rescaled.
67
+ Returns:
68
+ output (torch.Tensor): Loudness normalized output data.
69
+ """
70
+ energy = wav.pow(2).mean().sqrt().item()
71
+ if energy < energy_floor:
72
+ return wav
73
+ transform = torchaudio.transforms.Loudness(sample_rate)
74
+ input_loudness_db = transform(wav).item()
75
+ # calculate the gain needed to scale to the desired loudness level
76
+ delta_loudness = -loudness_headroom_db - input_loudness_db
77
+ gain = 10.0 ** (delta_loudness / 20.0)
78
+ output = gain * wav
79
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
80
+ return output
81
+
82
+
83
+ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
84
+ """Utility function to clip the audio with logging if specified."""
85
+ max_scale = wav.abs().max()
86
+ if log_clipping and max_scale > 1:
87
+ clamp_prob = (wav.abs() > 1).float().mean().item()
88
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
89
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
90
+ wav.clamp_(-1, 1)
91
+
92
+
93
+ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
94
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
95
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
96
+ log_clipping: bool = False, sample_rate: tp.Optional[int] = None,
97
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
98
+ """Normalize the audio according to the prescribed strategy (see after).
99
+
100
+ Args:
101
+ wav (torch.Tensor): Audio data.
102
+ normalize (bool): if `True` (default), normalizes according to the prescribed
103
+ strategy (see after). If `False`, the strategy is only used in case clipping
104
+ would happen.
105
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
106
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
107
+ with extra headroom to avoid clipping. 'clip' just clips.
108
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
109
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
110
+ than the `peak_clip` one to avoid further clipping.
111
+ loudness_headroom_db (float): Target loudness for loudness normalization.
112
+ log_clipping (bool): If True, basic logging on stderr when clipping still
113
+ occurs despite strategy (only for 'rms').
114
+ sample_rate (int): Sample rate for the audio data (required for loudness).
115
+ stem_name (Optional[str]): Stem name for clipping logging.
116
+ Returns:
117
+ torch.Tensor: Normalized audio.
118
+ """
119
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
120
+ scale_rms = 10 ** (-rms_headroom_db / 20)
121
+ if strategy == 'peak':
122
+ rescaling = (scale_peak / wav.abs().max())
123
+ if normalize or rescaling < 1:
124
+ wav = wav * rescaling
125
+ elif strategy == 'clip':
126
+ wav = wav.clamp(-scale_peak, scale_peak)
127
+ elif strategy == 'rms':
128
+ mono = wav.mean(dim=0)
129
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
130
+ if normalize or rescaling < 1:
131
+ wav = wav * rescaling
132
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
133
+ elif strategy == 'loudness':
134
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
135
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db)
136
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
137
+ else:
138
+ assert wav.abs().max() < 1
139
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
140
+ return wav
141
+
142
+
143
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
144
+ """Convert audio to float 32 bits PCM format.
145
+ """
146
+ if wav.dtype.is_floating_point:
147
+ return wav
148
+ else:
149
+ assert wav.dtype == torch.int16
150
+ return wav.float() / 2**15
151
+
152
+
153
+ def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
154
+ """Convert audio to int 16 bits PCM format.
155
+
156
+ ..Warning:: There exist many formula for doing this convertion. None are perfect
157
+ due to the asymetry of the int16 range. One either have possible clipping, DC offset,
158
+ or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
159
+ it is possible that `i16_pcm(f32_pcm)) != Identity`.
160
+ """
161
+ if wav.dtype.is_floating_point:
162
+ assert wav.abs().max() <= 1
163
+ candidate = (wav * 2 ** 15).round()
164
+ if candidate.max() >= 2 ** 15: # clipping would occur
165
+ candidate = (wav * (2 ** 15 - 1)).round()
166
+ return candidate.short()
167
+ else:
168
+ assert wav.dtype == torch.int16
169
+ return wav
audiocraft/data/zip.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing
8
+ import zipfile
9
+
10
+ from dataclasses import dataclass
11
+ from functools import lru_cache
12
+ from typing_extensions import Literal
13
+
14
+
15
+ DEFAULT_SIZE = 32
16
+ MODE = Literal['r', 'w', 'x', 'a']
17
+
18
+
19
+ @dataclass(order=True)
20
+ class PathInZip:
21
+ """Class for holding a path of file within a zip file.
22
+
23
+ Args:
24
+ path: The convention is <path_to_zip>:<relative_path_inside_zip>
25
+ Let's assume there is a zip file /some/location/foo.zip
26
+ and inside of it is a json file located at /data/file1.json,
27
+ Then we expect path = "/some/location/foo.zip:/data/file1.json"
28
+ """
29
+
30
+ INFO_PATH_SEP = ':'
31
+ zip_path: str
32
+ file_path: str
33
+
34
+ def __init__(self, path: str) -> None:
35
+ split_path = path.split(self.INFO_PATH_SEP)
36
+ assert len(split_path) == 2
37
+ self.zip_path, self.file_path = split_path
38
+
39
+ @classmethod
40
+ def from_paths(cls, zip_path: str, file_path: str):
41
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
42
+
43
+ def __str__(self) -> str:
44
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
45
+
46
+
47
+ def _open_zip(path: str, mode: MODE = 'r'):
48
+ return zipfile.ZipFile(path, mode)
49
+
50
+
51
+ _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
52
+
53
+
54
+ def set_zip_cache_size(max_size: int):
55
+ """Sets the maximal LRU caching for zip file opening.
56
+
57
+ Args:
58
+ max_size: the maximal LRU cache.
59
+ """
60
+ global _cached_open_zip
61
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
62
+
63
+
64
+ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
65
+ """Opens a file stored inside a zip and returns a file-like object.
66
+
67
+ Args:
68
+ path_in_zip: A PathInZip object representing the file to return a file-like object of.
69
+ mode: The mode in which to open the file with.
70
+ Returns:
71
+ A file-like object for PathInZip.
72
+ """
73
+ zf = _cached_open_zip(path_in_zip.zip_path)
74
+ return zf.open(path_in_zip.file_path)
audiocraft/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .musicgen import MusicGen
9
+ from .lm import LMModel
10
+ from .encodec import CompressionModel, EncodecModel
audiocraft/models/builders.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ All the functions to build the relevant models and modules
9
+ from the Hydra config.
10
+ """
11
+
12
+ import typing as tp
13
+ import warnings
14
+
15
+ import audiocraft
16
+ import omegaconf
17
+ import torch
18
+
19
+ from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
20
+ from .lm import LMModel
21
+ from ..modules.codebooks_patterns import (
22
+ CodebooksPatternProvider,
23
+ DelayedPatternProvider,
24
+ ParallelPatternProvider,
25
+ UnrolledPatternProvider,
26
+ VALLEPattern,
27
+ MusicLMPattern,
28
+ )
29
+ from ..modules.conditioners import (
30
+ BaseConditioner,
31
+ ConditioningProvider,
32
+ LUTConditioner,
33
+ T5Conditioner,
34
+ ConditionFuser,
35
+ ChromaStemConditioner,
36
+ )
37
+ from .. import quantization as qt
38
+ from ..utils.utils import dict_from_config
39
+
40
+
41
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
42
+ klass = {
43
+ 'no_quant': qt.DummyQuantizer,
44
+ 'rvq': qt.ResidualVectorQuantizer
45
+ }[quantizer]
46
+ kwargs = dict_from_config(getattr(cfg, quantizer))
47
+ if quantizer != 'no_quant':
48
+ kwargs['dimension'] = dimension
49
+ return klass(**kwargs)
50
+
51
+
52
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
53
+ if encoder_name == 'seanet':
54
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
55
+ encoder_override_kwargs = kwargs.pop('encoder')
56
+ decoder_override_kwargs = kwargs.pop('decoder')
57
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
58
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
59
+ encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
60
+ decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
61
+ return encoder, decoder
62
+ else:
63
+ raise KeyError(f'Unexpected compression model {cfg.compression_model}')
64
+
65
+
66
+ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
67
+ """Instantiate a compression model.
68
+ """
69
+ if cfg.compression_model == 'encodec':
70
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
71
+ encoder_name = kwargs.pop('autoencoder')
72
+ quantizer_name = kwargs.pop('quantizer')
73
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
74
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
75
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
76
+ renormalize = kwargs.pop('renormalize', None)
77
+ renorm = kwargs.pop('renorm')
78
+ if renormalize is None:
79
+ renormalize = renorm is not None
80
+ warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
81
+ return EncodecModel(encoder, decoder, quantizer,
82
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
83
+ else:
84
+ raise KeyError(f'Unexpected compression model {cfg.compression_model}')
85
+
86
+
87
+ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
88
+ """Instantiate a transformer LM.
89
+ """
90
+ if cfg.lm_model == 'transformer_lm':
91
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
92
+ n_q = kwargs['n_q']
93
+ q_modeling = kwargs.pop('q_modeling', None)
94
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
95
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
96
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
97
+ cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
98
+ fuser = get_condition_fuser(cfg)
99
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
100
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically
101
+ kwargs['cross_attention'] = True
102
+ if codebooks_pattern_cfg.modeling is None:
103
+ assert q_modeling is not None, \
104
+ 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
105
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
106
+ {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
107
+ )
108
+ pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
109
+ return LMModel(
110
+ pattern_provider=pattern_provider,
111
+ condition_provider=condition_provider,
112
+ fuser=fuser,
113
+ cfg_dropout=cfg_prob,
114
+ cfg_coef=cfg_coef,
115
+ attribute_dropout=attribute_dropout,
116
+ dtype=getattr(torch, cfg.dtype),
117
+ device=cfg.device,
118
+ **kwargs
119
+ ).to(cfg.device)
120
+ else:
121
+ raise KeyError(f'Unexpected LM model {cfg.lm_model}')
122
+
123
+
124
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
125
+ """Instantiate a conditioning model.
126
+ """
127
+ device = cfg.device
128
+ duration = cfg.dataset.segment_duration
129
+ cfg = getattr(cfg, "conditioners")
130
+ cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
131
+ conditioners: tp.Dict[str, BaseConditioner] = {}
132
+ with omegaconf.open_dict(cfg):
133
+ condition_provider_args = cfg.pop('args', {})
134
+ for cond, cond_cfg in cfg.items():
135
+ model_type = cond_cfg["model"]
136
+ model_args = cond_cfg[model_type]
137
+ if model_type == "t5":
138
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
139
+ elif model_type == "lut":
140
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
141
+ elif model_type == "chroma_stem":
142
+ model_args.pop('cache_path', None)
143
+ conditioners[str(cond)] = ChromaStemConditioner(
144
+ output_dim=output_dim,
145
+ duration=duration,
146
+ device=device,
147
+ **model_args
148
+ )
149
+ else:
150
+ raise ValueError(f"unrecognized conditioning model: {model_type}")
151
+ conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
152
+ return conditioner
153
+
154
+
155
+ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
156
+ """Instantiate a condition fuser object.
157
+ """
158
+ fuser_cfg = getattr(cfg, "fuser")
159
+ fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
160
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
161
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
162
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
163
+ return fuser
164
+
165
+
166
+ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
167
+ """Instantiate a codebooks pattern provider object.
168
+ """
169
+ pattern_providers = {
170
+ 'parallel': ParallelPatternProvider,
171
+ 'delay': DelayedPatternProvider,
172
+ 'unroll': UnrolledPatternProvider,
173
+ 'valle': VALLEPattern,
174
+ 'musiclm': MusicLMPattern,
175
+ }
176
+ name = cfg.modeling
177
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
178
+ klass = pattern_providers[name]
179
+ return klass(n_q, **kwargs)
180
+
181
+
182
+ def get_debug_compression_model(device='cpu'):
183
+ """Instantiate a debug compression model to be used for unit tests.
184
+ """
185
+ seanet_kwargs = {
186
+ 'n_filters': 4,
187
+ 'n_residual_layers': 1,
188
+ 'dimension': 32,
189
+ 'ratios': [10, 8, 16] # 25 Hz at 32kHz
190
+ }
191
+ encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
192
+ decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
193
+ quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
194
+ init_x = torch.randn(8, 32, 128)
195
+ quantizer(init_x, 1) # initialize kmeans etc.
196
+ compression_model = EncodecModel(
197
+ encoder, decoder, quantizer,
198
+ frame_rate=25, sample_rate=32000, channels=1).to(device)
199
+ return compression_model.eval()
200
+
201
+
202
+ def get_debug_lm_model(device='cpu'):
203
+ """Instantiate a debug LM to be used for unit tests.
204
+ """
205
+ pattern = DelayedPatternProvider(n_q=4)
206
+ dim = 16
207
+ providers = {
208
+ 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
209
+ }
210
+ condition_provider = ConditioningProvider(providers)
211
+ fuser = ConditionFuser(
212
+ {'cross': ['description'], 'prepend': [],
213
+ 'sum': [], 'input_interpolate': []})
214
+ lm = LMModel(
215
+ pattern, condition_provider, fuser,
216
+ n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
217
+ cross_attention=True, causal=True)
218
+ return lm.to(device).eval()
audiocraft/models/encodec.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import typing as tp
9
+
10
+ from einops import rearrange
11
+ import torch
12
+ from torch import nn
13
+
14
+ from .. import quantization as qt
15
+
16
+
17
+ class CompressionModel(ABC, nn.Module):
18
+
19
+ @abstractmethod
20
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
21
+ ...
22
+
23
+ @abstractmethod
24
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
25
+ """See `EncodecModel.encode`"""
26
+ ...
27
+
28
+ @abstractmethod
29
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
30
+ """See `EncodecModel.decode`"""
31
+ ...
32
+
33
+ @property
34
+ @abstractmethod
35
+ def channels(self) -> int:
36
+ ...
37
+
38
+ @property
39
+ @abstractmethod
40
+ def frame_rate(self) -> int:
41
+ ...
42
+
43
+ @property
44
+ @abstractmethod
45
+ def sample_rate(self) -> int:
46
+ ...
47
+
48
+ @property
49
+ @abstractmethod
50
+ def cardinality(self) -> int:
51
+ ...
52
+
53
+ @property
54
+ @abstractmethod
55
+ def num_codebooks(self) -> int:
56
+ ...
57
+
58
+ @property
59
+ @abstractmethod
60
+ def total_codebooks(self) -> int:
61
+ ...
62
+
63
+ @abstractmethod
64
+ def set_num_codebooks(self, n: int):
65
+ """Set the active number of codebooks used by the quantizer.
66
+ """
67
+ ...
68
+
69
+
70
+ class EncodecModel(CompressionModel):
71
+ """Encodec model operating on the raw waveform.
72
+
73
+ Args:
74
+ encoder (nn.Module): Encoder network.
75
+ decoder (nn.Module): Decoder network.
76
+ quantizer (qt.BaseQuantizer): Quantizer network.
77
+ frame_rate (int): Frame rate for the latent representation.
78
+ sample_rate (int): Audio sample rate.
79
+ channels (int): Number of audio channels.
80
+ causal (bool): Whether to use a causal version of the model.
81
+ renormalize (bool): Whether to renormalize the audio before running the model.
82
+ """
83
+ # we need assignement to override the property in the abstract class,
84
+ # I couldn't find a better way...
85
+ frame_rate: int = 0
86
+ sample_rate: int = 0
87
+ channels: int = 0
88
+
89
+ def __init__(self,
90
+ encoder: nn.Module,
91
+ decoder: nn.Module,
92
+ quantizer: qt.BaseQuantizer,
93
+ frame_rate: int,
94
+ sample_rate: int,
95
+ channels: int,
96
+ causal: bool = False,
97
+ renormalize: bool = False):
98
+ super().__init__()
99
+ self.encoder = encoder
100
+ self.decoder = decoder
101
+ self.quantizer = quantizer
102
+ self.frame_rate = frame_rate
103
+ self.sample_rate = sample_rate
104
+ self.channels = channels
105
+ self.renormalize = renormalize
106
+ self.causal = causal
107
+ if self.causal:
108
+ # we force disabling here to avoid handling linear overlap of segments
109
+ # as supported in original EnCodec codebase.
110
+ assert not self.renormalize, 'Causal model does not support renormalize'
111
+
112
+ @property
113
+ def total_codebooks(self):
114
+ """Total number of quantizer codebooks available.
115
+ """
116
+ return self.quantizer.total_codebooks
117
+
118
+ @property
119
+ def num_codebooks(self):
120
+ """Active number of codebooks used by the quantizer.
121
+ """
122
+ return self.quantizer.num_codebooks
123
+
124
+ def set_num_codebooks(self, n: int):
125
+ """Set the active number of codebooks used by the quantizer.
126
+ """
127
+ self.quantizer.set_num_codebooks(n)
128
+
129
+ @property
130
+ def cardinality(self):
131
+ """Cardinality of each codebook.
132
+ """
133
+ return self.quantizer.bins
134
+
135
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
136
+ scale: tp.Optional[torch.Tensor]
137
+ if self.renormalize:
138
+ mono = x.mean(dim=1, keepdim=True)
139
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
140
+ scale = 1e-8 + volume
141
+ x = x / scale
142
+ scale = scale.view(-1, 1)
143
+ else:
144
+ scale = None
145
+ return x, scale
146
+
147
+ def postprocess(self,
148
+ x: torch.Tensor,
149
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
150
+ if scale is not None:
151
+ assert self.renormalize
152
+ x = x * scale.view(-1, 1, 1)
153
+ return x
154
+
155
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
156
+ assert x.dim() == 3
157
+ length = x.shape[-1]
158
+ x, scale = self.preprocess(x)
159
+
160
+ emb = self.encoder(x)
161
+ q_res = self.quantizer(emb, self.frame_rate)
162
+ out = self.decoder(q_res.x)
163
+
164
+ # remove extra padding added by the encoder and decoder
165
+ assert out.shape[-1] >= length, (out.shape[-1], length)
166
+ out = out[..., :length]
167
+
168
+ q_res.x = self.postprocess(out, scale)
169
+
170
+ return q_res
171
+
172
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
173
+ """Encode the given input tensor to quantized representation along with scale parameter.
174
+
175
+ Args:
176
+ x (torch.Tensor): Float tensor of shape [B, C, T]
177
+
178
+ Returns:
179
+ codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
180
+ codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
181
+ scale a float tensor containing the scale for audio renormalizealization.
182
+ """
183
+ assert x.dim() == 3
184
+ x, scale = self.preprocess(x)
185
+ emb = self.encoder(x)
186
+ codes = self.quantizer.encode(emb)
187
+ return codes, scale
188
+
189
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
190
+ """Decode the given codes to a reconstructed representation, using the scale to perform
191
+ audio denormalization if needed.
192
+
193
+ Args:
194
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
195
+ scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
196
+
197
+ Returns:
198
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
199
+ """
200
+ emb = self.quantizer.decode(codes)
201
+ out = self.decoder(emb)
202
+ out = self.postprocess(out, scale)
203
+ # out contains extra padding added by the encoder and decoder
204
+ return out
205
+
206
+
207
+ class FlattenedCompressionModel(CompressionModel):
208
+ """Wraps a CompressionModel and flatten its codebooks, e.g.
209
+ instead of returning [B, K, T], return [B, S, T * (K // S)] with
210
+ S the number of codebooks per step, and `K // S` the number of 'virtual steps'
211
+ for each real time step.
212
+
213
+ Args:
214
+ model (CompressionModel): compression model to wrap.
215
+ codebooks_per_step (int): number of codebooks to keep per step,
216
+ this must divide the number of codebooks provided by the wrapped model.
217
+ extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
218
+ if each codebook has a cardinality N, then the first codebook will
219
+ use the range [0, N - 1], and the second [N, 2 N - 1] etc.
220
+ On decoding, this can lead to potentially invalid sequences.
221
+ Any invalid entry will be silently remapped to the proper range
222
+ with a modulo.
223
+ """
224
+ def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
225
+ extend_cardinality: bool = True):
226
+ super().__init__()
227
+ self.model = model
228
+ self.codebooks_per_step = codebooks_per_step
229
+ self.extend_cardinality = extend_cardinality
230
+
231
+ @property
232
+ def total_codebooks(self):
233
+ return self.model.total_codebooks
234
+
235
+ @property
236
+ def num_codebooks(self):
237
+ """Active number of codebooks used by the quantizer.
238
+
239
+ ..Warning:: this reports the number of codebooks after the flattening
240
+ of the codebooks!
241
+ """
242
+ assert self.model.num_codebooks % self.codebooks_per_step == 0
243
+ return self.codebooks_per_step
244
+
245
+ def set_num_codebooks(self, n: int):
246
+ """Set the active number of codebooks used by the quantizer.
247
+
248
+ ..Warning:: this sets the number of codebooks **before** the flattening
249
+ of the codebooks.
250
+ """
251
+ assert n % self.codebooks_per_step == 0
252
+ self.model.set_num_codebooks(n)
253
+
254
+ @property
255
+ def num_virtual_steps(self) -> int:
256
+ """Return the number of virtual steps, e.g. one real step
257
+ will be split into that many steps.
258
+ """
259
+ return self.model.num_codebooks // self.codebooks_per_step
260
+
261
+ @property
262
+ def frame_rate(self) -> int:
263
+ return self.model.frame_rate * self.num_virtual_steps
264
+
265
+ @property
266
+ def sample_rate(self) -> int:
267
+ return self.model.sample_rate
268
+
269
+ @property
270
+ def channels(self) -> int:
271
+ return self.model.channels
272
+
273
+ @property
274
+ def cardinality(self):
275
+ """Cardinality of each codebook.
276
+ """
277
+ if self.extend_cardinality:
278
+ return self.model.cardinality * self.num_virtual_steps
279
+ else:
280
+ return self.model.cardinality
281
+
282
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
283
+ raise NotImplementedError("Not supported, use encode and decode.")
284
+
285
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
286
+ indices, scales = self.model.encode(x)
287
+ B, K, T = indices.shape
288
+ indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
289
+ if self.extend_cardinality:
290
+ for virtual_step in range(1, self.num_virtual_steps):
291
+ indices[..., virtual_step] += self.model.cardinality * virtual_step
292
+ indices = rearrange(indices, 'b k t v -> b k (t v)')
293
+ return (indices, scales)
294
+
295
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
296
+ B, K, T = codes.shape
297
+ assert T % self.num_virtual_steps == 0
298
+ codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
299
+ # We silently ignore potential errors from the LM when
300
+ # using extend_cardinality.
301
+ codes = codes % self.model.cardinality
302
+ return self.model.decode(codes, scale)
audiocraft/models/lm.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ import logging
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from ..utils import utils
17
+ from ..modules.streaming import StreamingModule, State
18
+ from ..modules.transformer import StreamingTransformer, create_norm_fn
19
+ from ..modules.conditioners import (
20
+ ConditionFuser,
21
+ ClassifierFreeGuidanceDropout,
22
+ AttributeDropout,
23
+ ConditioningProvider,
24
+ ConditioningAttributes,
25
+ ConditionType,
26
+ )
27
+ from ..modules.codebooks_patterns import CodebooksPatternProvider
28
+ from ..modules.activations import get_activation_fn
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+ ConditionTensors = tp.Dict[str, ConditionType]
33
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
34
+
35
+
36
+ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
37
+ """LM layer initialization.
38
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
39
+
40
+ Args:
41
+ method (str): Method name for init function. Valid options are:
42
+ 'gaussian', 'uniform'.
43
+ input_dim (int): Input dimension of the initialized module.
44
+ init_depth (Optional[int]): Optional init depth value used to rescale
45
+ the standard deviation if defined.
46
+ """
47
+ # Compute std
48
+ std = 1 / math.sqrt(input_dim)
49
+ # Rescale with depth
50
+ if init_depth is not None:
51
+ std = std / math.sqrt(2 * init_depth)
52
+
53
+ if method == 'gaussian':
54
+ return partial(
55
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
56
+ )
57
+ elif method == 'uniform':
58
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
59
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
60
+ else:
61
+ raise ValueError("Unsupported layer initialization method")
62
+
63
+
64
+ def init_layer(m: nn.Module,
65
+ method: str,
66
+ init_depth: tp.Optional[int] = None,
67
+ zero_bias_init: bool = False):
68
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
69
+
70
+ Args:
71
+ m (nn.Module): Module to initialize.
72
+ method (str): Method name for the init function.
73
+ init_depth (Optional[int]): Optional init depth value used to rescale
74
+ the standard deviation if defined.
75
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
76
+ """
77
+ if isinstance(m, nn.Linear):
78
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
79
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
80
+ weight = m.weight.float()
81
+ init_fn(weight)
82
+ m.weight.data[:] = weight.half()
83
+ else:
84
+ init_fn(m.weight)
85
+ if zero_bias_init and m.bias is not None:
86
+ nn.init.constant_(m.bias, 0)
87
+ elif isinstance(m, nn.Embedding):
88
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
89
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
90
+ weight = m.weight.float()
91
+ init_fn(weight)
92
+ m.weight.data[:] = weight.half()
93
+ else:
94
+ init_fn(m.weight)
95
+
96
+
97
+ class ScaledEmbedding(nn.Embedding):
98
+ """Boost learning rate for embeddings (with `scale`).
99
+ """
100
+ def __init__(self, *args, lr=None, **kwargs):
101
+ super().__init__(*args, **kwargs)
102
+ self.lr = lr
103
+
104
+ def make_optim_group(self):
105
+ group = {"params": list(self.parameters())}
106
+ if self.lr is not None:
107
+ group["lr"] = self.lr
108
+ return group
109
+
110
+
111
+ @dataclass
112
+ class LMOutput:
113
+ # The logits are already re-aligned with the input codes
114
+ # hence no extra shift is required, e.g. when computing CE
115
+ logits: torch.Tensor # [B, K, T, card]
116
+ mask: torch.Tensor # [B, K, T]
117
+
118
+
119
+ class LMModel(StreamingModule):
120
+ """Transformer-based language model on multiple streams of codes.
121
+
122
+ Args:
123
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
124
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
125
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
126
+ n_q (int): Number of parallel streams to model.
127
+ card (int): Cardinality, vocabulary size.
128
+ dim (int): Dimension of the transformer encoder.
129
+ num_heads (int): Number of heads for the transformer encoder.
130
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
131
+ norm (str): Normalization method.
132
+ norm_first (bool): Use pre-norm instead of post-norm.
133
+ emb_lr (Optional[float]): Embedding-specific learning rate.
134
+ bias_proj (bool): Use bias for output projections.
135
+ weight_init (Optional[str]): Method for weight initialization.
136
+ depthwise_init (Optional[str]): Method for depthwise weight initialization.
137
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
138
+ cfg_dropout (float): Classifier-free guidance dropout.
139
+ cfg_coef (float): Classifier-free guidance coefficient.
140
+ attribute_dropout (dict): Attribute dropout probabilities.
141
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
142
+ **kwargs: Additional parameters for the transformer encoder.
143
+ """
144
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
145
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
146
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
147
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
148
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
149
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
150
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
151
+ **kwargs):
152
+ super().__init__()
153
+ self.cfg_coef = cfg_coef
154
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
155
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
156
+ self.condition_provider = condition_provider
157
+ self.fuser = fuser
158
+ self.card = card
159
+ embed_dim = self.card + 1
160
+ self.n_q = n_q
161
+ self.dim = dim
162
+ self.pattern_provider = pattern_provider
163
+ self.two_step_cfg = two_step_cfg
164
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
165
+ if 'activation' in kwargs:
166
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
167
+ self.transformer = StreamingTransformer(
168
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
169
+ norm=norm, norm_first=norm_first, **kwargs)
170
+ self.out_norm: tp.Optional[nn.Module] = None
171
+ if norm_first:
172
+ self.out_norm = create_norm_fn(norm, dim)
173
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
174
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
175
+ self._fsdp: tp.Optional[nn.Module]
176
+ self.__dict__['_fsdp'] = None
177
+
178
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
179
+ """Initialization of the transformer module weights.
180
+
181
+ Args:
182
+ weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
183
+ depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid:
184
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
185
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
186
+ zero_bias_init (bool): Whether to initalize bias to zero or not.
187
+ """
188
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
189
+ assert depthwise_init is None or weight_init is not None, \
190
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
191
+ assert not zero_bias_init or weight_init is not None, \
192
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
193
+
194
+ if weight_init is None:
195
+ return
196
+
197
+ for emb_layer in self.emb:
198
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
199
+
200
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
201
+ depth = None
202
+ if depthwise_init == 'current':
203
+ depth = layer_idx + 1
204
+ elif depthwise_init == 'global':
205
+ depth = len(self.transformer.layers)
206
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
207
+ tr_layer.apply(init_fn)
208
+
209
+ for linear in self.linears:
210
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
211
+
212
+ @property
213
+ def special_token_id(self) -> int:
214
+ return self.card
215
+
216
+ @property
217
+ def num_codebooks(self) -> int:
218
+ return self.n_q
219
+
220
+ def forward(self, sequence: torch.Tensor,
221
+ conditions: tp.List[ConditioningAttributes],
222
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
223
+ """Apply language model on sequence and conditions.
224
+ Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
225
+ S the sequence steps, return the logits with shape [B, card, K, S].
226
+
227
+ Args:
228
+ indices (torch.Tensor): indices of the codes to model.
229
+ conditions (list[ConditioningAttributes]): conditionings to use when modeling
230
+ the given codes. Note that when evaluating multiple time with the same conditioning
231
+ you should pre-compute those and pass them as `condition_tensors`.
232
+ condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
233
+ tensors, see `conditions`.
234
+ Returns:
235
+ torch.Tensor: Logits.
236
+ """
237
+ B, K, S = sequence.shape
238
+ assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
239
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
240
+ if condition_tensors is None:
241
+ assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
242
+ # apply dropout modules
243
+ conditions = self.cfg_dropout(conditions)
244
+ conditions = self.att_dropout(conditions)
245
+ tokenized = self.condition_provider.tokenize(conditions)
246
+ # encode conditions and fuse, both have a streaming cache to not recompute when generating.
247
+ condition_tensors = self.condition_provider(tokenized)
248
+ else:
249
+ assert not conditions, "Shouldn't pass both conditions and condition_tensors."
250
+
251
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
252
+
253
+ out = self.transformer(input_, cross_attention_src=cross_attention_input)
254
+ if self.out_norm:
255
+ out = self.out_norm(out)
256
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
257
+
258
+ # remove the prefix from the model outputs
259
+ if len(self.fuser.fuse2cond['prepend']) > 0:
260
+ logits = logits[:, :, -S:]
261
+
262
+ return logits # [B, K, S, card]
263
+
264
+ def compute_predictions(
265
+ self, codes: torch.Tensor,
266
+ conditions: tp.List[ConditioningAttributes],
267
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
268
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
269
+ forward using the specified codes interleaving pattern.
270
+
271
+ Args:
272
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
273
+ K the number of codebooks and T the number of timesteps.
274
+ conditions (list[ConditioningAttributes]): conditionings to use when modeling
275
+ the given codes. Note that when evaluating multiple time with the same conditioning
276
+ you should pre-compute those and pass them as `condition_tensors`.
277
+ condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
278
+ tensors, see `conditions`.
279
+ Returns:
280
+ LMOutput: Language model outputs
281
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
282
+ i.e. the first item corresponds to logits to predict the first code, meaning that
283
+ no additional shifting of codes and logits is required.
284
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
285
+ Given the specified interleaving strategies, parts of the logits and codes should
286
+ not be considered as valid predictions because of invalid context.
287
+ """
288
+ B, K, T = codes.shape
289
+ codes = codes.contiguous()
290
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
291
+ pattern = self.pattern_provider.get_pattern(T)
292
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
293
+ codes, self.special_token_id, keep_only_valid_steps=True
294
+ )
295
+ # apply model on pattern sequence
296
+ model = self if self._fsdp is None else self._fsdp
297
+ logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
298
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
299
+ # and provide the corresponding mask over invalid positions of tokens
300
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
301
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
302
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
303
+ logits, float('nan'), keep_only_valid_steps=True
304
+ )
305
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
306
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
307
+ return LMOutput(logits, logits_mask)
308
+
309
+ def _sample_next_token(self,
310
+ sequence: torch.Tensor,
311
+ cfg_conditions: CFGConditions,
312
+ unconditional_state: State,
313
+ use_sampling: bool = False,
314
+ temp: float = 1.0,
315
+ top_k: int = 0,
316
+ top_p: float = 0.0,
317
+ cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
318
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
319
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
320
+
321
+ Args:
322
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
323
+ with K corresponding to the number of codebooks and S the number of sequence steps.
324
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
325
+ condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used,
326
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
327
+ use_sampling (bool): Whether to use a sampling strategy or not.
328
+ temp (float): Sampling temperature.
329
+ top_k (int): K for "top-k" sampling.
330
+ top_p (float): P for "top-p" sampling.
331
+ cfg_coef (float): classifier free guidance coefficient
332
+ Returns:
333
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
334
+ """
335
+ B = sequence.shape[0]
336
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
337
+ model = self if self._fsdp is None else self._fsdp
338
+ if self.two_step_cfg and cfg_conditions != {}:
339
+ assert isinstance(cfg_conditions, tuple)
340
+ condition_tensors, null_condition_tensors = cfg_conditions
341
+ cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
342
+ state = self.get_streaming_state()
343
+ self.set_streaming_state(unconditional_state)
344
+ uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
345
+ unconditional_state.update(self.get_streaming_state())
346
+ self.set_streaming_state(state)
347
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
348
+ else:
349
+ assert isinstance(cfg_conditions, dict)
350
+ condition_tensors = cfg_conditions
351
+ if condition_tensors:
352
+ # Preparing for CFG, predicting both conditional and unconditional logits.
353
+ sequence = torch.cat([sequence, sequence], dim=0)
354
+ all_logits = model(
355
+ sequence,
356
+ conditions=[], condition_tensors=condition_tensors)
357
+ if condition_tensors:
358
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
359
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
360
+ else:
361
+ logits = all_logits
362
+
363
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
+ logits = logits[..., -1] # [B x K x card]
365
+
366
+ if use_sampling:
367
+ probs = torch.softmax(logits / temp, dim=-1)
368
+ if top_p > 0.0:
369
+ next_token = utils.sample_top_p(probs, p=top_p)
370
+ elif top_k > 0:
371
+ next_token = utils.sample_top_k(probs, k=top_k)
372
+ else:
373
+ next_token = utils.multinomial(probs, num_samples=1)
374
+ else:
375
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
376
+
377
+ return next_token
378
+
379
+ @torch.no_grad()
380
+ def generate(self,
381
+ prompt: tp.Optional[torch.Tensor] = None,
382
+ conditions: tp.List[ConditioningAttributes] = [],
383
+ num_samples: tp.Optional[int] = None,
384
+ max_gen_len: int = 256,
385
+ use_sampling: bool = True,
386
+ temp: float = 1.0,
387
+ top_k: int = 250,
388
+ top_p: float = 0.0,
389
+ cfg_coef: tp.Optional[float] = None,
390
+ two_step_cfg: bool = False,
391
+ remove_prompts: bool = False,
392
+ check: bool = False,
393
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
394
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
395
+ be perform in a greedy fashion or using sampling with top K and top P strategies.
396
+
397
+ Args:
398
+ prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
399
+ conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None.
400
+ num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
401
+ max_gen_len (int): Maximum generation length.
402
+ use_sampling (bool): Whether to use a sampling strategy or not.
403
+ temp (float): Sampling temperature.
404
+ top_k (int): K for "top-k" sampling.
405
+ top_p (float): P for "top-p" sampling.
406
+ remove_prompts (bool): Whether to remove prompts from generation or not.
407
+ Returns:
408
+ torch.Tensor: Generated tokens.
409
+ """
410
+ assert not self.training, "generation shouldn't be used in training mode."
411
+ first_param = next(iter(self.parameters()))
412
+ device = first_param.device
413
+
414
+ # Checking all input shapes are consistents.
415
+ possible_num_samples = []
416
+ if num_samples is not None:
417
+ possible_num_samples.append(num_samples)
418
+ elif prompt is not None:
419
+ possible_num_samples.append(prompt.shape[0])
420
+ elif conditions:
421
+ possible_num_samples.append(len(conditions))
422
+ else:
423
+ possible_num_samples.append(1)
424
+ assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes"
425
+ num_samples = possible_num_samples[0]
426
+
427
+ # below we create set of conditions: one conditional and one unconditional
428
+ # to do that we merge the regular condition together with the null condition
429
+ # we then do 1 forward pass instead of 2.
430
+ # the reason for that is two-fold:
431
+ # 1. it is about x2 faster than doing 2 forward passes
432
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
433
+ # We also support doing two different passes, in particular to ensure that
434
+ # the padding structure is exactly the same between train anf test.
435
+ # With a batch size of 1, this can be slower though.
436
+ cfg_conditions: CFGConditions
437
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
438
+ if conditions:
439
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
440
+ if two_step_cfg:
441
+ cfg_conditions = (
442
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
443
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
444
+ )
445
+ else:
446
+ conditions = conditions + null_conditions
447
+ tokenized = self.condition_provider.tokenize(conditions)
448
+ cfg_conditions = self.condition_provider(tokenized)
449
+ else:
450
+ cfg_conditions = {}
451
+
452
+ if prompt is None:
453
+ assert num_samples > 0
454
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
455
+
456
+ B, K, T = prompt.shape
457
+ start_offset = T
458
+ assert start_offset < max_gen_len
459
+
460
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
461
+ # this token is used as default value for codes that are not generated yet
462
+ unknown_token = -1
463
+
464
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
465
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
466
+ # filling the gen_codes with the prompt if needed
467
+ gen_codes[..., :start_offset] = prompt
468
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
469
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
470
+ # retrieve the start_offset in the sequence:
471
+ # it is the first sequence step that contains the `start_offset` timestep
472
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
473
+ assert start_offset_sequence is not None
474
+
475
+ with self.streaming():
476
+ unconditional_state = self.get_streaming_state()
477
+ prev_offset = 0
478
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
479
+ for offset in range(start_offset_sequence, gen_sequence_len):
480
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
481
+ curr_sequence = gen_sequence[..., prev_offset:offset]
482
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
483
+ if check:
484
+ # check coherence between mask and sequence
485
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
486
+ # should never happen as gen_sequence is filled progressively
487
+ assert not (curr_sequence == unknown_token).any()
488
+ # sample next token from the model, next token shape is [B, K, 1]
489
+ next_token = self._sample_next_token(
490
+ curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
491
+ cfg_coef=cfg_coef)
492
+ # ensure the tokens that should be masked are properly set to special_token_id
493
+ # as the model never output special_token_id
494
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
495
+ next_token[~valid_mask] = self.special_token_id
496
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
497
+ # (then mask tokens should be left as is as well, which is correct)
498
+ gen_sequence[..., offset:offset+1] = torch.where(
499
+ gen_sequence[..., offset:offset+1] == unknown_token,
500
+ next_token, gen_sequence[..., offset:offset+1]
501
+ )
502
+ prev_offset = offset
503
+ if callback is not None:
504
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
505
+ unconditional_state.clear()
506
+
507
+ # ensure sequence has been entirely filled
508
+ assert not (gen_sequence == unknown_token).any()
509
+ # ensure gen_sequence pattern and mask are matching
510
+ # which means the gen_sequence is valid according to the pattern
511
+ assert (
512
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
513
+ ).all()
514
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
515
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
516
+
517
+ # sanity checks over the returned codes and corresponding masks
518
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
519
+ assert (out_mask[..., :max_gen_len] == 1).all()
520
+
521
+ out_start_offset = start_offset if remove_prompts else 0
522
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
523
+
524
+ # ensure the returned codes are all valid
525
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
526
+ return out_codes
audiocraft/models/loaders.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility functions to load from the checkpoints.
9
+ Each checkpoint is a torch.saved dict with the following keys:
10
+ - 'xp.cfg': the hydra config as dumped during training. This should be used
11
+ to rebuild the object using the audiocraft.models.builders functions,
12
+ - 'model_best_state': a readily loadable best state for the model, including
13
+ the conditioner. The model obtained from `xp.cfg` should be compatible
14
+ with this state dict. In the case of a LM, the encodec model would not be
15
+ bundled along but instead provided separately.
16
+
17
+ Those functions also support loading from a remote location with the Torch Hub API.
18
+ They also support overriding some parameters, in particular the device and dtype
19
+ of the returned model.
20
+ """
21
+
22
+ from pathlib import Path
23
+ from huggingface_hub import hf_hub_download
24
+ import typing as tp
25
+ import os
26
+
27
+ from omegaconf import OmegaConf
28
+ import torch
29
+
30
+ from . import builders
31
+
32
+
33
+ HF_MODEL_CHECKPOINTS_MAP = {
34
+ "small": "facebook/musicgen-small",
35
+ "medium": "facebook/musicgen-medium",
36
+ "large": "facebook/musicgen-large",
37
+ "melody": "facebook/musicgen-melody",
38
+ }
39
+
40
+
41
+ def _get_state_dict(
42
+ file_or_url_or_id: tp.Union[Path, str],
43
+ filename: tp.Optional[str] = None,
44
+ device='cpu',
45
+ cache_dir: tp.Optional[str] = None,
46
+ ):
47
+ # Return the state dict either from a file or url
48
+ file_or_url_or_id = str(file_or_url_or_id)
49
+ assert isinstance(file_or_url_or_id, str)
50
+
51
+ if os.path.isfile(file_or_url_or_id):
52
+ return torch.load(file_or_url_or_id, map_location=device)
53
+
54
+ elif file_or_url_or_id.startswith('https://'):
55
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
56
+
57
+ elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
58
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
59
+
60
+ repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
61
+ file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
62
+ return torch.load(file, map_location=device)
63
+
64
+ else:
65
+ raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
66
+
67
+
68
+ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
69
+ pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
70
+ cfg = OmegaConf.create(pkg['xp.cfg'])
71
+ cfg.device = str(device)
72
+ model = builders.get_compression_model(cfg)
73
+ model.load_state_dict(pkg['best_state'])
74
+ model.eval()
75
+ return model
76
+
77
+
78
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
79
+ pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
80
+ cfg = OmegaConf.create(pkg['xp.cfg'])
81
+ cfg.device = str(device)
82
+ if cfg.device == 'cpu':
83
+ cfg.transformer_lm.memory_efficient = False
84
+ cfg.transformer_lm.custom = True
85
+ cfg.dtype = 'float32'
86
+ else:
87
+ cfg.dtype = 'float16'
88
+ model = builders.get_lm_model(cfg)
89
+ model.load_state_dict(pkg['best_state'])
90
+ model.eval()
91
+ model.cfg = cfg
92
+ return model
audiocraft/models/musicgen.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using MusicGen. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+
12
+ import os
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .encodec import CompressionModel
18
+ from .lm import LMModel
19
+ from .builders import get_debug_compression_model, get_debug_lm_model
20
+ from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
21
+ from ..data.audio_utils import convert_audio
22
+ from ..modules.conditioners import ConditioningAttributes, WavCondition
23
+ from ..utils.autocast import TorchAutocast
24
+
25
+
26
+ MelodyList = tp.List[tp.Optional[torch.Tensor]]
27
+ MelodyType = tp.Union[torch.Tensor, MelodyList]
28
+
29
+
30
+ class MusicGen:
31
+ """MusicGen main model with convenient generation API.
32
+
33
+ Args:
34
+ name (str): name of the model.
35
+ compression_model (CompressionModel): Compression model
36
+ used to map audio to invertible discrete representations.
37
+ lm (LMModel): Language model over discrete representations.
38
+ """
39
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel):
40
+ self.name = name
41
+ self.compression_model = compression_model
42
+ self.lm = lm
43
+ self.device = next(iter(lm.parameters())).device
44
+ self.generation_params: dict = {}
45
+ self.set_generation_params(duration=15) # 15 seconds by default
46
+ if self.device.type == 'cpu':
47
+ self.autocast = TorchAutocast(enabled=False)
48
+ else:
49
+ self.autocast = TorchAutocast(
50
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
51
+
52
+ @property
53
+ def frame_rate(self) -> int:
54
+ """Roughly the number of AR steps per seconds."""
55
+ return self.compression_model.frame_rate
56
+
57
+ @property
58
+ def sample_rate(self) -> int:
59
+ """Sample rate of the generated audio."""
60
+ return self.compression_model.sample_rate
61
+
62
+ @property
63
+ def audio_channels(self) -> int:
64
+ """Audio channels of the generated audio."""
65
+ return self.compression_model.channels
66
+
67
+ @staticmethod
68
+ def get_pretrained(name: str = 'melody', device='cuda'):
69
+ """Return pretrained model, we provide four models:
70
+ - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
71
+ - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
72
+ - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
73
+ - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
74
+ """
75
+
76
+ if name == 'debug':
77
+ # used only for unit tests
78
+ compression_model = get_debug_compression_model(device)
79
+ lm = get_debug_lm_model(device)
80
+ return MusicGen(name, compression_model, lm)
81
+
82
+ if name not in HF_MODEL_CHECKPOINTS_MAP:
83
+ raise ValueError(
84
+ f"{name} is not a valid checkpoint name. "
85
+ f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
86
+ )
87
+
88
+ cache_dir = os.environ.get('MUSICGEN_ROOT', None)
89
+ compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
90
+ lm = load_lm_model(name, device=device, cache_dir=cache_dir)
91
+
92
+ return MusicGen(name, compression_model, lm)
93
+
94
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
95
+ top_p: float = 0.0, temperature: float = 1.0,
96
+ duration: float = 30.0, cfg_coef: float = 3.0,
97
+ two_step_cfg: bool = False):
98
+ """Set the generation parameters for MusicGen.
99
+
100
+ Args:
101
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
102
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
103
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
104
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
105
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
106
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
107
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
108
+ instead of batching together the two. This has some impact on how things
109
+ are padded but seems to have little impact in practice.
110
+ """
111
+ assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
112
+ self.generation_params = {
113
+ 'max_gen_len': int(duration * self.frame_rate),
114
+ 'use_sampling': use_sampling,
115
+ 'temp': temperature,
116
+ 'top_k': top_k,
117
+ 'top_p': top_p,
118
+ 'cfg_coef': cfg_coef,
119
+ 'two_step_cfg': two_step_cfg,
120
+ }
121
+
122
+ def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
123
+ """Generate samples in an unconditional manner.
124
+
125
+ Args:
126
+ num_samples (int): Number of samples to be generated.
127
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
128
+ """
129
+ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
130
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
131
+ return self._generate_tokens(attributes, prompt_tokens, progress)
132
+
133
+ def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
134
+ """Generate samples conditioned on text.
135
+
136
+ Args:
137
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
138
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
139
+ """
140
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
141
+ assert prompt_tokens is None
142
+ return self._generate_tokens(attributes, prompt_tokens, progress)
143
+
144
+ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
145
+ melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
146
+ """Generate samples conditioned on text and melody.
147
+
148
+ Args:
149
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
150
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
151
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
152
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
153
+ a list of [C, T] tensors.
154
+ melody_sample_rate: (int): Sample rate of the melody waveforms.
155
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
156
+ """
157
+ if isinstance(melody_wavs, torch.Tensor):
158
+ if melody_wavs.dim() == 2:
159
+ melody_wavs = melody_wavs[None]
160
+ if melody_wavs.dim() != 3:
161
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
162
+ melody_wavs = list(melody_wavs)
163
+ else:
164
+ for melody in melody_wavs:
165
+ if melody is not None:
166
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
167
+
168
+ melody_wavs = [
169
+ convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
170
+ if wav is not None else None
171
+ for wav in melody_wavs]
172
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
173
+ melody_wavs=melody_wavs)
174
+ assert prompt_tokens is None
175
+ return self._generate_tokens(attributes, prompt_tokens, progress)
176
+
177
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
178
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
179
+ progress: bool = False) -> torch.Tensor:
180
+ """Generate samples conditioned on audio prompts.
181
+
182
+ Args:
183
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
184
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
185
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
186
+ descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
187
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
188
+ """
189
+ if prompt.dim() == 2:
190
+ prompt = prompt[None]
191
+ if prompt.dim() != 3:
192
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
193
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
194
+ if descriptions is None:
195
+ descriptions = [None] * len(prompt)
196
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
197
+ assert prompt_tokens is not None
198
+ return self._generate_tokens(attributes, prompt_tokens, progress)
199
+
200
+ @torch.no_grad()
201
+ def _prepare_tokens_and_attributes(
202
+ self,
203
+ descriptions: tp.Sequence[tp.Optional[str]],
204
+ prompt: tp.Optional[torch.Tensor],
205
+ melody_wavs: tp.Optional[MelodyList] = None,
206
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
207
+ """Prepare model inputs.
208
+
209
+ Args:
210
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
211
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
212
+ melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
213
+ used as melody conditioning. Defaults to None.
214
+ """
215
+ attributes = [
216
+ ConditioningAttributes(text={'description': description})
217
+ for description in descriptions]
218
+
219
+ if melody_wavs is None:
220
+ for attr in attributes:
221
+ attr.wav['self_wav'] = WavCondition(
222
+ torch.zeros((1, 1), device=self.device),
223
+ torch.tensor([0], device=self.device),
224
+ path='null_wav') # type: ignore
225
+ else:
226
+ if self.name != "melody":
227
+ raise RuntimeError("This model doesn't support melody conditioning. "
228
+ "Use the `melody` model.")
229
+ assert len(melody_wavs) == len(descriptions), \
230
+ f"number of melody wavs must match number of descriptions! " \
231
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
232
+ for attr, melody in zip(attributes, melody_wavs):
233
+ if melody is None:
234
+ attr.wav['self_wav'] = WavCondition(
235
+ torch.zeros((1, 1), device=self.device),
236
+ torch.tensor([0], device=self.device),
237
+ path='null_wav') # type: ignore
238
+ else:
239
+ attr.wav['self_wav'] = WavCondition(
240
+ melody.to(device=self.device),
241
+ torch.tensor([melody.shape[-1]], device=self.device))
242
+
243
+ if prompt is not None:
244
+ if descriptions is not None:
245
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
246
+ prompt = prompt.to(self.device)
247
+ prompt_tokens, scale = self.compression_model.encode(prompt)
248
+ assert scale is None
249
+ else:
250
+ prompt_tokens = None
251
+ return attributes, prompt_tokens
252
+
253
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
254
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
255
+ """Generate discrete audio tokens given audio prompt and/or conditions.
256
+
257
+ Args:
258
+ attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
259
+ prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
260
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
261
+ Returns:
262
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
263
+ """
264
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
265
+ print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
266
+
267
+ if prompt_tokens is not None:
268
+ assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
269
+ "Prompt is longer than audio to generate"
270
+
271
+ callback = None
272
+ if progress:
273
+ callback = _progress_callback
274
+
275
+ # generate by sampling from LM
276
+ with self.autocast:
277
+ gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
278
+
279
+ # generate audio
280
+ assert gen_tokens.dim() == 3
281
+ with torch.no_grad():
282
+ gen_audio = self.compression_model.decode(gen_tokens, None)
283
+ return gen_audio
audiocraft/modules/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .conv import (
9
+ NormConv1d,
10
+ NormConv2d,
11
+ NormConvTranspose1d,
12
+ NormConvTranspose2d,
13
+ StreamableConv1d,
14
+ StreamableConvTranspose1d,
15
+ pad_for_conv1d,
16
+ pad1d,
17
+ unpad1d,
18
+ )
19
+ from .lstm import StreamableLSTM
20
+ from .seanet import SEANetEncoder, SEANetDecoder
audiocraft/modules/activations.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+ from typing import Union, Callable
11
+
12
+
13
+ class CustomGLU(nn.Module):
14
+ """Custom Gated Linear Unit activation.
15
+ Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
16
+ of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
17
+ function (i.e. sigmoid, swish, etc.).
18
+
19
+ Args:
20
+ activation (nn.Module): The custom activation to apply in the Gated Linear Unit
21
+ dim (int): the dimension on which to split the input. Default: -1
22
+
23
+ Shape:
24
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
25
+ dimensions
26
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
27
+
28
+ Examples::
29
+ >>> m = CustomGLU(nn.Sigmoid())
30
+ >>> input = torch.randn(4, 2)
31
+ >>> output = m(input)
32
+ """
33
+ def __init__(self, activation: nn.Module, dim: int = -1):
34
+ super(CustomGLU, self).__init__()
35
+ self.dim = dim
36
+ self.activation = activation
37
+
38
+ def forward(self, x: Tensor):
39
+ assert x.shape[self.dim] % 2 == 0 # M = N / 2
40
+ a, b = torch.chunk(x, 2, dim=self.dim)
41
+ return a * self.activation(b)
42
+
43
+
44
+ class SwiGLU(CustomGLU):
45
+ """SiLU Gated Linear Unit activation.
46
+ Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
47
+ the first half of the input matrices, :math:`b` is the second half.
48
+
49
+ Args:
50
+ dim (int): the dimension on which to split the input. Default: -1
51
+ """
52
+ def __init__(self, dim: int = -1):
53
+ super(SwiGLU, self).__init__(nn.SiLU(), dim)
54
+
55
+
56
+ class GeGLU(CustomGLU):
57
+ """GeLU Gated Linear Unit activation.
58
+ Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
59
+ the first half of the input matrices, :math:`b` is the second half.
60
+
61
+ Args:
62
+ dim (int): the dimension on which to split the input. Default: -1
63
+ """
64
+ def __init__(self, dim: int = -1):
65
+ super(GeGLU, self).__init__(nn.GELU(), dim)
66
+
67
+
68
+ class ReGLU(CustomGLU):
69
+ """ReLU Gated Linear Unit activation.
70
+ Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
71
+ the first half of the input matrices, :math:`b` is the second half.
72
+
73
+ Args:
74
+ dim (int): the dimension on which to split the input. Default: -1
75
+ """
76
+ def __init__(self, dim: int = -1):
77
+ super(ReGLU, self).__init__(nn.ReLU(), dim)
78
+
79
+
80
+ def get_activation_fn(
81
+ activation: Union[str, Callable[[Tensor], Tensor]]
82
+ ) -> Union[str, Callable[[Tensor], Tensor]]:
83
+ """Helper function to map an activation string to the activation class.
84
+ If the supplied activation is not a string that is recognized, the activation is passed back.
85
+
86
+ Args:
87
+ activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
88
+ """
89
+ if isinstance(activation, str):
90
+ if activation == "reglu":
91
+ return ReGLU()
92
+ elif activation == "geglu":
93
+ return GeGLU()
94
+ elif activation == "swiglu":
95
+ return SwiGLU()
96
+ return activation
audiocraft/modules/codebooks_patterns.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ import logging
11
+ import typing as tp
12
+
13
+ from abc import ABC, abstractmethod
14
+ import torch
15
+
16
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
17
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class Pattern:
23
+ """Base implementation of a pattern over a sequence with multiple codebooks.
24
+
25
+ The codebook pattern consists in a layout, defining for each sequence step
26
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
27
+ The first item of the pattern is always an empty list in order to properly insert a special token
28
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
29
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
30
+
31
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
32
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
33
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
34
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
35
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
36
+ is returned along with a mask indicating valid tokens.
37
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
38
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
39
+ to fill and specify invalid positions if needed.
40
+ See the dedicated methods for more details.
41
+ """
42
+ # Pattern layout, for each sequence step, we have a list of coordinates
43
+ # corresponding to the original codebook timestep and position.
44
+ # The first list is always an empty list in order to properly insert
45
+ # a special token to start with.
46
+ layout: PatternLayout
47
+ timesteps: int
48
+ n_q: int
49
+
50
+ def __post_init__(self):
51
+ assert len(self.layout) > 0
52
+ assert self.layout[0] == []
53
+ self._validate_layout()
54
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
55
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
56
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
57
+
58
+ def _validate_layout(self):
59
+ """Runs checks on the layout to ensure a valid pattern is defined.
60
+ A pattern is considered invalid if:
61
+ - Multiple timesteps for a same codebook are defined in the same sequence step
62
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
63
+ (this would mean that we have future timesteps before past timesteps).
64
+ """
65
+ q_timesteps = {q: 0 for q in range(self.n_q)}
66
+ for s, seq_coords in enumerate(self.layout):
67
+ if len(seq_coords) > 0:
68
+ qs = set()
69
+ for coord in seq_coords:
70
+ qs.add(coord.q)
71
+ last_q_timestep = q_timesteps[coord.q]
72
+ assert coord.t >= last_q_timestep, \
73
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
74
+ q_timesteps[coord.q] = coord.t
75
+ # each sequence step contains at max 1 coordinate per codebook
76
+ assert len(qs) == len(seq_coords), \
77
+ f"Multiple entries for a same codebook are found at step {s}"
78
+
79
+ @property
80
+ def num_sequence_steps(self):
81
+ return len(self.layout) - 1
82
+
83
+ @property
84
+ def max_delay(self):
85
+ max_t_in_seq_coords = 0
86
+ for seq_coords in self.layout[1:]:
87
+ for coords in seq_coords:
88
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
89
+ return max_t_in_seq_coords - self.timesteps
90
+
91
+ @property
92
+ def valid_layout(self):
93
+ valid_step = len(self.layout) - self.max_delay
94
+ return self.layout[:valid_step]
95
+
96
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
97
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
98
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
99
+ and the actual codebook coordinates.
100
+ """
101
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
102
+ if q is not None:
103
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
104
+ coords = []
105
+ for s, seq_codes in enumerate(self.layout):
106
+ for code in seq_codes:
107
+ if code.t == t and (q is None or code.q == q):
108
+ coords.append((s, code))
109
+ return coords
110
+
111
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
112
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
113
+
114
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
115
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
116
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
117
+
118
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
119
+ device: tp.Union[torch.device, str] = 'cpu'):
120
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
121
+
122
+ Args:
123
+ timesteps (int): Maximum number of timesteps steps to consider.
124
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
125
+ device (Union[torch.device, str]): Device for created tensors.
126
+ Returns:
127
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
128
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
129
+ """
130
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
131
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
132
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
133
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
134
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
135
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
136
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
137
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
138
+ # fill indexes with last sequence step value that will correspond to our special token
139
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
140
+ # which will correspond to the index: n_q * timesteps
141
+ indexes[:] = n_q * timesteps
142
+ # iterate over the pattern and fill scattered indexes and mask
143
+ for s, sequence_coords in enumerate(ref_layout):
144
+ for coords in sequence_coords:
145
+ if coords.t < timesteps:
146
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
147
+ mask[coords.q, s] = 1
148
+ indexes = torch.from_numpy(indexes).to(device)
149
+ mask = torch.from_numpy(mask).to(device)
150
+ return indexes, mask
151
+
152
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
153
+ """Build sequence corresponding to the pattern from the input tensor z.
154
+ The sequence is built using up to sequence_steps if specified, and non-pattern
155
+ coordinates are filled with the special token.
156
+
157
+ Args:
158
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
159
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
160
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
161
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
162
+ Returns:
163
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
164
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
165
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
166
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
167
+ """
168
+ B, K, T = z.shape
169
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
170
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
171
+ )
172
+ z = z.view(B, -1)
173
+ # we append the special token as the last index of our flattened z tensor
174
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
175
+ values = z[:, indexes.view(-1)]
176
+ values = values.view(B, K, indexes.shape[-1])
177
+ return values, indexes, mask
178
+
179
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
180
+ keep_only_valid_steps: bool = False,
181
+ is_model_output: bool = False,
182
+ device: tp.Union[torch.device, str] = 'cpu'):
183
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
184
+ from interleaving pattern.
185
+
186
+ Args:
187
+ sequence_steps (int): Sequence steps.
188
+ n_q (int): Number of codebooks.
189
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
190
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
191
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
192
+ device (Union[torch.device, str]): Device for created tensors.
193
+ Returns:
194
+ torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
195
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
196
+ """
197
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
198
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
199
+ timesteps = self.timesteps
200
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
201
+ assert sequence_steps <= len(ref_layout), \
202
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
203
+
204
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
205
+ if is_model_output:
206
+ ref_layout = ref_layout[1:]
207
+
208
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
209
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
210
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
211
+ # fill indexes with last sequence step value that will correspond to our special token
212
+ indexes[:] = n_q * sequence_steps
213
+ for s, sequence_codes in enumerate(ref_layout):
214
+ if s < sequence_steps:
215
+ for code in sequence_codes:
216
+ if code.t < timesteps:
217
+ indexes[code.q, code.t] = s + code.q * sequence_steps
218
+ mask[code.q, code.t] = 1
219
+ indexes = torch.from_numpy(indexes).to(device)
220
+ mask = torch.from_numpy(mask).to(device)
221
+ return indexes, mask
222
+
223
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
224
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
225
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
226
+ are filled with the special token.
227
+
228
+ Args:
229
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
230
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
231
+ Returns:
232
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
233
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
234
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
235
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
236
+ """
237
+ B, K, S = s.shape
238
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
239
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
240
+ )
241
+ s = s.view(B, -1)
242
+ # we append the special token as the last index of our flattened z tensor
243
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
244
+ values = s[:, indexes.view(-1)]
245
+ values = values.view(B, K, indexes.shape[-1])
246
+ return values, indexes, mask
247
+
248
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
249
+ """Revert model logits obtained on a sequence built from the pattern
250
+ back to a tensor matching the original sequence.
251
+
252
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
253
+ 1. It is designed to work with the extra cardinality dimension
254
+ 2. We return the logits for the first sequence item that matches the special_token and
255
+ which matching target in the original sequence is the first item of the sequence,
256
+ while we skip the last logits as there is no matching target
257
+ """
258
+ B, card, K, S = logits.shape
259
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
260
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
261
+ )
262
+ logits = logits.reshape(B, card, -1)
263
+ # we append the special token as the last index of our flattened z tensor
264
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
265
+ values = logits[:, :, indexes.view(-1)]
266
+ values = values.view(B, card, K, indexes.shape[-1])
267
+ return values, indexes, mask
268
+
269
+
270
+ class CodebooksPatternProvider(ABC):
271
+ """Abstraction around providing pattern for interleaving codebooks.
272
+
273
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
274
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
275
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
276
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
277
+ can be used to construct a new sequence from the original codes respecting the specified
278
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
279
+ being a tuple with the original timestep and codebook to build the new sequence.
280
+ Note that all patterns must start with an empty list that is then used to insert a first
281
+ sequence step of special tokens in the newly generated sequence.
282
+
283
+ Args:
284
+ n_q (int): number of codebooks.
285
+ cached (bool): if True, patterns for a given length are cached. In general
286
+ that should be true for efficiency reason to avoid synchronization points.
287
+ """
288
+ def __init__(self, n_q: int, cached: bool = True):
289
+ assert n_q > 0
290
+ self.n_q = n_q
291
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
292
+
293
+ @abstractmethod
294
+ def get_pattern(self, timesteps: int) -> Pattern:
295
+ """Builds pattern with specific interleaving between codebooks.
296
+
297
+ Args:
298
+ timesteps (int): Total numer of timesteps.
299
+ """
300
+ raise NotImplementedError()
301
+
302
+
303
+ class DelayedPatternProvider(CodebooksPatternProvider):
304
+ """Provider for delayed pattern across delayed codebooks.
305
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
306
+ from different timesteps.
307
+
308
+ Example:
309
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
310
+ [[1, 2, 3, 4],
311
+ [1, 2, 3, 4],
312
+ [1, 2, 3, 4]]
313
+ The resulting sequence obtained from the returned pattern is:
314
+ [[S, 1, 2, 3, 4],
315
+ [S, S, 1, 2, 3],
316
+ [S, S, S, 1, 2]]
317
+ (with S being a special token)
318
+
319
+ Args:
320
+ n_q (int): Number of codebooks.
321
+ delays (Optional[List[int]]): Delay for each of the codebooks.
322
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
323
+ flatten_first (int): Flatten the first N timesteps.
324
+ empty_initial (int): Prepend with N empty list of coordinates.
325
+ """
326
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
327
+ flatten_first: int = 0, empty_initial: int = 0):
328
+ super().__init__(n_q)
329
+ if delays is None:
330
+ delays = list(range(n_q))
331
+ self.delays = delays
332
+ self.flatten_first = flatten_first
333
+ self.empty_initial = empty_initial
334
+ assert len(self.delays) == self.n_q
335
+ assert sorted(self.delays) == self.delays
336
+
337
+ def get_pattern(self, timesteps: int) -> Pattern:
338
+ out: PatternLayout = [[]]
339
+ max_delay = max(self.delays)
340
+ if self.empty_initial:
341
+ out += [[] for _ in range(self.empty_initial)]
342
+ if self.flatten_first:
343
+ for t in range(min(timesteps, self.flatten_first)):
344
+ for q in range(self.n_q):
345
+ out.append([LayoutCoord(t, q)])
346
+ for t in range(self.flatten_first, timesteps + max_delay):
347
+ v = []
348
+ for q, delay in enumerate(self.delays):
349
+ t_for_q = t - delay
350
+ if t_for_q >= self.flatten_first:
351
+ v.append(LayoutCoord(t_for_q, q))
352
+ out.append(v)
353
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
354
+
355
+
356
+ class ParallelPatternProvider(DelayedPatternProvider):
357
+ """Provider for parallel pattern across codebooks.
358
+ This pattern provider is a special case of the delayed pattern with actually no delay,
359
+ hence delays=repeat(0, n_q).
360
+
361
+ Args:
362
+ n_q (int): Number of codebooks.
363
+ """
364
+ def __init__(self, n_q: int):
365
+ super().__init__(n_q, [0] * n_q)
366
+
367
+
368
+ class UnrolledPatternProvider(CodebooksPatternProvider):
369
+ """Provider for unrolling codebooks pattern.
370
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
371
+ while also specifying a given delay between the flattened codebooks representation, allowing to
372
+ unroll the codebooks in the sequence.
373
+
374
+ Example:
375
+ 1. Flattening of the codebooks.
376
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
377
+ taking n_q = 3 and timesteps = 4:
378
+ [[1, 2, 3, 4],
379
+ [1, 2, 3, 4],
380
+ [1, 2, 3, 4]]
381
+ will result into:
382
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
383
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
384
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
385
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
386
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
387
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
388
+ [[1, 2, 3, 4],
389
+ [1, 2, 3, 4],
390
+ [1, 2, 3, 4]]
391
+ will result into:
392
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
393
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
395
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
396
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
397
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
398
+ and delays = [0, 3, 3]:
399
+ [[1, 2, 3, 4],
400
+ [1, 2, 3, 4],
401
+ [1, 2, 3, 4]]
402
+ will result into:
403
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
404
+ [S, S, S, 1, S, 2, S, 3, S, 4],
405
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
406
+
407
+ Args:
408
+ n_q (int): Number of codebooks.
409
+ flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
410
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
411
+ have n_q extra steps for each timestep.
412
+ delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
413
+ no delay is added and therefore will default to [0] * ``n_q``.
414
+ Note that two codebooks that will be flattened to the same inner step
415
+ should have the same delay, otherwise the pattern is considered as invalid.
416
+ """
417
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
418
+
419
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
420
+ delays: tp.Optional[tp.List[int]] = None):
421
+ super().__init__(n_q)
422
+ if flattening is None:
423
+ flattening = list(range(n_q))
424
+ if delays is None:
425
+ delays = [0] * n_q
426
+ assert len(flattening) == n_q
427
+ assert len(delays) == n_q
428
+ assert sorted(flattening) == flattening
429
+ assert sorted(delays) == delays
430
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
431
+ self.max_delay = max(delays)
432
+
433
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
434
+ """Build a flattened codebooks representation as a dictionary of inner step
435
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
436
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
437
+ """
438
+ flattened_codebooks: dict = {}
439
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
440
+ if inner_step not in flattened_codebooks:
441
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
442
+ else:
443
+ flat_codebook = flattened_codebooks[inner_step]
444
+ assert flat_codebook.delay == delay, (
445
+ "Delay and flattening between codebooks is inconsistent: ",
446
+ "two codebooks flattened to the same position should have the same delay."
447
+ )
448
+ flat_codebook.codebooks.append(q)
449
+ flattened_codebooks[inner_step] = flat_codebook
450
+ return flattened_codebooks
451
+
452
+ @property
453
+ def _num_inner_steps(self):
454
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
455
+ """
456
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
457
+
458
+ def num_virtual_steps(self, timesteps: int) -> int:
459
+ return timesteps * self._num_inner_steps + 1
460
+
461
+ def get_pattern(self, timesteps: int) -> Pattern:
462
+ """Builds pattern for delay across codebooks.
463
+
464
+ Args:
465
+ timesteps (int): Total numer of timesteps.
466
+ """
467
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
468
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
469
+ indexed_out: list = [(-1, [])]
470
+ max_timesteps = timesteps + self.max_delay
471
+ for t in range(max_timesteps):
472
+ # for each timestep, we unroll the flattened codebooks,
473
+ # emitting the sequence step with the corresponding delay
474
+ for step in range(self._num_inner_steps):
475
+ if step in self._flattened_codebooks:
476
+ # we have codebooks at this virtual step to emit
477
+ step_codebooks = self._flattened_codebooks[step]
478
+ t_for_q = t + step_codebooks.delay
479
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
480
+ if t_for_q < max_timesteps and t < max_timesteps:
481
+ indexed_out.append((t_for_q, coords))
482
+ else:
483
+ # there is no codebook in this virtual step so we emit an empty list
484
+ indexed_out.append((t, []))
485
+ out = [coords for _, coords in sorted(indexed_out)]
486
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
487
+
488
+
489
+ class VALLEPattern(CodebooksPatternProvider):
490
+ """Almost VALL-E style pattern. We futher allow some delays for the
491
+ codebooks other than the first one.
492
+
493
+ Args:
494
+ n_q (int): Number of codebooks.
495
+ delays (Optional[List[int]]): Delay for each of the codebooks.
496
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
497
+ """
498
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
499
+ super().__init__(n_q)
500
+ if delays is None:
501
+ delays = [0] * (n_q - 1)
502
+ self.delays = delays
503
+ assert len(self.delays) == self.n_q - 1
504
+ assert sorted(self.delays) == self.delays
505
+
506
+ def get_pattern(self, timesteps: int) -> Pattern:
507
+ out: PatternLayout = [[]]
508
+ for t in range(timesteps):
509
+ out.append([LayoutCoord(t, 0)])
510
+ max_delay = max(self.delays)
511
+ for t in range(timesteps + max_delay):
512
+ v = []
513
+ for q, delay in enumerate(self.delays):
514
+ t_for_q = t - delay
515
+ if t_for_q >= 0:
516
+ v.append(LayoutCoord(t_for_q, q + 1))
517
+ out.append(v)
518
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
519
+
520
+
521
+ class MusicLMPattern(CodebooksPatternProvider):
522
+ """Almost MusicLM style pattern. This is equivalent to full flattening
523
+ but in a different order.
524
+
525
+ Args:
526
+ n_q (int): Number of codebooks.
527
+ group_by (int): Number of codebooks to group together.
528
+ """
529
+ def __init__(self, n_q: int, group_by: int = 2):
530
+ super().__init__(n_q)
531
+ self.group_by = group_by
532
+
533
+ def get_pattern(self, timesteps: int) -> Pattern:
534
+ out: PatternLayout = [[]]
535
+ for offset in range(0, self.n_q, self.group_by):
536
+ for t in range(timesteps):
537
+ for q in range(offset, offset + self.group_by):
538
+ out.append([LayoutCoord(t, q)])
539
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
audiocraft/modules/conditioners.py ADDED
@@ -0,0 +1,986 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from copy import deepcopy
9
+ from dataclasses import dataclass, field
10
+ from itertools import chain
11
+ import logging
12
+ import random
13
+ import re
14
+ import typing as tp
15
+ import warnings
16
+
17
+ from einops import rearrange
18
+ from num2words import num2words
19
+ import spacy
20
+ from transformers import T5EncoderModel, T5Tokenizer # type: ignore
21
+ import torchaudio
22
+ import torch
23
+ from torch import nn
24
+ from torch import Tensor
25
+ import torch.nn.functional as F
26
+ from torch.nn.utils.rnn import pad_sequence
27
+
28
+ from .streaming import StreamingModule
29
+ from .transformer import create_sin_embedding
30
+ from ..data.audio_dataset import SegmentInfo
31
+ from ..utils.autocast import TorchAutocast
32
+ from ..utils.utils import hash_trick, length_to_mask, collate
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+ TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
37
+ ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask
38
+
39
+
40
+ class WavCondition(tp.NamedTuple):
41
+ wav: Tensor
42
+ length: Tensor
43
+ path: tp.List[tp.Optional[str]] = []
44
+
45
+
46
+ def nullify_condition(condition: ConditionType, dim: int = 1):
47
+ """This function transforms an input condition to a null condition.
48
+ The way it is done by converting it to a single zero vector similarly
49
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
50
+
51
+ Args:
52
+ condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
53
+ dim (int): the dimension that will be truncated (should be the time dimension)
54
+ WARNING!: dim should not be the batch dimension!
55
+ Returns:
56
+ ConditionType: a tuple of null condition and mask
57
+ """
58
+ assert dim != 0, "dim cannot be the batch dimension!"
59
+ assert type(condition) == tuple and \
60
+ type(condition[0]) == Tensor and \
61
+ type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
62
+ cond, mask = condition
63
+ B = cond.shape[0]
64
+ last_dim = cond.dim() - 1
65
+ out = cond.transpose(dim, last_dim)
66
+ out = 0. * out[..., :1]
67
+ out = out.transpose(dim, last_dim)
68
+ mask = torch.zeros((B, 1), device=out.device).int()
69
+ assert cond.dim() == out.dim()
70
+ return out, mask
71
+
72
+
73
+ def nullify_wav(wav: Tensor) -> WavCondition:
74
+ """Create a nullified WavCondition from a wav tensor with appropriate shape.
75
+
76
+ Args:
77
+ wav (Tensor): tensor of shape [B, T]
78
+ Returns:
79
+ WavCondition: wav condition with nullified wav.
80
+ """
81
+ null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
82
+ return WavCondition(
83
+ wav=null_wav,
84
+ length=torch.tensor([0] * wav.shape[0], device=wav.device),
85
+ path=['null_wav'] * wav.shape[0]
86
+ )
87
+
88
+
89
+ @dataclass
90
+ class ConditioningAttributes:
91
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
92
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
93
+
94
+ def __getitem__(self, item):
95
+ return getattr(self, item)
96
+
97
+ @property
98
+ def text_attributes(self):
99
+ return self.text.keys()
100
+
101
+ @property
102
+ def wav_attributes(self):
103
+ return self.wav.keys()
104
+
105
+ @property
106
+ def attributes(self):
107
+ return {"text": self.text_attributes, "wav": self.wav_attributes}
108
+
109
+ def to_flat_dict(self):
110
+ return {
111
+ **{f"text.{k}": v for k, v in self.text.items()},
112
+ **{f"wav.{k}": v for k, v in self.wav.items()},
113
+ }
114
+
115
+ @classmethod
116
+ def from_flat_dict(cls, x):
117
+ out = cls()
118
+ for k, v in x.items():
119
+ kind, att = k.split(".")
120
+ out[kind][att] = v
121
+ return out
122
+
123
+
124
+ class SegmentWithAttributes(SegmentInfo):
125
+ """Base class for all dataclasses that are used for conditioning.
126
+ All child classes should implement `to_condition_attributes` that converts
127
+ the existing attributes to a dataclass of type ConditioningAttributes.
128
+ """
129
+ def to_condition_attributes(self) -> ConditioningAttributes:
130
+ raise NotImplementedError()
131
+
132
+
133
+ class Tokenizer:
134
+ """Base class for all tokenizers
135
+ (in case we want to introduce more advances tokenizers in the future).
136
+ """
137
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
138
+ raise NotImplementedError()
139
+
140
+
141
+ class WhiteSpaceTokenizer(Tokenizer):
142
+ """This tokenizer should be used for natural language descriptions.
143
+ For example:
144
+ ["he didn't, know he's going home.", 'shorter sentence'] =>
145
+ [[78, 62, 31, 4, 78, 25, 19, 34],
146
+ [59, 77, 0, 0, 0, 0, 0, 0]]
147
+ """
148
+ PUNCTUATIONS = "?:!.,;"
149
+
150
+ def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
151
+ lemma: bool = True, stopwords: bool = True) -> None:
152
+ self.n_bins = n_bins
153
+ self.pad_idx = pad_idx
154
+ self.lemma = lemma
155
+ self.stopwords = stopwords
156
+ try:
157
+ self.nlp = spacy.load(language)
158
+ except IOError:
159
+ spacy.cli.download(language) # type: ignore
160
+ self.nlp = spacy.load(language)
161
+
162
+ @tp.no_type_check
163
+ def __call__(
164
+ self,
165
+ texts: tp.List[tp.Optional[str]],
166
+ return_text: bool = False
167
+ ) -> tp.Tuple[Tensor, Tensor]:
168
+ """Take a list of strings and convert them to a tensor of indices.
169
+
170
+ Args:
171
+ texts (tp.List[str]): List of strings.
172
+ return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
173
+ Returns:
174
+ tp.Tuple[Tensor, Tensor]:
175
+ - Indices of words in the LUT.
176
+ - And a mask indicating where the padding tokens are
177
+ """
178
+ output, lengths = [], []
179
+ texts = deepcopy(texts)
180
+ for i, text in enumerate(texts):
181
+ # if current sample doesn't have a certain attribute, replace with pad token
182
+ if text is None:
183
+ output.append(Tensor([self.pad_idx]))
184
+ lengths.append(0)
185
+ continue
186
+
187
+ # convert numbers to words
188
+ text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
189
+ # normalize text
190
+ text = self.nlp(text) # type: ignore
191
+ # remove stopwords
192
+ if self.stopwords:
193
+ text = [w for w in text if not w.is_stop] # type: ignore
194
+ # remove punctuations
195
+ text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore
196
+ # lemmatize if needed
197
+ text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
198
+
199
+ texts[i] = " ".join(text)
200
+ lengths.append(len(text))
201
+ # convert to tensor
202
+ tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
203
+ output.append(tokens)
204
+
205
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
206
+ padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
207
+ if return_text:
208
+ return padded_output, mask, texts # type: ignore
209
+ return padded_output, mask
210
+
211
+
212
+ class NoopTokenizer(Tokenizer):
213
+ """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
214
+ The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
215
+ strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
216
+ split it to ["Jeff", "Buckley"] and return an index per word.
217
+
218
+ For example:
219
+ ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
220
+ ["Metal", "Rock", "Classical"] => [0, 223, 51]
221
+ """
222
+ def __init__(self, n_bins: int, pad_idx: int = 0):
223
+ self.n_bins = n_bins
224
+ self.pad_idx = pad_idx
225
+
226
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
227
+ output, lengths = [], []
228
+ for text in texts:
229
+ # if current sample doesn't have a certain attribute, replace with pad token
230
+ if text is None:
231
+ output.append(self.pad_idx)
232
+ lengths.append(0)
233
+ else:
234
+ output.append(hash_trick(text, self.n_bins))
235
+ lengths.append(1)
236
+
237
+ tokens = torch.LongTensor(output).unsqueeze(1)
238
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
239
+ return tokens, mask
240
+
241
+
242
+ class BaseConditioner(nn.Module):
243
+ """Base model for all conditioner modules. We allow the output dim to be different
244
+ than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
245
+ 2) make all condition dims consistent.
246
+
247
+ Args:
248
+ dim (int): Hidden dim of the model (text-encoder/LUT).
249
+ output_dim (int): Output dim of the conditioner.
250
+ """
251
+ def __init__(self, dim, output_dim):
252
+ super().__init__()
253
+ self.dim = dim
254
+ self.output_dim = output_dim
255
+ self.output_proj = nn.Linear(dim, output_dim)
256
+
257
+ def tokenize(self, *args, **kwargs) -> tp.Any:
258
+ """Should be any part of the processing that will lead to a synchronization
259
+ point, e.g. BPE tokenization with transfer to the GPU.
260
+
261
+ The returned value will be saved and return later when calling forward().
262
+ """
263
+ raise NotImplementedError()
264
+
265
+ def forward(self, inputs: tp.Any) -> ConditionType:
266
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
267
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
268
+
269
+ Returns:
270
+ ConditionType:
271
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
272
+ output embedding and D is the dimension of the embedding.
273
+ - And a mask indicating where the padding tokens.
274
+ """
275
+ raise NotImplementedError()
276
+
277
+
278
+ class TextConditioner(BaseConditioner):
279
+ ...
280
+
281
+
282
+ class LUTConditioner(TextConditioner):
283
+ """Lookup table TextConditioner.
284
+
285
+ Args:
286
+ n_bins (int): Number of bins.
287
+ dim (int): Hidden dim of the model (text-encoder/LUT).
288
+ output_dim (int): Output dim of the conditioner.
289
+ tokenizer (str): Name of the tokenizer.
290
+ pad_idx (int, optional): Index for padding token. Defaults to 0.
291
+ """
292
+ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
293
+ super().__init__(dim, output_dim)
294
+ self.embed = nn.Embedding(n_bins, dim)
295
+ self.tokenizer: Tokenizer
296
+ if tokenizer == "whitespace":
297
+ self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
298
+ elif tokenizer == "noop":
299
+ self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
300
+ else:
301
+ raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
302
+
303
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
304
+ device = self.embed.weight.device
305
+ tokens, mask = self.tokenizer(x)
306
+ tokens, mask = tokens.to(device), mask.to(device)
307
+ return tokens, mask
308
+
309
+ def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
310
+ tokens, mask = inputs
311
+ embeds = self.embed(tokens)
312
+ embeds = self.output_proj(embeds)
313
+ embeds = (embeds * mask.unsqueeze(-1))
314
+ return embeds, mask
315
+
316
+
317
+ class T5Conditioner(TextConditioner):
318
+ """T5-based TextConditioner.
319
+
320
+ Args:
321
+ name (str): Name of the T5 model.
322
+ output_dim (int): Output dim of the conditioner.
323
+ finetune (bool): Whether to fine-tune T5 at train time.
324
+ device (str): Device for T5 Conditioner.
325
+ autocast_dtype (tp.Optional[str], optional): Autocast dtype.
326
+ word_dropout (float, optional): Word dropout probability.
327
+ normalize_text (bool, optional): Whether to apply text normalization.
328
+ """
329
+ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
330
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
331
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
332
+ MODELS_DIMS = {
333
+ "t5-small": 512,
334
+ "t5-base": 768,
335
+ "t5-large": 1024,
336
+ "t5-3b": 1024,
337
+ "t5-11b": 1024,
338
+ "google/flan-t5-small": 512,
339
+ "google/flan-t5-base": 768,
340
+ "google/flan-t5-large": 1024,
341
+ "google/flan-t5-3b": 1024,
342
+ "google/flan-t5-11b": 1024,
343
+ }
344
+
345
+ def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
346
+ autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
347
+ normalize_text: bool = False):
348
+ assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})"
349
+ super().__init__(self.MODELS_DIMS[name], output_dim)
350
+ self.device = device
351
+ self.name = name
352
+ self.finetune = finetune
353
+ self.word_dropout = word_dropout
354
+
355
+ if autocast_dtype is None or self.device == 'cpu':
356
+ self.autocast = TorchAutocast(enabled=False)
357
+ if self.device != 'cpu':
358
+ logger.warning("T5 has no autocast, this might lead to NaN")
359
+ else:
360
+ dtype = getattr(torch, autocast_dtype)
361
+ assert isinstance(dtype, torch.dtype)
362
+ logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
363
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
364
+ # Let's disable logging temporarily because T5 will vomit some errors otherwise.
365
+ # thanks https://gist.github.com/simon-weber/7853144
366
+ previous_level = logging.root.manager.disable
367
+ logging.disable(logging.ERROR)
368
+ with warnings.catch_warnings():
369
+ warnings.simplefilter("ignore")
370
+ try:
371
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
372
+ t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
373
+ finally:
374
+ logging.disable(previous_level)
375
+ if finetune:
376
+ self.t5 = t5
377
+ else:
378
+ # this makes sure that the t5 models is not part
379
+ # of the saved checkpoint
380
+ self.__dict__["t5"] = t5.to(device)
381
+
382
+ self.normalize_text = normalize_text
383
+ if normalize_text:
384
+ self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
385
+
386
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
387
+ # if current sample doesn't have a certain attribute, replace with empty string
388
+ entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
389
+ if self.normalize_text:
390
+ _, _, entries = self.text_normalizer(entries, return_text=True)
391
+ if self.word_dropout > 0. and self.training:
392
+ new_entries = []
393
+ for entry in entries:
394
+ words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
395
+ new_entries.append(" ".join(words))
396
+ entries = new_entries
397
+
398
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
399
+
400
+ inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
401
+ mask = inputs["attention_mask"]
402
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
403
+ return inputs
404
+
405
+ def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
406
+ mask = inputs["attention_mask"]
407
+ with torch.set_grad_enabled(self.finetune), self.autocast:
408
+ embeds = self.t5(**inputs).last_hidden_state
409
+ embeds = self.output_proj(embeds.to(self.output_proj.weight))
410
+ embeds = (embeds * mask.unsqueeze(-1))
411
+ return embeds, mask
412
+
413
+
414
+ class WaveformConditioner(BaseConditioner):
415
+ """Base class for all conditioners that take a waveform as input.
416
+ Classes that inherit must implement `_get_wav_embedding` that outputs
417
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
418
+ factor of the embedding model.
419
+
420
+ Args:
421
+ dim (int): The internal representation dimension.
422
+ output_dim (int): Output dimension.
423
+ device (tp.Union[torch.device, str]): Device.
424
+ """
425
+ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
426
+ super().__init__(dim, output_dim)
427
+ self.device = device
428
+
429
+ def tokenize(self, wav_length: WavCondition) -> WavCondition:
430
+ wav, length, path = wav_length
431
+ assert length is not None
432
+ return WavCondition(wav.to(self.device), length.to(self.device), path)
433
+
434
+ def _get_wav_embedding(self, wav: Tensor) -> Tensor:
435
+ """Gets as input a wav and returns a dense vector of conditions."""
436
+ raise NotImplementedError()
437
+
438
+ def _downsampling_factor(self):
439
+ """Returns the downsampling factor of the embedding model."""
440
+ raise NotImplementedError()
441
+
442
+ def forward(self, inputs: WavCondition) -> ConditionType:
443
+ """
444
+ Args:
445
+ input (WavCondition): Tuple of (waveform, lengths).
446
+ Returns:
447
+ ConditionType: Dense vector representing the conditioning along with its' mask.
448
+ """
449
+ wav, lengths, path = inputs
450
+ with torch.no_grad():
451
+ embeds = self._get_wav_embedding(wav)
452
+ embeds = embeds.to(self.output_proj.weight)
453
+ embeds = self.output_proj(embeds)
454
+
455
+ if lengths is not None:
456
+ lengths = lengths / self._downsampling_factor()
457
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
458
+ else:
459
+ mask = torch.ones_like(embeds)
460
+ embeds = (embeds * mask.unsqueeze(2).to(self.device))
461
+
462
+ return embeds, mask
463
+
464
+
465
+ class ChromaStemConditioner(WaveformConditioner):
466
+ """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by
467
+ the insight the drums and bass often dominate the chroma, leading to the chroma not containing the
468
+ information about melody.
469
+
470
+ Args:
471
+ output_dim (int): Output dimension for the conditioner.
472
+ sample_rate (int): Sample rate for the chroma extractor.
473
+ n_chroma (int): Number of chroma for the chroma extractor.
474
+ radix2_exp (int): Radix2 exponent for the chroma extractor.
475
+ duration (float): Duration used during training. This is later used for correct padding
476
+ in case we are using chroma as prefix.
477
+ match_len_on_eval (bool, optional): If True then all chromas are padded to the training
478
+ duration. Defaults to False.
479
+ eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as
480
+ conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
481
+ Defaults to None.
482
+ n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
483
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
484
+ **kwargs: Additional parameters for the chroma extractor.
485
+ """
486
+ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
487
+ duration: float, match_len_on_eval: bool = False, eval_wavs: tp.Optional[str] = None,
488
+ n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
489
+ from demucs import pretrained
490
+ super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
491
+ self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
492
+ self.sample_rate = sample_rate
493
+ self.match_len_on_eval = match_len_on_eval
494
+ self.duration = duration
495
+ self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
496
+ self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
497
+ self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
498
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
499
+ device=device, **kwargs)
500
+ self.chroma_len = self._get_chroma_len()
501
+
502
+ def _downsampling_factor(self):
503
+ return self.chroma.winhop
504
+
505
+ def _get_chroma_len(self):
506
+ """Get length of chroma during training"""
507
+ dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device)
508
+ dummy_chr = self.chroma(dummy_wav)
509
+ return dummy_chr.shape[1]
510
+
511
+ @torch.no_grad()
512
+ def _get_filtered_wav(self, wav):
513
+ from demucs.apply import apply_model
514
+ from demucs.audio import convert_audio
515
+ with self.autocast:
516
+ wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
517
+ stems = apply_model(self.demucs, wav, device=self.device)
518
+ stems = stems[:, self.stem_idx] # extract stem
519
+ stems = stems.sum(1) # merge extracted stems
520
+ stems = stems.mean(1, keepdim=True) # mono
521
+ stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
522
+ return stems
523
+
524
+ @torch.no_grad()
525
+ def _get_wav_embedding(self, wav):
526
+ # avoid 0-size tensors when we are working with null conds
527
+ if wav.shape[-1] == 1:
528
+ return self.chroma(wav)
529
+ stems = self._get_filtered_wav(wav)
530
+ chroma = self.chroma(stems)
531
+
532
+ if self.match_len_on_eval:
533
+ b, t, c = chroma.shape
534
+ if t > self.chroma_len:
535
+ chroma = chroma[:, :self.chroma_len]
536
+ logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
537
+ elif t < self.chroma_len:
538
+ chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
539
+ logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
540
+ return chroma
541
+
542
+
543
+ class ChromaExtractor(nn.Module):
544
+ """Chroma extraction class, handles chroma extraction and quantization.
545
+
546
+ Args:
547
+ sample_rate (int): Sample rate.
548
+ n_chroma (int): Number of chroma to consider.
549
+ radix2_exp (int): Radix2 exponent.
550
+ nfft (tp.Optional[int], optional): Number of FFT.
551
+ winlen (tp.Optional[int], optional): Window length.
552
+ winhop (tp.Optional[int], optional): Window hop size.
553
+ argmax (bool, optional): Whether to use argmax. Defaults to False.
554
+ norm (float, optional): Norm for chroma normalization. Defaults to inf.
555
+ device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
556
+ """
557
+ def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
558
+ nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
559
+ argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
560
+ super().__init__()
561
+ from librosa import filters
562
+ self.device = device
563
+ self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
564
+ self.winlen = winlen or 2 ** radix2_exp
565
+ self.nfft = nfft or self.winlen
566
+ self.winhop = winhop or (self.winlen // 4)
567
+ self.sr = sample_rate
568
+ self.n_chroma = n_chroma
569
+ self.norm = norm
570
+ self.argmax = argmax
571
+ self.window = torch.hann_window(self.winlen).to(device)
572
+ self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
573
+ n_chroma=self.n_chroma)).to(device)
574
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
575
+ hop_length=self.winhop, power=2, center=True,
576
+ pad=0, normalized=True).to(device)
577
+
578
+ def forward(self, wav):
579
+ with self.autocast:
580
+ T = wav.shape[-1]
581
+ # in case we are getting a wav that was dropped out (nullified)
582
+ # make sure wav length is no less that nfft
583
+ if T < self.nfft:
584
+ pad = self.nfft - T
585
+ r = 0 if pad % 2 == 0 else 1
586
+ wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
587
+ assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
588
+ spec = self.spec(wav).squeeze(1)
589
+ raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
590
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
591
+ norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
592
+
593
+ if self.argmax:
594
+ idx = norm_chroma.argmax(-1, keepdims=True)
595
+ norm_chroma[:] = 0
596
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
597
+
598
+ return norm_chroma
599
+
600
+
601
+ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
602
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
603
+ If the condition is of type "wav", then nullify it using "nullify_condition".
604
+ If the condition is of any other type, set its' value to None.
605
+ Works in-place.
606
+ """
607
+ if condition_type not in ["text", "wav"]:
608
+ raise ValueError(
609
+ "dropout_condition got an unexpected condition type!"
610
+ f" expected 'wav' or 'text' but got '{condition_type}'"
611
+ )
612
+
613
+ if condition not in getattr(sample, condition_type):
614
+ raise ValueError(
615
+ "dropout_condition received an unexpected condition!"
616
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
617
+ f"but got '{condition}' of type '{condition_type}'!"
618
+ )
619
+
620
+ if condition_type == "wav":
621
+ wav, length, path = sample.wav[condition]
622
+ sample.wav[condition] = nullify_wav(wav)
623
+ else:
624
+ sample.text[condition] = None
625
+
626
+ return sample
627
+
628
+
629
+ class DropoutModule(nn.Module):
630
+ """Base class for all dropout modules."""
631
+ def __init__(self, seed: int = 1234):
632
+ super().__init__()
633
+ self.rng = torch.Generator()
634
+ self.rng.manual_seed(seed)
635
+
636
+
637
+ class AttributeDropout(DropoutModule):
638
+ """Applies dropout with a given probability per attribute. This is different from the behavior of
639
+ ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
640
+ "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
641
+ where if "artist" is dropped "genre" must also be dropped.
642
+
643
+ Args:
644
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
645
+ ...
646
+ "genre": 0.1,
647
+ "artist": 0.5,
648
+ "wav": 0.25,
649
+ ...
650
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
651
+ seed (int, optional): Random seed.
652
+ """
653
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
654
+ super().__init__(seed=seed)
655
+ self.active_on_eval = active_on_eval
656
+ # construct dict that return the values from p otherwise 0
657
+ self.p = {}
658
+ for condition_type, probs in p.items():
659
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
660
+
661
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
662
+ """
663
+ Args:
664
+ samples (tp.List[ConditioningAttributes]): List of conditions.
665
+ Returns:
666
+ tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
667
+ """
668
+ if not self.training and not self.active_on_eval:
669
+ return samples
670
+
671
+ samples = deepcopy(samples)
672
+
673
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
674
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
675
+ if torch.rand(1, generator=self.rng).item() < p:
676
+ for sample in samples:
677
+ dropout_condition(sample, condition_type, condition)
678
+
679
+ return samples
680
+
681
+ def __repr__(self):
682
+ return f"AttributeDropout({dict(self.p)})"
683
+
684
+
685
+ class ClassifierFreeGuidanceDropout(DropoutModule):
686
+ """Applies Classifier Free Guidance dropout, meaning all attributes
687
+ are dropped with the same probability.
688
+
689
+ Args:
690
+ p (float): Probability to apply condition dropout during training.
691
+ seed (int): Random seed.
692
+ """
693
+ def __init__(self, p: float, seed: int = 1234):
694
+ super().__init__(seed=seed)
695
+ self.p = p
696
+
697
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
698
+ """
699
+ Args:
700
+ samples (tp.List[ConditioningAttributes]): List of conditions.
701
+ Returns:
702
+ tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
703
+ """
704
+ if not self.training:
705
+ return samples
706
+
707
+ # decide on which attributes to drop in a batched fashion
708
+ drop = torch.rand(1, generator=self.rng).item() < self.p
709
+ if not drop:
710
+ return samples
711
+
712
+ # nullify conditions of all attributes
713
+ samples = deepcopy(samples)
714
+
715
+ for condition_type in ["wav", "text"]:
716
+ for sample in samples:
717
+ for condition in sample.attributes[condition_type]:
718
+ dropout_condition(sample, condition_type, condition)
719
+
720
+ return samples
721
+
722
+ def __repr__(self):
723
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
724
+
725
+
726
+ class ConditioningProvider(nn.Module):
727
+ """Main class to provide conditions given all the supported conditioners.
728
+
729
+ Args:
730
+ conditioners (dict): Dictionary of conditioners.
731
+ merge_text_conditions_p (float, optional): Probability to merge all text sources
732
+ into a single text condition. Defaults to 0.
733
+ drop_desc_p (float, optional): Probability to drop the original description
734
+ when merging all text sources into a single text condition. Defaults to 0.
735
+ device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
736
+ """
737
+ def __init__(
738
+ self,
739
+ conditioners: tp.Dict[str, BaseConditioner],
740
+ merge_text_conditions_p: float = 0,
741
+ drop_desc_p: float = 0,
742
+ device: tp.Union[torch.device, str] = "cpu",
743
+ ):
744
+ super().__init__()
745
+ self.device = device
746
+ self.merge_text_conditions_p = merge_text_conditions_p
747
+ self.drop_desc_p = drop_desc_p
748
+ self.conditioners = nn.ModuleDict(conditioners)
749
+
750
+ @property
751
+ def text_conditions(self):
752
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
753
+
754
+ @property
755
+ def wav_conditions(self):
756
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
757
+
758
+ @property
759
+ def has_wav_condition(self):
760
+ return len(self.wav_conditions) > 0
761
+
762
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
763
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
764
+ This should be called before starting any real GPU work to avoid synchronization points.
765
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
766
+
767
+ Args:
768
+ inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing
769
+ text and wav conditions.
770
+ """
771
+ assert all([type(x) == ConditioningAttributes for x in inputs]), \
772
+ "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
773
+ f" but types were {set([type(x) for x in inputs])}"
774
+
775
+ output = {}
776
+ text = self._collate_text(inputs)
777
+ wavs = self._collate_wavs(inputs)
778
+
779
+ assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
780
+ f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
781
+
782
+ for attribute, batch in chain(text.items(), wavs.items()):
783
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
784
+ return output
785
+
786
+ def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
787
+ """Compute pairs of `(embedding, mask)` using the configured conditioners
788
+ and the tokenized representations. The output is for example:
789
+
790
+ {
791
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
792
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
793
+ ...
794
+ }
795
+
796
+ Args:
797
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
798
+ """
799
+ output = {}
800
+ for attribute, inputs in tokenized.items():
801
+ condition, mask = self.conditioners[attribute](inputs)
802
+ output[attribute] = (condition, mask)
803
+ return output
804
+
805
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
806
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
807
+ are the attributes and the values are the aggregated input per attribute.
808
+ For example:
809
+ Input:
810
+ [
811
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
812
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
813
+ ]
814
+ Output:
815
+ {
816
+ "genre": ["Rock", "Hip-hop"],
817
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
818
+ }
819
+ """
820
+ batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
821
+
822
+ def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
823
+ def is_valid(k, v):
824
+ k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
825
+ v_valid = v is not None and isinstance(v, (int, float, str, list))
826
+ return k_valid and v_valid
827
+
828
+ def process_value(v):
829
+ if isinstance(v, (int, float, str)):
830
+ return v
831
+ if isinstance(v, list):
832
+ return ", ".join(v)
833
+ else:
834
+ RuntimeError(f"unknown type for text value! ({type(v), v})")
835
+
836
+ desc = cond.text['description']
837
+ meta_data = ""
838
+ if random.uniform(0, 1) < merge_text_conditions_p:
839
+ meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
840
+ random.shuffle(meta_pairs)
841
+ meta_data = ". ".join(meta_pairs)
842
+ desc = desc if not random.uniform(0, 1) < drop_desc_p else None
843
+
844
+ if desc is None:
845
+ desc = meta_data if len(meta_data) > 1 else None
846
+ else:
847
+ desc = desc.rstrip('.') + ". " + meta_data
848
+ cond.text['description'] = desc.strip() if desc else None
849
+
850
+ if self.training and self.merge_text_conditions_p:
851
+ for sample in samples:
852
+ _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
853
+
854
+ texts = [x.text for x in samples]
855
+ for text in texts:
856
+ for condition in self.text_conditions:
857
+ batch_per_attribute[condition].append(text[condition])
858
+
859
+ return batch_per_attribute
860
+
861
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
862
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
863
+ and the values are Tensors of wavs according to said attribtues.
864
+
865
+ *Note*: by the time the samples reach this function, each sample should have some waveform
866
+ inside the "wav" attribute. It should be either:
867
+ 1. A real waveform
868
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
869
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
870
+
871
+ Args:
872
+ samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
873
+ Returns:
874
+ dict: A dicionary mapping an attribute name to wavs.
875
+ """
876
+ wavs = defaultdict(list)
877
+ lens = defaultdict(list)
878
+ paths = defaultdict(list)
879
+ out = {}
880
+
881
+ for sample in samples:
882
+ for attribute in self.wav_conditions:
883
+ wav, length, path = sample.wav[attribute]
884
+ wavs[attribute].append(wav.flatten())
885
+ lens[attribute].append(length)
886
+ paths[attribute].append(path)
887
+
888
+ # stack all wavs to a single tensor
889
+ for attribute in self.wav_conditions:
890
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
891
+ out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
892
+ torch.cat(lens['self_wav']), paths[attribute]) # type: ignore
893
+
894
+ return out
895
+
896
+
897
+ class ConditionFuser(StreamingModule):
898
+ """Condition fuser handles the logic to combine the different conditions
899
+ to the actual model input.
900
+
901
+ Args:
902
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
903
+ each condition. For example:
904
+ {
905
+ "prepend": ["description"],
906
+ "sum": ["genre", "bpm"],
907
+ "cross": ["description"],
908
+ }
909
+ cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
910
+ cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
911
+ """
912
+ FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
913
+
914
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
915
+ cross_attention_pos_emb_scale: float = 1.0):
916
+ super().__init__()
917
+ assert all(
918
+ [k in self.FUSING_METHODS for k in fuse2cond.keys()]
919
+ ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}"
920
+ self.cross_attention_pos_emb = cross_attention_pos_emb
921
+ self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
922
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
923
+ self.cond2fuse: tp.Dict[str, str] = {}
924
+ for fuse_method, conditions in fuse2cond.items():
925
+ for condition in conditions:
926
+ self.cond2fuse[condition] = fuse_method
927
+
928
+ def forward(
929
+ self,
930
+ input: Tensor,
931
+ conditions: tp.Dict[str, ConditionType]
932
+ ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
933
+ """Fuse the conditions to the provided model input.
934
+
935
+ Args:
936
+ input (Tensor): Transformer input.
937
+ conditions (tp.Dict[str, ConditionType]): Dict of conditions.
938
+ Returns:
939
+ tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
940
+ after the conditions have been fused. The second output tensor is the tensor
941
+ used for cross-attention or None if no cross attention inputs exist.
942
+ """
943
+ B, T, _ = input.shape
944
+
945
+ if 'offsets' in self._streaming_state:
946
+ first_step = False
947
+ offsets = self._streaming_state['offsets']
948
+ else:
949
+ first_step = True
950
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
951
+
952
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
953
+ f"given conditions contain unknown attributes for fuser, " \
954
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
955
+ cross_attention_output = None
956
+ for cond_type, (cond, cond_mask) in conditions.items():
957
+ op = self.cond2fuse[cond_type]
958
+ if op == "sum":
959
+ input += cond
960
+ elif op == "input_interpolate":
961
+ cond = rearrange(cond, "b t d -> b d t")
962
+ cond = F.interpolate(cond, size=input.shape[1])
963
+ input += rearrange(cond, "b d t -> b t d")
964
+ elif op == "prepend":
965
+ if first_step:
966
+ input = torch.cat([cond, input], dim=1)
967
+ elif op == "cross":
968
+ if cross_attention_output is not None:
969
+ cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
970
+ else:
971
+ cross_attention_output = cond
972
+ else:
973
+ raise ValueError(f"unknown op ({op})")
974
+
975
+ if self.cross_attention_pos_emb and cross_attention_output is not None:
976
+ positions = torch.arange(
977
+ cross_attention_output.shape[1],
978
+ device=cross_attention_output.device
979
+ ).view(1, -1, 1)
980
+ pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
981
+ cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
982
+
983
+ if self._is_streaming:
984
+ self._streaming_state['offsets'] = offsets + T
985
+
986
+ return input, cross_attention_output
audiocraft/modules/conv.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+ import warnings
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from torch.nn.utils import spectral_norm, weight_norm
15
+
16
+
17
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
18
+ 'time_group_norm'])
19
+
20
+
21
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
22
+ assert norm in CONV_NORMALIZATIONS
23
+ if norm == 'weight_norm':
24
+ return weight_norm(module)
25
+ elif norm == 'spectral_norm':
26
+ return spectral_norm(module)
27
+ else:
28
+ # We already check was in CONV_NORMALIZATION, so any other choice
29
+ # doesn't need reparametrization.
30
+ return module
31
+
32
+
33
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
34
+ """Return the proper normalization module. If causal is True, this will ensure the returned
35
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
36
+ """
37
+ assert norm in CONV_NORMALIZATIONS
38
+ if norm == 'time_group_norm':
39
+ if causal:
40
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
41
+ assert isinstance(module, nn.modules.conv._ConvNd)
42
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
43
+ else:
44
+ return nn.Identity()
45
+
46
+
47
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
48
+ padding_total: int = 0) -> int:
49
+ """See `pad_for_conv1d`.
50
+ """
51
+ length = x.shape[-1]
52
+ n_frames = (length - kernel_size + padding_total) / stride + 1
53
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
54
+ return ideal_length - length
55
+
56
+
57
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
58
+ """Pad for a convolution to make sure that the last window is full.
59
+ Extra padding is added at the end. This is required to ensure that we can rebuild
60
+ an output of the same length, as otherwise, even with padding, some time steps
61
+ might get removed.
62
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
63
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
64
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
65
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
66
+ 1 2 3 4 # once you removed padding, we are missing one time step !
67
+ """
68
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
69
+ return F.pad(x, (0, extra_padding))
70
+
71
+
72
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
73
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
74
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
75
+ """
76
+ length = x.shape[-1]
77
+ padding_left, padding_right = paddings
78
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
79
+ if mode == 'reflect':
80
+ max_pad = max(padding_left, padding_right)
81
+ extra_pad = 0
82
+ if length <= max_pad:
83
+ extra_pad = max_pad - length + 1
84
+ x = F.pad(x, (0, extra_pad))
85
+ padded = F.pad(x, paddings, mode, value)
86
+ end = padded.shape[-1] - extra_pad
87
+ return padded[..., :end]
88
+ else:
89
+ return F.pad(x, paddings, mode, value)
90
+
91
+
92
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
93
+ """Remove padding from x, handling properly zero padding. Only for 1d!
94
+ """
95
+ padding_left, padding_right = paddings
96
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
97
+ assert (padding_left + padding_right) <= x.shape[-1]
98
+ end = x.shape[-1] - padding_right
99
+ return x[..., padding_left: end]
100
+
101
+
102
+ class NormConv1d(nn.Module):
103
+ """Wrapper around Conv1d and normalization applied to this conv
104
+ to provide a uniform interface across normalization approaches.
105
+ """
106
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
107
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
108
+ super().__init__()
109
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
110
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
111
+ self.norm_type = norm
112
+
113
+ def forward(self, x):
114
+ x = self.conv(x)
115
+ x = self.norm(x)
116
+ return x
117
+
118
+
119
+ class NormConv2d(nn.Module):
120
+ """Wrapper around Conv2d and normalization applied to this conv
121
+ to provide a uniform interface across normalization approaches.
122
+ """
123
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
124
+ super().__init__()
125
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
126
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
127
+ self.norm_type = norm
128
+
129
+ def forward(self, x):
130
+ x = self.conv(x)
131
+ x = self.norm(x)
132
+ return x
133
+
134
+
135
+ class NormConvTranspose1d(nn.Module):
136
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
137
+ to provide a uniform interface across normalization approaches.
138
+ """
139
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
140
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
141
+ super().__init__()
142
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
143
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
144
+ self.norm_type = norm
145
+
146
+ def forward(self, x):
147
+ x = self.convtr(x)
148
+ x = self.norm(x)
149
+ return x
150
+
151
+
152
+ class NormConvTranspose2d(nn.Module):
153
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
154
+ to provide a uniform interface across normalization approaches.
155
+ """
156
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
157
+ super().__init__()
158
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
159
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
160
+
161
+ def forward(self, x):
162
+ x = self.convtr(x)
163
+ x = self.norm(x)
164
+ return x
165
+
166
+
167
+ class StreamableConv1d(nn.Module):
168
+ """Conv1d with some builtin handling of asymmetric or causal padding
169
+ and normalization.
170
+ """
171
+ def __init__(self, in_channels: int, out_channels: int,
172
+ kernel_size: int, stride: int = 1, dilation: int = 1,
173
+ groups: int = 1, bias: bool = True, causal: bool = False,
174
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
175
+ pad_mode: str = 'reflect'):
176
+ super().__init__()
177
+ # warn user on unusual setup between dilation and stride
178
+ if stride > 1 and dilation > 1:
179
+ warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
180
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
181
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
182
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
183
+ norm=norm, norm_kwargs=norm_kwargs)
184
+ self.causal = causal
185
+ self.pad_mode = pad_mode
186
+
187
+ def forward(self, x):
188
+ B, C, T = x.shape
189
+ kernel_size = self.conv.conv.kernel_size[0]
190
+ stride = self.conv.conv.stride[0]
191
+ dilation = self.conv.conv.dilation[0]
192
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
193
+ padding_total = kernel_size - stride
194
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
195
+ if self.causal:
196
+ # Left padding for causal
197
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
198
+ else:
199
+ # Asymmetric padding required for odd strides
200
+ padding_right = padding_total // 2
201
+ padding_left = padding_total - padding_right
202
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
203
+ return self.conv(x)
204
+
205
+
206
+ class StreamableConvTranspose1d(nn.Module):
207
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
208
+ and normalization.
209
+ """
210
+ def __init__(self, in_channels: int, out_channels: int,
211
+ kernel_size: int, stride: int = 1, causal: bool = False,
212
+ norm: str = 'none', trim_right_ratio: float = 1.,
213
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
214
+ super().__init__()
215
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
216
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
217
+ self.causal = causal
218
+ self.trim_right_ratio = trim_right_ratio
219
+ assert self.causal or self.trim_right_ratio == 1., \
220
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
221
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
222
+
223
+ def forward(self, x):
224
+ kernel_size = self.convtr.convtr.kernel_size[0]
225
+ stride = self.convtr.convtr.stride[0]
226
+ padding_total = kernel_size - stride
227
+
228
+ y = self.convtr(x)
229
+
230
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
231
+ # removed at the very end, when keeping only the right length for the output,
232
+ # as removing it here would require also passing the length at the matching layer
233
+ # in the encoder.
234
+ if self.causal:
235
+ # Trim the padding on the right according to the specified ratio
236
+ # if trim_right_ratio = 1.0, trim everything from right
237
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
238
+ padding_left = padding_total - padding_right
239
+ y = unpad1d(y, (padding_left, padding_right))
240
+ else:
241
+ # Asymmetric padding required for odd strides
242
+ padding_right = padding_total // 2
243
+ padding_left = padding_total - padding_right
244
+ y = unpad1d(y, (padding_left, padding_right))
245
+ return y
audiocraft/modules/lstm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+
9
+
10
+ class StreamableLSTM(nn.Module):
11
+ """LSTM without worrying about the hidden state, nor the layout of the data.
12
+ Expects input as convolutional layout.
13
+ """
14
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
15
+ super().__init__()
16
+ self.skip = skip
17
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
18
+
19
+ def forward(self, x):
20
+ x = x.permute(2, 0, 1)
21
+ y, _ = self.lstm(x)
22
+ if self.skip:
23
+ y = y + x
24
+ y = y.permute(1, 2, 0)
25
+ return y
audiocraft/modules/rope.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ from torch import nn
10
+ import torch
11
+
12
+
13
+ class XPos(nn.Module):
14
+ """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
15
+ This applies an exponential decay to the RoPE rotation matrix.
16
+
17
+ Args:
18
+ dim (int): Embedding dimension.
19
+ smoothing (float): Smoothing factor applied to the decay rates.
20
+ base_scale (int): Base decay rate, given in terms of scaling time.
21
+ device (torch.device or None): Device on which to initialize the module.
22
+ dtype (torch.dtype): dtype to use to generate the embedding.
23
+ """
24
+ def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
25
+ device=None, dtype: torch.dtype = torch.float32):
26
+ super().__init__()
27
+ assert dim % 2 == 0
28
+ assert dtype in [torch.float64, torch.float32]
29
+ self.dtype = dtype
30
+ self.base_scale = base_scale
31
+
32
+ half_dim = dim // 2
33
+ adim = torch.arange(half_dim, device=device, dtype=dtype)
34
+ decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
35
+ self.register_buffer("decay_rates", decay_rates)
36
+ self.decay: tp.Optional[torch.Tensor] = None
37
+
38
+ def get_decay(self, start: int, end: int):
39
+ """Create complex decay tensor, cache values for fast computation.
40
+ """
41
+ if self.decay is None or end > self.decay.shape[0]:
42
+ assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
43
+ idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
44
+ power = idx / self.base_scale
45
+ scale = self.decay_rates ** power.unsqueeze(-1)
46
+ self.decay = torch.polar(scale, torch.zeros_like(scale))
47
+ return self.decay[start:end] # [T, C/2]
48
+
49
+
50
+ class RotaryEmbedding(nn.Module):
51
+ """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
52
+
53
+ Args:
54
+ dim (int): Embedding dimension (twice the number of frequencies).
55
+ max_period (float): Maximum period of the rotation frequencies.
56
+ xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
57
+ scale (float): Scale of positional embedding, set to 0 to deactivate.
58
+ device (torch.device or None): Device on which to initialize the module.
59
+ dtype (torch.dtype): dtype to use to generate the embedding.
60
+ """
61
+ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
62
+ scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
63
+ super().__init__()
64
+ assert dim % 2 == 0
65
+ self.scale = scale
66
+ assert dtype in [torch.float64, torch.float32]
67
+ self.dtype = dtype
68
+
69
+ adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
70
+ frequencies = 1.0 / (max_period ** (adim / dim))
71
+ self.register_buffer("frequencies", frequencies)
72
+ self.rotation: tp.Optional[torch.Tensor] = None
73
+
74
+ self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
75
+
76
+ def get_rotation(self, start: int, end: int):
77
+ """Create complex rotation tensor, cache values for fast computation.
78
+ """
79
+ if self.rotation is None or end > self.rotation.shape[0]:
80
+ assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
81
+ idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
82
+ angles = torch.outer(idx, self.frequencies)
83
+ self.rotation = torch.polar(torch.ones_like(angles), angles)
84
+ return self.rotation[start:end]
85
+
86
+ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
87
+ """Apply rope rotation to query or key tensor.
88
+ """
89
+ T = x.shape[1]
90
+ rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
91
+
92
+ if self.xpos:
93
+ decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
94
+ else:
95
+ decay = 1.0
96
+
97
+ if invert_decay:
98
+ decay = decay ** -1
99
+
100
+ x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101
+ scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102
+ x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
103
+
104
+ return x_out.type_as(x)
105
+
106
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
107
+ """ Apply rope rotation to both query and key tensors.
108
+ Supports streaming mode, in which query and key are not expected to have the same shape.
109
+ In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but
110
+ query will be [C] (typically C == 1).
111
+
112
+ Args:
113
+ query (torch.Tensor): Query to rotate.
114
+ key (torch.Tensor): Key to rotate.
115
+ start (int): Start index of the sequence for time offset.
116
+ """
117
+ query_timesteps = query.shape[1]
118
+ key_timesteps = key.shape[1]
119
+ streaming_offset = key_timesteps - query_timesteps
120
+
121
+ query_out = self.rotate(query, start + streaming_offset)
122
+ key_out = self.rotate(key, start, invert_decay=True)
123
+
124
+ return query_out, key_out
audiocraft/modules/seanet.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ import torch.nn as nn
11
+
12
+ from .conv import StreamableConv1d, StreamableConvTranspose1d
13
+ from .lstm import StreamableLSTM
14
+
15
+
16
+ class SEANetResnetBlock(nn.Module):
17
+ """Residual block from SEANet model.
18
+
19
+ Args:
20
+ dim (int): Dimension of the input/output.
21
+ kernel_sizes (list): List of kernel sizes for the convolutions.
22
+ dilations (list): List of dilations for the convolutions.
23
+ activation (str): Activation function.
24
+ activation_params (dict): Parameters to provide to the activation function.
25
+ norm (str): Normalization method.
26
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
27
+ causal (bool): Whether to use fully causal convolution.
28
+ pad_mode (str): Padding mode for the convolutions.
29
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
30
+ true_skip (bool): Whether to use true skip connection or a simple
31
+ (streamable) convolution as the skip connection.
32
+ """
33
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
34
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
35
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
36
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
37
+ super().__init__()
38
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
39
+ act = getattr(nn, activation)
40
+ hidden = dim // compress
41
+ block = []
42
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
43
+ in_chs = dim if i == 0 else hidden
44
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
45
+ block += [
46
+ act(**activation_params),
47
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
48
+ norm=norm, norm_kwargs=norm_params,
49
+ causal=causal, pad_mode=pad_mode),
50
+ ]
51
+ self.block = nn.Sequential(*block)
52
+ self.shortcut: nn.Module
53
+ if true_skip:
54
+ self.shortcut = nn.Identity()
55
+ else:
56
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
57
+ causal=causal, pad_mode=pad_mode)
58
+
59
+ def forward(self, x):
60
+ return self.shortcut(x) + self.block(x)
61
+
62
+
63
+ class SEANetEncoder(nn.Module):
64
+ """SEANet encoder.
65
+
66
+ Args:
67
+ channels (int): Audio channels.
68
+ dimension (int): Intermediate representation dimension.
69
+ n_filters (int): Base width for the model.
70
+ n_residual_layers (int): nb of residual layers.
71
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
72
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
73
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
74
+ activation (str): Activation function.
75
+ activation_params (dict): Parameters to provide to the activation function.
76
+ norm (str): Normalization method.
77
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
78
+ kernel_size (int): Kernel size for the initial convolution.
79
+ last_kernel_size (int): Kernel size for the initial convolution.
80
+ residual_kernel_size (int): Kernel size for the residual layers.
81
+ dilation_base (int): How much to increase the dilation with each layer.
82
+ causal (bool): Whether to use fully causal convolution.
83
+ pad_mode (str): Padding mode for the convolutions.
84
+ true_skip (bool): Whether to use true skip connection or a simple
85
+ (streamable) convolution as the skip connection in the residual network blocks.
86
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
87
+ lstm (int): Number of LSTM layers at the end of the encoder.
88
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
89
+ For the encoder, it corresponds to the N first blocks.
90
+ """
91
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
92
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
93
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
94
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
95
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
96
+ disable_norm_outer_blocks: int = 0):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.dimension = dimension
100
+ self.n_filters = n_filters
101
+ self.ratios = list(reversed(ratios))
102
+ del ratios
103
+ self.n_residual_layers = n_residual_layers
104
+ self.hop_length = np.prod(self.ratios)
105
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
106
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
107
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
108
+ "Number of blocks for which to disable norm is invalid." \
109
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
110
+
111
+ act = getattr(nn, activation)
112
+ mult = 1
113
+ model: tp.List[nn.Module] = [
114
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
115
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
116
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
117
+ ]
118
+ # Downsample to raw audio scale
119
+ for i, ratio in enumerate(self.ratios):
120
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
121
+ # Add residual layers
122
+ for j in range(n_residual_layers):
123
+ model += [
124
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
125
+ dilations=[dilation_base ** j, 1],
126
+ norm=block_norm, norm_params=norm_params,
127
+ activation=activation, activation_params=activation_params,
128
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
129
+
130
+ # Add downsampling layers
131
+ model += [
132
+ act(**activation_params),
133
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
134
+ kernel_size=ratio * 2, stride=ratio,
135
+ norm=block_norm, norm_kwargs=norm_params,
136
+ causal=causal, pad_mode=pad_mode),
137
+ ]
138
+ mult *= 2
139
+
140
+ if lstm:
141
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
142
+
143
+ model += [
144
+ act(**activation_params),
145
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
146
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
147
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
148
+ ]
149
+
150
+ self.model = nn.Sequential(*model)
151
+
152
+ def forward(self, x):
153
+ return self.model(x)
154
+
155
+
156
+ class SEANetDecoder(nn.Module):
157
+ """SEANet decoder.
158
+
159
+ Args:
160
+ channels (int): Audio channels.
161
+ dimension (int): Intermediate representation dimension.
162
+ n_filters (int): Base width for the model.
163
+ n_residual_layers (int): nb of residual layers.
164
+ ratios (Sequence[int]): kernel size and stride ratios.
165
+ activation (str): Activation function.
166
+ activation_params (dict): Parameters to provide to the activation function.
167
+ final_activation (str): Final activation function after all convolutions.
168
+ final_activation_params (dict): Parameters to provide to the activation function.
169
+ norm (str): Normalization method.
170
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
171
+ kernel_size (int): Kernel size for the initial convolution.
172
+ last_kernel_size (int): Kernel size for the initial convolution.
173
+ residual_kernel_size (int): Kernel size for the residual layers.
174
+ dilation_base (int): How much to increase the dilation with each layer.
175
+ causal (bool): Whether to use fully causal convolution.
176
+ pad_mode (str): Padding mode for the convolutions.
177
+ true_skip (bool): Whether to use true skip connection or a simple.
178
+ (streamable) convolution as the skip connection in the residual network blocks.
179
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
180
+ lstm (int): Number of LSTM layers at the end of the encoder.
181
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
182
+ For the decoder, it corresponds to the N last blocks.
183
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
184
+ If equal to 1.0, it means that all the trimming is done at the right.
185
+ """
186
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
187
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
188
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
189
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
190
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
191
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
192
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
193
+ super().__init__()
194
+ self.dimension = dimension
195
+ self.channels = channels
196
+ self.n_filters = n_filters
197
+ self.ratios = ratios
198
+ del ratios
199
+ self.n_residual_layers = n_residual_layers
200
+ self.hop_length = np.prod(self.ratios)
201
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
202
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
203
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
204
+ "Number of blocks for which to disable norm is invalid." \
205
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
206
+
207
+ act = getattr(nn, activation)
208
+ mult = int(2 ** len(self.ratios))
209
+ model: tp.List[nn.Module] = [
210
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
211
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
212
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
213
+ ]
214
+
215
+ if lstm:
216
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
217
+
218
+ # Upsample to raw audio scale
219
+ for i, ratio in enumerate(self.ratios):
220
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
221
+ # Add upsampling layers
222
+ model += [
223
+ act(**activation_params),
224
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
225
+ kernel_size=ratio * 2, stride=ratio,
226
+ norm=block_norm, norm_kwargs=norm_params,
227
+ causal=causal, trim_right_ratio=trim_right_ratio),
228
+ ]
229
+ # Add residual layers
230
+ for j in range(n_residual_layers):
231
+ model += [
232
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
233
+ dilations=[dilation_base ** j, 1],
234
+ activation=activation, activation_params=activation_params,
235
+ norm=block_norm, norm_params=norm_params, causal=causal,
236
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
237
+
238
+ mult //= 2
239
+
240
+ # Add final layers
241
+ model += [
242
+ act(**activation_params),
243
+ StreamableConv1d(n_filters, channels, last_kernel_size,
244
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
245
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
246
+ ]
247
+ # Add optional final activation to decoder (eg. tanh)
248
+ if final_activation is not None:
249
+ final_act = getattr(nn, final_activation)
250
+ final_activation_params = final_activation_params or {}
251
+ model += [
252
+ final_act(**final_activation_params)
253
+ ]
254
+ self.model = nn.Sequential(*model)
255
+
256
+ def forward(self, z):
257
+ y = self.model(z)
258
+ return y
audiocraft/modules/streaming.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Streaming module API that should be implemented by all Streaming components,
9
+ """
10
+
11
+ from contextlib import contextmanager
12
+ import typing as tp
13
+ from torch import nn
14
+ import torch
15
+
16
+
17
+ State = tp.Dict[str, torch.Tensor]
18
+
19
+
20
+ class StreamingModule(nn.Module):
21
+ """Common API for streaming components.
22
+
23
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
24
+ By convention, the first dim of each tensor must be the batch size.
25
+ Don't use dots in the key names, as this would clash with submodules
26
+ (like in state_dict).
27
+
28
+ If `self._is_streaming` is True, the component should use and remember
29
+ the proper state inside `self._streaming_state`.
30
+
31
+ To set a streaming component in streaming state, use
32
+
33
+ with module.streaming():
34
+ ...
35
+
36
+ This will automatically reset the streaming state when exiting the context manager.
37
+ This also automatically propagates to all streaming children module.
38
+
39
+ Some module might also implement the `StreamingModule.flush` method, although
40
+ this one is trickier, as all parents module must be StreamingModule and implement
41
+ it as well for it to work properly. See `StreamingSequential` after.
42
+ """
43
+ def __init__(self) -> None:
44
+ super().__init__()
45
+ self._streaming_state: State = {}
46
+ self._is_streaming = False
47
+
48
+ def _apply_named_streaming(self, fn: tp.Any):
49
+ for name, module in self.named_modules():
50
+ if isinstance(module, StreamingModule):
51
+ fn(name, module)
52
+
53
+ def _set_streaming(self, streaming: bool):
54
+ def _set_streaming(name, module):
55
+ module._is_streaming = streaming
56
+ self._apply_named_streaming(_set_streaming)
57
+
58
+ @contextmanager
59
+ def streaming(self):
60
+ """Context manager to enter streaming mode. Reset streaming state on exit.
61
+ """
62
+ self._set_streaming(True)
63
+ try:
64
+ yield
65
+ finally:
66
+ self._set_streaming(False)
67
+ self.reset_streaming()
68
+
69
+ def reset_streaming(self):
70
+ """Reset the streaming state.
71
+ """
72
+ def _reset(name: str, module: StreamingModule):
73
+ module._streaming_state.clear()
74
+
75
+ self._apply_named_streaming(_reset)
76
+
77
+ def get_streaming_state(self) -> State:
78
+ """Return the streaming state, including that of sub-modules.
79
+ """
80
+ state: State = {}
81
+
82
+ def _add(name: str, module: StreamingModule):
83
+ if name:
84
+ name += "."
85
+ for key, value in module._streaming_state.items():
86
+ state[name + key] = value
87
+
88
+ self._apply_named_streaming(_add)
89
+ return state
90
+
91
+ def set_streaming_state(self, state: State):
92
+ """Set the streaming state, including that of sub-modules.
93
+ """
94
+ state = dict(state)
95
+
96
+ def _set(name: str, module: StreamingModule):
97
+ if name:
98
+ name += "."
99
+ module._streaming_state.clear()
100
+ for key, value in list(state.items()):
101
+ # complexity is not ideal here, but probably fine.
102
+ if key.startswith(name):
103
+ local_key = key[len(name):]
104
+ if '.' not in local_key:
105
+ module._streaming_state[local_key] = value
106
+ del state[key]
107
+
108
+ self._apply_named_streaming(_set)
109
+ assert len(state) == 0, list(state.keys())
110
+
111
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
112
+ """Flush any remaining outputs that were waiting for completion.
113
+ Typically, for convolutions, this will add the final padding
114
+ and process the last buffer.
115
+
116
+ This should take an optional argument `x`, which will be provided
117
+ if a module before this one in the streaming pipeline has already
118
+ spitted out a flushed out buffer.
119
+ """
120
+ if x is None:
121
+ return None
122
+ else:
123
+ return self(x)
124
+
125
+
126
+ class StreamingSequential(StreamingModule, nn.Sequential):
127
+ """A streaming compatible alternative of `nn.Sequential`.
128
+ """
129
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
130
+ for module in self:
131
+ if isinstance(module, StreamingModule):
132
+ x = module.flush(x)
133
+ elif x is not None:
134
+ x = module(x)
135
+ return x
audiocraft/modules/transformer.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Transformer model, with streaming support, xformer attention support
9
+ and easy causal attention with a potentially finite receptive field.
10
+
11
+ See `StreamingTransformer` for more information.
12
+
13
+ Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
14
+ """
15
+
16
+ import typing as tp
17
+
18
+ from einops import rearrange
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn import functional as F
22
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
23
+ from xformers import ops
24
+
25
+ from .rope import RotaryEmbedding
26
+ from .streaming import StreamingModule
27
+
28
+
29
+ def _is_profiled() -> bool:
30
+ # Return true if we are currently running with a xformers profiler activated.
31
+ try:
32
+ from xformers.profiler import profiler
33
+ except ImportError:
34
+ return False
35
+ return profiler._Profiler._CURRENT_PROFILER is not None
36
+
37
+
38
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
39
+ """Create normalization module for transformer encoder layer.
40
+
41
+ Args:
42
+ norm_type (str): Normalization method.
43
+ dim (int): Dimension of the normalized layer.
44
+ **kwargs (dict): Additional parameters for normalization layer.
45
+ Returns:
46
+ nn.Module: Normalization module.
47
+ """
48
+ if norm_type == 'layer_norm':
49
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
50
+ else:
51
+ raise ValueError(f"Unknown norm type: {norm_type}")
52
+
53
+
54
+ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
55
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
56
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
57
+
58
+ Args:
59
+ positions (torch.Tensor): LongTensor of positions.
60
+ dim (int): Dimension of the embedding.
61
+ max_period (float): Maximum period of the cosine/sine functions.
62
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
63
+ Returns:
64
+ torch.Tensor: Sinusoidal positional embedding.
65
+ """
66
+ # We aim for BTC format
67
+ assert dim % 2 == 0
68
+ half_dim = dim // 2
69
+ positions = positions.to(dtype)
70
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
71
+ max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
72
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
73
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
74
+
75
+
76
+ def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
77
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
78
+ bs, slen, n_kv_heads, head_dim = x.shape
79
+ if n_rep == 1:
80
+ return x
81
+ return (
82
+ x[:, :, :, None, :]
83
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
84
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
85
+ )
86
+
87
+
88
+ class LayerScale(nn.Module):
89
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
90
+ This rescales diagonaly the residual outputs close to 0, with a learnt scale.
91
+
92
+ Args:
93
+ channels (int): Number of channels.
94
+ init (float): Initial scale.
95
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
96
+ device (torch.device or None): Device on which to initialize the module.
97
+ dtype (torch.dtype or None): dtype to use to initialize the module.
98
+ """
99
+ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
100
+ device=None, dtype=None):
101
+ super().__init__()
102
+ self.channel_last = channel_last
103
+ self.scale = nn.Parameter(
104
+ torch.full((channels,), init,
105
+ requires_grad=True, device=device, dtype=dtype))
106
+
107
+ def forward(self, x: torch.Tensor):
108
+ if self.channel_last:
109
+ return self.scale * x
110
+ else:
111
+ return self.scale[:, None] * x
112
+
113
+
114
+ class StreamingMultiheadAttention(StreamingModule):
115
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
116
+
117
+ Args:
118
+ embed_dim (int): Dimension to project to.
119
+ num_heads (int): Number of heads.
120
+ dropout (float): Dropout level.
121
+ bias (bool): Use bias in projections.
122
+ causal (bool): Causal mask applied automatically.
123
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
124
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
125
+ memory_efficient (bool): Use xformers based memory efficient attention.
126
+ attention_as_float32 (bool): Perform the attention as float32
127
+ (especially important with memory_efficient as autocast won't do this automatically).
128
+ rope (`RotaryEmbedding` or None): Rope embedding to use.
129
+ cross_attention: Should be true when used as a cross attention.
130
+ All keys and values must be available at once, streaming is only for the queries.
131
+ Cannot be used with `causal` or `rope` (as it wouldn't make sens to
132
+ intepret the time steps in the keys relative to those in the queries).
133
+ safe_streaming (bool): Bug fix, will go away with xformers update.
134
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
135
+ kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
136
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
137
+ device (torch.device or None): Sevice on which to initialize.
138
+ dtype (torch.dtype or None): dtype to use.
139
+ """
140
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
141
+ causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
142
+ memory_efficient: bool = False, attention_as_float32: bool = False,
143
+ rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
144
+ safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
145
+ device=None, dtype=None):
146
+ super().__init__()
147
+ factory_kwargs = {'device': device, 'dtype': dtype}
148
+ if past_context is not None:
149
+ assert causal
150
+
151
+ self.embed_dim = embed_dim
152
+ self.causal = causal
153
+ self.past_context = past_context
154
+ self.memory_efficient = memory_efficient
155
+ self.attention_as_float32 = attention_as_float32
156
+ self.rope = rope
157
+ self.cross_attention = cross_attention
158
+ self.safe_streaming = safe_streaming
159
+ self.num_heads = num_heads
160
+ self.dropout = dropout
161
+ self.kv_repeat = kv_repeat
162
+ if cross_attention:
163
+ assert not causal, "Causal cannot work with cross attention."
164
+ assert rope is None, "Rope cannot work with cross attention."
165
+
166
+ if memory_efficient:
167
+ _verify_xformers_memory_efficient_compat()
168
+
169
+ self.custom = _is_custom(custom, memory_efficient)
170
+ if self.custom:
171
+ out_dim = embed_dim
172
+ assert num_heads % kv_repeat == 0
173
+ assert not cross_attention or kv_repeat == 1
174
+ num_kv = num_heads // kv_repeat
175
+ kv_dim = (embed_dim // num_heads) * num_kv
176
+ out_dim += 2 * kv_dim
177
+ in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
178
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
179
+ self.in_proj_weight = in_proj.weight
180
+ self.in_proj_bias = in_proj.bias
181
+ if bias:
182
+ self.in_proj_bias.data.zero_() # Following Pytorch convention
183
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
184
+ if bias:
185
+ self.out_proj.bias.data.zero_()
186
+ else:
187
+ assert not qk_layer_norm
188
+ assert kv_repeat == 1
189
+ self.mha = nn.MultiheadAttention(
190
+ embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
191
+ **factory_kwargs)
192
+ self.qk_layer_norm = qk_layer_norm
193
+ if qk_layer_norm:
194
+ assert self.custom
195
+ assert kv_repeat == 1
196
+ ln_dim = embed_dim
197
+ self.q_layer_norm = nn.LayerNorm(ln_dim)
198
+ self.k_layer_norm = nn.LayerNorm(ln_dim)
199
+
200
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
201
+ if not self.custom:
202
+ # Support compat with regular MHA
203
+ keys = [n for n, _ in self.mha.named_parameters()]
204
+ for key in keys:
205
+ if prefix + key in state_dict:
206
+ state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
207
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
208
+
209
+ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
210
+ # Return a causal mask, accounting for potentially stored past keys/values
211
+ # We actually return a bias for the attention score, as this has the same
212
+ # convention both in the builtin MHA in Pytorch, and Xformers functions.
213
+ if self.memory_efficient:
214
+ from xformers.ops import LowerTriangularMask
215
+ if current_steps == 1:
216
+ # If we only have one step, then we do not need a mask.
217
+ return None
218
+ elif 'past_keys' in self._streaming_state:
219
+ raise RuntimeError('Not supported at the moment')
220
+ else:
221
+ # Then we can safely use a lower triangular mask
222
+ return LowerTriangularMask()
223
+ if self._streaming_state:
224
+ past_keys = self._streaming_state['past_keys']
225
+ past_steps = past_keys.shape[1]
226
+ else:
227
+ past_steps = 0
228
+
229
+ queries_pos = torch.arange(
230
+ past_steps, current_steps + past_steps, device=device).view(-1, 1)
231
+ keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
232
+ delta = queries_pos - keys_pos
233
+ valid = delta >= 0
234
+ if self.past_context is not None:
235
+ valid &= (delta <= self.past_context)
236
+ return torch.where(
237
+ valid,
238
+ torch.zeros([], device=device, dtype=dtype),
239
+ torch.full([], float('-inf'), device=device, dtype=dtype))
240
+
241
+ def _complete_kv(self, k, v):
242
+ if self.cross_attention:
243
+ # With cross attention we assume all keys and values
244
+ # are already available, and streaming is with respect
245
+ # to the queries only.
246
+ return k, v
247
+ # Complete the key/value pair using the streaming state.
248
+ if self._streaming_state:
249
+ pk = self._streaming_state['past_keys']
250
+ nk = torch.cat([pk, k], dim=1)
251
+ if v is k:
252
+ nv = nk
253
+ else:
254
+ pv = self._streaming_state['past_values']
255
+ nv = torch.cat([pv, v], dim=1)
256
+ else:
257
+ nk = k
258
+ nv = v
259
+
260
+ assert nk.shape[1] == nv.shape[1]
261
+ offset = 0
262
+ if self.past_context is not None:
263
+ offset = max(0, nk.shape[1] - self.past_context)
264
+ if self._is_streaming:
265
+ self._streaming_state['past_keys'] = nk[:, offset:]
266
+ if v is not k:
267
+ self._streaming_state['past_values'] = nv[:, offset:]
268
+ if 'offset' in self._streaming_state:
269
+ self._streaming_state['offset'] += offset
270
+ else:
271
+ self._streaming_state['offset'] = torch.tensor(0)
272
+ return nk, nv
273
+
274
+ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
275
+ # Apply rope embeddings to query and key tensors.
276
+ assert self.rope is not None
277
+ if 'past_keys' in self._streaming_state:
278
+ past_keys_offset = self._streaming_state['past_keys'].shape[1]
279
+ else:
280
+ past_keys_offset = 0
281
+ if 'offset' in self._streaming_state:
282
+ past_context_offset = int(self._streaming_state['offset'].item())
283
+ else:
284
+ past_context_offset = 0
285
+ streaming_offset = past_context_offset + past_keys_offset
286
+ return self.rope.rotate_qk(query, key, start=streaming_offset)
287
+
288
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
289
+ key_padding_mask=None, need_weights=False, attn_mask=None,
290
+ average_attn_weights=True, is_causal=False):
291
+ assert attn_mask is None
292
+ assert not is_causal, ("new param added in torch 2.0.1 not supported, "
293
+ "use the causal args in the constructor.")
294
+
295
+ dtype = query.dtype
296
+ if self._is_streaming:
297
+ assert self.causal or self.cross_attention, \
298
+ "Streaming only available for causal or cross attention"
299
+
300
+ if self.causal:
301
+ # At the moment we specialize only for the self-attention case.
302
+ assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
303
+ assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
304
+ attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
305
+
306
+ if self.custom:
307
+ # custom implementation
308
+ assert need_weights is False
309
+ assert key_padding_mask is None
310
+ if self.cross_attention:
311
+ # Different queries, keys, values, we have to spit manually the weights
312
+ # before applying the linear.
313
+ dim = self.in_proj_weight.shape[0] // 3
314
+ if self.in_proj_bias is None:
315
+ bias_q, bias_k, bias_v = None, None, None
316
+ else:
317
+ bias_q = self.in_proj_bias[:dim]
318
+ bias_k = self.in_proj_bias[dim: 2 * dim]
319
+ bias_v = self.in_proj_bias[2 * dim:]
320
+ q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
321
+ # todo: when streaming, we could actually save k, v and check the shape actually match.
322
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
323
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
324
+ if self.qk_layer_norm is True:
325
+ q = self.q_layer_norm(q)
326
+ k = self.k_layer_norm(k)
327
+ # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
328
+ q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
329
+ else:
330
+ if not _is_profiled():
331
+ # profiling breaks that propertysomehow.
332
+ assert query is key, "specialized implementation"
333
+ assert value is key, "specialized implementation"
334
+ projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
335
+ if self.kv_repeat == 1:
336
+ packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
337
+ q, k, v = ops.unbind(packed, dim=2)
338
+ else:
339
+ embed_dim = self.embed_dim
340
+ per_head_dim = (embed_dim // self.num_heads)
341
+ kv_heads = self.num_heads // self.kv_repeat
342
+ q = projected[:, :, :embed_dim]
343
+ start = embed_dim
344
+ end = start + per_head_dim * kv_heads
345
+ k = projected[:, :, start: end]
346
+ v = projected[:, :, end:]
347
+ q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
348
+ k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
349
+ v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)
350
+
351
+ if self.qk_layer_norm is True:
352
+ assert self.kv_repeat == 1
353
+ q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
354
+ q = self.q_layer_norm(q)
355
+ k = self.k_layer_norm(k)
356
+ q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
357
+ if self.rope:
358
+ q, k = self._apply_rope(q, k)
359
+ k, v = self._complete_kv(k, v)
360
+ if self.kv_repeat > 1:
361
+ k = expand_repeated_kv(k, self.kv_repeat)
362
+ v = expand_repeated_kv(v, self.kv_repeat)
363
+ if self.attention_as_float32:
364
+ q, k, v = [x.float() for x in [q, k, v]]
365
+ if self.memory_efficient:
366
+ p = self.dropout if self.training else 0
367
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
368
+ else:
369
+ # We include the dot product as float32, for consistency
370
+ # with the other implementations that include that step
371
+ # as part of the attention. Note that when using `autocast`,
372
+ # the einsums would be done as bfloat16, but the softmax
373
+ # would be done as bfloat16, so `attention_as_float32` will
374
+ # extend a bit the range of operations done in float32,
375
+ # although this should make no difference.
376
+ q = q / q.shape[-1] ** 0.5
377
+ if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
378
+ with torch.autocast(device_type=q.device.type, dtype=torch.float32):
379
+ pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
380
+ else:
381
+ pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
382
+ if attn_mask is not None:
383
+ pre_w = pre_w + attn_mask
384
+ w = torch.softmax(pre_w, dim=-1)
385
+ w = F.dropout(w, self.dropout, training=self.training).to(v)
386
+ x = torch.einsum("bhqk,bkhc->bqhc", w, v)
387
+ x = x.to(dtype)
388
+ x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
389
+ x = self.out_proj(x)
390
+ else:
391
+ key, value = self._complete_kv(key, value)
392
+ if self.attention_as_float32:
393
+ query, key, value = [x.float() for x in [query, key, value]]
394
+ x, _ = self.mha(
395
+ query, key, value, key_padding_mask,
396
+ need_weights, attn_mask, average_attn_weights)
397
+ x = x.to(dtype)
398
+
399
+ return x, None
400
+
401
+
402
+ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
403
+ """TransformerLayer with Streaming / Causal support.
404
+ This also integrates cross_attention, when passing `cross_attention=True`,
405
+ rather than having two separate classes like in PyTorch.
406
+
407
+ Args:
408
+ d_model (int): Dimension of the data.
409
+ num_heads (int): Number of heads.
410
+ dim_feedforward (int): Intermediate dimension of FF module.
411
+ dropout (float): Dropout both for MHA and FF.
412
+ bias_ff (bool): Use bias for FF.
413
+ bias_attn (bool): Use bias for MHA.
414
+ causal (bool): Causal mask applied automatically.
415
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
416
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
417
+ memory_efficient (bool): Use xformers based memory efficient attention.
418
+ attention_as_float32 (bool): Perform the attention as float32
419
+ (especially important with memory_efficient as autocast won't do this automatically).
420
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
421
+ qk_layer_norm_cross (bool): Same for the cross attention.
422
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
423
+ Cross attention will use the default MHA, as it typically won't require
424
+ special treatment.
425
+ layer_scale (float or None): If not None, LayerScale will be used with
426
+ the given value as initial scale.
427
+ rope (`RotaryEmbedding` or None): Rope embedding to use.
428
+ attention_dropout (float or None): If not None, separate the value of the dimension dropout
429
+ in FFN and of the attention dropout.
430
+ kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
431
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
432
+ device (torch.device or None): Device on which to initialize.
433
+ dtype (torch.dtype or None): dtype to use.
434
+ **kwargs: See `nn.TransformerEncoderLayer`.
435
+ """
436
+ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
437
+ bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
438
+ past_context: tp.Optional[int] = None, custom: bool = False,
439
+ memory_efficient: bool = False, attention_as_float32: bool = False,
440
+ qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
441
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
442
+ rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
443
+ kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
444
+ super().__init__(d_model, num_heads, dim_feedforward, dropout,
445
+ device=device, dtype=dtype, batch_first=True, **kwargs)
446
+ factory_kwargs = {'device': device, 'dtype': dtype}
447
+ # Redefine self_attn to our streaming multi-head attention
448
+ attn_kwargs: tp.Dict[str, tp.Any] = {
449
+ 'embed_dim': d_model,
450
+ 'num_heads': num_heads,
451
+ 'dropout': dropout if attention_dropout is None else attention_dropout,
452
+ 'bias': bias_attn,
453
+ 'custom': custom,
454
+ 'memory_efficient': memory_efficient,
455
+ 'attention_as_float32': attention_as_float32,
456
+ }
457
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
458
+ causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
459
+ kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
460
+ # Redefine feedforward layers to expose bias parameter
461
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
462
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
463
+
464
+ self.layer_scale_1: nn.Module
465
+ self.layer_scale_2: nn.Module
466
+ if layer_scale is None:
467
+ self.layer_scale_1 = nn.Identity()
468
+ self.layer_scale_2 = nn.Identity()
469
+ else:
470
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
471
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
472
+
473
+ self.cross_attention: tp.Optional[nn.Module] = None
474
+ if cross_attention:
475
+ self.cross_attention = StreamingMultiheadAttention(
476
+ cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
477
+ **attn_kwargs, **factory_kwargs)
478
+ # Norm and dropout
479
+ self.dropout_cross = nn.Dropout(dropout)
480
+ # eps value matching that used in PyTorch reference implementation.
481
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
482
+ self.layer_scale_cross: nn.Module
483
+ if layer_scale is None:
484
+ self.layer_scale_cross = nn.Identity()
485
+ else:
486
+ self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
487
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
488
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
489
+
490
+ def _cross_attention_block(self, src: torch.Tensor,
491
+ cross_attention_src: torch.Tensor) -> torch.Tensor:
492
+ assert self.cross_attention is not None
493
+ # queries are from src, keys and values from cross_attention_src.
494
+ x = self.cross_attention(
495
+ src, cross_attention_src, cross_attention_src, need_weights=False)[0]
496
+ return self.dropout_cross(x) # type: ignore
497
+
498
+ def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
499
+ src_key_padding_mask: tp.Optional[torch.Tensor] = None,
500
+ cross_attention_src: tp.Optional[torch.Tensor] = None):
501
+ if self.cross_attention is None:
502
+ assert cross_attention_src is None
503
+ else:
504
+ assert cross_attention_src is not None
505
+ x = src
506
+ if self.norm_first:
507
+ x = x + self.layer_scale_1(
508
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
509
+ if cross_attention_src is not None:
510
+ x = x + self.layer_scale_cross(
511
+ self._cross_attention_block(
512
+ self.norm_cross(x), cross_attention_src))
513
+ x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
514
+ else:
515
+ x = self.norm1(x + self.layer_scale_1(
516
+ self._sa_block(x, src_mask, src_key_padding_mask)))
517
+ if cross_attention_src is not None:
518
+ x = self.norm_cross(
519
+ x + self.layer_scale_cross(
520
+ self._cross_attention_block(src, cross_attention_src)))
521
+ x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
522
+ return x
523
+
524
+
525
+ class StreamingTransformer(StreamingModule):
526
+ """Transformer with Streaming / Causal support.
527
+
528
+ Args:
529
+ d_model (int): Dimension of the data.
530
+ num_heads (int): Number of heads.
531
+ dim_feedforward (int): Intermediate dimension of FF module.
532
+ dropout (float): Dropout both for MHA and FF.
533
+ bias_ff (bool): Use bias for FF.
534
+ bias_attn (bool): Use bias for MHA.
535
+ causal (bool): Causal mask applied automatically.
536
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
537
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
538
+ memory_efficient (bool): Use xformers based memory efficient attention.
539
+ attention_as_float32 (bool): Perform the attention as float32
540
+ (especially important with memory_efficient as autocast won't do this automatically).
541
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
542
+ layer_scale (float or None): If not None, LayerScale will be used
543
+ with the given value as initial scale.
544
+ positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
545
+ max_period (float): Maximum period of the time embedding.
546
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
547
+ xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
548
+ lr (float or None): learning rate override through the `make_optim_group` API.
549
+ weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
550
+ layer_class: (subclass of `StreamingTransformerLayer): class to use
551
+ to initialize the layers, allowing further customization outside of Audiocraft.
552
+ checkpointing (str): Checkpointing strategy to reduce memory usage.
553
+ No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
554
+ if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
555
+ minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
556
+ a policy for opting-out some operations of the checkpointing like
557
+ linear layers and attention, providing a middle ground between speed and memory.
558
+ device (torch.device or None): Device on which to initialize.
559
+ dtype (torch.dtype or None): dtype to use.
560
+ **kwargs: See `nn.TransformerEncoderLayer`.
561
+ """
562
+ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
563
+ dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
564
+ causal: bool = False, past_context: tp.Optional[int] = None,
565
+ custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
566
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
567
+ positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
568
+ xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
569
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
570
+ checkpointing: str = 'none', device=None, dtype=None, **kwargs):
571
+ super().__init__()
572
+ assert d_model % num_heads == 0
573
+
574
+ self.positional_embedding = positional_embedding
575
+ self.max_period = max_period
576
+ self.positional_scale = positional_scale
577
+ self.weight_decay = weight_decay
578
+ self.lr = lr
579
+
580
+ assert positional_embedding in ['sin', 'rope', 'sin_rope']
581
+ self.rope: tp.Optional[RotaryEmbedding] = None
582
+ if self.positional_embedding in ['rope', 'sin_rope']:
583
+ assert _is_custom(custom, memory_efficient)
584
+ self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
585
+ xpos=xpos, scale=positional_scale, device=device)
586
+
587
+ self.checkpointing = checkpointing
588
+
589
+ assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
590
+ if self.checkpointing.startswith('xformers'):
591
+ _verify_xformers_internal_compat()
592
+
593
+ self.layers = nn.ModuleList()
594
+ for idx in range(num_layers):
595
+ self.layers.append(
596
+ layer_class(
597
+ d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
598
+ dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
599
+ causal=causal, past_context=past_context, custom=custom,
600
+ memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
601
+ cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
602
+ device=device, dtype=dtype, **kwargs))
603
+
604
+ if self.checkpointing != 'none':
605
+ for layer in self.layers:
606
+ # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
607
+ # backward hook inside of FSDP...
608
+ layer._magma_checkpointed = True # type: ignore
609
+ assert layer.layer_drop == 0., "Need further checking" # type: ignore
610
+
611
+ def _apply_layer(self, layer, *args, **kwargs):
612
+ method = self.checkpointing
613
+ if method == 'none':
614
+ return layer(*args, **kwargs)
615
+ elif method == 'torch':
616
+ return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
617
+ elif method.startswith('xformers'):
618
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
619
+ if method == 'xformers_default':
620
+ # those operations will be saved, and not recomputed.
621
+ # According to Francisco we can get smarter policies but this is a good start.
622
+ allow_list = [
623
+ "xformers.efficient_attention_forward_cutlass.default",
624
+ "xformers_flash.flash_fwd.default",
625
+ "aten.addmm.default",
626
+ "aten.mm.default",
627
+ ]
628
+ elif method == 'xformers_mm':
629
+ # those operations will be saved, and not recomputed.
630
+ # According to Francisco we can get smarter policies but this is a good start.
631
+ allow_list = [
632
+ "aten.addmm.default",
633
+ "aten.mm.default",
634
+ ]
635
+ else:
636
+ raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
637
+ policy_fn = _get_default_policy(allow_list)
638
+ return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
639
+ else:
640
+ raise ValueError(f"Checkpointing method {method} is unknown.")
641
+
642
+ def forward(self, x: torch.Tensor, *args, **kwargs):
643
+ B, T, C = x.shape
644
+
645
+ if 'offsets' in self._streaming_state:
646
+ offsets = self._streaming_state['offsets']
647
+ else:
648
+ offsets = torch.zeros(B, dtype=torch.long, device=x.device)
649
+
650
+ if self.positional_embedding in ['sin', 'sin_rope']:
651
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
652
+ positions = positions + offsets.view(-1, 1, 1)
653
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
654
+ x = x + self.positional_scale * pos_emb
655
+
656
+ for layer in self.layers:
657
+ x = self._apply_layer(layer, x, *args, **kwargs)
658
+
659
+ if self._is_streaming:
660
+ self._streaming_state['offsets'] = offsets + T
661
+
662
+ return x
663
+
664
+ def make_optim_group(self):
665
+ group = {"params": list(self.parameters())}
666
+ if self.lr is not None:
667
+ group["lr"] = self.lr
668
+ if self.weight_decay is not None:
669
+ group["weight_decay"] = self.weight_decay
670
+ return group
671
+
672
+
673
+ # special attention attention related function
674
+
675
+ def _verify_xformers_memory_efficient_compat():
676
+ try:
677
+ from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
678
+ except ImportError:
679
+ raise ImportError(
680
+ "xformers is not installed. Please install it and try again.\n"
681
+ "To install on AWS and Azure, run \n"
682
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
683
+ "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
684
+ "To install on FAIR Cluster, run \n"
685
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
686
+ "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")
687
+
688
+
689
+ def _verify_xformers_internal_compat():
690
+ try:
691
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
692
+ except ImportError:
693
+ raise ImportError(
694
+ "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
695
+ "To install on AWS and Azure, run \n"
696
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
697
+ "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
698
+ "To install on FAIR Cluster, run \n"
699
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
700
+ "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")
701
+
702
+
703
+ def _is_custom(custom: bool, memory_efficient: bool):
704
+ return custom or memory_efficient
audiocraft/py.typed ADDED
File without changes
audiocraft/quantization/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .vq import ResidualVectorQuantizer
9
+ from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
audiocraft/quantization/base.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Base class for all quantizers.
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+
18
+ @dataclass
19
+ class QuantizedResult:
20
+ x: torch.Tensor
21
+ codes: torch.Tensor
22
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
23
+ penalty: tp.Optional[torch.Tensor] = None
24
+ metrics: dict = field(default_factory=dict)
25
+
26
+
27
+ class BaseQuantizer(nn.Module):
28
+ """Base class for quantizers.
29
+ """
30
+
31
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
32
+ """
33
+ Given input tensor x, returns first the quantized (or approximately quantized)
34
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
35
+ Finally, this returns a dict of metrics to update logging etc.
36
+ Frame rate must be passed so that the bandwidth is properly computed.
37
+ """
38
+ raise NotImplementedError()
39
+
40
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
41
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
42
+ """
43
+ raise NotImplementedError()
44
+
45
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
46
+ """Decode the given codes to the quantized representation.
47
+ """
48
+ raise NotImplementedError()
49
+
50
+ @property
51
+ def total_codebooks(self):
52
+ """Total number of codebooks.
53
+ """
54
+ raise NotImplementedError()
55
+
56
+ @property
57
+ def num_codebooks(self):
58
+ """Number of active codebooks.
59
+ """
60
+ raise NotImplementedError()
61
+
62
+ def set_num_codebooks(self, n: int):
63
+ """Set the number of active codebooks.
64
+ """
65
+ raise NotImplementedError()
66
+
67
+
68
+ class DummyQuantizer(BaseQuantizer):
69
+ """Fake quantizer that actually does not perform any quantization.
70
+ """
71
+ def __init__(self):
72
+ super().__init__()
73
+
74
+ def forward(self, x: torch.Tensor, frame_rate: int):
75
+ q = x.unsqueeze(1)
76
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
77
+
78
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
79
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
80
+ In the case of the DummyQuantizer, the codes are actually identical
81
+ to the input and resulting quantized representation as no quantization is done.
82
+ """
83
+ return x.unsqueeze(1)
84
+
85
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
86
+ """Decode the given codes to the quantized representation.
87
+ In the case of the DummyQuantizer, the codes are actually identical
88
+ to the input and resulting quantized representation as no quantization is done.
89
+ """
90
+ return codes.squeeze(1)
91
+
92
+ @property
93
+ def total_codebooks(self):
94
+ """Total number of codebooks.
95
+ """
96
+ return 1
97
+
98
+ @property
99
+ def num_codebooks(self):
100
+ """Total number of codebooks.
101
+ """
102
+ return self.total_codebooks
103
+
104
+ def set_num_codebooks(self, n: int):
105
+ """Set the number of active codebooks.
106
+ """
107
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
audiocraft/quantization/core_vq.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ from einops import rearrange, repeat
10
+ import flashy
11
+ import torch
12
+ from torch import nn, einsum
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def exists(val: tp.Optional[tp.Any]) -> bool:
17
+ return val is not None
18
+
19
+
20
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
21
+ return val if exists(val) else d
22
+
23
+
24
+ def l2norm(t):
25
+ return F.normalize(t, p=2, dim=-1)
26
+
27
+
28
+ def ema_inplace(moving_avg, new, decay: float):
29
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
30
+
31
+
32
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
33
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
34
+
35
+
36
+ def uniform_init(*shape: int):
37
+ t = torch.empty(shape)
38
+ nn.init.kaiming_uniform_(t)
39
+ return t
40
+
41
+
42
+ def sample_vectors(samples, num: int):
43
+ num_samples, device = samples.shape[0], samples.device
44
+
45
+ if num_samples >= num:
46
+ indices = torch.randperm(num_samples, device=device)[:num]
47
+ else:
48
+ indices = torch.randint(0, num_samples, (num,), device=device)
49
+
50
+ return samples[indices]
51
+
52
+
53
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
54
+ dim, dtype = samples.shape[-1], samples.dtype
55
+
56
+ means = sample_vectors(samples, num_clusters)
57
+
58
+ for _ in range(num_iters):
59
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
60
+ means, "c d -> () c d"
61
+ )
62
+ dists = -(diffs ** 2).sum(dim=-1)
63
+
64
+ buckets = dists.max(dim=-1).indices
65
+ bins = torch.bincount(buckets, minlength=num_clusters)
66
+ zero_mask = bins == 0
67
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
68
+
69
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
70
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
71
+ new_means = new_means / bins_min_clamped[..., None]
72
+
73
+ means = torch.where(zero_mask[..., None], means, new_means)
74
+
75
+ return means, bins
76
+
77
+
78
+ def orthgonal_loss_fn(t):
79
+ # eq (2) from https://arxiv.org/abs/2112.00384
80
+ n = t.shape[0]
81
+ normed_codes = l2norm(t)
82
+ identity = torch.eye(n, device=t.device)
83
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
84
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
85
+
86
+
87
+ class EuclideanCodebook(nn.Module):
88
+ """Codebook with Euclidean distance.
89
+
90
+ Args:
91
+ dim (int): Dimension.
92
+ codebook_size (int): Codebook size.
93
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
94
+ If set to true, run the k-means algorithm on the first training batch and use
95
+ the learned centroids as initialization.
96
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
97
+ decay (float): Decay for exponential moving average over the codebooks.
98
+ epsilon (float): Epsilon value for numerical stability.
99
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
100
+ that have an exponential moving average cluster size less than the specified threshold with
101
+ randomly selected vector from the current batch.
102
+ """
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ codebook_size: int,
107
+ kmeans_init: int = False,
108
+ kmeans_iters: int = 10,
109
+ decay: float = 0.8,
110
+ epsilon: float = 1e-5,
111
+ threshold_ema_dead_code: int = 2,
112
+ ):
113
+ super().__init__()
114
+ self.decay = decay
115
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
116
+ embed = init_fn(codebook_size, dim)
117
+
118
+ self.codebook_size = codebook_size
119
+
120
+ self.kmeans_iters = kmeans_iters
121
+ self.epsilon = epsilon
122
+ self.threshold_ema_dead_code = threshold_ema_dead_code
123
+
124
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
125
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
126
+ self.register_buffer("embed", embed)
127
+ self.register_buffer("embed_avg", embed.clone())
128
+
129
+ @torch.jit.ignore
130
+ def init_embed_(self, data):
131
+ if self.inited:
132
+ return
133
+
134
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
135
+ self.embed.data.copy_(embed)
136
+ self.embed_avg.data.copy_(embed.clone())
137
+ self.cluster_size.data.copy_(cluster_size)
138
+ self.inited.data.copy_(torch.Tensor([True]))
139
+ # Make sure all buffers across workers are in sync after initialization
140
+ flashy.distrib.broadcast_tensors(self.buffers())
141
+
142
+ def replace_(self, samples, mask):
143
+ modified_codebook = torch.where(
144
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
145
+ )
146
+ self.embed.data.copy_(modified_codebook)
147
+
148
+ def expire_codes_(self, batch_samples):
149
+ if self.threshold_ema_dead_code == 0:
150
+ return
151
+
152
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
153
+ if not torch.any(expired_codes):
154
+ return
155
+
156
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
157
+ self.replace_(batch_samples, mask=expired_codes)
158
+ flashy.distrib.broadcast_tensors(self.buffers())
159
+
160
+ def preprocess(self, x):
161
+ x = rearrange(x, "... d -> (...) d")
162
+ return x
163
+
164
+ def quantize(self, x):
165
+ embed = self.embed.t()
166
+ dist = -(
167
+ x.pow(2).sum(1, keepdim=True)
168
+ - 2 * x @ embed
169
+ + embed.pow(2).sum(0, keepdim=True)
170
+ )
171
+ embed_ind = dist.max(dim=-1).indices
172
+ return embed_ind
173
+
174
+ def postprocess_emb(self, embed_ind, shape):
175
+ return embed_ind.view(*shape[:-1])
176
+
177
+ def dequantize(self, embed_ind):
178
+ quantize = F.embedding(embed_ind, self.embed)
179
+ return quantize
180
+
181
+ def encode(self, x):
182
+ shape = x.shape
183
+ # pre-process
184
+ x = self.preprocess(x)
185
+ # quantize
186
+ embed_ind = self.quantize(x)
187
+ # post-process
188
+ embed_ind = self.postprocess_emb(embed_ind, shape)
189
+ return embed_ind
190
+
191
+ def decode(self, embed_ind):
192
+ quantize = self.dequantize(embed_ind)
193
+ return quantize
194
+
195
+ def forward(self, x):
196
+ shape, dtype = x.shape, x.dtype
197
+ x = self.preprocess(x)
198
+ self.init_embed_(x)
199
+
200
+ embed_ind = self.quantize(x)
201
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
202
+ embed_ind = self.postprocess_emb(embed_ind, shape)
203
+ quantize = self.dequantize(embed_ind)
204
+
205
+ if self.training:
206
+ # We do the expiry of code at that point as buffers are in sync
207
+ # and all the workers will take the same decision.
208
+ self.expire_codes_(x)
209
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
210
+ embed_sum = x.t() @ embed_onehot
211
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
212
+ cluster_size = (
213
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
214
+ * self.cluster_size.sum()
215
+ )
216
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
217
+ self.embed.data.copy_(embed_normalized)
218
+
219
+ return quantize, embed_ind
220
+
221
+
222
+ class VectorQuantization(nn.Module):
223
+ """Vector quantization implementation.
224
+ Currently supports only euclidean distance.
225
+
226
+ Args:
227
+ dim (int): Dimension
228
+ codebook_size (int): Codebook size
229
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
230
+ decay (float): Decay for exponential moving average over the codebooks.
231
+ epsilon (float): Epsilon value for numerical stability.
232
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
233
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
234
+ threshold_ema_dead_code (int):
235
+ channels_last (bool): Channels are the last dimension in the input tensors.
236
+ commitment_weight (float): Weight for commitment loss.
237
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
238
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
239
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
240
+ for orthogonal regulariation.
241
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242
+ that have an exponential moving average cluster size less than the specified threshold with
243
+ randomly selected vector from the current batch.
244
+ """
245
+ def __init__(
246
+ self,
247
+ dim: int,
248
+ codebook_size: int,
249
+ codebook_dim: tp.Optional[int] = None,
250
+ decay: float = 0.8,
251
+ epsilon: float = 1e-5,
252
+ kmeans_init: bool = False,
253
+ kmeans_iters: int = 10,
254
+ threshold_ema_dead_code: int = 2,
255
+ channels_last: bool = False,
256
+ commitment_weight: float = 1.,
257
+ orthogonal_reg_weight: float = 0.0,
258
+ orthogonal_reg_active_codes_only: bool = False,
259
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
260
+ ):
261
+ super().__init__()
262
+ _codebook_dim: int = default(codebook_dim, dim)
263
+
264
+ requires_projection = _codebook_dim != dim
265
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
266
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
267
+
268
+ self.epsilon = epsilon
269
+ self.commitment_weight = commitment_weight
270
+
271
+ self.orthogonal_reg_weight = orthogonal_reg_weight
272
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
273
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
274
+
275
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
276
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
277
+ decay=decay, epsilon=epsilon,
278
+ threshold_ema_dead_code=threshold_ema_dead_code)
279
+ self.codebook_size = codebook_size
280
+
281
+ self.channels_last = channels_last
282
+
283
+ @property
284
+ def codebook(self):
285
+ return self._codebook.embed
286
+
287
+ @property
288
+ def inited(self):
289
+ return self._codebook.inited
290
+
291
+ def _preprocess(self, x):
292
+ if not self.channels_last:
293
+ x = rearrange(x, "b d n -> b n d")
294
+ return x
295
+
296
+ def _postprocess(self, quantize):
297
+ if not self.channels_last:
298
+ quantize = rearrange(quantize, "b n d -> b d n")
299
+ return quantize
300
+
301
+ def encode(self, x):
302
+ x = self._preprocess(x)
303
+ x = self.project_in(x)
304
+ embed_in = self._codebook.encode(x)
305
+ return embed_in
306
+
307
+ def decode(self, embed_ind):
308
+ quantize = self._codebook.decode(embed_ind)
309
+ quantize = self.project_out(quantize)
310
+ quantize = self._postprocess(quantize)
311
+ return quantize
312
+
313
+ def forward(self, x):
314
+ device = x.device
315
+ x = self._preprocess(x)
316
+
317
+ x = self.project_in(x)
318
+ quantize, embed_ind = self._codebook(x)
319
+
320
+ if self.training:
321
+ quantize = x + (quantize - x).detach()
322
+
323
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
324
+
325
+ if self.training:
326
+ if self.commitment_weight > 0:
327
+ commit_loss = F.mse_loss(quantize.detach(), x)
328
+ loss = loss + commit_loss * self.commitment_weight
329
+
330
+ if self.orthogonal_reg_weight > 0:
331
+ codebook = self.codebook
332
+
333
+ if self.orthogonal_reg_active_codes_only:
334
+ # only calculate orthogonal loss for the activated codes for this batch
335
+ unique_code_ids = torch.unique(embed_ind)
336
+ codebook = codebook[unique_code_ids]
337
+
338
+ num_codes = codebook.shape[0]
339
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
340
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
341
+ codebook = codebook[rand_ids]
342
+
343
+ orthogonal_reg_loss = orthgonal_loss_fn(codebook)
344
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
345
+
346
+ quantize = self.project_out(quantize)
347
+ quantize = self._postprocess(quantize)
348
+
349
+ return quantize, embed_ind, loss
350
+
351
+
352
+ class ResidualVectorQuantization(nn.Module):
353
+ """Residual vector quantization implementation.
354
+
355
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
356
+ """
357
+ def __init__(self, *, num_quantizers, **kwargs):
358
+ super().__init__()
359
+ self.layers = nn.ModuleList(
360
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
361
+ )
362
+
363
+ def forward(self, x, n_q: tp.Optional[int] = None):
364
+ quantized_out = 0.0
365
+ residual = x
366
+
367
+ all_losses = []
368
+ all_indices = []
369
+
370
+ n_q = n_q or len(self.layers)
371
+
372
+ for i, layer in enumerate(self.layers[:n_q]):
373
+ quantized, indices, loss = layer(residual)
374
+ residual = residual - quantized
375
+ quantized_out = quantized_out + quantized
376
+ all_indices.append(indices)
377
+ all_losses.append(loss)
378
+
379
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
380
+ return quantized_out, out_indices, out_losses
381
+
382
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
383
+ residual = x
384
+ all_indices = []
385
+ n_q = n_q or len(self.layers)
386
+ for layer in self.layers[:n_q]:
387
+ indices = layer.encode(residual)
388
+ quantized = layer.decode(indices)
389
+ residual = residual - quantized
390
+ all_indices.append(indices)
391
+ out_indices = torch.stack(all_indices)
392
+ return out_indices
393
+
394
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
395
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
396
+ for i, indices in enumerate(q_indices):
397
+ layer = self.layers[i]
398
+ quantized = layer.decode(indices)
399
+ quantized_out = quantized_out + quantized
400
+ return quantized_out
audiocraft/quantization/vq.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import torch
11
+
12
+ from .base import BaseQuantizer, QuantizedResult
13
+ from .core_vq import ResidualVectorQuantization
14
+
15
+
16
+ class ResidualVectorQuantizer(BaseQuantizer):
17
+ """Residual Vector Quantizer.
18
+
19
+ Args:
20
+ dimension (int): Dimension of the codebooks.
21
+ n_q (int): Number of residual vector quantizers used.
22
+ q_dropout (bool): Random quantizer drop out at train time.
23
+ bins (int): Codebook size.
24
+ decay (float): Decay for exponential moving average over the codebooks.
25
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
26
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
27
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
28
+ that have an exponential moving average cluster size less than the specified threshold with
29
+ randomly selected vector from the current batch.
30
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
31
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
32
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
33
+ for orthogonal regulariation.
34
+ """
35
+ def __init__(
36
+ self,
37
+ dimension: int = 256,
38
+ n_q: int = 8,
39
+ q_dropout: bool = False,
40
+ bins: int = 1024,
41
+ decay: float = 0.99,
42
+ kmeans_init: bool = True,
43
+ kmeans_iters: int = 10,
44
+ threshold_ema_dead_code: int = 2,
45
+ orthogonal_reg_weight: float = 0.0,
46
+ orthogonal_reg_active_codes_only: bool = False,
47
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
48
+ ):
49
+ super().__init__()
50
+ self.max_n_q = n_q
51
+ self.n_q = n_q
52
+ self.q_dropout = q_dropout
53
+ self.dimension = dimension
54
+ self.bins = bins
55
+ self.decay = decay
56
+ self.kmeans_init = kmeans_init
57
+ self.kmeans_iters = kmeans_iters
58
+ self.threshold_ema_dead_code = threshold_ema_dead_code
59
+ self.orthogonal_reg_weight = orthogonal_reg_weight
60
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
61
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
62
+ self.vq = ResidualVectorQuantization(
63
+ dim=self.dimension,
64
+ codebook_size=self.bins,
65
+ num_quantizers=self.n_q,
66
+ decay=self.decay,
67
+ kmeans_init=self.kmeans_init,
68
+ kmeans_iters=self.kmeans_iters,
69
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
70
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
71
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
72
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
73
+ channels_last=False
74
+ )
75
+
76
+ def forward(self, x: torch.Tensor, frame_rate: int):
77
+ n_q = self.n_q
78
+ if self.training and self.q_dropout:
79
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
80
+ bw_per_q = math.log2(self.bins) * frame_rate / 1000
81
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
82
+ codes = codes.transpose(0, 1)
83
+ # codes is [B, K, T], with T frames, K nb of codebooks.
84
+ bw = torch.tensor(n_q * bw_per_q).to(x)
85
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
86
+
87
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
88
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
89
+ The RVQ encode method sets the appropriate number of quantizer to use
90
+ and returns indices for each quantizer.
91
+ """
92
+ n_q = self.n_q
93
+ codes = self.vq.encode(x, n_q=n_q)
94
+ codes = codes.transpose(0, 1)
95
+ # codes is [B, K, T], with T frames, K nb of codebooks.
96
+ return codes
97
+
98
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
99
+ """Decode the given codes to the quantized representation.
100
+ """
101
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
102
+ codes = codes.transpose(0, 1)
103
+ quantized = self.vq.decode(codes)
104
+ return quantized
105
+
106
+ @property
107
+ def total_codebooks(self):
108
+ return self.max_n_q
109
+
110
+ @property
111
+ def num_codebooks(self):
112
+ return self.n_q
113
+
114
+ def set_num_codebooks(self, n: int):
115
+ assert n > 0 and n <= self.max_n_q
116
+ self.n_q = n
audiocraft/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
audiocraft/utils/autocast.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+
10
+ class TorchAutocast:
11
+ """TorchAutocast utility class.
12
+ Allows you to enable and disable autocast. This is specially useful
13
+ when dealing with different architectures and clusters with different
14
+ levels of support.
15
+
16
+ Args:
17
+ enabled (bool): Whether to enable torch.autocast or not.
18
+ args: Additional args for torch.autocast.
19
+ kwargs: Additional kwargs for torch.autocast
20
+ """
21
+ def __init__(self, enabled: bool, *args, **kwargs):
22
+ self.autocast = torch.autocast(*args, **kwargs) if enabled else None
23
+
24
+ def __enter__(self):
25
+ if self.autocast is None:
26
+ return
27
+ try:
28
+ self.autocast.__enter__()
29
+ except RuntimeError:
30
+ device = self.autocast.device
31
+ dtype = self.autocast.fast_dtype
32
+ raise RuntimeError(
33
+ f"There was an error autocasting with dtype={dtype} device={device}\n"
34
+ "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
35
+ )
36
+
37
+ def __exit__(self, *args, **kwargs):
38
+ if self.autocast is None:
39
+ return
40
+ self.autocast.__exit__(*args, **kwargs)
audiocraft/utils/export.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility to export a training checkpoint to a lightweight release checkpoint.
9
+ """
10
+
11
+ from pathlib import Path
12
+ import typing as tp
13
+
14
+ from omegaconf import OmegaConf, DictConfig
15
+ import torch
16
+
17
+
18
+ def _clean_lm_cfg(cfg: DictConfig):
19
+ OmegaConf.set_struct(cfg, False)
20
+ # This used to be set automatically in the LM solver, need a more robust solution
21
+ # for the future.
22
+ cfg['transformer_lm']['card'] = 2048
23
+ cfg['transformer_lm']['n_q'] = 4
24
+ # Experimental params no longer supported.
25
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
26
+ 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
27
+ for name in bad_params:
28
+ del cfg['transformer_lm'][name]
29
+ OmegaConf.set_struct(cfg, True)
30
+ return cfg
31
+
32
+
33
+ def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
34
+ sig = Path(checkpoint_path).parent.name
35
+ assert len(sig) == 8, "Not a valid Dora signature"
36
+ pkg = torch.load(checkpoint_path, 'cpu')
37
+ new_pkg = {
38
+ 'best_state': pkg['ema']['state']['model'],
39
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
40
+ }
41
+ out_file = Path(out_folder) / f'{sig}.th'
42
+ torch.save(new_pkg, out_file)
43
+ return out_file
44
+
45
+
46
+ def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
47
+ sig = Path(checkpoint_path).parent.name
48
+ assert len(sig) == 8, "Not a valid Dora signature"
49
+ pkg = torch.load(checkpoint_path, 'cpu')
50
+ new_pkg = {
51
+ 'best_state': pkg['fsdp_best_state']['model'],
52
+ 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
53
+ }
54
+ out_file = Path(out_folder) / f'{sig}.th'
55
+ torch.save(new_pkg, out_file)
56
+ return out_file
audiocraft/utils/notebook.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ try:
8
+ import IPython.display as ipd # type: ignore
9
+ except ImportError:
10
+ # Note in a notebook...
11
+ pass
12
+
13
+
14
+ import torch
15
+
16
+
17
+ def display_audio(samples: torch.Tensor, sample_rate: int):
18
+ """Renders an audio player for the given audio samples.
19
+
20
+ Args:
21
+ samples (torch.Tensor): a Tensor of decoded audio samples
22
+ with shapes [B, C, T] or [C, T]
23
+ sample_rate (int): sample rate audio should be displayed with.
24
+ """
25
+ assert samples.dim() == 2 or samples.dim() == 3
26
+
27
+ samples = samples.detach().cpu()
28
+ if samples.dim() == 2:
29
+ samples = samples[None, ...]
30
+
31
+ for audio in samples:
32
+ ipd.display(ipd.Audio(audio, rate=sample_rate))
audiocraft/utils/utils.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from concurrent.futures import ProcessPoolExecutor
8
+ from functools import wraps
9
+ import hashlib
10
+ import logging
11
+ import typing as tp
12
+
13
+ import flashy
14
+ import flashy.distrib
15
+ import omegaconf
16
+ import torch
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
24
+ """Convenience function to map an omegaconf configuration to a dictionary.
25
+
26
+ Args:
27
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
28
+ Returns:
29
+ dict: Config as dictionary object.
30
+ """
31
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
32
+ assert isinstance(dct, dict)
33
+ return dct
34
+
35
+
36
+ def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
37
+ if max_samples >= len(dataset):
38
+ return dataset
39
+
40
+ generator = torch.Generator().manual_seed(seed)
41
+ perm = torch.randperm(len(dataset), generator=generator)
42
+ return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
43
+
44
+
45
+ def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
46
+ num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
47
+ """Convenience function to load dataset into a dataloader with optional subset sampling.
48
+
49
+ Args:
50
+ dataset: Dataset to load.
51
+ num_samples (Optional[int]): Number of samples to limit subset size.
52
+ batch_size (int): Batch size.
53
+ num_workers (int): Number of workers for data loading.
54
+ seed (int): Random seed.
55
+ """
56
+ if num_samples is not None:
57
+ dataset = random_subset(dataset, num_samples, seed)
58
+
59
+ dataloader = flashy.distrib.loader(
60
+ dataset,
61
+ batch_size=batch_size,
62
+ num_workers=num_workers,
63
+ **kwargs
64
+ )
65
+ return dataloader
66
+
67
+
68
+ def get_dataset_from_loader(dataloader):
69
+ dataset = dataloader.dataset
70
+ if isinstance(dataset, torch.utils.data.Subset):
71
+ return dataset.dataset
72
+ else:
73
+ return dataset
74
+
75
+
76
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
77
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
78
+
79
+ Args:
80
+ input (torch.Tensor): The input tensor containing probabilities.
81
+ num_samples (int): Number of samples to draw.
82
+ replacement (bool): Whether to draw with replacement or not.
83
+ Keywords args:
84
+ generator (torch.Generator): A pseudorandom number generator for sampling.
85
+ Returns:
86
+ torch.Tensor: Last dimension contains num_samples indices
87
+ sampled from the multinomial probability distribution
88
+ located in the last dimension of tensor input.
89
+ """
90
+ input_ = input.reshape(-1, input.shape[-1])
91
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
92
+ output = output_.reshape(*list(input.shape[:-1]), -1)
93
+ return output
94
+
95
+
96
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
97
+ """Sample next token from top K values along the last dimension of the input probs tensor.
98
+
99
+ Args:
100
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
101
+ k (int): The k in “top-k”.
102
+ Returns:
103
+ torch.Tensor: Sampled tokens.
104
+ """
105
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
106
+ min_value_top_k = top_k_value[..., [-1]]
107
+ probs *= (probs >= min_value_top_k).float()
108
+ probs.div_(probs.sum(dim=-1, keepdim=True))
109
+ next_token = multinomial(probs, num_samples=1)
110
+ return next_token
111
+
112
+
113
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
114
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
115
+
116
+ Args:
117
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
118
+ p (int): The p in “top-p”.
119
+ Returns:
120
+ torch.Tensor: Sampled tokens.
121
+ """
122
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
124
+ mask = probs_sum - probs_sort > p
125
+ probs_sort *= (~mask).float()
126
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127
+ next_token = multinomial(probs_sort, num_samples=1)
128
+ next_token = torch.gather(probs_idx, -1, next_token)
129
+ return next_token
130
+
131
+
132
+ class DummyPoolExecutor:
133
+ """Dummy pool executor to use when we actually have only 1 worker.
134
+ (e.g. instead of ProcessPoolExecutor).
135
+ """
136
+ class DummyResult:
137
+ def __init__(self, func, *args, **kwargs):
138
+ self.func = func
139
+ self.args = args
140
+ self.kwargs = kwargs
141
+
142
+ def result(self):
143
+ return self.func(*self.args, **self.kwargs)
144
+
145
+ def __init__(self, workers, mp_context=None):
146
+ pass
147
+
148
+ def submit(self, func, *args, **kwargs):
149
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
150
+
151
+ def __enter__(self):
152
+ return self
153
+
154
+ def __exit__(self, exc_type, exc_value, exc_tb):
155
+ return
156
+
157
+
158
+ def get_pool_executor(num_workers: int, mp_context=None):
159
+ return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
160
+
161
+
162
+ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
163
+ """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
164
+ For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
165
+
166
+ Args:
167
+ lengths (torch.Tensor): tensor with lengths
168
+ max_len (int): can set the max length manually. Defaults to None.
169
+ Returns:
170
+ torch.Tensor: mask with 0s where there is pad tokens else 1s
171
+ """
172
+ assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
173
+ final_length = lengths.max().item() if not max_len else max_len
174
+ final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
175
+ return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
176
+
177
+
178
+ def hash_trick(word: str, vocab_size: int) -> int:
179
+ """Hash trick to pair each word with an index
180
+
181
+ Args:
182
+ word (str): word we wish to convert to an index
183
+ vocab_size (int): size of the vocabulary
184
+ Returns:
185
+ int: index of the word in the embedding LUT
186
+ """
187
+ hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
188
+ return hash % vocab_size
189
+
190
+
191
+ def with_rank_rng(base_seed: int = 1234):
192
+ """Decorator for a function so that the function will use a Random Number Generator
193
+ whose state depend on the GPU rank. The original RNG state is restored upon returning.
194
+
195
+ Args:
196
+ base_seed (int): Random seed.
197
+ """
198
+ def _decorator(fun: tp.Callable):
199
+ @wraps(fun)
200
+ def _decorated(*args, **kwargs):
201
+ state = torch.get_rng_state()
202
+ seed = base_seed ^ flashy.distrib.rank()
203
+ torch.manual_seed(seed)
204
+ logger.debug('Rank dependent seed set to %d', seed)
205
+ try:
206
+ return fun(*args, **kwargs)
207
+ finally:
208
+ torch.set_rng_state(state)
209
+ logger.debug('RNG state restored.')
210
+ return _decorated
211
+ return _decorator
212
+
213
+
214
+ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
215
+ """Get a list of tensors and collate them to a single tensor. according to the following logic:
216
+ - `dim` specifies the time dimension which will be stacked and padded.
217
+ - The output will contain 1 new dimension (dimension index 0) which will be the size of
218
+ of the original list.
219
+
220
+ Args:
221
+ tensors (tp.List[torch.Tensor]): List of tensors to collate.
222
+ dim (int): Dimension which will be stacked and padded.
223
+ Returns:
224
+ tp.Tuple[torch.Tensor, torch.Tensor]:
225
+ torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
226
+ (dimension index 0) which will be the size of the original list.
227
+ torch.Tensor: Tensor containing length of original tensor sizes (without padding).
228
+ """
229
+ tensors = [x.transpose(0, dim) for x in tensors]
230
+ lens = torch.LongTensor([len(x) for x in tensors])
231
+ padded_tensors = pad_sequence(tensors)
232
+ padded_tensors = padded_tensors.transpose(0, 1)
233
+ padded_tensors = padded_tensors.transpose(1, dim + 1)
234
+ return padded_tensors, lens