diff --git a/.github/actions/audiocraft_build/action.yml b/.github/actions/audiocraft_build/action.yml
new file mode 100644
index 0000000000000000000000000000000000000000..be5dae26afef4c5e756135cbddfab034db6a016e
--- /dev/null
+++ b/.github/actions/audiocraft_build/action.yml
@@ -0,0 +1,29 @@
+name: audiocraft_build
+description: 'Build audiocraft env.'
+runs:
+ using: "composite"
+ steps:
+ - uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+ - uses: actions/cache@v2
+ id: cache
+ with:
+ path: env
+ key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
+
+ - if: ${{ steps.cache.outputs.cache-hit != 'true' }}
+ name: Install dependencies
+ shell: bash
+ run: |
+ sudo apt-get update
+ sudo apt-get install libsndfile1-dev ffmpeg
+ python3 -m venv env
+ . env/bin/activate
+ python -m pip install --upgrade pip
+ pip install -e '.[dev]'
+ - name: System Dependencies
+ shell: bash
+ run: |
+ sudo apt-get update
+ sudo apt-get install libsndfile1-dev ffmpeg
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f401b614804fdfc6bb427a50b92feab1c758c159
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,63 @@
+# Byte-compiled / optimized / DLL files
+__pycache__
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# macOS dir files
+.DS_Store
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+.ipynb_checkpoints
+
+# Tests and linter
+.pytest_cache/
+.mypy_cache/
+.coverage
+
+# docs
+/api_docs
+
+# dotenv
+.env
+.envrc
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# egs with manifest files
+egs/*
+!egs/example
+# local datasets
+dataset/*
+!dataset/example
+
+# personal notebooks & scripts
+*/local_scripts
+*/notes
+.vscode/
+/notebooks
+/local_scripts
+/notes
+/cache
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..aabf9130b0a67aca9beaac9f2cb1a40237a4468d
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,28 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+
+## [1.0.0] - 2023-08-02
+
+Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
+Added pretrained model for AudioGen and MultiBandDiffusion.
+
+## [0.0.2] - 2023-08-01
+
+Improved demo, fixed top p (thanks @jnordberg).
+
+Compressor tanh on output to avoid clipping with some style (especially piano).
+Now repeating the conditioning periodically if it is too short.
+
+More options when launching Gradio app locally (thanks @ashleykleynhans).
+
+Testing out PyTorch 2.0 memory efficient attention.
+
+Added extended generation (infinite length) by slowly moving the windows.
+Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
+
+## [0.0.1] - 2023-06-09
+
+Initial release, with model evaluation only.
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..83f431e8feeb7e80d571f39c9f6c1b96857b5f85
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at . All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..a3e9507643d4439f509a8fc8b87dc73417ef9822
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,35 @@
+# Contributing to AudioCraft
+
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+
+AudioCraft is the implementation of a research paper.
+Therefore, we do not plan on accepting many pull requests for new features.
+We certainly welcome them for bug fixes.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Meta's open source projects.
+
+Complete your CLA here:
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to encodec, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..efc2431ec0fe674c22fe2fdb9d7045cdf6cd2748
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,26 @@
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive \
+ PYTHONUNBUFFERED=1 \
+ PYTHONIOENCODING=UTF-8
+RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\
+ apt install -y \
+ wget \
+ git \
+ pkg-config \
+ python3 \
+ python3-pip \
+ python-is-python3 \
+ ffmpeg \
+ libnvrtc11.2 \
+ libtcmalloc-minimal4
+
+RUN useradd -m -u 1000 ac
+RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel
+ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118"
+RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so
+USER 1000
+RUN mkdir ~/.cache
+RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt
+WORKDIR /home/ac/audiocraft
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b93be90515ccd0b9daedaa589e42bf5929693f1f
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/LICENSE_weights b/LICENSE_weights
new file mode 100644
index 0000000000000000000000000000000000000000..108b5f002fc31efe11d881de2cd05329ebe8cc37
--- /dev/null
+++ b/LICENSE_weights
@@ -0,0 +1,399 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..ac6828f0ab296c7e34e44548b14bce9df4f65a6c
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,9 @@
+include Makefile
+include LICENSE
+include LICENSE_weights
+include *.md
+include *.ini
+include requirements.txt
+include audiocraft/py.typed
+include assets/*.mp3
+recursive-include conf *.yaml
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..be8a8b03aa984ac5ed95c98e05887fe108dce073
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,40 @@
+INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
+ dataset.train.num_samples=10 dataset.valid.num_samples=10 \
+ dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
+ logging.level=DEBUG
+INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 616d7b3c
+INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
+INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
+ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c
+INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
+ checkpoint.save_last=false # Using compression model from 616d7b3c
+
+default: linter tests
+
+install:
+ pip install -U pip
+ pip install -U -e '.[dev]'
+
+linter:
+ flake8 audiocraft && mypy audiocraft
+ flake8 tests && mypy tests
+
+tests:
+ coverage run -m pytest tests
+ coverage report
+
+tests_integ:
+ $(INTEG_COMPRESSION)
+ $(INTEG_MBD)
+ $(INTEG_MUSICGEN)
+ $(INTEG_AUDIOGEN)
+
+
+api_docs:
+ pdoc3 --html -o api_docs -f audiocraft
+
+dist:
+ python setup.py sdist
+
+.PHONY: linter tests api_docs dist
diff --git a/README.md b/README.md
index 7af2123a959f70532e60ff406fc211458dfa6273..5b9b5569e2eaa6642dfe10aa7da8832df28e98c7 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,102 @@
---
-title: TulipAI Sounscapes
-emoji: 🏆
-colorFrom: red
-colorTo: purple
-sdk: gradio
-sdk_version: 3.50.2
+title: TulipAI_Sounscapes
app_file: app.py
-pinned: false
+sdk: gradio
+sdk_version: 3.40.1
+duplicated_from: TulipAIs/TulipAI_Sounscapes
---
+# AudioCraft Plus
+![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
+![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
+![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)
+
+AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
+for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
+
+![image](https://github.com/GrandaddyShmax/audiocraft_plus/assets/52707645/c4c5327c-901a-40d8-91be-aa5afcf80b52)
+
+## Features
+AudioCraft Plus is an all-in-one WebUI for the original AudioCraft, adding many quality features on top.
+
+- AudioGen Model
+- Multiband Diffusion
+- Custom Model Support
+- Generation Metadata and Audio Info tab
+- Mono to Stereo
+- Multiprompt/Prompt Segmentation with Structure Prompts
+- Video Output Customization
+- Music Continuation
+
+## Installation
+AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following:
+
+```shell
+# Best to make sure you have torch installed first, in particular before installing xformers.
+# Don't run this if you already have PyTorch installed.
+pip install 'torch>=2.0'
+# Then proceed to one of the following
+pip install -U audiocraft # stable release
+pip install -U git+https://git@github.com/GrandaddyShmax/audiocraft_plus#egg=audiocraft # bleeding edge
+pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
+```
+
+We also recommend having `ffmpeg` installed, either through your system or Anaconda:
+```bash
+sudo apt-get install ffmpeg
+# Or if you are using Anaconda or Miniconda
+conda install 'ffmpeg<5' -c conda-forge
+```
+
+## Models
+
+At the moment, AudioCraft contains the training code and inference code for:
+* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
+* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
+* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
+* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
+
+## Training code
+
+AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
+For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
+the [AudioCraft training documentation](./docs/TRAINING.md).
+
+For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model
+that provides pointers to configuration, example grids and model/task-specific information and FAQ.
+
+
+## API documentation
+
+We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
+
+
+## FAQ
+
+#### Is the training code available?
+
+Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).
+
+#### Where are the models stored?
+
+Hugging Face stored the model in a specific location, which can be overriden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable.
+
+
+## License
+* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
+* The models weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
+
+
+## Citation
+
+For the general framework of AudioCraft, please cite the following.
+```
+@article{copet2023simple,
+ title={Simple and Controllable Music Generation},
+ author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ year={2023},
+ journal={arXiv preprint arXiv:2306.05284},
+}
+```
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+When referring to a specific model, please cite as mentioned in the model specific README, e.g
+[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e3c3e417a920d6d76946892930a9173755d1d98
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1685 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
+# also released under the MIT license.
+
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import os
+from pathlib import Path
+import subprocess as sp
+from tempfile import NamedTemporaryFile
+import time
+import warnings
+import glob
+import re
+from PIL import Image
+from pydub import AudioSegment
+from datetime import datetime
+
+import json
+import shutil
+import taglib
+import torch
+import torchaudio
+import gradio as gr
+import numpy as np
+import typing as tp
+
+from audiocraft.data.audio_utils import convert_audio
+from audiocraft.data.audio import audio_write
+from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion
+from audiocraft.utils import ui
+import random, string
+
+version = "2.0.0a"
+
+theme = gr.themes.Base(
+ primary_hue="lime",
+ secondary_hue="lime",
+ neutral_hue="neutral",
+).set(
+ button_primary_background_fill_hover='*primary_500',
+ button_primary_background_fill_hover_dark='*primary_500',
+ button_secondary_background_fill_hover='*primary_500',
+ button_secondary_background_fill_hover_dark='*primary_500'
+)
+
+MODEL = None # Last used model
+MODELS = None
+UNLOAD_MODEL = False
+MOVE_TO_CPU = False
+IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+print(IS_BATCHED)
+MAX_BATCH_SIZE = 12
+BATCHED_DURATION = 15
+INTERRUPTING = False
+MBD = None
+# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
+_old_call = sp.call
+
+
+def generate_random_string(length):
+ characters = string.ascii_letters + string.digits
+ return ''.join(random.choice(characters) for _ in range(length))
+
+
+def resize_video(input_path, output_path, target_width, target_height):
+ ffmpeg_cmd = [
+ 'ffmpeg',
+ '-y',
+ '-i', input_path,
+ '-vf', f'scale={target_width}:{target_height}',
+ '-c:a', 'copy',
+ output_path
+ ]
+ sp.run(ffmpeg_cmd)
+
+
+def _call_nostderr(*args, **kwargs):
+ # Avoid ffmpeg vomiting on the logs.
+ kwargs['stderr'] = sp.DEVNULL
+ kwargs['stdout'] = sp.DEVNULL
+ _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+# Preallocating the pool of processes.
+pool = ProcessPoolExecutor(4)
+pool.__enter__()
+
+
+def interrupt():
+ global INTERRUPTING
+ INTERRUPTING = True
+
+
+class FileCleaner:
+ def __init__(self, file_lifetime: float = 3600):
+ self.file_lifetime = file_lifetime
+ self.files = []
+
+ def add(self, path: tp.Union[str, Path]):
+ self._cleanup()
+ self.files.append((time.time(), Path(path)))
+
+ def _cleanup(self):
+ now = time.time()
+ for time_added, path in list(self.files):
+ if now - time_added > self.file_lifetime:
+ if path.exists():
+ path.unlink()
+ self.files.pop(0)
+ else:
+ break
+
+
+file_cleaner = FileCleaner()
+
+
+def make_waveform(*args, **kwargs):
+ # Further remove some warnings.
+ be = time.time()
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ height = kwargs.pop('height')
+ width = kwargs.pop('width')
+ if height < 256:
+ height = 256
+ if width < 256:
+ width = 256
+ waveform_video = gr.make_waveform(*args, **kwargs)
+ out = f"{generate_random_string(12)}.mp4"
+ image = kwargs.get('bg_image', None)
+ if image is None:
+ resize_video(waveform_video, out, 900, 300)
+ else:
+ resize_video(waveform_video, out, width, height)
+ print("Make a video took", time.time() - be)
+ return out
+
+
+def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, base_model='GrandaddyShmax/musicgen-medium', gen_type="music"):
+ global MODEL, MODELS
+ print("Loading model", version)
+ if MODELS is None:
+ if version == 'GrandaddyShmax/musicgen-custom':
+ MODEL = MusicGen.get_pretrained(base_model)
+ file_path = os.path.abspath("models/" + str(custom_model) + ".pt")
+ MODEL.lm.load_state_dict(torch.load(file_path))
+ else:
+ if gen_type == "music":
+ MODEL = MusicGen.get_pretrained(version)
+ elif gen_type == "audio":
+ MODEL = AudioGen.get_pretrained(version)
+
+ return
+
+ else:
+ t1 = time.monotonic()
+ if MODEL is not None:
+ MODEL.to('cpu') # move to cache
+ print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
+ t1 = time.monotonic()
+ if version != 'GrandaddyShmax/musicgen-custom' and MODELS.get(version) is None:
+ print("Loading model %s from disk" % version)
+ if gen_type == "music":
+ result = MusicGen.get_pretrained(version)
+ elif gen_type == "audio":
+ result = AudioGen.get_pretrained(version)
+ MODELS[version] = result
+ print("Model loaded in %.2fs" % (time.monotonic() - t1))
+ MODEL = result
+ return
+ result = MODELS[version].to('cuda')
+ print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
+ MODEL = result
+
+def get_audio_info(audio_path):
+ if audio_path is not None:
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
+ if not audio_path.name.endswith(".json"):
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
+ if 'COMMENT' not in song.tags:
+ return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
+ json_string = song.tags['COMMENT'][0]
+ data = json.loads(json_string)
+
+ global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
+ bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
+ key = str("\nKey: " + data['key']) if 'key' in data else ""
+ scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
+ prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
+ duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
+ overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
+ seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
+ audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
+ input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
+ channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
+ sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
+ gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
+ model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
+ custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
+ base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
+ decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
+ topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
+ topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
+ temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
+ cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
+ version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
+ info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
+ if info == "":
+ return "No tags found. Either the file is not generated by V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
+ return info
+ else:
+ with open(audio_path.name) as json_file:
+ data = json.load(json_file)
+ #if 'global_prompt' not in data:
+ #return "No tags found. Either the file is not generated by V1.2.8a and higher or the tags are corrupted."
+ global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else ""
+ bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else ""
+ key = str("\nKey: " + data['key']) if 'key' in data else ""
+ scale = str("\nScale: " + data['scale']) if 'scale' in data else ""
+ prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else ""
+ duration = str("\nDuration: " + data['duration']) if 'duration' in data else ""
+ overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else ""
+ seed = str("\nSeed: " + data['seed']) if 'seed' in data else ""
+ audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else ""
+ input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else ""
+ channel = str("\nChannel: " + data['channel']) if 'channel' in data else ""
+ sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else ""
+ gen_type = str(data['generator'] + "gen-") if 'generator' in data else ""
+ model = str("\nModel: " + gen_type + data['model']) if 'model' in data else ""
+ custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else ""
+ base_model = str("\nBase Model: " + data['base_model']) if 'base_model' in data else ""
+ decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else ""
+ topk = str("\nTopk: " + data['topk']) if 'topk' in data else ""
+ topp = str("\nTopp: " + data['topp']) if 'topp' in data else ""
+ temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else ""
+ cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else ""
+ version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown"
+ info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + input_length + channel + sr_select + model + custom_model + base_model + decoder + topk + topp + temperature + cfg_coef)
+ if info == "":
+ return "No tags found. Either the file is not generated by V1.2.7 and higher or the tags are corrupted."
+ return info
+ else:
+ return "Only .wav ,.mp4 and .json files are supported"
+ else:
+ return None
+
+
+def info_to_params(audio_path):
+ if audio_path is not None:
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
+ if not audio_path.name.endswith(".json"):
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
+ if 'COMMENT' not in song.tags:
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+ json_string = song.tags['COMMENT'][0]
+ data = json.loads(json_string)
+ struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
+ bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
+ key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
+ scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
+ model = data['model'] if 'model' in data else "large"
+ custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
+ base_model = data['base_model'] if 'base_model' in data else "medium"
+ decoder = data['decoder'] if 'decoder' in data else "Default"
+ if 'texts' not in data:
+ unique_prompts = 1
+ text = ["", "", "", "", "", "", "", "", "", ""]
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ else:
+ s = data['texts']
+ s = re.findall(r"'(.*?)'", s)
+ text = []
+ repeat = []
+ i = 0
+ for elem in s:
+ if elem.strip():
+ if i == 0 or elem != s[i-1]:
+ text.append(elem)
+ repeat.append(1)
+ else:
+ repeat[-1] += 1
+ i += 1
+ text.extend([""] * (10 - len(text)))
+ repeat.extend([1] * (10 - len(repeat)))
+ unique_prompts = len([t for t in text if t])
+ audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
+ duration = int(data['duration']) if 'duration' in data else 10
+ topk = float(data['topk']) if 'topk' in data else 250
+ topp = float(data['topp']) if 'topp' in data else 0
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
+ seed = int(data['seed']) if 'seed' in data else -1
+ overlap = int(data['overlap']) if 'overlap' in data else 12
+ channel = data['channel'] if 'channel' in data else "stereo"
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
+ return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
+ else:
+ with open(audio_path.name) as json_file:
+ data = json.load(json_file)
+ struc_prompt = (False if data['bpm'] == "none" else True) if 'bpm' in data else False
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
+ bpm = (120 if data['bpm'] == "none" else int(data['bpm'])) if 'bpm' in data else 120
+ key = ("C" if data['key'] == "none" else data['key']) if 'key' in data else "C"
+ scale = ("Major" if data['scale'] == "none" else data['scale']) if 'scale' in data else "Major"
+ model = data['model'] if 'model' in data else "large"
+ custom_model = (data['custom_model'] if data['custom_model'] in get_available_models() else None) if 'custom_model' in data else None
+ base_model = data['base_model'] if 'base_model' in data else "medium"
+ decoder = data['decoder'] if 'decoder' in data else "Default"
+ if 'texts' not in data:
+ unique_prompts = 1
+ text = ["", "", "", "", "", "", "", "", "", ""]
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ else:
+ s = data['texts']
+ s = re.findall(r"'(.*?)'", s)
+ text = []
+ repeat = []
+ i = 0
+ for elem in s:
+ if elem.strip():
+ if i == 0 or elem != s[i-1]:
+ text.append(elem)
+ repeat.append(1)
+ else:
+ repeat[-1] += 1
+ i += 1
+ text.extend([""] * (10 - len(text)))
+ repeat.extend([1] * (10 - len(repeat)))
+ unique_prompts = len([t for t in text if t])
+ audio_mode = ("sample" if data['audio_mode'] == "none" else data['audio_mode']) if 'audio_mode' in data else "sample"
+ duration = int(data['duration']) if 'duration' in data else 10
+ topk = float(data['topk']) if 'topk' in data else 250
+ topp = float(data['topp']) if 'topp' in data else 0
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
+ seed = int(data['seed']) if 'seed' in data else -1
+ overlap = int(data['overlap']) if 'overlap' in data else 12
+ channel = data['channel'] if 'channel' in data else "stereo"
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
+ return decoder, struc_prompt, global_prompt, bpm, key, scale, model, custom_model, base_model, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
+ else:
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+ else:
+ return "Default", False, "", 120, "C", "Major", "large", None, "medium", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+
+
+def info_to_params_a(audio_path):
+ if audio_path is not None:
+ if audio_path.name.endswith(".wav") or audio_path.name.endswith(".mp4") or audio_path.name.endswith(".json"):
+ if not audio_path.name.endswith(".json"):
+ with taglib.File(audio_path.name, save_on_exit=False) as song:
+ if 'COMMENT' not in song.tags:
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+ json_string = song.tags['COMMENT'][0]
+ data = json.loads(json_string)
+ struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
+ decoder = data['decoder'] if 'decoder' in data else "Default"
+ if 'texts' not in data:
+ unique_prompts = 1
+ text = ["", "", "", "", "", "", "", "", "", ""]
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ else:
+ s = data['texts']
+ s = re.findall(r"'(.*?)'", s)
+ text = []
+ repeat = []
+ i = 0
+ for elem in s:
+ if elem.strip():
+ if i == 0 or elem != s[i-1]:
+ text.append(elem)
+ repeat.append(1)
+ else:
+ repeat[-1] += 1
+ i += 1
+ text.extend([""] * (10 - len(text)))
+ repeat.extend([1] * (10 - len(repeat)))
+ unique_prompts = len([t for t in text if t])
+ duration = int(data['duration']) if 'duration' in data else 10
+ topk = float(data['topk']) if 'topk' in data else 250
+ topp = float(data['topp']) if 'topp' in data else 0
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
+ seed = int(data['seed']) if 'seed' in data else -1
+ overlap = int(data['overlap']) if 'overlap' in data else 12
+ channel = data['channel'] if 'channel' in data else "stereo"
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
+ return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
+ else:
+ with open(audio_path.name) as json_file:
+ data = json.load(json_file)
+ struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False
+ global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
+ decoder = data['decoder'] if 'decoder' in data else "Default"
+ if 'texts' not in data:
+ unique_prompts = 1
+ text = ["", "", "", "", "", "", "", "", "", ""]
+ repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ else:
+ s = data['texts']
+ s = re.findall(r"'(.*?)'", s)
+ text = []
+ repeat = []
+ i = 0
+ for elem in s:
+ if elem.strip():
+ if i == 0 or elem != s[i-1]:
+ text.append(elem)
+ repeat.append(1)
+ else:
+ repeat[-1] += 1
+ i += 1
+ text.extend([""] * (10 - len(text)))
+ repeat.extend([1] * (10 - len(repeat)))
+ unique_prompts = len([t for t in text if t])
+ duration = int(data['duration']) if 'duration' in data else 10
+ topk = float(data['topk']) if 'topk' in data else 250
+ topp = float(data['topp']) if 'topp' in data else 0
+ temperature = float(data['temperature']) if 'temperature' in data else 1.0
+ cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
+ seed = int(data['seed']) if 'seed' in data else -1
+ overlap = int(data['overlap']) if 'overlap' in data else 12
+ channel = data['channel'] if 'channel' in data else "stereo"
+ sr_select = data['sr_select'] if 'sr_select' in data else "48000"
+ return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select
+
+ else:
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+ else:
+ return "Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000"
+
+
+def make_pseudo_stereo (filename, sr_select, pan, delay):
+ if pan:
+ temp = AudioSegment.from_wav(filename)
+ if sr_select != "32000":
+ temp = temp.set_frame_rate(int(sr_select))
+ left = temp.pan(-0.5) - 5
+ right = temp.pan(0.6) - 5
+ temp = left.overlay(right, position=5)
+ temp.export(filename, format="wav")
+ if delay:
+ waveform, sample_rate = torchaudio.load(filename) # load mono WAV file
+ delay_seconds = 0.01 # set delay 10ms
+ delay_samples = int(delay_seconds * sample_rate) # Calculating delay value in number of samples
+ stereo_waveform = torch.stack([waveform[0], torch.cat((torch.zeros(delay_samples), waveform[0][:-delay_samples]))]) # Generate a stereo file with original mono audio and delayed version
+ torchaudio.save(filename, stereo_waveform, sample_rate)
+ return
+
+
+def normalize_audio(audio_data):
+ audio_data = audio_data.astype(np.float32)
+ max_value = np.max(np.abs(audio_data))
+ audio_data /= max_value
+ return audio_data
+
+
+def load_diffusion():
+ global MBD
+ if MBD is None:
+ print("loading MBD")
+ MBD = MultiBandDiffusion.get_mbd_
+ ()
+
+
+def unload_diffusion():
+ global MBD
+ if MBD is not None:
+ print("unloading MBD")
+ MBD = None
+
+
+def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=False, **gen_kwargs):
+ if gen_type == "music":
+ maximum_size = 29.5
+ elif gen_type == "audio":
+ maximum_size = 9.5
+ cut_size = 0
+ input_length = 0
+ sampleP = None
+ if sample is not None:
+ globalSR, sampleM = sample[0], sample[1]
+ sampleM = normalize_audio(sampleM)
+ sampleM = torch.from_numpy(sampleM).t()
+ if sampleM.dim() == 1:
+ sampleM = sampleM.unsqueeze(0)
+ sample_length = sampleM.shape[sampleM.dim() - 1] / globalSR
+ if trim_start >= sample_length:
+ trim_start = sample_length - 0.5
+ if trim_end >= sample_length:
+ trim_end = sample_length - 0.5
+ if trim_start + trim_end >= sample_length:
+ tmp = sample_length - 0.5
+ trim_start = tmp / 2
+ trim_end = tmp / 2
+ sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))]
+ sample_length = sample_length - (trim_start + trim_end)
+ if sample_length > maximum_size:
+ cut_size = sample_length - maximum_size
+ sampleP = sampleM[..., :int(globalSR * cut_size)]
+ sampleM = sampleM[..., int(globalSR * cut_size):]
+ if sample_length >= duration:
+ duration = sample_length + 0.5
+ input_length = sample_length
+ global MODEL
+ MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs)
+ print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)])
+ be = time.time()
+ processed_melodies = []
+ if gen_type == "music":
+ target_sr = 32000
+ elif gen_type == "audio":
+ target_sr = 16000
+ target_ac = 1
+
+ for melody in melodies:
+ if melody is None:
+ processed_melodies.append(None)
+ else:
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+ if melody.dim() == 1:
+ melody = melody[None]
+ melody = melody[..., :int(sr * duration)]
+ melody = convert_audio(melody, sr, target_sr, target_ac)
+ processed_melodies.append(melody)
+
+ if sample is not None:
+ if sampleP is None:
+ if gen_type == "music":
+ outputs = MODEL.generate_continuation(
+ prompt=sampleM,
+ prompt_sample_rate=globalSR,
+ descriptions=texts,
+ progress=progress,
+ return_tokens=USE_DIFFUSION
+ )
+ elif gen_type == "audio":
+ outputs = MODEL.generate_continuation(
+ prompt=sampleM,
+ prompt_sample_rate=globalSR,
+ descriptions=texts,
+ progress=progress
+ )
+ else:
+ if sampleP.dim() > 1:
+ sampleP = convert_audio(sampleP, globalSR, target_sr, target_ac)
+ sampleP = sampleP.to(MODEL.device).float().unsqueeze(0)
+ if gen_type == "music":
+ outputs = MODEL.generate_continuation(
+ prompt=sampleM,
+ prompt_sample_rate=globalSR,
+ descriptions=texts,
+ progress=progress,
+ return_tokens=USE_DIFFUSION
+ )
+ elif gen_type == "audio":
+ outputs = MODEL.generate_continuation(
+ prompt=sampleM,
+ prompt_sample_rate=globalSR,
+ descriptions=texts,
+ progress=progress
+ )
+ outputs = torch.cat([sampleP, outputs], 2)
+
+ elif any(m is not None for m in processed_melodies):
+ if gen_type == "music":
+ outputs = MODEL.generate_with_chroma(
+ descriptions=texts,
+ melody_wavs=processed_melodies,
+ melody_sample_rate=target_sr,
+ progress=progress,
+ return_tokens=USE_DIFFUSION
+ )
+ elif gen_type == "audio":
+ outputs = MODEL.generate_with_chroma(
+ descriptions=texts,
+ melody_wavs=processed_melodies,
+ melody_sample_rate=target_sr,
+ progress=progress
+ )
+ else:
+ if gen_type == "music":
+ outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
+ elif gen_type == "audio":
+ outputs = MODEL.generate(texts, progress=progress)
+
+ if USE_DIFFUSION:
+ print("outputs: " + str(outputs))
+ outputs_diffusion = MBD.tokens_to_wav(outputs[1])
+ outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
+ outputs = outputs.detach().cpu().float()
+ backups = outputs
+ if channel == "stereo":
+ outputs = convert_audio(outputs, target_sr, int(sr_select), 2)
+ elif channel == "mono" and sr_select != "32000":
+ outputs = convert_audio(outputs, target_sr, int(sr_select), 1)
+ out_files = []
+ out_audios = []
+ out_backup = []
+ for output in outputs:
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+ audio_write(
+ file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness",
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+
+ if channel == "stereo effect":
+ make_pseudo_stereo(file.name, sr_select, pan=True, delay=True);
+
+ out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width))
+ out_audios.append(file.name)
+ file_cleaner.add(file.name)
+ print(f'wav: {file.name}')
+ for backup in backups:
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+ audio_write(
+ file.name, backup, MODEL.sample_rate, strategy="loudness",
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+ out_backup.append(file.name)
+ file_cleaner.add(file.name)
+ res = [out_file.result() for out_file in out_files]
+ res_audio = out_audios
+ res_backup = out_backup
+ for file in res:
+ file_cleaner.add(file)
+ print(f'video: {file}')
+ print("batch finished", len(texts), time.time() - be)
+ print("Tempfiles currently stored: ", len(file_cleaner.files))
+ if MOVE_TO_CPU:
+ MODEL.to('cpu')
+ if UNLOAD_MODEL:
+ MODEL = None
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ return res, res_audio, res_backup, input_length
+
+
+def predict_batched(texts, melodies):
+ max_text_length = 512
+ texts = [text[:max_text_length] for text in texts]
+ load_model('melody')
+ res = _do_predictions(texts, melodies, BATCHED_DURATION)
+ return res
+
+
+def add_tags(filename, tags):
+ json_string = None
+
+ data = {
+ "global_prompt": tags[0],
+ "bpm": tags[1],
+ "key": tags[2],
+ "scale": tags[3],
+ "texts": tags[4],
+ "duration": tags[5],
+ "overlap": tags[6],
+ "seed": tags[7],
+ "audio_mode": tags[8],
+ "input_length": tags[9],
+ "channel": tags[10],
+ "sr_select": tags[11],
+ "model": tags[12],
+ "custom_model": tags[13],
+ "base_model": tags[14],
+ "decoder": tags[15],
+ "topk": tags[16],
+ "topp": tags[17],
+ "temperature": tags[18],
+ "cfg_coef": tags[19],
+ "generator": tags[20],
+ "version": version
+ }
+
+ json_string = json.dumps(data)
+
+ if os.path.exists(filename):
+ with taglib.File(filename, save_on_exit=True) as song:
+ song.tags = {'COMMENT': json_string }
+
+ json_file = open(tags[7] + '.json', 'w')
+ json_file.write(json_string)
+ json_file.close()
+
+ return json_file.name;
+
+
+def save_outputs(mp4, wav_tmp, tags, gen_type):
+ # mp4: .mp4 file name in root running folder of app.py
+ # wav_tmp: temporary wav file located in %TEMP% folder
+ # seed - used seed
+ # exanple BgnJtr4Pn1AJ.mp4, C:\Users\Alex\AppData\Local\Temp\tmp4ermrebs.wav, 195123182343465
+ # procedure read generated .mp4 and wav files, rename it by using seed as name,
+ # and will store it to ./output/today_date/wav and ./output/today_date/mp4 folders.
+ # if file with same seed number already exist its make postfix in name like seed(n)
+ # where is n - consiqunce number 1-2-3-4 and so on
+ # then we store generated mp4 and wav into destination folders.
+
+ current_date = datetime.now().strftime("%Y%m%d")
+ wav_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'wav')
+ mp4_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'mp4')
+ json_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type,'json')
+ os.makedirs(wav_directory, exist_ok=True)
+ os.makedirs(mp4_directory, exist_ok=True)
+ os.makedirs(json_directory, exist_ok=True)
+
+ filename = str(tags[7]) + '.wav'
+ target = os.path.join(wav_directory, filename)
+ counter = 1
+ while os.path.exists(target):
+ filename = str(tags[7]) + f'({counter})' + '.wav'
+ target = os.path.join(wav_directory, filename)
+ counter += 1
+
+ shutil.copyfile(wav_tmp, target); # make copy of original file
+ json_file = add_tags(target, tags);
+
+ wav_target=target;
+ target=target.replace('wav', 'mp4');
+ mp4_target=target;
+
+ mp4=r'./' +mp4;
+ shutil.copyfile(mp4, target); # make copy of original file
+ _ = add_tags(target, tags);
+
+ target=target.replace('mp4', 'json'); # change the extension to json
+ json_target=target; # store the json target
+
+ with open(target, 'w') as f: # open a writable file object
+ shutil.copyfile(json_file, target); # make copy of original file
+
+ os.remove(json_file)
+
+ return wav_target, mp4_target, json_target;
+
+
+def clear_cash():
+ # delete all temporary files genegated my system
+ current_date = datetime.now().date()
+ current_directory = os.getcwd()
+ files = glob.glob(os.path.join(current_directory, '*.mp4'))
+ for file in files:
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
+ if creation_date == current_date:
+ os.remove(file)
+
+ temp_directory = os.environ.get('TEMP')
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.mp4'))
+ for file in files:
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
+ if creation_date == current_date:
+ os.remove(file)
+
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.wav'))
+ for file in files:
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
+ if creation_date == current_date:
+ os.remove(file)
+
+ files = glob.glob(os.path.join(temp_directory, 'tmp*.png'))
+ for file in files:
+ creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
+ if creation_date == current_date:
+ os.remove(file)
+ return
+
+
+def s2t(seconds, seconds2):
+ # convert seconds to time format
+ # seconds - time in seconds
+ # return time in format 00:00
+ m, s = divmod(seconds, 60)
+ m2, s2 = divmod(seconds2, 60)
+ if seconds != 0 and seconds < seconds2:
+ s = s + 1
+ return ("%02d:%02d - %02d:%02d" % (m, s, m2, s2))
+
+
+def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
+ # calculate the time of generation
+ # overlap - overlap in seconds
+ # d0-d9 - drag
+ # return time in seconds
+ d_amount = [int(d0), int(d1), int(d2), int(d3), int(d4), int(d5), int(d6), int(d7), int(d8), int(d9)]
+ calc = []
+ tracks = []
+ time = 0
+ s = s - 1
+ max_time = duration
+ max_limit = 0
+ if gen_type == "music":
+ max_limit = 30
+ elif gen_type == "audio":
+ max_limit = 10
+ track_add = max_limit - overlap
+ tracks.append(max_limit + ((d_amount[0] - 1) * track_add))
+ for i in range(1, 10):
+ tracks.append(d_amount[i] * track_add)
+
+ if tracks[0] >= max_time or s == 0:
+ calc.append(s2t(time, max_time))
+ time = max_time
+ else:
+ calc.append(s2t(time, tracks[0]))
+ time = tracks[0]
+
+ for i in range(1, 10):
+ if time + tracks[i] >= max_time or i == s:
+ calc.append(s2t(time, max_time))
+ time = max_time
+ else:
+ calc.append(s2t(time, time + tracks[i]))
+ time = time + tracks[i]
+
+ return calc[0], calc[1], calc[2], calc[3], calc[4], calc[5], calc[6], calc[7], calc[8], calc[9]
+
+
+def predict_full(gen_type, model, decoder, custom_model, base_model, prompt_amount, struc_prompt, bpm, key, scale, global_prompt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()):
+ global INTERRUPTING
+ global USE_DIFFUSION
+ INTERRUPTING = False
+
+ if gen_type == "audio":
+ custom_model = None
+ base_model = "medium"
+
+ if temperature < 0:
+ raise gr.Error("Temperature must be >= 0.")
+ if topk < 0:
+ raise gr.Error("Topk must be non-negative.")
+ if topp < 0:
+ raise gr.Error("Topp must be non-negative.")
+
+ if trim_start < 0:
+ trim_start = 0
+ if trim_end < 0:
+ trim_end = 0
+
+ topk = int(topk)
+
+ if decoder == "MultiBand_Diffusion":
+ USE_DIFFUSION = True
+ load_diffusion()
+ else:
+ USE_DIFFUSION = False
+ unload_diffusion()
+
+ if gen_type == "music":
+ model_shrt = model
+ model = "GrandaddyShmax/-" + model
+ elif gen_type == "audio":
+ model_shrt = model
+ model = "GrandaddyShmax/audiogen-" + model
+ base_model_shrt = base_model
+ base_model = "GrandaddyShmax/-" + base_model
+
+ if MODEL is None or MODEL.name != (model):
+ load_model(model, custom_model, base_model, gen_type)
+ else:
+ if MOVE_TO_CPU:
+ MODEL.to('cuda')
+
+ if seed < 0:
+ seed = random.randint(0, 0xffff_ffff_ffff)
+ torch.manual_seed(seed)
+
+ def _progress(generated, to_generate):
+ progress((min(generated, to_generate), to_generate))
+ if INTERRUPTING:
+ raise gr.Error("Interrupted.")
+ MODEL.set_custom_progress_callback(_progress)
+
+ audio_mode = "none"
+ melody = None
+ sample = None
+ if audio:
+ audio_mode = mode
+ if mode == "sample":
+ sample = audio
+ elif mode == "melody":
+ melody = audio
+
+ base_model = "none" if model != "custom" else base_model
+ custom_model = "none" if model != "custom" else custom_model
+
+ text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9]
+ drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9]
+ texts = []
+ raw_texts = []
+ ind = 0
+ ind2 = 0
+ while ind < prompt_amount:
+ for ind2 in range(int(drag_cat[ind])):
+ if not struc_prompt:
+ texts.append(text_cat[ind])
+ global_prompt = "none"
+ bpm = "none"
+ key = "none"
+ scale = "none"
+ raw_texts.append(text_cat[ind])
+ else:
+ if gen_type == "music":
+ bpm_str = str(bpm) + " bpm"
+ key_str = ", " + str(key) + " " + str(scale)
+ global_str = (", " + str(global_prompt)) if str(global_prompt) != "" else ""
+ elif gen_type == "audio":
+ bpm_str = ""
+ key_str = ""
+ global_str = (str(global_prompt)) if str(global_prompt) != "" else ""
+ texts_str = (", " + str(text_cat[ind])) if str(text_cat[ind]) != "" else ""
+ texts.append(bpm_str + key_str + global_str + texts_str)
+ raw_texts.append(text_cat[ind])
+ ind2 = 0
+ ind = ind + 1
+
+ outs, outs_audio, outs_backup, input_length = _do_predictions(
+ gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True,
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap)
+ tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model), str(base_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)]
+ wav_target, mp4_target, json_target = save_outputs(outs[0], outs_audio[0], tags, gen_type);
+ # Removes the temporary files.
+ for out in outs:
+ os.remove(out)
+ for out in outs_audio:
+ os.remove(out)
+
+ return mp4_target, wav_target, outs_backup[0], [mp4_target, wav_target, json_target], seed
+
+
+max_textboxes = 10
+
+
+def get_available_models():
+ return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])
+
+
+def toggle_audio_src(choice):
+ if choice == "mic":
+ return gr.update(source="microphone", value=None, label="Microphone")
+ else:
+ return gr.update(source="upload", value=None, label="File")
+
+
+def ui_full(launch_kwargs):
+ with gr.Blocks(title='TulipAI Soundscapes', theme=theme) as interface:
+ gr.Markdown(
+ """
+ # TulipAI Soundscapes
+
+ ### TulipAI's Audio Storytelling Toolkit
+
+ Welcome to Soundscapes - TulipAI’s flagship Audio Storytelling Toolkit. Designed with modern content creators in mind, our AI-driven platform generates audio sound effects in just minutes tailored to your unique needs.
+ """
+ )
+ with gr.Tab("AudioGen"):
+ gr.Markdown(
+ """
+ ### AudioGen
+ Check the "Wiki" to learn how to take the most out of TulipAI Soundscapes Sound Effects Generation Tool.
+ """
+ )
+ with gr.Tab("Generate Sound Effects"):
+ with gr.Row():
+ #with gr.Column():
+ with gr.Tab("Generation"):
+ with gr.Column():
+ textboxes_a = []
+ prompts_a = []
+ repeats_a = []
+ calcs_a = []
+ with gr.Row():
+ text0_a = gr.Text(label="Global Prompt", interactive=True, scale=4)
+ prompts_a.append(text0_a)
+ drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
+ repeats_a.append(drag0_a)
+ calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
+ calcs_a.append(calc0_a)
+
+ with gr.Accordion("Structured Prompt (Optional)", open=False):
+ with gr.Row():
+ struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
+ #global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3)
+ global_prompt_a = text0_a
+ with gr.Row():
+ s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
+ for i in range(max_textboxes):
+ with gr.Row(visible=False) as t_a:
+ text_a = gr.Text(label="Input Text", interactive=True, scale=3)
+ repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
+ calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
+ textboxes_a.append(t_a)
+ prompts_a.append(text_a)
+ repeats_a.append(repeat_a)
+ calcs_a.append(calc_a)
+
+ overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True)
+ to_calc_a = gr.Button("Calculate Timings", variant="secondary")
+
+ with gr.Row():
+ duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
+ with gr.Row():
+ seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
+ gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False)
+ reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1)
+
+ with gr.Tab("Audio"):
+ with gr.Row():
+ with gr.Column():
+ input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
+ mode_a = gr.Radio(["sample"], label="Input Audio Mode (optional)", value="sample", interactive=False, visible=False)
+ with gr.Row():
+ trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True)
+ trim_end_a = gr.Number(label="Trim End", value=0, interactive=True)
+ audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
+
+ with gr.Tab("Customization"):
+ with gr.Row():
+ with gr.Column():
+ background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
+ bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
+ bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
+ with gr.Column():
+ image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
+ with gr.Row():
+ height_a = gr.Number(label="Height", value=512, interactive=True)
+ width_a = gr.Number(label="Width", value=768, interactive=True)
+
+ with gr.Tab("Settings"):
+ with gr.Row():
+ channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
+ sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
+ with gr.Column():
+ dropdown = gr.Dropdown(choices=get_available_models(), value=("No models found" if len(get_available_models()) < 1 else get_available_models()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
+ ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
+ basemodel = gr.Radio(["small", "medium", "melody", "large"], label="Base Model", value="medium", interactive=True, scale=1)
+ with gr.Row():
+ struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
+ bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
+ key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
+ scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
+ with gr.Row():
+ model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False)
+ decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False)
+ with gr.Row():
+ topk_a = gr.Number(label="Top-k", value=250, interactive=True)
+ topp_a = gr.Number(label="Top-p", value=0, interactive=True)
+ temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True)
+ cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+ with gr.Row():
+ submit_a = gr.Button("Generate", variant="primary")
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+ with gr.Row():
+ with gr.Tab("Output"):
+ output_a = gr.Video(label="Generated Audio", scale=0)
+ with gr.Row():
+ audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False)
+ backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
+ send_audio_a = gr.Button("Send to Input Audio")
+ seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False)
+ download_a = gr.File(label="Generated Files", interactive=False)
+ with gr.Tab("Wiki"):
+ gr.Markdown(
+ """
+ - **[Generate (button)]:**
+ Generates the audio with the given settings and prompts.
+
+ - **[Interrupt (button)]:**
+ Stops the audio generation as soon as it can, providing an incomplete output.
+
+ ---
+
+ ### Generation Tab:
+
+ #### Structure Prompts:
+
+ This feature helps reduce repetetive prompts by allowing you to set global prompts
+ that will be used for all prompt segments.
+
+ - **[Structure Prompts (checkbox)]:**
+ Enable/Disable the structure prompts feature.
+
+ - **[Global Prompt (text)]:**
+ Here write the prompt that you wish to be used for all prompt segments.
+
+ #### Multi-Prompt:
+
+ This feature allows you to control the audio, adding variation to different time segments.
+ You have up to 10 prompt segments. the first prompt will always be 10s long
+ the other prompts will be [10s - overlap].
+ for example if the overlap is 2s, each prompt segment will be 8s.
+
+ - **[Prompt Segments (number)]:**
+ Amount of unique prompt to generate throughout the audio generation.
+
+ - **[Prompt/Input Text (prompt)]:**
+ Here describe the audio you wish the model to generate.
+
+ - **[Repeat (number)]:**
+ Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
+
+ - **[Time (text)]:**
+ The time of the prompt segment.
+
+ - **[Calculate Timings (button)]:**
+ Calculates the timings of the prompt segments.
+
+ - **[Duration (number)]:**
+ How long you want the generated audio to be (in seconds).
+
+ - **[Overlap (number)]:**
+ How much each new segment will reference the previous segment (in seconds).
+ For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s
+ and will generate only 8s of new audio. The model can only process 10s of music.
+
+ - **[Seed (number)]:**
+ Your generated audio id. If you wish to generate the exact same audio,
+ place the exact seed with the exact prompts
+ (This way you can also extend specific song that was generated short).
+
+ - **[Random Seed (button)]:**
+ Gives "-1" as a seed, which counts as a random seed.
+
+ - **[Copy Previous Seed (button)]:**
+ Copies the seed from the output seed (if you don't feel like doing it manualy).
+
+ ---
+
+ ### Audio Tab:
+
+ - **[Input Type (selection)]:**
+ `File` mode allows you to upload an audio file to use as input
+ `Mic` mode allows you to use your microphone as input
+
+ - **[Trim Start and Trim End (numbers)]:**
+ `Trim Start` set how much you'd like to trim the input audio from the start
+ `Trim End` same as the above but from the end
+
+ - **[Input Audio (audio file)]:**
+ Input here the audio you wish to use.
+
+ ---
+
+ ### Customization Tab:
+
+ - **[Background Color (color)]:**
+ Works only if you don't upload image. Color of the background of the waveform.
+
+ - **[Bar Color Start (color)]:**
+ First color of the waveform bars.
+
+ - **[Bar Color End (color)]:**
+ Second color of the waveform bars.
+
+ - **[Background Image (image)]:**
+ Background image that you wish to be attached to the generated video along with the waveform.
+
+ - **[Height and Width (numbers)]:**
+ Output video resolution, only works with image.
+ (minimum height and width is 256).
+
+ ---
+
+ ### Settings Tab:
+
+ - **[Output Audio Channels (selection)]:**
+ With this you can select the amount of channels that you wish for your output audio.
+ `mono` is a straightforward single channel audio
+ `stereo` is a dual channel audio but it will sound more or less like mono
+ `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
+
+ - **[Output Audio Sample Rate (dropdown)]:**
+ The output audio sample rate, the model default is 32000.
+
+ - **[Top-k (number)]:**
+ is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
+
+ - **[Top-p (number)]:**
+ also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
+
+ - **[Temperature (number)]:**
+ is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
+
+ - **[Classifier Free Guidance (number)]:**
+ refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
+ """
+ )
+ '''with gr.Tab("MusicGen"):
+ gr.Markdown(
+ """
+ ### MusicGen
+ Check the "Wiki" to learn how to take the most out of TulipAI Soundscapes Music Generation Tool.
+ """
+ )
+ with gr.Tab("Generate Music"):
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("Generation"):
+ with gr.Accordion("Structure Prompts", open=False):
+ with gr.Column():
+ with gr.Row():
+ struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
+ bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
+ key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
+ scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
+ with gr.Row():
+ global_prompt = gr.Text(label="Global Prompt", interactive=True, scale=3)
+ with gr.Row():
+ s = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
+ #s_mode = gr.Radio(["segmentation", "batch"], value="segmentation", interactive=True, scale=1, label="Generation Mode")
+ with gr.Column():
+ textboxes = []
+ prompts = []
+ repeats = []
+ calcs = []
+ with gr.Row():
+ text0 = gr.Text(label="Input Text", interactive=True, scale=4)
+ prompts.append(text0)
+ drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
+ repeats.append(drag0)
+ calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
+ calcs.append(calc0)
+ for i in range(max_textboxes):
+ with gr.Row(visible=False) as t:
+ text = gr.Text(label="Input Text", interactive=True, scale=3)
+ repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
+ calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
+ textboxes.append(t)
+ prompts.append(text)
+ repeats.append(repeat)
+ calcs.append(calc)
+ to_calc = gr.Button("Calculate Timings", variant="secondary")
+ with gr.Row():
+ duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
+ with gr.Row():
+ overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True)
+ with gr.Row():
+ seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
+ gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False)
+ reuse_seed = gr.Button('\u267b\ufe0f', scale=1)
+
+ with gr.Tab("Audio"):
+ with gr.Row():
+ with gr.Column():
+ input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
+ mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True)
+ with gr.Row():
+ trim_start = gr.Number(label="Trim Start", value=0, interactive=True)
+ trim_end = gr.Number(label="Trim End", value=0, interactive=True)
+ audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
+
+ with gr.Tab("Customization"):
+ with gr.Row():
+ with gr.Column():
+ background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
+ bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
+ bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
+ with gr.Column():
+ image = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
+ with gr.Row():
+ height = gr.Number(label="Height", value=512, interactive=True)
+ width = gr.Number(label="Width", value=768, interactive=True)
+
+ with gr.Tab("Settings"):
+ with gr.Row():
+ channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
+ sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
+ with gr.Row():
+ model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1)
+ with gr.Column():
+ dropdown = gr.Dropdown(choices=get_available_models(), value=("No models found" if len(get_available_models()) < 1 else get_available_models()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
+ ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
+ basemodel = gr.Radio(["small", "medium", "melody", "large"], label="Base Model", value="medium", interactive=True, scale=1)
+ with gr.Row():
+ decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True)
+ with gr.Row():
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+ with gr.Row():
+ submit = gr.Button("Generate", variant="primary")
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+ with gr.Column() as c:
+ with gr.Tab("Output"):
+ output = gr.Video(label="Generated Music", scale=0)
+ with gr.Row():
+ audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False)
+ backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
+ send_audio = gr.Button("Send to Input Audio")
+ seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
+ download = gr.File(label="Generated Files", interactive=False)
+ with gr.Tab("Wiki"):
+ gr.Markdown(
+ """
+ - **[Generate (button)]:**
+ Generates the music with the given settings and prompts.
+
+ - **[Interrupt (button)]:**
+ Stops the music generation as soon as it can, providing an incomplete output.
+
+ ---
+
+ ### Generation Tab:
+
+ #### Structure Prompts:
+
+ This feature helps reduce repetetive prompts by allowing you to set global prompts
+ that will be used for all prompt segments.
+
+ - **[Structure Prompts (checkbox)]:**
+ Enable/Disable the structure prompts feature.
+
+ - **[BPM (number)]:**
+ Beats per minute of the generated music.
+
+ - **[Key (dropdown)]:**
+ The key of the generated music.
+
+ - **[Scale (dropdown)]:**
+ The scale of the generated music.
+
+ - **[Global Prompt (text)]:**
+ Here write the prompt that you wish to be used for all prompt segments.
+
+ #### Multi-Prompt:
+
+ This feature allows you to control the music, adding variation to different time segments.
+ You have up to 10 prompt segments. the first prompt will always be 30s long
+ the other prompts will be [30s - overlap].
+ for example if the overlap is 10s, each prompt segment will be 20s.
+
+ - **[Prompt Segments (number)]:**
+ Amount of unique prompt to generate throughout the music generation.
+
+ - **[Prompt/Input Text (prompt)]:**
+ Here describe the music you wish the model to generate.
+
+ - **[Repeat (number)]:**
+ Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
+
+ - **[Time (text)]:**
+ The time of the prompt segment.
+
+ - **[Calculate Timings (button)]:**
+ Calculates the timings of the prompt segments.
+
+ - **[Duration (number)]:**
+ How long you want the generated music to be (in seconds).
+
+ - **[Overlap (number)]:**
+ How much each new segment will reference the previous segment (in seconds).
+ For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s
+ and will generate only 10s of new music. The model can only process 30s of music.
+
+ - **[Seed (number)]:**
+ Your generated music id. If you wish to generate the exact same music,
+ place the exact seed with the exact prompts
+ (This way you can also extend specific song that was generated short).
+
+ - **[Random Seed (button)]:**
+ Gives "-1" as a seed, which counts as a random seed.
+
+ - **[Copy Previous Seed (button)]:**
+ Copies the seed from the output seed (if you don't feel like doing it manualy).
+
+ ---
+
+ ### Audio Tab:
+
+ - **[Input Type (selection)]:**
+ `File` mode allows you to upload an audio file to use as input
+ `Mic` mode allows you to use your microphone as input
+
+ - **[Input Audio Mode (selection)]:**
+ `Melody` mode only works with the melody model: it conditions the music generation to reference the melody
+ `Sample` mode works with any model: it gives a music sample to the model to generate its continuation.
+
+ - **[Trim Start and Trim End (numbers)]:**
+ `Trim Start` set how much you'd like to trim the input audio from the start
+ `Trim End` same as the above but from the end
+
+ - **[Input Audio (audio file)]:**
+ Input here the audio you wish to use with "melody" or "sample" mode.
+
+ ---
+
+ ### Customization Tab:
+
+ - **[Background Color (color)]:**
+ Works only if you don't upload image. Color of the background of the waveform.
+
+ - **[Bar Color Start (color)]:**
+ First color of the waveform bars.
+
+ - **[Bar Color End (color)]:**
+ Second color of the waveform bars.
+
+ - **[Background Image (image)]:**
+ Background image that you wish to be attached to the generated video along with the waveform.
+
+ - **[Height and Width (numbers)]:**
+ Output video resolution, only works with image.
+ (minimum height and width is 256).
+
+ ---
+
+ ### Settings Tab:
+
+ - **[Output Audio Channels (selection)]:**
+ With this you can select the amount of channels that you wish for your output audio.
+ `mono` is a straightforward single channel audio
+ `stereo` is a dual channel audio but it will sound more or less like mono
+ `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
+
+ - **[Output Audio Sample Rate (dropdown)]:**
+ The output audio sample rate, the model default is 32000.
+
+ - **[Model (selection)]:**
+ Here you can choose which model you wish to use:
+ `melody` model is based on the medium model with a unique feature that lets you use melody conditioning
+ `small` model is trained on 300M parameters
+ `medium` model is trained on 1.5B parameters
+ `large` model is trained on 3.3B parameters
+ `custom` model runs the custom model that you provided.
+
+ - **[Custom Model (selection)]:**
+ This dropdown will show you models that are placed in the `models` folder
+ you must select `custom` in the model options in order to use it.
+
+ - **[Refresh (button)]:**
+ Refreshes the dropdown list for custom model.
+
+ - **[Base Model (selection)]:**
+ Choose here the model that your custom model is based on.
+
+ - **[Decoder (selection)]:**
+ Choose here the decoder that you wish to use:
+ `Default` is the default decoder
+ `MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio.
+
+ - **[Top-k (number)]:**
+ is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
+
+ - **[Top-p (number)]:**
+ also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
+
+ - **[Temperature (number)]:**
+ is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
+
+ - **[Classifier Free Guidance (number)]:**
+ refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
+ """
+ )
+ with gr.Tab("Audio Info"):
+ gr.Markdown(
+ """
+ ### Audio Info
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ in_audio = gr.File(type="file", label="Input Any Audio", interactive=True)
+ with gr.Row():
+ send_gen = gr.Button("Send to MusicGen", variant="primary")
+ send_gen_a = gr.Button("Send to AudioGen", variant="primary")
+ with gr.Column():
+ info = gr.Textbox(label="Audio Info", lines=10, interactive=False)
+ with gr.Tab("About"):
+ with gr.Row():
+ with gr.Column():
+ gen_type = gr.Text(value="music", interactive=False, visible=False)
+ gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
+ gr.Markdown(
+ """
+ # Soundscapes by TulipAI
+ Welcome to Soundscapes - TulipAI’s flagship Audio Storytelling Toolkit. Designed with modern content creators in mind, our AI-driven platform generates audio sound effects in just minutes tailored to your unique needs.
+
+ ## PERFECT FOR:
+
+ - Podcasters aiming to immerse their listeners.
+ - Audiobooks sound engineers
+ - Audio engineers seeking that elusive sound.
+ - Producers wanting to enrich their auditory experience.
+ - Sound designers craving innovative tools.
+ - YouTubers desiring to elevate their content.
+ """
+ )
+ with gr.Column():
+ #gr.Image(shape=(5,5))
+ gr.Image(shape=(5,5), value = "https://tulipai.co/assets/images/image01.png")
+
+ send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, basemodel, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False)
+ reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
+ send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False)
+ submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, basemodel, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used])
+ input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False)
+ to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False)'''
+
+
+ gen_type = gr.Text(value="music", interactive=False, visible=False)
+ gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
+
+ #send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False)
+ reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False)
+ send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False)
+ submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, basemodel, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a])
+ input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=True)
+ to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False)
+
+ #in_audio.change(get_audio_info, in_audio, outputs=[info])
+
+ def variable_outputs(k):
+ k = int(k) - 1
+ return [gr.Textbox.update(visible=True)]*k + [gr.Textbox.update(visible=False)]*(max_textboxes-k)
+ def get_size(image):
+ if image is not None:
+ img = Image.open(image)
+ img_height = img.height
+ img_width = img.width
+ if (img_height%2) != 0:
+ img_height = img_height + 1
+ if (img_width%2) != 0:
+ img_width = img_width + 1
+ return img_height, img_width
+ else:
+ return 512, 768
+
+ #image.change(get_size, image, outputs=[height, width])
+ #image_a.change(get_size, image_a, outputs=[height_a, width_a])
+ #s.change(variable_outputs, s, textboxes)
+ s_a.change(variable_outputs, s_a, textboxes_a)
+ #interface.queue().launch(**launch_kwargs)
+ interface.queue().launch(share=True)
+
+
+def ui_batched(launch_kwargs):
+ with gr.Blocks() as demo:
+ gr.Markdown(
+ """
+ # MusicGen
+
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+ a simple and controllable model for music generation
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+
+
+
+ for longer sequences, more control and no queue.
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
+ with gr.Column():
+ radio = gr.Radio(["file", "mic"], value="file",
+ label="Condition on a melody (optional) File or Mic")
+ melody = gr.Audio(source="upload", type="numpy", label="File",
+ interactive=True, elem_id="melody-input")
+ with gr.Row():
+ submit = gr.Button("Generate")
+ with gr.Column():
+ output = gr.Video(label="Generated Music")
+ audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
+ submit.click(predict_batched, inputs=[text, melody],
+ outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+ radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
+ gr.Examples(
+ fn=predict_batched,
+ examples=[
+ [
+ "An 80s driving pop song with heavy drums and synth pads in the background",
+ "./assets/bach.mp3",
+ ],
+ [
+ "A cheerful country song with acoustic guitars",
+ "./assets/bolero_ravel.mp3",
+ ],
+ [
+ "90s rock song with electric guitar and heavy drums",
+ None,
+ ],
+ [
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+ "./assets/bach.mp3",
+ ],
+ [
+ "lofi slow bpm electro chill with organic samples",
+ None,
+ ],
+ ],
+ inputs=[text, melody],
+ outputs=[output]
+ )
+ gr.Markdown("""
+ ### More details
+
+ The model will generate 12 seconds of audio based on the description you provided.
+ You can optionally provide a reference audio from which a broad melody will be extracted.
+ The model will then try to follow both the description and melody provided.
+ All samples are generated with the `melody` model.
+
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+ for more details.
+ """)
+
+ #demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+ demo.queue().launch(share=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--listen',
+ type=str,
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
+ help='IP to listen on for connections to Gradio',
+ )
+ parser.add_argument(
+ '--username', type=str, default='', help='Username for authentication'
+ )
+ parser.add_argument(
+ '--password', type=str, default='', help='Password for authentication'
+ )
+ parser.add_argument(
+ '--server_port',
+ type=int,
+ default=0,
+ help='Port to run the server listener on',
+ )
+ parser.add_argument(
+ '--inbrowser', action='store_true', help='Open in browser'
+ )
+ parser.add_argument(
+ '--share', action='store_true', help='Share the gradio UI'
+ )
+ parser.add_argument(
+ '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
+ )
+
+ parser.add_argument(
+ '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
+ )
+
+ parser.add_argument(
+ '--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
+ )
+
+ args = parser.parse_args()
+ UNLOAD_MODEL = args.unload_model
+ MOVE_TO_CPU = args.unload_to_cpu
+ if args.cache:
+ MODELS = {}
+
+ launch_kwargs = {}
+ launch_kwargs['server_name'] = args.listen
+
+ if args.username and args.password:
+ launch_kwargs['auth'] = (args.username, args.password)
+ if args.server_port:
+ launch_kwargs['server_port'] = args.server_port
+ if args.inbrowser:
+ launch_kwargs['inbrowser'] = args.inbrowser
+ if args.share:
+ launch_kwargs['share'] = args.share
+
+ # Show the interface
+ if IS_BATCHED:
+ global USE_DIFFUSION
+ USE_DIFFUSION = False
+ ui_batched(launch_kwargs)
+ else:
+ ui_full(launch_kwargs)
\ No newline at end of file
diff --git a/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..71be35a12d3e97993996806d6a94175568b2761f
Binary files /dev/null and b/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 differ
diff --git a/assets/bach.mp3 b/assets/bach.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..16d0da76cdae45a067c0d3360503509768fa0b34
Binary files /dev/null and b/assets/bach.mp3 differ
diff --git a/assets/bolero_ravel.mp3 b/assets/bolero_ravel.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..cbec949b9bfcec881ffce1b097325f3377f01830
Binary files /dev/null and b/assets/bolero_ravel.mp3 differ
diff --git a/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..e74b5b61a624fbf69f5e70febc64c91658bb38ac
Binary files /dev/null and b/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 differ
diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab346075f1b35366e7231054513097b87552c6f
--- /dev/null
+++ b/audiocraft/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+AudioCraft is a general framework for training audio generative models.
+At the moment we provide the training code for:
+
+- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
+ text-to-music and melody+text autoregressive generative model.
+ For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
+ `audiocraft.models.musicgen.MusicGen`.
+- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
+ text-to-general-audio generative model.
+- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
+ neural audio codec which provides an excellent tokenizer for autoregressive language models.
+ See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
+- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
+ improves the perceived quality and reduces the artifacts coming from adversarial decoders.
+"""
+
+# flake8: noqa
+from . import data, modules, models
+
+__version__ = '1.0.0'
diff --git a/audiocraft/__pycache__/__init__.cpython-310.pyc b/audiocraft/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e2979ad3de3d4b198c99910cf462aa4fa95645a
Binary files /dev/null and b/audiocraft/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/__pycache__/environment.cpython-310.pyc b/audiocraft/__pycache__/environment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..821a1853b1078a72b251edbed0a59eaa4fc3f734
Binary files /dev/null and b/audiocraft/__pycache__/environment.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/__init__.py b/audiocraft/adversarial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..864058706fbfae13d7f7dc850cc411a2f27d1510
--- /dev/null
+++ b/audiocraft/adversarial/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Adversarial losses and discriminator architectures."""
+
+# flake8: noqa
+from .discriminators import (
+ MultiPeriodDiscriminator,
+ MultiScaleDiscriminator,
+ MultiScaleSTFTDiscriminator
+)
+from .losses import (
+ AdversarialLoss,
+ AdvLossType,
+ get_adv_criterion,
+ get_fake_criterion,
+ get_real_criterion,
+ FeatLossType,
+ FeatureMatchingLoss
+)
diff --git a/audiocraft/adversarial/__pycache__/__init__.cpython-310.pyc b/audiocraft/adversarial/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..076bfe1b479296918dc8056bfe8369d69e4334de
Binary files /dev/null and b/audiocraft/adversarial/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/__pycache__/losses.cpython-310.pyc b/audiocraft/adversarial/__pycache__/losses.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c20cce860fb9d6b5a64b8d11f268112cc64487e
Binary files /dev/null and b/audiocraft/adversarial/__pycache__/losses.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/__init__.py b/audiocraft/adversarial/discriminators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9e5ff59950ee0b1d1a67c9b3831d67d08048148
--- /dev/null
+++ b/audiocraft/adversarial/discriminators/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .mpd import MultiPeriodDiscriminator
+from .msd import MultiScaleDiscriminator
+from .msstftd import MultiScaleSTFTDiscriminator
diff --git a/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-310.pyc b/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5674f791a892574d891a3a2d42288350fc3aa11f
Binary files /dev/null and b/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/__pycache__/base.cpython-310.pyc b/audiocraft/adversarial/discriminators/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31d481a5d8f529b968ec5d8350689f7316b3478b
Binary files /dev/null and b/audiocraft/adversarial/discriminators/__pycache__/base.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-310.pyc b/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..460ab35690ef6662ece4a2d48527c6818cd917e0
Binary files /dev/null and b/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-310.pyc b/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e93ade43fc4625d1754ecd84d1d034a8bafb7c68
Binary files /dev/null and b/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-310.pyc b/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37d6f1bdfd3a2b46b56c078ebf0f7d3e039e3100
Binary files /dev/null and b/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-310.pyc differ
diff --git a/audiocraft/adversarial/discriminators/base.py b/audiocraft/adversarial/discriminators/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9d517e9f5bf0f4e18252c45c8db3a35a7255f69
--- /dev/null
+++ b/audiocraft/adversarial/discriminators/base.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import typing as tp
+
+import torch
+import torch.nn as nn
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+class MultiDiscriminator(ABC, nn.Module):
+ """Base implementation for discriminators composed of sub-discriminators acting at different scales.
+ """
+ def __init__(self):
+ super().__init__()
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+ ...
+
+ @property
+ @abstractmethod
+ def num_discriminators(self) -> int:
+ """Number of discriminators.
+ """
+ ...
diff --git a/audiocraft/adversarial/discriminators/mpd.py b/audiocraft/adversarial/discriminators/mpd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8debd1fa72d77ca03df680facb60bdf79638cade
--- /dev/null
+++ b/audiocraft/adversarial/discriminators/mpd.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class PeriodDiscriminator(nn.Module):
+ """Period sub-discriminator.
+
+ Args:
+ period (int): Period between samples of audio.
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ n_layers (int): Number of convolutional layers.
+ kernel_sizes (list of int): Kernel sizes for convolutions.
+ stride (int): Stride for convolutions.
+ filters (int): Initial number of filters in convolutions.
+ filters_scale (int): Multiplier of number of filters as we increase depth.
+ max_filters (int): Maximum number of filters.
+ norm (str): Normalization method.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ """
+ def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
+ n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
+ filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+ activation_params: dict = {'negative_slope': 0.2}):
+ super().__init__()
+ self.period = period
+ self.n_layers = n_layers
+ self.activation = getattr(torch.nn, activation)(**activation_params)
+ self.convs = nn.ModuleList()
+ in_chs = in_channels
+ for i in range(self.n_layers):
+ out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
+ eff_stride = 1 if i == self.n_layers - 1 else stride
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
+ padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
+ in_chs = out_chs
+ self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
+ padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
+
+ def forward(self, x: torch.Tensor):
+ fmap = []
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), 'reflect')
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for conv in self.convs:
+ x = conv(x)
+ x = self.activation(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ # x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminator):
+ """Multi-Period (MPD) Discriminator.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
+ **kwargs: Additional args for `PeriodDiscriminator`
+ """
+ def __init__(self, in_channels: int = 1, out_channels: int = 1,
+ periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
+ super().__init__()
+ self.discriminators = nn.ModuleList([
+ PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
+ ])
+
+ @property
+ def num_discriminators(self):
+ return len(self.discriminators)
+
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+ logits = []
+ fmaps = []
+ for disc in self.discriminators:
+ logit, fmap = disc(x)
+ logits.append(logit)
+ fmaps.append(fmap)
+ return logits, fmaps
diff --git a/audiocraft/adversarial/discriminators/msd.py b/audiocraft/adversarial/discriminators/msd.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e67e29b46ab22f6ffeec85ffc64d8b99800b1b
--- /dev/null
+++ b/audiocraft/adversarial/discriminators/msd.py
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...modules import NormConv1d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+class ScaleDiscriminator(nn.Module):
+ """Waveform sub-discriminator.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
+ filters (int): Number of initial filters for convolutions.
+ max_filters (int): Maximum number of filters.
+ downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
+ inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
+ groups (Sequence[int] or None): Groups for inner convolutions.
+ strides (Sequence[int] or None): Strides for inner convolutions.
+ paddings (Sequence[int] or None): Paddings for inner convolutions.
+ norm (str): Normalization method.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ pad (str): Padding for initial convolution.
+ pad_params (dict): Parameters to provide to the padding module.
+ """
+ def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
+ filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
+ inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
+ strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
+ norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+ activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
+ pad_params: dict = {}):
+ super().__init__()
+ assert len(kernel_sizes) == 2
+ assert kernel_sizes[0] % 2 == 1
+ assert kernel_sizes[1] % 2 == 1
+ assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
+ assert (groups is None or len(groups) == len(downsample_scales))
+ assert (strides is None or len(strides) == len(downsample_scales))
+ assert (paddings is None or len(paddings) == len(downsample_scales))
+ self.activation = getattr(torch.nn, activation)(**activation_params)
+ self.convs = nn.ModuleList()
+ self.convs.append(
+ nn.Sequential(
+ getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+ NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
+ )
+ )
+
+ in_chs = filters
+ for i, downsample_scale in enumerate(downsample_scales):
+ out_chs = min(in_chs * downsample_scale, max_filters)
+ default_kernel_size = downsample_scale * 10 + 1
+ default_stride = downsample_scale
+ default_padding = (default_kernel_size - 1) // 2
+ default_groups = in_chs // 4
+ self.convs.append(
+ NormConv1d(in_chs, out_chs,
+ kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
+ stride=strides[i] if strides else default_stride,
+ groups=groups[i] if groups else default_groups,
+ padding=paddings[i] if paddings else default_padding,
+ norm=norm))
+ in_chs = out_chs
+
+ out_chs = min(in_chs * 2, max_filters)
+ self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
+ padding=(kernel_sizes[0] - 1) // 2, norm=norm))
+ self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
+ padding=(kernel_sizes[1] - 1) // 2, norm=norm)
+
+ def forward(self, x: torch.Tensor):
+ fmap = []
+ for layer in self.convs:
+ x = layer(x)
+ x = self.activation(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ # x = torch.flatten(x, 1, -1)
+ return x, fmap
+
+
+class MultiScaleDiscriminator(MultiDiscriminator):
+ """Multi-Scale (MSD) Discriminator,
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ downsample_factor (int): Downsampling factor between the different scales.
+ scale_norms (Sequence[str]): Normalization for each sub-discriminator.
+ **kwargs: Additional args for ScaleDiscriminator.
+ """
+ def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
+ scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
+ super().__init__()
+ self.discriminators = nn.ModuleList([
+ ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
+ ])
+ self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
+
+ @property
+ def num_discriminators(self):
+ return len(self.discriminators)
+
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+ logits = []
+ fmaps = []
+ for i, disc in enumerate(self.discriminators):
+ if i != 0:
+ self.downsample(x)
+ logit, fmap = disc(x)
+ logits.append(logit)
+ fmaps.append(fmap)
+ return logits, fmaps
diff --git a/audiocraft/adversarial/discriminators/msstftd.py b/audiocraft/adversarial/discriminators/msstftd.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a9100961c7a89a39df2643b24268fb90bfeaa4
--- /dev/null
+++ b/audiocraft/adversarial/discriminators/msstftd.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+ """STFT sub-discriminator.
+
+ Args:
+ filters (int): Number of filters in convolutions.
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ n_fft (int): Size of FFT for each scale.
+ hop_length (int): Length of hop between STFT windows for each scale.
+ kernel_size (tuple of int): Inner Conv2d kernel sizes.
+ stride (tuple of int): Inner Conv2d strides.
+ dilations (list of int): Inner Conv2d dilation on the time dimension.
+ win_length (int): Window size for each scale.
+ normalized (bool): Whether to normalize by magnitude after stft.
+ norm (str): Normalization method.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ growth (int): Growth factor for the filters.
+ """
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+ n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+ filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+ stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+ activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+ super().__init__()
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ self.filters = filters
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.normalized = normalized
+ self.activation = getattr(torch.nn, activation)(**activation_params)
+ self.spec_transform = torchaudio.transforms.Spectrogram(
+ n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+ normalized=self.normalized, center=False, pad_mode=None, power=None)
+ spec_channels = 2 * self.in_channels
+ self.convs = nn.ModuleList()
+ self.convs.append(
+ NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+ )
+ in_chs = min(filters_scale * self.filters, max_filters)
+ for i, dilation in enumerate(dilations):
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+ dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+ norm=norm))
+ in_chs = out_chs
+ out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm))
+ self.conv_post = NormConv2d(out_chs, self.out_channels,
+ kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm)
+
+ def forward(self, x: torch.Tensor):
+ fmap = []
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
+ z = torch.cat([z.real, z.imag], dim=1)
+ z = rearrange(z, 'b c w t -> b c t w')
+ for i, layer in enumerate(self.convs):
+ z = layer(z)
+ z = self.activation(z)
+ fmap.append(z)
+ z = self.conv_post(z)
+ return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+ """Multi-Scale STFT (MS-STFT) discriminator.
+
+ Args:
+ filters (int): Number of filters in convolutions.
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ sep_channels (bool): Separate channels to distinct samples for stereo support.
+ n_ffts (Sequence[int]): Size of FFT for each scale.
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+ win_lengths (Sequence[int]): Window size for each scale.
+ **kwargs: Additional args for STFTDiscriminator.
+ """
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+ n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+ win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+ super().__init__()
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+ self.sep_channels = sep_channels
+ self.discriminators = nn.ModuleList([
+ DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+ n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+ for i in range(len(n_ffts))
+ ])
+
+ @property
+ def num_discriminators(self):
+ return len(self.discriminators)
+
+ def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T = x.shape
+ return x.view(-1, 1, T)
+
+ def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+ logits = []
+ fmaps = []
+ for disc in self.discriminators:
+ logit, fmap = disc(x)
+ logits.append(logit)
+ fmaps.append(fmap)
+ return logits, fmaps
diff --git a/audiocraft/adversarial/losses.py b/audiocraft/adversarial/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..be293e739bdc2d91273f30fb789befe7c8b49a43
--- /dev/null
+++ b/audiocraft/adversarial/losses.py
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility module to handle adversarial losses without requiring to mess up the main training loop.
+"""
+
+import typing as tp
+
+import flashy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
+
+
+AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
+FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+
+
+class AdversarialLoss(nn.Module):
+ """Adversary training wrapper.
+
+ Args:
+ adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+ We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+ where the first item is a list of logits and the second item is a list of feature maps.
+ optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+ loss (AdvLossType): Loss function for generator training.
+ loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+ loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+ loss_feat (FeatLossType): Feature matching loss function for generator training.
+ normalize (bool): Whether to normalize by number of sub-discriminators.
+
+ Example of usage:
+ adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+ for real in loader:
+ noise = torch.randn(...)
+ fake = model(noise)
+ adv_loss.train_adv(fake, real)
+ loss, _ = adv_loss(fake, real)
+ loss.backward()
+ """
+ def __init__(self,
+ adversary: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ loss: AdvLossType,
+ loss_real: AdvLossType,
+ loss_fake: AdvLossType,
+ loss_feat: tp.Optional[FeatLossType] = None,
+ normalize: bool = True):
+ super().__init__()
+ self.adversary: nn.Module = adversary
+ flashy.distrib.broadcast_model(self.adversary)
+ self.optimizer = optimizer
+ self.loss = loss
+ self.loss_real = loss_real
+ self.loss_fake = loss_fake
+ self.loss_feat = loss_feat
+ self.normalize = normalize
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ # Add the optimizer state dict inside our own.
+ super()._save_to_state_dict(destination, prefix, keep_vars)
+ destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+ return destination
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ # Load optimizer state.
+ self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ def get_adversary_pred(self, x):
+ """Run adversary model, validating expected output format."""
+ logits, fmaps = self.adversary(x)
+ assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+ f'Expecting a list of tensors as logits but {type(logits)} found.'
+ assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+ for fmap in fmaps:
+ assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+ f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+ return logits, fmaps
+
+ def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+ """Train the adversary with the given fake and real example.
+
+ We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+ The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+ This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+ and call the optimizer.
+ """
+ loss = torch.tensor(0., device=fake.device)
+ all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+ all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+ n_sub_adversaries = len(all_logits_fake_is_fake)
+ for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+ loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+ if self.normalize:
+ loss /= n_sub_adversaries
+
+ self.optimizer.zero_grad()
+ with flashy.distrib.eager_sync_model(self.adversary):
+ loss.backward()
+ self.optimizer.step()
+
+ return loss
+
+ def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Return the loss for the generator, i.e. trying to fool the adversary,
+ and feature matching loss if provided.
+ """
+ adv = torch.tensor(0., device=fake.device)
+ feat = torch.tensor(0., device=fake.device)
+ with flashy.utils.readonly(self.adversary):
+ all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+ all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+ n_sub_adversaries = len(all_logits_fake_is_fake)
+ for logit_fake_is_fake in all_logits_fake_is_fake:
+ adv += self.loss(logit_fake_is_fake)
+ if self.loss_feat:
+ for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+ feat += self.loss_feat(fmap_fake, fmap_real)
+
+ if self.normalize:
+ adv /= n_sub_adversaries
+ feat /= n_sub_adversaries
+
+ return adv, feat
+
+
+def get_adv_criterion(loss_type: str) -> tp.Callable:
+ assert loss_type in ADVERSARIAL_LOSSES
+ if loss_type == 'mse':
+ return mse_loss
+ elif loss_type == 'hinge':
+ return hinge_loss
+ elif loss_type == 'hinge2':
+ return hinge2_loss
+ raise ValueError('Unsupported loss')
+
+
+def get_fake_criterion(loss_type: str) -> tp.Callable:
+ assert loss_type in ADVERSARIAL_LOSSES
+ if loss_type == 'mse':
+ return mse_fake_loss
+ elif loss_type in ['hinge', 'hinge2']:
+ return hinge_fake_loss
+ raise ValueError('Unsupported loss')
+
+
+def get_real_criterion(loss_type: str) -> tp.Callable:
+ assert loss_type in ADVERSARIAL_LOSSES
+ if loss_type == 'mse':
+ return mse_real_loss
+ elif loss_type in ['hinge', 'hinge2']:
+ return hinge_real_loss
+ raise ValueError('Unsupported loss')
+
+
+def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def mse_loss(x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0:
+ return torch.tensor([0.0], device=x.device)
+ return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0:
+ return torch.tensor([0.0], device=x.device)
+ return -x.mean()
+
+
+def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0:
+ return torch.tensor([0.0])
+ return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+class FeatureMatchingLoss(nn.Module):
+ """Feature matching loss for adversarial training.
+
+ Args:
+ loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
+ normalize (bool): Whether to normalize the loss.
+ by number of feature maps.
+ """
+ def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+ super().__init__()
+ self.loss = loss
+ self.normalize = normalize
+
+ def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+ assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+ feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+ feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+ n_fmaps = 0
+ for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+ assert feat_fake.shape == feat_real.shape
+ n_fmaps += 1
+ feat_loss += self.loss(feat_fake, feat_real)
+ feat_scale += torch.mean(torch.abs(feat_real))
+
+ if self.normalize:
+ feat_loss /= n_fmaps
+
+ return feat_loss
diff --git a/audiocraft/data/__init__.py b/audiocraft/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2906ff12bc85a894837579f3137f6f71a0438329
--- /dev/null
+++ b/audiocraft/data/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Audio loading and writing support. Datasets for raw audio
+or also including some metadata."""
+
+# flake8: noqa
+from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
diff --git a/audiocraft/data/__pycache__/__init__.cpython-310.pyc b/audiocraft/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06bee6fe008abcee3a4cd5b08d7921462255fae8
Binary files /dev/null and b/audiocraft/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/audio.cpython-310.pyc b/audiocraft/data/__pycache__/audio.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88f43cfd25a4bb5a32f088a1808115a390f6d49c
Binary files /dev/null and b/audiocraft/data/__pycache__/audio.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/audio_dataset.cpython-310.pyc b/audiocraft/data/__pycache__/audio_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a474c00b076ba98ec794d0f4b41001ea8c32b45
Binary files /dev/null and b/audiocraft/data/__pycache__/audio_dataset.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/audio_utils.cpython-310.pyc b/audiocraft/data/__pycache__/audio_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8da8487a7cbda680dbd0d1feef5929dd401eb8ee
Binary files /dev/null and b/audiocraft/data/__pycache__/audio_utils.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/info_audio_dataset.cpython-310.pyc b/audiocraft/data/__pycache__/info_audio_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fdae11a46e6f69111d518fdf49c3f3d4cb2b0dd
Binary files /dev/null and b/audiocraft/data/__pycache__/info_audio_dataset.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/music_dataset.cpython-310.pyc b/audiocraft/data/__pycache__/music_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1d6739469e3168cc8af78d4c2115252855121dc
Binary files /dev/null and b/audiocraft/data/__pycache__/music_dataset.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/sound_dataset.cpython-310.pyc b/audiocraft/data/__pycache__/sound_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..001dc9c00c2ee0ec4f67827a6e40d9f25c3e7342
Binary files /dev/null and b/audiocraft/data/__pycache__/sound_dataset.cpython-310.pyc differ
diff --git a/audiocraft/data/__pycache__/zip.cpython-310.pyc b/audiocraft/data/__pycache__/zip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13b3b489b97cddad5b00efc0ffa1f519064ff944
Binary files /dev/null and b/audiocraft/data/__pycache__/zip.cpython-310.pyc differ
diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c87047f5033d0016200df77004a9536e06e81a
--- /dev/null
+++ b/audiocraft/data/audio.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Audio IO methods are defined in this module (info, read, write),
+We rely on av library for faster read when possible, otherwise on torchaudio.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+import logging
+import typing as tp
+
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+import torchaudio as ta
+
+import av
+
+from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+
+
+_av_initialized = False
+
+
+def _init_av():
+ global _av_initialized
+ if _av_initialized:
+ return
+ logger = logging.getLogger('libav.mp3')
+ logger.setLevel(logging.ERROR)
+ _av_initialized = True
+
+
+@dataclass(frozen=True)
+class AudioFileInfo:
+ sample_rate: int
+ duration: float
+ channels: int
+
+
+def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ _init_av()
+ with av.open(str(filepath)) as af:
+ stream = af.streams.audio[0]
+ sample_rate = stream.codec_context.sample_rate
+ duration = float(stream.duration * stream.time_base)
+ channels = stream.channels
+ return AudioFileInfo(sample_rate, duration, channels)
+
+
+def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ info = soundfile.info(filepath)
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ # torchaudio no longer returns useful duration informations for some formats like mp3s.
+ filepath = Path(filepath)
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
+ # ffmpeg has some weird issue with flac.
+ return _soundfile_info(filepath)
+ else:
+ return _av_info(filepath)
+
+
+def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+ """FFMPEG-based audio file reading using PyAV bindings.
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+ Args:
+ filepath (str or Path): Path to audio file to read.
+ seek_time (float): Time at which to start reading in the file.
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
+ Returns:
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate
+ """
+ _init_av()
+ with av.open(str(filepath)) as af:
+ stream = af.streams.audio[0]
+ sr = stream.codec_context.sample_rate
+ num_frames = int(sr * duration) if duration >= 0 else -1
+ frame_offset = int(sr * seek_time)
+ # we need a small negative offset otherwise we get some edge artifact
+ # from the mp3 decoder.
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+ frames = []
+ length = 0
+ for frame in af.decode(streams=stream.index):
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
+ strip = max(0, frame_offset - current_offset)
+ buf = torch.from_numpy(frame.to_ndarray())
+ if buf.shape[0] != stream.channels:
+ buf = buf.view(-1, stream.channels).t()
+ buf = buf[:, strip:]
+ frames.append(buf)
+ length += buf.shape[1]
+ if num_frames > 0 and length >= num_frames:
+ break
+ assert frames
+ # If the above assert fails, it is likely because we seeked past the end of file point,
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+ # This will need proper debugging, in due time.
+ wav = torch.cat(frames, dim=1)
+ assert wav.shape[0] == stream.channels
+ if num_frames > 0:
+ wav = wav[:, :num_frames]
+ return f32_pcm(wav), sr
+
+
+def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+ """Read audio by picking the most appropriate backend tool based on the audio format.
+
+ Args:
+ filepath (str or Path): Path to audio file to read.
+ seek_time (float): Time at which to start reading in the file.
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
+ pad (bool): Pad output audio if not reaching expected duration.
+ Returns:
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
+ """
+ fp = Path(filepath)
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
+ # There is some bug with ffmpeg and reading flac
+ info = _soundfile_info(filepath)
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+ frame_offset = int(seek_time * info.sample_rate)
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+ wav = torch.from_numpy(wav).t().contiguous()
+ if len(wav.shape) == 1:
+ wav = torch.unsqueeze(wav, 0)
+ elif (
+ fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
+ and duration <= 0 and seek_time == 0
+ ):
+ # Torchaudio is faster if we load an entire file at once.
+ wav, sr = ta.load(fp)
+ else:
+ wav, sr = _av_read(filepath, seek_time, duration)
+ if pad and duration > 0:
+ expected_frames = int(duration * sr)
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+ return wav, sr
+
+
+def audio_write(stem_name: tp.Union[str, Path],
+ wav: torch.Tensor, sample_rate: int,
+ format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False,
+ log_clipping: bool = True, make_parent_dir: bool = True,
+ add_suffix: bool = True) -> Path:
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+ Args:
+ stem_name (str or Path): Filename without extension which will be added automatically.
+ format (str): Either "wav" or "mp3".
+ mp3_rate (int): kbps when using mp3s.
+ normalize (bool): if `True` (default), normalizes according to the prescribed
+ strategy (see after). If `False`, the strategy is only used in case clipping
+ would happen.
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+ with extra headroom to avoid clipping. 'clip' just clips.
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+ than the `peak_clip` one to avoid further clipping.
+ loudness_headroom_db (float): Target loudness for loudness normalization.
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+ when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
+ occurs despite strategy (only for 'rms').
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
+ Returns:
+ Path: Path of the saved audio.
+ """
+ assert wav.dtype.is_floating_point, "wav is not floating point"
+ if wav.dim() == 1:
+ wav = wav[None]
+ elif wav.dim() > 2:
+ raise ValueError("Input wav should be at most 2 dimension.")
+ assert wav.isfinite().all()
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
+ log_clipping=log_clipping, sample_rate=sample_rate,
+ stem_name=str(stem_name))
+ kwargs: dict = {}
+ if format == 'mp3':
+ suffix = '.mp3'
+ kwargs.update({"compression": mp3_rate})
+ elif format == 'wav':
+ wav = i16_pcm(wav)
+ suffix = '.wav'
+ kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+ else:
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+ if not add_suffix:
+ suffix = ''
+ path = Path(str(stem_name) + suffix)
+ if make_parent_dir:
+ path.parent.mkdir(exist_ok=True, parents=True)
+ try:
+ ta.save(path, wav, sample_rate, **kwargs)
+ except Exception:
+ if path.exists():
+ # we do not want to leave half written files around.
+ path.unlink()
+ raise
+ return path
diff --git a/audiocraft/data/audio_dataset.py b/audiocraft/data/audio_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7442526186b3712f5d4754f928a40ecd964174
--- /dev/null
+++ b/audiocraft/data/audio_dataset.py
@@ -0,0 +1,587 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""AudioDataset support. In order to handle a larger number of files
+without having to scan again the folders, we precompute some metadata
+(filename, sample rate, duration), and use that to efficiently sample audio segments.
+"""
+import argparse
+import copy
+from concurrent.futures import ThreadPoolExecutor, Future
+from dataclasses import dataclass, fields
+from contextlib import ExitStack
+from functools import lru_cache
+import gzip
+import json
+import logging
+import os
+from pathlib import Path
+import random
+import sys
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+
+from .audio import audio_read, audio_info
+from .audio_utils import convert_audio
+from .zip import PathInZip
+
+try:
+ import dora
+except ImportError:
+ dora = None # type: ignore
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+ @classmethod
+ def _dict2fields(cls, dictionary: dict):
+ return {
+ field.name: dictionary[field.name]
+ for field in fields(cls) if field.name in dictionary
+ }
+
+ @classmethod
+ def from_dict(cls, dictionary: dict):
+ _dictionary = cls._dict2fields(dictionary)
+ return cls(**_dictionary)
+
+ def to_dict(self):
+ return {
+ field.name: self.__getattribute__(field.name)
+ for field in fields(self)
+ }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+ path: str
+ duration: float
+ sample_rate: int
+ amplitude: tp.Optional[float] = None
+ weight: tp.Optional[float] = None
+ # info_path is used to load additional information about the audio file that is stored in zip files.
+ info_path: tp.Optional[PathInZip] = None
+
+ @classmethod
+ def from_dict(cls, dictionary: dict):
+ base = cls._dict2fields(dictionary)
+ if 'info_path' in base and base['info_path'] is not None:
+ base['info_path'] = PathInZip(base['info_path'])
+ return cls(**base)
+
+ def to_dict(self):
+ d = super().to_dict()
+ if d['info_path'] is not None:
+ d['info_path'] = str(d['info_path'])
+ return d
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+ meta: AudioMeta
+ seek_time: float
+ # The following values are given once the audio is processed, e.g.
+ # at the target sample rate and target number of channels.
+ n_frames: int # actual number of frames without padding
+ total_frames: int # total number of frames, padding included
+ sample_rate: int # actual sample rate
+ channels: int # number of audio channels.
+
+
+DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
+
+logger = logging.getLogger(__name__)
+
+
+def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
+ """AudioMeta from a path to an audio file.
+
+ Args:
+ file_path (str): Resolved path of valid audio file.
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+ Returns:
+ AudioMeta: Audio file path and its metadata.
+ """
+ info = audio_info(file_path)
+ amplitude: tp.Optional[float] = None
+ if not minimal:
+ wav, sr = audio_read(file_path)
+ amplitude = wav.abs().max().item()
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
+
+
+def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
+ """If Dora is available as a dependency, try to resolve potential relative paths
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
+
+ Args:
+ m (AudioMeta): Audio meta to resolve.
+ fast (bool): If True, uses a really fast check for determining if a file
+ is already absolute or not. Only valid on Linux/Mac.
+ Returns:
+ AudioMeta: Audio meta with resolved path.
+ """
+ def is_abs(m):
+ if fast:
+ return str(m)[0] == '/'
+ else:
+ os.path.isabs(str(m))
+
+ if not dora:
+ return m
+
+ if not is_abs(m.path):
+ m.path = dora.git_save.to_absolute_path(m.path)
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
+ return m
+
+
+def find_audio_files(path: tp.Union[Path, str],
+ exts: tp.List[str] = DEFAULT_EXTS,
+ resolve: bool = True,
+ minimal: bool = True,
+ progress: bool = False,
+ workers: int = 0) -> tp.List[AudioMeta]:
+ """Build a list of AudioMeta from a given path,
+ collecting relevant audio files and fetching meta info.
+
+ Args:
+ path (str or Path): Path to folder containing audio files.
+ exts (list of str): List of file extensions to consider for audio files.
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+ progress (bool): Whether to log progress on audio files collection.
+ workers (int): number of parallel workers, if 0, use only the current thread.
+ Returns:
+ list of AudioMeta: List of audio file path and its metadata.
+ """
+ audio_files = []
+ futures: tp.List[Future] = []
+ pool: tp.Optional[ThreadPoolExecutor] = None
+ with ExitStack() as stack:
+ if workers > 0:
+ pool = ThreadPoolExecutor(workers)
+ stack.enter_context(pool)
+
+ if progress:
+ print("Finding audio files...")
+ for root, folders, files in os.walk(path, followlinks=True):
+ for file in files:
+ full_path = Path(root) / file
+ if full_path.suffix.lower() in exts:
+ audio_files.append(full_path)
+ if pool is not None:
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+ if progress:
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+ if progress:
+ print("Getting audio metadata...")
+ meta: tp.List[AudioMeta] = []
+ for idx, file_path in enumerate(audio_files):
+ try:
+ if pool is None:
+ m = _get_audio_meta(str(file_path), minimal)
+ else:
+ m = futures[idx].result()
+ if resolve:
+ m = _resolve_audio_meta(m)
+ except Exception as err:
+ print("Error with", str(file_path), err, file=sys.stderr)
+ continue
+ meta.append(m)
+ if progress:
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+ meta.sort()
+ return meta
+
+
+def load_audio_meta(path: tp.Union[str, Path],
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+ """Load list of AudioMeta from an optionally compressed json file.
+
+ Args:
+ path (str or Path): Path to JSON file.
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+ fast (bool): activates some tricks to make things faster.
+ Returns:
+ list of AudioMeta: List of audio file path and its total duration.
+ """
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+ with open_fn(path, 'rb') as fp: # type: ignore
+ lines = fp.readlines()
+ meta = []
+ for line in lines:
+ d = json.loads(line)
+ m = AudioMeta.from_dict(d)
+ if resolve:
+ m = _resolve_audio_meta(m, fast=fast)
+ meta.append(m)
+ return meta
+
+
+def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+ """Save the audio metadata to the file pointer as json.
+
+ Args:
+ path (str or Path): Path to JSON file.
+ metadata (list of BaseAudioMeta): List of audio meta to save.
+ """
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+ with open_fn(path, 'wb') as fp: # type: ignore
+ for m in meta:
+ json_str = json.dumps(m.to_dict()) + '\n'
+ json_bytes = json_str.encode('utf-8')
+ fp.write(json_bytes)
+
+
+class AudioDataset:
+ """Base audio dataset.
+
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
+ and potentially additional information, by creating random segments from the list of audio
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
+ mixing of channels, padding, etc.
+
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+ duration, applying padding if required.
+
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+ allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
+ original audio meta.
+
+ Note that you can call `start_epoch(epoch)` in order to get
+ a deterministic "randomization" for `shuffle=True`.
+ For a given epoch and dataset index, this will always return the same extract.
+ You can get back some diversity by setting the `shuffle_seed` param.
+
+ Args:
+ meta (list of AudioMeta): List of audio files metadata.
+ segment_duration (float, optional): Optional segment duration of audio to load.
+ If not specified, the dataset will load the full audio segment from the file.
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+ sample_rate (int): Target sample rate of the loaded audio samples.
+ channels (int): Target number of channels of the loaded audio samples.
+ sample_on_duration (bool): Set to `True` to sample segments with probability
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
+ is shorter than the desired segment.
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
+ audio shorter than this will be filtered out.
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
+ audio longer than this will be filtered out.
+ shuffle_seed (int): can be used to further randomize
+ load_wav (bool): if False, skip loading the wav but returns a tensor of 0
+ with the expected segment_duration (which must be provided if load_wav is False).
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
+ are False. Will ensure a permutation on files when going through the dataset.
+ In that case the epoch number must be provided in order for the model
+ to continue the permutation across epochs. In that case, it is assumed
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
+ `total_batch_size` the overall batch size accounting for all gpus.
+ """
+ def __init__(self,
+ meta: tp.List[AudioMeta],
+ segment_duration: tp.Optional[float] = None,
+ shuffle: bool = True,
+ num_samples: int = 10_000,
+ sample_rate: int = 48_000,
+ channels: int = 2,
+ pad: bool = True,
+ sample_on_duration: bool = True,
+ sample_on_weight: bool = True,
+ min_segment_ratio: float = 0.5,
+ max_read_retry: int = 10,
+ return_info: bool = False,
+ min_audio_duration: tp.Optional[float] = None,
+ max_audio_duration: tp.Optional[float] = None,
+ shuffle_seed: int = 0,
+ load_wav: bool = True,
+ permutation_on_files: bool = False,
+ ):
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
+ assert segment_duration is None or segment_duration > 0
+ assert segment_duration is None or min_segment_ratio >= 0
+ self.segment_duration = segment_duration
+ self.min_segment_ratio = min_segment_ratio
+ self.max_audio_duration = max_audio_duration
+ self.min_audio_duration = min_audio_duration
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
+ assert self.min_audio_duration <= self.max_audio_duration
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+ assert len(self.meta) # Fail fast if all data has been filtered.
+ self.total_duration = sum(d.duration for d in self.meta)
+
+ if segment_duration is None:
+ num_samples = len(self.meta)
+ self.num_samples = num_samples
+ self.shuffle = shuffle
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.pad = pad
+ self.sample_on_weight = sample_on_weight
+ self.sample_on_duration = sample_on_duration
+ self.sampling_probabilities = self._get_sampling_probabilities()
+ self.max_read_retry = max_read_retry
+ self.return_info = return_info
+ self.shuffle_seed = shuffle_seed
+ self.current_epoch: tp.Optional[int] = None
+ self.load_wav = load_wav
+ if not load_wav:
+ assert segment_duration is not None
+ self.permutation_on_files = permutation_on_files
+ if permutation_on_files:
+ assert not self.sample_on_duration
+ assert not self.sample_on_weight
+ assert self.shuffle
+
+ def start_epoch(self, epoch: int):
+ self.current_epoch = epoch
+
+ def __len__(self):
+ return self.num_samples
+
+ def _get_sampling_probabilities(self, normalized: bool = True):
+ """Return the sampling probabilities for each file inside `self.meta`."""
+ scores: tp.List[float] = []
+ for file_meta in self.meta:
+ score = 1.
+ if self.sample_on_weight and file_meta.weight is not None:
+ score *= file_meta.weight
+ if self.sample_on_duration:
+ score *= file_meta.duration
+ scores.append(score)
+ probabilities = torch.tensor(scores)
+ if normalized:
+ probabilities /= probabilities.sum()
+ return probabilities
+
+ @staticmethod
+ @lru_cache(16)
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
+ # Used to keep the most recent files permutation in memory implicitely.
+ # will work unless someone is using a lot of Datasets in parallel.
+ rng = torch.Generator()
+ rng.manual_seed(base_seed + permutation_index)
+ return torch.randperm(num_files, generator=rng)
+
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
+ This is only called if `segment_duration` is not None.
+
+ You must use the provided random number generator `rng` for reproducibility.
+ You can further make use of the index accessed.
+ """
+ if self.permutation_on_files:
+ assert self.current_epoch is not None
+ total_index = self.current_epoch * len(self) + index
+ permutation_index = total_index // len(self.meta)
+ relative_index = total_index % len(self.meta)
+ permutation = AudioDataset._get_file_permutation(
+ len(self.meta), permutation_index, self.shuffle_seed)
+ file_index = permutation[relative_index]
+ return self.meta[file_index]
+
+ if not self.sample_on_weight and not self.sample_on_duration:
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+ else:
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+ return self.meta[file_index]
+
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
+ # Override this method in subclass if needed.
+ if self.load_wav:
+ return audio_read(path, seek_time, duration, pad=False)
+ else:
+ assert self.segment_duration is not None
+ n_frames = int(self.sample_rate * self.segment_duration)
+ return torch.zeros(self.channels, n_frames), self.sample_rate
+
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+ if self.segment_duration is None:
+ file_meta = self.meta[index]
+ out, sr = audio_read(file_meta.path)
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
+ n_frames = out.shape[-1]
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+ sample_rate=self.sample_rate, channels=out.shape[0])
+ else:
+ rng = torch.Generator()
+ if self.shuffle:
+ # We use index, plus extra randomness, either totally random if we don't know the epoch.
+ # otherwise we make use of the epoch number and optional shuffle_seed.
+ if self.current_epoch is None:
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+ else:
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
+ else:
+ # We only use index
+ rng.manual_seed(index)
+
+ for retry in range(self.max_read_retry):
+ file_meta = self.sample_file(index, rng)
+ # We add some variance in the file position even if audio file is smaller than segment
+ # without ending up with empty segments
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
+ try:
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
+ n_frames = out.shape[-1]
+ target_frames = int(self.segment_duration * self.sample_rate)
+ if self.pad:
+ out = F.pad(out, (0, target_frames - n_frames))
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+ sample_rate=self.sample_rate, channels=out.shape[0])
+ except Exception as exc:
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
+ if retry == self.max_read_retry - 1:
+ raise
+ else:
+ break
+
+ if self.return_info:
+ # Returns the wav and additional information on the wave segment
+ return out, segment_info
+ else:
+ return out
+
+ def collater(self, samples):
+ """The collater function has to be provided to the dataloader
+ if AudioDataset has return_info=True in order to properly collate
+ the samples of a batch.
+ """
+ if self.segment_duration is None and len(samples) > 1:
+ assert self.pad, "Must allow padding when batching examples of different durations."
+
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
+ to_pad = self.segment_duration is None and self.pad
+ if to_pad:
+ max_len = max([wav.shape[-1] for wav, _ in samples])
+
+ def _pad_wav(wav):
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+ if self.return_info:
+ if len(samples) > 0:
+ assert len(samples[0]) == 2
+ assert isinstance(samples[0][0], torch.Tensor)
+ assert isinstance(samples[0][1], SegmentInfo)
+
+ wavs = [wav for wav, _ in samples]
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+ if to_pad:
+ # Each wav could be of a different duration as they are not segmented.
+ for i in range(len(samples)):
+ # Determines the total length of the signal with padding, so we update here as we pad.
+ segment_infos[i].total_frames = max_len
+ wavs[i] = _pad_wav(wavs[i])
+
+ wav = torch.stack(wavs)
+ return wav, segment_infos
+ else:
+ assert isinstance(samples[0], torch.Tensor)
+ if to_pad:
+ samples = [_pad_wav(s) for s in samples]
+ return torch.stack(samples)
+
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
+ orig_len = len(meta)
+
+ # Filter data that is too short.
+ if self.min_audio_duration is not None:
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+ # Filter data that is too long.
+ if self.max_audio_duration is not None:
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+ filtered_len = len(meta)
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+ if removed_percentage < 10:
+ logging.debug(msg)
+ else:
+ logging.warning(msg)
+ return meta
+
+ @classmethod
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+ Args:
+ root (str or Path): Path to root folder containing audio files.
+ kwargs: Additional keyword arguments for the AudioDataset.
+ """
+ root = Path(root)
+ if root.is_dir():
+ if (root / 'data.jsonl').exists():
+ root = root / 'data.jsonl'
+ elif (root / 'data.jsonl.gz').exists():
+ root = root / 'data.jsonl.gz'
+ else:
+ raise ValueError("Don't know where to read metadata from in the dir. "
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+ meta = load_audio_meta(root)
+ return cls(meta, **kwargs)
+
+ @classmethod
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+ Args:
+ root (str or Path): Path to root folder containing audio files.
+ minimal_meta (bool): Whether to only load minimal metadata or not.
+ exts (list of str): Extensions for audio files.
+ kwargs: Additional keyword arguments for the AudioDataset.
+ """
+ root = Path(root)
+ if root.is_file():
+ meta = load_audio_meta(root, resolve=True)
+ else:
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+ return cls(meta, **kwargs)
+
+
+def main():
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ parser = argparse.ArgumentParser(
+ prog='audio_dataset',
+ description='Generate .jsonl files by scanning a folder.')
+ parser.add_argument('root', help='Root folder with all the audio files')
+ parser.add_argument('output_meta_file',
+ help='Output file to store the metadata, ')
+ parser.add_argument('--complete',
+ action='store_false', dest='minimal', default=True,
+ help='Retrieve all metadata, even the one that are expansive '
+ 'to compute (e.g. normalization).')
+ parser.add_argument('--resolve',
+ action='store_true', default=False,
+ help='Resolve the paths to be absolute and with no symlinks.')
+ parser.add_argument('--workers',
+ default=10, type=int,
+ help='Number of workers.')
+ args = parser.parse_args()
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+ save_audio_meta(args.output_meta_file, meta)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/audiocraft/data/audio_utils.py b/audiocraft/data/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..565b63a4ef78dcd802dda932b42ebe518ffe7397
--- /dev/null
+++ b/audiocraft/data/audio_utils.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Various utilities for audio convertion (pcm format, sample rate and channels),
+and volume normalization."""
+import sys
+import typing as tp
+
+import julius
+import torch
+import torchaudio
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+ """Convert audio to the given number of channels.
+
+ Args:
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
+ channels (int): Expected number of channels as output.
+ Returns:
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+ """
+ *shape, src_channels, length = wav.shape
+ if src_channels == channels:
+ pass
+ elif channels == 1:
+ # Case 1:
+ # The caller asked 1-channel audio, and the stream has multiple
+ # channels, downmix all channels.
+ wav = wav.mean(dim=-2, keepdim=True)
+ elif src_channels == 1:
+ # Case 2:
+ # The caller asked for multiple channels, but the input file has
+ # a single channel, replicate the audio over all channels.
+ wav = wav.expand(*shape, channels, length)
+ elif src_channels >= channels:
+ # Case 3:
+ # The caller asked for multiple channels, and the input file has
+ # more channels than requested. In that case return the first channels.
+ wav = wav[..., :channels, :]
+ else:
+ # Case 4: What is a reasonable choice here?
+ raise ValueError('The audio file has less channels than requested but is not mono.')
+ return wav
+
+
+def convert_audio(wav: torch.Tensor, from_rate: float,
+ to_rate: float, to_channels: int) -> torch.Tensor:
+ """Convert audio to new sample rate and number of audio channels."""
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+ wav = convert_audio_channels(wav, to_channels)
+ return wav
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
+ """Normalize an input signal to a user loudness in dB LKFS.
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+ Args:
+ wav (torch.Tensor): Input multichannel audio data.
+ sample_rate (int): Sample rate.
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+ loudness_compressor (bool): Uses tanh for soft clipping.
+ energy_floor (float): anything below that RMS level will not be rescaled.
+ Returns:
+ torch.Tensor: Loudness normalized output data.
+ """
+ energy = wav.pow(2).mean().sqrt().item()
+ if energy < energy_floor:
+ return wav
+ transform = torchaudio.transforms.Loudness(sample_rate)
+ input_loudness_db = transform(wav).item()
+ # calculate the gain needed to scale to the desired loudness level
+ delta_loudness = -loudness_headroom_db - input_loudness_db
+ gain = 10.0 ** (delta_loudness / 20.0)
+ output = gain * wav
+ if loudness_compressor:
+ output = torch.tanh(output)
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+ return output
+
+
+def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
+ """Utility function to clip the audio with logging if specified."""
+ max_scale = wav.abs().max()
+ if log_clipping and max_scale > 1:
+ clamp_prob = (wav.abs() > 1).float().mean().item()
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
+ #wav.clamp_(-1, 1)
+ wav = wav.clone().clamp_(-1, 1)
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, log_clipping: bool = False,
+ sample_rate: tp.Optional[int] = None,
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
+ """Normalize the audio according to the prescribed strategy (see after).
+
+ Args:
+ wav (torch.Tensor): Audio data.
+ normalize (bool): if `True` (default), normalizes according to the prescribed
+ strategy (see after). If `False`, the strategy is only used in case clipping
+ would happen.
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+ with extra headroom to avoid clipping. 'clip' just clips.
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+ than the `peak_clip` one to avoid further clipping.
+ loudness_headroom_db (float): Target loudness for loudness normalization.
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
+ log_clipping (bool): If True, basic logging on stderr when clipping still
+ occurs despite strategy (only for 'rms').
+ sample_rate (int): Sample rate for the audio data (required for loudness).
+ stem_name (str, optional): Stem name for clipping logging.
+ Returns:
+ torch.Tensor: Normalized audio.
+ """
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+ scale_rms = 10 ** (-rms_headroom_db / 20)
+ if strategy == 'peak':
+ rescaling = (scale_peak / wav.abs().max())
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ elif strategy == 'clip':
+ wav = wav.clamp(-scale_peak, scale_peak)
+ elif strategy == 'rms':
+ mono = wav.mean(dim=0)
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ elif strategy == 'loudness':
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ else:
+ assert wav.abs().max() < 1
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+ return wav
+
+
+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """Convert audio to float 32 bits PCM format.
+ """
+ if wav.dtype.is_floating_point:
+ return wav
+ elif wav.dtype == torch.int16:
+ return wav.float() / 2**15
+ elif wav.dtype == torch.int32:
+ return wav.float() / 2**31
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
+
+
+def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """Convert audio to int 16 bits PCM format.
+
+ ..Warning:: There exist many formula for doing this conversion. None are perfect
+ due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
+ it is possible that `i16_pcm(f32_pcm)) != Identity`.
+ """
+ if wav.dtype.is_floating_point:
+ assert wav.abs().max() <= 1
+ candidate = (wav * 2 ** 15).round()
+ if candidate.max() >= 2 ** 15: # clipping would occur
+ candidate = (wav * (2 ** 15 - 1)).round()
+ return candidate.short()
+ else:
+ assert wav.dtype == torch.int16
+ return wav
diff --git a/audiocraft/data/info_audio_dataset.py b/audiocraft/data/info_audio_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..47ab4b1594faf1e9f1ce962fb980d80295b1f079
--- /dev/null
+++ b/audiocraft/data/info_audio_dataset.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base classes for the datasets that also provide non-audio metadata,
+e.g. description, text transcription etc.
+"""
+from dataclasses import dataclass
+import logging
+import math
+import re
+import typing as tp
+
+import torch
+
+from .audio_dataset import AudioDataset, AudioMeta
+from ..environment import AudioCraftEnvironment
+from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
+
+
+logger = logging.getLogger(__name__)
+
+
+def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
+ """Monkey-patch meta to match cluster specificities."""
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
+ if meta.info_path is not None:
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
+ return meta
+
+
+def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+ """Monkey-patch all meta to match cluster specificities."""
+ return [_clusterify_meta(m) for m in meta]
+
+
+@dataclass
+class AudioInfo(SegmentWithAttributes):
+ """Dummy SegmentInfo with empty attributes.
+
+ The InfoAudioDataset is expected to return metadata that inherits
+ from SegmentWithAttributes class and can return conditioning attributes.
+
+ This basically guarantees all datasets will be compatible with current
+ solver that contain conditioners requiring this.
+ """
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
+
+ def to_condition_attributes(self) -> ConditioningAttributes:
+ return ConditioningAttributes()
+
+
+class InfoAudioDataset(AudioDataset):
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
+
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
+ """
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
+ super().__init__(clusterify_all_meta(meta), **kwargs)
+
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
+ if not self.return_info:
+ wav = super().__getitem__(index)
+ assert isinstance(wav, torch.Tensor)
+ return wav
+ wav, meta = super().__getitem__(index)
+ return wav, AudioInfo(**meta.to_dict())
+
+
+def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
+ """Preprocess a single keyword or possible a list of keywords."""
+ if isinstance(value, list):
+ return get_keyword_list(value)
+ else:
+ return get_keyword(value)
+
+
+def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
+ """Preprocess a single keyword."""
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+ return None
+ else:
+ return value.strip()
+
+
+def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
+ """Preprocess a single keyword."""
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+ return None
+ else:
+ return value.strip().lower()
+
+
+def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
+ """Preprocess a list of keywords."""
+ if isinstance(values, str):
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
+ elif isinstance(values, float) and math.isnan(values):
+ values = []
+ if not isinstance(values, list):
+ logger.debug(f"Unexpected keyword list {values}")
+ values = [str(values)]
+
+ kws = [get_keyword(v) for v in values]
+ kw_list = [k for k in kws if k is not None]
+ if len(kw_list) == 0:
+ return None
+ else:
+ return kw_list
diff --git a/audiocraft/data/music_dataset.py b/audiocraft/data/music_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e28796939f9cde2b23a2c4bf43fd7ba5fa26b2d
--- /dev/null
+++ b/audiocraft/data/music_dataset.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dataset of music tracks with rich metadata.
+"""
+from dataclasses import dataclass, field, fields, replace
+import gzip
+import json
+import logging
+from pathlib import Path
+import random
+import typing as tp
+
+import torch
+
+from .info_audio_dataset import (
+ InfoAudioDataset,
+ AudioInfo,
+ get_keyword_list,
+ get_keyword,
+ get_string
+)
+from ..modules.conditioners import (
+ ConditioningAttributes,
+ JointEmbedCondition,
+ WavCondition,
+)
+from ..utils.utils import warn_once
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MusicInfo(AudioInfo):
+ """Segment info augmented with music metadata.
+ """
+ # music-specific metadata
+ title: tp.Optional[str] = None
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
+ key: tp.Optional[str] = None
+ bpm: tp.Optional[float] = None
+ genre: tp.Optional[str] = None
+ moods: tp.Optional[list] = None
+ keywords: tp.Optional[list] = None
+ description: tp.Optional[str] = None
+ name: tp.Optional[str] = None
+ instrument: tp.Optional[str] = None
+ # original wav accompanying the metadata
+ self_wav: tp.Optional[WavCondition] = None
+ # dict mapping attributes names to tuple of wav, text and metadata
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+ @property
+ def has_music_meta(self) -> bool:
+ return self.name is not None
+
+ def to_condition_attributes(self) -> ConditioningAttributes:
+ out = ConditioningAttributes()
+ for _field in fields(self):
+ key, value = _field.name, getattr(self, _field.name)
+ if key == 'self_wav':
+ out.wav[key] = value
+ elif key == 'joint_embed':
+ for embed_attribute, embed_cond in value.items():
+ out.joint_embed[embed_attribute] = embed_cond
+ else:
+ if isinstance(value, list):
+ value = ' '.join(value)
+ out.text[key] = value
+ return out
+
+ @staticmethod
+ def attribute_getter(attribute):
+ if attribute == 'bpm':
+ preprocess_func = get_bpm
+ elif attribute == 'key':
+ preprocess_func = get_musical_key
+ elif attribute in ['moods', 'keywords']:
+ preprocess_func = get_keyword_list
+ elif attribute in ['genre', 'name', 'instrument']:
+ preprocess_func = get_keyword
+ elif attribute in ['title', 'artist', 'description']:
+ preprocess_func = get_string
+ else:
+ preprocess_func = None
+ return preprocess_func
+
+ @classmethod
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
+ _dictionary: tp.Dict[str, tp.Any] = {}
+
+ # allow a subset of attributes to not be loaded from the dictionary
+ # these attributes may be populated later
+ post_init_attributes = ['self_wav', 'joint_embed']
+ optional_fields = ['keywords']
+
+ for _field in fields(cls):
+ if _field.name in post_init_attributes:
+ continue
+ elif _field.name not in dictionary:
+ if fields_required and _field.name not in optional_fields:
+ raise KeyError(f"Unexpected missing key: {_field.name}")
+ else:
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+ value = dictionary[_field.name]
+ if preprocess_func:
+ value = preprocess_func(value)
+ _dictionary[_field.name] = value
+ return cls(**_dictionary)
+
+
+def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
+ Additional textual attributes are added given probability 'merge_text_conditions_p' and
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
+
+ Args:
+ music_info (MusicInfo): The music metadata to augment.
+ merge_text_p (float): Probability of merging additional metadata to the description.
+ If provided value is 0, then no merging is performed.
+ drop_desc_p (float): Probability of dropping the original description on text merge.
+ if provided value is 0, then no drop out is performed.
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+ Returns:
+ MusicInfo: The MusicInfo with augmented textual description.
+ """
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
+ keep_field = random.uniform(0, 1) < drop_other_p
+ return valid_field_name and valid_field_value and keep_field
+
+ def process_value(v: tp.Any) -> str:
+ if isinstance(v, (int, float, str)):
+ return str(v)
+ if isinstance(v, list):
+ return ", ".join(v)
+ else:
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
+
+ description = music_info.description
+
+ metadata_text = ""
+ if random.uniform(0, 1) < merge_text_p:
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
+ random.shuffle(meta_pairs)
+ metadata_text = ". ".join(meta_pairs)
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
+
+ if description is None:
+ description = metadata_text if len(metadata_text) > 1 else None
+ else:
+ description = ". ".join([description.rstrip('.'), metadata_text])
+ description = description.strip() if description else None
+
+ music_info = replace(music_info)
+ music_info.description = description
+ return music_info
+
+
+class Paraphraser:
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
+ self.paraphrase_p = paraphrase_p
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
+ self.paraphrase_source = json.loads(f.read())
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
+
+ def sample_paraphrase(self, audio_path: str, description: str):
+ if random.random() >= self.paraphrase_p:
+ return description
+ info_path = Path(audio_path).with_suffix('.json')
+ if info_path not in self.paraphrase_source:
+ warn_once(logger, f"{info_path} not in paraphrase source!")
+ return description
+ new_desc = random.choice(self.paraphrase_source[info_path])
+ logger.debug(f"{description} -> {new_desc}")
+ return new_desc
+
+
+class MusicDataset(InfoAudioDataset):
+ """Music dataset is an AudioDataset with music-related metadata.
+
+ Args:
+ info_fields_required (bool): Whether to enforce having required fields.
+ merge_text_p (float): Probability of merging additional metadata to the description.
+ drop_desc_p (float): Probability of dropping the original description on text merge.
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
+ paraphrases for the description. The json should be a dict with keys are the
+ original info path (e.g. track_path.json) and each value is a list of possible
+ paraphrased.
+ paraphrase_p (float): probability of taking a paraphrase.
+
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+ """
+ def __init__(self, *args, info_fields_required: bool = True,
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
+ joint_embed_attributes: tp.List[str] = [],
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
+ **kwargs):
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
+ super().__init__(*args, **kwargs)
+ self.info_fields_required = info_fields_required
+ self.merge_text_p = merge_text_p
+ self.drop_desc_p = drop_desc_p
+ self.drop_other_p = drop_other_p
+ self.joint_embed_attributes = joint_embed_attributes
+ self.paraphraser = None
+ if paraphrase_source is not None:
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
+
+ def __getitem__(self, index):
+ wav, info = super().__getitem__(index)
+ info_data = info.to_dict()
+ music_info_path = Path(info.meta.path).with_suffix('.json')
+
+ if Path(music_info_path).exists():
+ with open(music_info_path, 'r') as json_file:
+ music_data = json.load(json_file)
+ music_data.update(info_data)
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
+ if self.paraphraser is not None:
+ music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
+ if self.merge_text_p:
+ music_info = augment_music_info_description(
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
+ else:
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
+
+ music_info.self_wav = WavCondition(
+ wav=wav[None], length=torch.tensor([info.n_frames]),
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+ for att in self.joint_embed_attributes:
+ att_value = getattr(music_info, att)
+ joint_embed_cond = JointEmbedCondition(
+ wav[None], [att_value], torch.tensor([info.n_frames]),
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+ music_info.joint_embed[att] = joint_embed_cond
+
+ return wav, music_info
+
+
+def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
+ """Preprocess key keywords, discarding them if there are multiple key defined."""
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+ return None
+ elif ',' in value:
+ # For now, we discard when multiple keys are defined separated with comas
+ return None
+ else:
+ return value.strip().lower()
+
+
+def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
+ """Preprocess to a float."""
+ if value is None:
+ return None
+ try:
+ return float(value)
+ except ValueError:
+ return None
diff --git a/audiocraft/data/sound_dataset.py b/audiocraft/data/sound_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b88cbe8016b4bd28c2de749177c9af29f7755fc
--- /dev/null
+++ b/audiocraft/data/sound_dataset.py
@@ -0,0 +1,330 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dataset of audio with a simple description.
+"""
+
+from dataclasses import dataclass, fields, replace
+import json
+from pathlib import Path
+import random
+import typing as tp
+
+import numpy as np
+import torch
+
+from .info_audio_dataset import (
+ InfoAudioDataset,
+ get_keyword_or_keyword_list
+)
+from ..modules.conditioners import (
+ ConditioningAttributes,
+ SegmentWithAttributes,
+ WavCondition,
+)
+
+
+EPS = torch.finfo(torch.float32).eps
+TARGET_LEVEL_LOWER = -35
+TARGET_LEVEL_UPPER = -15
+
+
+@dataclass
+class SoundInfo(SegmentWithAttributes):
+ """Segment info augmented with Sound metadata.
+ """
+ description: tp.Optional[str] = None
+ self_wav: tp.Optional[torch.Tensor] = None
+
+ @property
+ def has_sound_meta(self) -> bool:
+ return self.description is not None
+
+ def to_condition_attributes(self) -> ConditioningAttributes:
+ out = ConditioningAttributes()
+
+ for _field in fields(self):
+ key, value = _field.name, getattr(self, _field.name)
+ if key == 'self_wav':
+ out.wav[key] = value
+ else:
+ out.text[key] = value
+ return out
+
+ @staticmethod
+ def attribute_getter(attribute):
+ if attribute == 'description':
+ preprocess_func = get_keyword_or_keyword_list
+ else:
+ preprocess_func = None
+ return preprocess_func
+
+ @classmethod
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
+ _dictionary: tp.Dict[str, tp.Any] = {}
+
+ # allow a subset of attributes to not be loaded from the dictionary
+ # these attributes may be populated later
+ post_init_attributes = ['self_wav']
+
+ for _field in fields(cls):
+ if _field.name in post_init_attributes:
+ continue
+ elif _field.name not in dictionary:
+ if fields_required:
+ raise KeyError(f"Unexpected missing key: {_field.name}")
+ else:
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+ value = dictionary[_field.name]
+ if preprocess_func:
+ value = preprocess_func(value)
+ _dictionary[_field.name] = value
+ return cls(**_dictionary)
+
+
+class SoundDataset(InfoAudioDataset):
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
+
+ Args:
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
+ The metadata files contained in this folder are expected to match the stem of the audio file with
+ a json extension.
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
+ mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
+ mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
+ kwargs: Additional arguments for AudioDataset.
+
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+ """
+ def __init__(
+ self,
+ *args,
+ info_fields_required: bool = True,
+ external_metadata_source: tp.Optional[str] = None,
+ aug_p: float = 0.,
+ mix_p: float = 0.,
+ mix_snr_low: int = -5,
+ mix_snr_high: int = 5,
+ mix_min_overlap: float = 0.5,
+ **kwargs
+ ):
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
+ super().__init__(*args, **kwargs)
+ self.info_fields_required = info_fields_required
+ self.external_metadata_source = external_metadata_source
+ self.aug_p = aug_p
+ self.mix_p = mix_p
+ if self.aug_p > 0:
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
+ self.mix_snr_low = mix_snr_low
+ self.mix_snr_high = mix_snr_high
+ self.mix_min_overlap = mix_min_overlap
+
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
+ """Get path of JSON with metadata (description, etc.).
+ If there exists a JSON with the same name as 'path.name', then it will be used.
+ Else, such JSON will be searched for in an external json source folder if it exists.
+ """
+ info_path = Path(path).with_suffix('.json')
+ if Path(info_path).exists():
+ return info_path
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
+ return Path(self.external_metadata_source) / info_path.name
+ else:
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
+
+ def __getitem__(self, index):
+ wav, info = super().__getitem__(index)
+ info_data = info.to_dict()
+ info_path = self._get_info_path(info.meta.path)
+ if Path(info_path).exists():
+ with open(info_path, 'r') as json_file:
+ sound_data = json.load(json_file)
+ sound_data.update(info_data)
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
+ # if there are multiple descriptions, sample one randomly
+ if isinstance(sound_info.description, list):
+ sound_info.description = random.choice(sound_info.description)
+ else:
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
+
+ sound_info.self_wav = WavCondition(
+ wav=wav[None], length=torch.tensor([info.n_frames]),
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+ return wav, sound_info
+
+ def collater(self, samples):
+ # when training, audio mixing is performed in the collate function
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
+ if self.aug_p > 0:
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
+ min_overlap=self.mix_min_overlap)
+ return wav, sound_info
+
+
+def rms_f(x: torch.Tensor) -> torch.Tensor:
+ return (x ** 2).mean(1).pow(0.5)
+
+
+def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
+ """Normalize the signal to the target level."""
+ rms = rms_f(audio)
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
+ audio = audio * scalar.unsqueeze(1)
+ return audio
+
+
+def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
+ return (abs(audio) > clipping_threshold).any(1)
+
+
+def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
+ remainder = src.shape[1] - start
+ if dst.shape[1] > remainder:
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
+ else:
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
+ return src
+
+
+def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
+ """Function to mix clean speech and noise at various SNR levels.
+
+ Args:
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
+ snr (int): SNR level when mixing.
+ min_overlap (float): Minimum overlap between the two mixed sources.
+ target_level (int): Gain level in dB.
+ clipping_threshold (float): Threshold for clipping the audio.
+ Returns:
+ torch.Tensor: The mixed audio, of shape [B, T].
+ """
+ if clean.shape[1] > noise.shape[1]:
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
+ else:
+ noise = noise[:, :clean.shape[1]]
+
+ # normalizing to -25 dB FS
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
+ clean = normalize(clean, target_level)
+ rmsclean = rms_f(clean)
+
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
+ noise = normalize(noise, target_level)
+ rmsnoise = rms_f(noise)
+
+ # set the noise level for a given SNR
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
+ noisenewlevel = noise * noisescalar
+
+ # mix noise and clean speech
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
+
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
+ # there is a chance of clipping that might happen with very less probability, which is not a major issue.
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
+ rmsnoisy = rms_f(noisyspeech)
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
+ noisyspeech = noisyspeech * scalarnoisy
+ clean = clean * scalarnoisy
+ noisenewlevel = noisenewlevel * scalarnoisy
+
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
+ clipped = is_clipped(noisyspeech)
+ if clipped.any():
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
+
+ return noisyspeech
+
+
+def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
+ if snr_low == snr_high:
+ snr = snr_low
+ else:
+ snr = np.random.randint(snr_low, snr_high)
+ mix = snr_mixer(src, dst, snr, min_overlap)
+ return mix
+
+
+def mix_text(src_text: str, dst_text: str):
+ """Mix text from different sources by concatenating them."""
+ if src_text == dst_text:
+ return src_text
+ return src_text + " " + dst_text
+
+
+def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
+ snr_low: int, snr_high: int, min_overlap: float):
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
+
+ Args:
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
+ aug_p (float): Augmentation probability.
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
+ snr_low (int): Lowerbound for sampling SNR.
+ snr_high (int): Upperbound for sampling SNR.
+ min_overlap (float): Minimum overlap between mixed samples.
+ Returns:
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
+ and mixed SoundInfo for the given batch.
+ """
+ # no mixing to perform within the batch
+ if mix_p == 0:
+ return wavs, infos
+
+ if random.uniform(0, 1) < aug_p:
+ # perform all augmentations on waveforms as [B, T]
+ # randomly picking pairs of audio to mix
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
+ wavs = wavs.mean(dim=1, keepdim=False)
+ B, T = wavs.shape
+ k = int(mix_p * B)
+ mixed_sources_idx = torch.randperm(B)[:k]
+ mixed_targets_idx = torch.randperm(B)[:k]
+ aug_wavs = snr_mix(
+ wavs[mixed_sources_idx],
+ wavs[mixed_targets_idx],
+ snr_low,
+ snr_high,
+ min_overlap,
+ )
+ # mixing textual descriptions in metadata
+ descriptions = [info.description for info in infos]
+ aug_infos = []
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
+ text = mix_text(descriptions[i], descriptions[j])
+ m = replace(infos[i])
+ m.description = text
+ aug_infos.append(m)
+
+ # back to [B, C, T]
+ aug_wavs = aug_wavs.unsqueeze(1)
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
+
+ return aug_wavs, aug_infos # [B, C, T]
+ else:
+ # randomly pick samples in the batch to match
+ # the batch size when performing audio mixing
+ B, C, T = wavs.shape
+ k = int(mix_p * B)
+ wav_idx = torch.randperm(B)[:k]
+ wavs = wavs[wav_idx]
+ infos = [infos[i] for i in wav_idx]
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
+
+ return wavs, infos # [B, C, T]
diff --git a/audiocraft/data/zip.py b/audiocraft/data/zip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b17849d36991e7def35a14d3d518b9d867ce36
--- /dev/null
+++ b/audiocraft/data/zip.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utility for reading some info from inside a zip file.
+"""
+
+import typing
+import zipfile
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing_extensions import Literal
+
+
+DEFAULT_SIZE = 32
+MODE = Literal['r', 'w', 'x', 'a']
+
+
+@dataclass(order=True)
+class PathInZip:
+ """Hold a path of file within a zip file.
+
+ Args:
+ path (str): The convention is :.
+ Let's assume there is a zip file /some/location/foo.zip
+ and inside of it is a json file located at /data/file1.json,
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
+ """
+
+ INFO_PATH_SEP = ':'
+ zip_path: str
+ file_path: str
+
+ def __init__(self, path: str) -> None:
+ split_path = path.split(self.INFO_PATH_SEP)
+ assert len(split_path) == 2
+ self.zip_path, self.file_path = split_path
+
+ @classmethod
+ def from_paths(cls, zip_path: str, file_path: str):
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+ def __str__(self) -> str:
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+
+def _open_zip(path: str, mode: MODE = 'r'):
+ return zipfile.ZipFile(path, mode)
+
+
+_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
+
+
+def set_zip_cache_size(max_size: int):
+ """Sets the maximal LRU caching for zip file opening.
+
+ Args:
+ max_size (int): the maximal LRU cache.
+ """
+ global _cached_open_zip
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+ """Opens a file stored inside a zip and returns a file-like object.
+
+ Args:
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
+ mode (str): The mode in which to open the file with.
+ Returns:
+ A file-like object for PathInZip.
+ """
+ zf = _cached_open_zip(path_in_zip.zip_path)
+ return zf.open(path_in_zip.file_path)
diff --git a/audiocraft/environment.py b/audiocraft/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc7819305758bb50a9984928bfa7f13eabef5f5
--- /dev/null
+++ b/audiocraft/environment.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Provides cluster and tools configuration across clusters (slurm, dora, utilities).
+"""
+
+import logging
+import os
+from pathlib import Path
+import re
+import typing as tp
+
+import omegaconf
+
+from .utils.cluster import _guess_cluster_type
+
+
+logger = logging.getLogger(__name__)
+
+
+class AudioCraftEnvironment:
+ """Environment configuration for teams and clusters.
+
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
+ allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
+ map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
+
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
+ Use the following environment variables to specify the cluster, team or configuration:
+
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
+ cannot be inferred automatically.
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
+ If not set, configuration is read from config/teams.yaml.
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
+ Cluster configuration are shared across teams to match compute allocation,
+ specify your cluster configuration in the configuration file under a key mapping
+ your team name.
+ """
+ _instance = None
+ DEFAULT_TEAM = "default"
+
+ def __init__(self) -> None:
+ """Loads configuration."""
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
+ cluster_type = _guess_cluster_type()
+ cluster = os.getenv(
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
+ )
+ logger.info("Detecting cluster type %s", cluster_type)
+
+ self.cluster: str = cluster
+
+ config_path = os.getenv(
+ "AUDIOCRAFT_CONFIG",
+ Path(__file__)
+ .parent.parent.joinpath("config/teams", self.team)
+ .with_suffix(".yaml"),
+ )
+ self.config = omegaconf.OmegaConf.load(config_path)
+ self._dataset_mappers = []
+ cluster_config = self._get_cluster_config()
+ if "dataset_mappers" in cluster_config:
+ for pattern, repl in cluster_config["dataset_mappers"].items():
+ regex = re.compile(pattern)
+ self._dataset_mappers.append((regex, repl))
+
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
+ assert isinstance(self.config, omegaconf.DictConfig)
+ return self.config[self.cluster]
+
+ @classmethod
+ def instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+ @classmethod
+ def reset(cls):
+ """Clears the environment and forces a reload on next invocation."""
+ cls._instance = None
+
+ @classmethod
+ def get_team(cls) -> str:
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
+ If not defined, defaults to "labs".
+ """
+ return cls.instance().team
+
+ @classmethod
+ def get_cluster(cls) -> str:
+ """Gets the detected cluster.
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
+ """
+ return cls.instance().cluster
+
+ @classmethod
+ def get_dora_dir(cls) -> Path:
+ """Gets the path to the dora directory for the current team and cluster.
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
+ """
+ cluster_config = cls.instance()._get_cluster_config()
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
+ logger.warning(f"Dora directory: {dora_dir}")
+ return Path(dora_dir)
+
+ @classmethod
+ def get_reference_dir(cls) -> Path:
+ """Gets the path to the reference directory for the current team and cluster.
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
+ """
+ cluster_config = cls.instance()._get_cluster_config()
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
+
+ @classmethod
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
+ """Get the list of nodes to exclude for that cluster."""
+ cluster_config = cls.instance()._get_cluster_config()
+ return cluster_config.get("slurm_exclude")
+
+ @classmethod
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
+
+ Args:
+ partition_types (list[str], optional): partition types to retrieve. Values must be
+ from ['global', 'team']. If not provided, the global partition is returned.
+ """
+ if not partition_types:
+ partition_types = ["global"]
+
+ cluster_config = cls.instance()._get_cluster_config()
+ partitions = [
+ cluster_config["partitions"][partition_type]
+ for partition_type in partition_types
+ ]
+ return ",".join(partitions)
+
+ @classmethod
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
+
+ Args:
+ path (str or Path): Path to resolve.
+ Returns:
+ Path: Resolved path.
+ """
+ path = str(path)
+
+ if path.startswith("//reference"):
+ reference_dir = cls.get_reference_dir()
+ logger.warn(f"Reference directory: {reference_dir}")
+ assert (
+ reference_dir.exists() and reference_dir.is_dir()
+ ), f"Reference directory does not exist: {reference_dir}."
+ path = re.sub("^//reference", str(reference_dir), path)
+
+ return Path(path)
+
+ @classmethod
+ def apply_dataset_mappers(cls, path: str) -> str:
+ """Applies dataset mapping regex rules as defined in the configuration.
+ If no rules are defined, the path is returned as-is.
+ """
+ instance = cls.instance()
+
+ for pattern, repl in instance._dataset_mappers:
+ path = pattern.sub(repl, path)
+
+ return path
diff --git a/audiocraft/grids/__init__.py b/audiocraft/grids/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..70643517cd1a8b4e712eca90e23411ae89937795
--- /dev/null
+++ b/audiocraft/grids/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dora Grids."""
diff --git a/audiocraft/grids/_base_explorers.py b/audiocraft/grids/_base_explorers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f26666aa596f7bd2e8695c4f00e7963e978ceb
--- /dev/null
+++ b/audiocraft/grids/_base_explorers.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import time
+import typing as tp
+from dora import Explorer
+import treetable as tt
+
+
+def get_sheep_ping(sheep) -> tp.Optional[str]:
+ """Return the amount of time since the Sheep made some update
+ to its log. Returns a str using the relevant time unit."""
+ ping = None
+ if sheep.log is not None and sheep.log.exists():
+ delta = time.time() - sheep.log.stat().st_mtime
+ if delta > 3600 * 24:
+ ping = f'{delta / (3600 * 24):.1f}d'
+ elif delta > 3600:
+ ping = f'{delta / (3600):.1f}h'
+ elif delta > 60:
+ ping = f'{delta / 60:.1f}m'
+ else:
+ ping = f'{delta:.1f}s'
+ return ping
+
+
+class BaseExplorer(ABC, Explorer):
+ """Base explorer for AudioCraft grids.
+
+ All task specific solvers are expected to implement the `get_grid_metrics`
+ method to specify logic about metrics to display for a given task.
+
+ If additional stages are used, the child explorer must define how to handle
+ these new stages in the `process_history` and `process_sheep` methods.
+ """
+ def stages(self):
+ return ["train", "valid", "evaluate"]
+
+ def get_grid_meta(self):
+ """Returns the list of Meta information to display for each XP/job.
+ """
+ return [
+ tt.leaf("index", align=">"),
+ tt.leaf("name", wrap=140),
+ tt.leaf("state"),
+ tt.leaf("sig", align=">"),
+ tt.leaf("sid", align="<"),
+ ]
+
+ @abstractmethod
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table.
+ """
+ ...
+
+ def process_sheep(self, sheep, history):
+ train = {
+ "epoch": len(history),
+ }
+ parts = {"train": train}
+ for metrics in history:
+ for key, sub in metrics.items():
+ part = parts.get(key, {})
+ if 'duration' in sub:
+ # Convert to minutes for readability.
+ sub['duration'] = sub['duration'] / 60.
+ part.update(sub)
+ parts[key] = part
+ ping = get_sheep_ping(sheep)
+ if ping is not None:
+ for name in self.stages():
+ if name not in parts:
+ parts[name] = {}
+ # Add the ping to each part for convenience.
+ parts[name]['ping'] = ping
+ return parts
diff --git a/audiocraft/grids/audiogen/__init__.py b/audiocraft/grids/audiogen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0a2688450ce120088b79c3314a2f267394dc11
--- /dev/null
+++ b/audiocraft/grids/audiogen/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""AudioGen grids."""
diff --git a/audiocraft/grids/audiogen/audiogen_base_16khz.py b/audiocraft/grids/audiogen/audiogen_base_16khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..190cc1d0a1e316347e8ebbdfc8de7e2942c1b3d7
--- /dev/null
+++ b/audiocraft/grids/audiogen/audiogen_base_16khz.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ..musicgen._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=64, partition=partitions)
+ launcher.bind_(solver='audiogen/audiogen_base_16khz')
+ # replace this by the desired environmental sound dataset
+ launcher.bind_(dset='internal/sounds_16khz')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ medium = {'model/lm/model_scale': 'medium'}
+
+ launcher.bind_(fsdp)
+ launcher(medium)
diff --git a/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py b/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f6d402a3c4a113d4c37be062790fa435b72104
--- /dev/null
+++ b/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Evaluation with objective metrics for the pretrained AudioGen models.
+This grid takes signature from the training grid and runs evaluation-only stage.
+
+When running the grid for the first time, please use:
+REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
+and re-use the REGEN=1 option when the grid is changed to force regenerating it.
+
+Note that you need the proper metrics external libraries setup to use all
+the objective metrics activated in this grid. Refer to the README for more information.
+"""
+
+import os
+
+from ..musicgen._explorers import GenerationEvalExplorer
+from ...environment import AudioCraftEnvironment
+from ... import train
+
+
+def eval(launcher, batch_size: int = 32):
+ opts = {
+ 'dset': 'audio/audiocaps_16khz',
+ 'solver/audiogen/evaluation': 'objective_eval',
+ 'execute_only': 'evaluate',
+ '+dataset.evaluate.batch_size': batch_size,
+ '+metrics.fad.tf.batch_size': 32,
+ }
+ # binary for FAD computation: replace this path with your own path
+ metrics_opts = {
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+ }
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+ opt2 = {'transformer_lm.two_step_cfg': True}
+
+ sub = launcher.bind(opts)
+ sub.bind_(metrics_opts)
+
+ # base objective metrics
+ sub(opt1, opt2)
+
+
+@GenerationEvalExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=4, partition=partitions)
+
+ if 'REGEN' not in os.environ:
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
+ with launcher.job_array():
+ for sig in folder.iterdir():
+ if not sig.is_symlink():
+ continue
+ xp = train.main.get_xp_from_sig(sig.name)
+ launcher(xp.argv)
+ return
+
+ audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
+ audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
+
+ audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
+ audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
+ eval(audiogen_base_medium, batch_size=128)
diff --git a/audiocraft/grids/compression/__init__.py b/audiocraft/grids/compression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b688528f1f3e4efc0c2a1e9d490f33c4158b3f0
--- /dev/null
+++ b/audiocraft/grids/compression/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""EnCodec grids."""
diff --git a/audiocraft/grids/compression/_explorers.py b/audiocraft/grids/compression/_explorers.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed30d5b8a1c14676503148ddf133c79ed2e33bf
--- /dev/null
+++ b/audiocraft/grids/compression/_explorers.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import treetable as tt
+
+from .._base_explorers import BaseExplorer
+
+
+class CompressionExplorer(BaseExplorer):
+ eval_metrics = ["sisnr", "visqol"]
+
+ def stages(self):
+ return ["train", "valid", "evaluate"]
+
+ def get_grid_meta(self):
+ """Returns the list of Meta information to display for each XP/job.
+ """
+ return [
+ tt.leaf("index", align=">"),
+ tt.leaf("name", wrap=140),
+ tt.leaf("state"),
+ tt.leaf("sig", align=">"),
+ ]
+
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table.
+ """
+ return [
+ tt.group(
+ "train",
+ [
+ tt.leaf("epoch"),
+ tt.leaf("bandwidth", ".2f"),
+ tt.leaf("adv", ".4f"),
+ tt.leaf("d_loss", ".4f"),
+ ],
+ align=">",
+ ),
+ tt.group(
+ "valid",
+ [
+ tt.leaf("bandwidth", ".2f"),
+ tt.leaf("adv", ".4f"),
+ tt.leaf("msspec", ".4f"),
+ tt.leaf("sisnr", ".2f"),
+ ],
+ align=">",
+ ),
+ tt.group(
+ "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
+ ),
+ ]
diff --git a/audiocraft/grids/compression/debug.py b/audiocraft/grids/compression/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..5612ff5688d85fede0e605b244919e8081cb1da9
--- /dev/null
+++ b/audiocraft/grids/compression/debug.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid is a minimal example for debugging compression task
+and how to override parameters directly in a grid.
+Learn more about dora grids: https://github.com/facebookresearch/dora
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=2, partition=partitions)
+ launcher.bind_(solver='compression/debug')
+
+ with launcher.job_array():
+ # base debug task using config from solver=compression/debug
+ launcher()
+ # we can override parameters in the grid to launch additional xps
+ launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
diff --git a/audiocraft/grids/compression/encodec_audiogen_16khz.py b/audiocraft/grids/compression/encodec_audiogen_16khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b41f684045594bb264cfb7f4f15d1da439382c
--- /dev/null
+++ b/audiocraft/grids/compression/encodec_audiogen_16khz.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=8, partition=partitions)
+ # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
+ # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
+ launcher.bind_(solver='compression/encodec_audiogen_16khz')
+ # replace this by the desired sound dataset
+ launcher.bind_(dset='internal/sounds_16khz')
+ # launch xp
+ launcher()
diff --git a/audiocraft/grids/compression/encodec_base_24khz.py b/audiocraft/grids/compression/encodec_base_24khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..117b2b1e496ca31b3d614672b472c9213cedb4ad
--- /dev/null
+++ b/audiocraft/grids/compression/encodec_base_24khz.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train a base causal EnCodec model at 24 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=8, partition=partitions)
+ # base causal EnCodec trained on monophonic audio sampled at 24 kHz
+ launcher.bind_(solver='compression/encodec_base_24khz')
+ # replace this by the desired dataset
+ launcher.bind_(dset='audio/example')
+ # launch xp
+ launcher()
diff --git a/audiocraft/grids/compression/encodec_musicgen_32khz.py b/audiocraft/grids/compression/encodec_musicgen_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da31daa5f009f46e753601a51a06391594b8f9b
--- /dev/null
+++ b/audiocraft/grids/compression/encodec_musicgen_32khz.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train a MusicGen EnCodec model at 32 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=8, partition=partitions)
+ # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
+ # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
+ launcher.bind_(solver='compression/encodec_musicgen_32khz')
+ # replace this by the desired music dataset
+ launcher.bind_(dset='internal/music_400k_32khz')
+ # launch xp
+ launcher()
+ launcher({
+ 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
+ 'label': 'visqol',
+ 'evaluate.metrics.visqol': True
+ })
diff --git a/audiocraft/grids/diffusion/4_bands_base_32khz.py b/audiocraft/grids/diffusion/4_bands_base_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7e67bcc89dd0c8e50d770e600b55f179fe19588
--- /dev/null
+++ b/audiocraft/grids/diffusion/4_bands_base_32khz.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Training of the 4 diffusion models described in
+"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
+(paper link).
+"""
+
+from ._explorers import DiffusionExplorer
+
+
+@DiffusionExplorer
+def explorer(launcher):
+ launcher.slurm_(gpus=4, partition='learnfair')
+
+ launcher.bind_({'solver': 'diffusion/default',
+ 'dset': 'internal/music_10k_32khz'})
+
+ with launcher.job_array():
+ launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
+ launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
+ launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
+ launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
diff --git a/audiocraft/grids/diffusion/__init__.py b/audiocraft/grids/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5737294ae16c0de52085b8dcf6825c348f617e4
--- /dev/null
+++ b/audiocraft/grids/diffusion/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Diffusion grids."""
diff --git a/audiocraft/grids/diffusion/_explorers.py b/audiocraft/grids/diffusion/_explorers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bf4ca57b63f5f9308bd1178ddbde5d8f06748e5
--- /dev/null
+++ b/audiocraft/grids/diffusion/_explorers.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import treetable as tt
+
+from .._base_explorers import BaseExplorer
+
+
+class DiffusionExplorer(BaseExplorer):
+ eval_metrics = ["sisnr", "visqol"]
+
+ def stages(self):
+ return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
+
+ def get_grid_meta(self):
+ """Returns the list of Meta information to display for each XP/job.
+ """
+ return [
+ tt.leaf("index", align=">"),
+ tt.leaf("name", wrap=140),
+ tt.leaf("state"),
+ tt.leaf("sig", align=">"),
+ ]
+
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table.
+ """
+ return [
+ tt.group(
+ "train",
+ [
+ tt.leaf("epoch"),
+ tt.leaf("loss", ".3%"),
+ ],
+ align=">",
+ ),
+ tt.group(
+ "valid",
+ [
+ tt.leaf("loss", ".3%"),
+ # tt.leaf("loss_0", ".3%"),
+ ],
+ align=">",
+ ),
+ tt.group(
+ "valid_ema",
+ [
+ tt.leaf("loss", ".3%"),
+ # tt.leaf("loss_0", ".3%"),
+ ],
+ align=">",
+ ),
+ tt.group(
+ "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
+ tt.leaf("rvm_3", ".4f"), ], align=">"
+ ),
+ tt.group(
+ "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
+ tt.leaf("rvm_3", ".4f")], align=">"
+ ),
+ ]
diff --git a/audiocraft/grids/musicgen/__init__.py b/audiocraft/grids/musicgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f101f5a29ff85271e44e4f27545168a8f27baa
--- /dev/null
+++ b/audiocraft/grids/musicgen/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""MusicGen grids."""
diff --git a/audiocraft/grids/musicgen/_explorers.py b/audiocraft/grids/musicgen/_explorers.py
new file mode 100644
index 0000000000000000000000000000000000000000..334836b72559a120feb8a15eef3fe96ce88a4edb
--- /dev/null
+++ b/audiocraft/grids/musicgen/_explorers.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import treetable as tt
+
+from .._base_explorers import BaseExplorer
+
+
+class LMExplorer(BaseExplorer):
+ eval_metrics: tp.List[str] = []
+
+ def stages(self) -> tp.List[str]:
+ return ['train', 'valid']
+
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table."""
+ return [
+ tt.group(
+ 'train',
+ [
+ tt.leaf('epoch'),
+ tt.leaf('duration', '.1f'), # duration in minutes
+ tt.leaf('ping'),
+ tt.leaf('ce', '.4f'), # cross entropy
+ tt.leaf("ppl", '.3f'), # perplexity
+ ],
+ align='>',
+ ),
+ tt.group(
+ 'valid',
+ [
+ tt.leaf('ce', '.4f'),
+ tt.leaf('ppl', '.3f'),
+ tt.leaf('best_ppl', '.3f'),
+ ],
+ align='>',
+ ),
+ ]
+
+ def process_sheep(self, sheep, history):
+ parts = super().process_sheep(sheep, history)
+
+ track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
+ best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
+
+ def comparator(mode, a, b):
+ return a < b if mode == 'lower' else a > b
+
+ for metrics in history:
+ for key, sub in metrics.items():
+ for metric in track_by:
+ # for the validation set, keep track of best metrics (ppl in this example)
+ # this is so we can conveniently compare metrics between runs in the grid
+ if key == 'valid' and metric in sub and comparator(
+ track_by[metric], sub[metric], best_metrics[metric]
+ ):
+ best_metrics[metric] = sub[metric]
+
+ if 'valid' in parts:
+ parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
+ return parts
+
+
+class GenerationEvalExplorer(BaseExplorer):
+ eval_metrics: tp.List[str] = []
+
+ def stages(self) -> tp.List[str]:
+ return ['evaluate']
+
+ def get_grid_metrics(self):
+ """Return the metrics that should be displayed in the tracking table."""
+ return [
+ tt.group(
+ 'evaluate',
+ [
+ tt.leaf('epoch', '.3f'),
+ tt.leaf('duration', '.1f'),
+ tt.leaf('ping'),
+ tt.leaf('ce', '.4f'),
+ tt.leaf('ppl', '.3f'),
+ tt.leaf('fad', '.3f'),
+ tt.leaf('kld', '.3f'),
+ tt.leaf('text_consistency', '.3f'),
+ tt.leaf('chroma_cosine', '.3f'),
+ ],
+ align='>',
+ ),
+ ]
diff --git a/audiocraft/grids/musicgen/musicgen_base_32khz.py b/audiocraft/grids/musicgen/musicgen_base_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e364614537e426f21c18a2c2a9d94b3babce051
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_base_32khz.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=32, partition=partitions)
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
+ # replace this by the desired music dataset
+ launcher.bind_(dset='internal/music_400k_32khz')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ medium = {'model/lm/model_scale': 'medium'}
+ large = {'model/lm/model_scale': 'large'}
+
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+ launcher.bind_(fsdp)
+
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub()
+
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(medium, adam)
+
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
diff --git a/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py b/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9a43f37d7369b5de4542fba87c4c8739d58b1e8
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=32, partition=partitions)
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
+ # replace this by the desired music dataset
+ launcher.bind_(dset='internal/music_400k_32khz')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ medium = {'model/lm/model_scale': 'medium'}
+ large = {'model/lm/model_scale': 'large'}
+
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+ # BEGINNING OF CACHE WRITING JOBS.
+ cache_write = {
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
+ 'cache.write': True,
+ 'generate.every': 500,
+ 'evaluate.every': 500,
+ 'logging.log_updates': 50,
+ }
+
+ cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
+ cache_sub.bind_({'deadlock.use': True})
+ cache_sub.slurm_(gpus=8)
+ with launcher.job_array():
+ num_shards = 10 # total number of jobs running in parallel.
+ for shard in range(0, num_shards):
+ launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
+
+ # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
+ # OR SUFFICIENTLY AHEAD.
+ return
+
+ cache = {
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
+ }
+ launcher.bind_(fsdp, cache)
+
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub()
+
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(medium, adam)
+
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
diff --git a/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py b/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..64ad3f8c77afe1ab5908e407ad14d4879e1b1ad1
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=32, partition=partitions)
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
+ # replace this by the desired music dataset
+ launcher.bind_(dset='internal/music_400k_32khz')
+ launcher.bind_(conditioner='clapemb2music')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ cache_path = {'conditioners.description.clap.cache_path':
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
+ text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
+
+ launcher.bind_(fsdp)
+
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
+ with launcher.job_array():
+ launcher()
+ launcher(text_wav_training_opt)
+ launcher(cache_path)
+ launcher(cache_path, text_wav_training_opt)
diff --git a/audiocraft/grids/musicgen/musicgen_melody_32khz.py b/audiocraft/grids/musicgen/musicgen_melody_32khz.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0d6710a23c117406e9724057a62eccab88ce907
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_melody_32khz.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=32, partition=partitions)
+ launcher.bind_(solver='musicgen/musicgen_melody_32khz')
+ # replace this by the desired music dataset
+ launcher.bind_(dset='internal/music_400k_32khz')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ medium = {'model/lm/model_scale': 'medium'}
+ large = {'model/lm/model_scale': 'large'}
+
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+ cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
+
+ # CACHE GENERATION JOBS
+ n_cache_gen_jobs = 4
+ gen_sub = launcher.slurm(gpus=1)
+ gen_sub.bind_(
+ cache_path, {
+ # the cache is always computed over the whole file, so duration doesn't matter here.
+ 'dataset.segment_duration': 2.,
+ 'dataset.batch_size': 8,
+ 'dataset.train.permutation_on_files': True, # try to not repeat files.
+ 'optim.epochs': 10,
+ 'model/lm/model_scale': 'xsmall',
+
+ })
+ with gen_sub.job_array():
+ for gen_job in range(n_cache_gen_jobs):
+ gen_sub({'dataset.train.shuffle_seed': gen_job})
+
+ # ACTUAL TRAINING JOBS.
+ launcher.bind_(fsdp)
+
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub()
+ sub(cache_path)
+
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(medium, adam)
+
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
+ with launcher.job_array():
+ sub = launcher.bind()
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
diff --git a/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py b/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..39ceaf7dab15ec3f0f669cfe57ca9e932a9ab40d
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Evaluation with objective metrics for the pretrained MusicGen models.
+This grid takes signature from the training grid and runs evaluation-only stage.
+
+When running the grid for the first time, please use:
+REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
+and re-use the REGEN=1 option when the grid is changed to force regenerating it.
+
+Note that you need the proper metrics external libraries setup to use all
+the objective metrics activated in this grid. Refer to the README for more information.
+"""
+
+import os
+
+from ._explorers import GenerationEvalExplorer
+from ...environment import AudioCraftEnvironment
+from ... import train
+
+
+def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
+ opts = {
+ 'dset': 'audio/musiccaps_32khz',
+ 'solver/musicgen/evaluation': 'objective_eval',
+ 'execute_only': 'evaluate',
+ '+dataset.evaluate.batch_size': batch_size,
+ '+metrics.fad.tf.batch_size': 16,
+ }
+ # chroma-specific evaluation
+ chroma_opts = {
+ 'dset': 'internal/music_400k_32khz',
+ 'dataset.evaluate.segment_duration': 30,
+ 'dataset.evaluate.num_samples': 1000,
+ 'evaluate.metrics.chroma_cosine': True,
+ 'evaluate.metrics.fad': False,
+ 'evaluate.metrics.kld': False,
+ 'evaluate.metrics.text_consistency': False,
+ }
+ # binary for FAD computation: replace this path with your own path
+ metrics_opts = {
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+ }
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+ opt2 = {'transformer_lm.two_step_cfg': True}
+
+ sub = launcher.bind(opts)
+ sub.bind_(metrics_opts)
+
+ # base objective metrics
+ sub(opt1, opt2)
+
+ if eval_melody:
+ # chroma-specific metrics
+ sub(opt1, opt2, chroma_opts)
+
+
+@GenerationEvalExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=4, partition=partitions)
+
+ if 'REGEN' not in os.environ:
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
+ with launcher.job_array():
+ for sig in folder.iterdir():
+ if not sig.is_symlink():
+ continue
+ xp = train.main.get_xp_from_sig(sig.name)
+ launcher(xp.argv)
+ return
+
+ with launcher.job_array():
+ musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
+ musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
+
+ # base musicgen models
+ musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
+ eval(musicgen_base_small, batch_size=128)
+
+ musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
+ musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
+ eval(musicgen_base_medium, batch_size=128)
+
+ musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
+ musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
+ eval(musicgen_base_large, batch_size=128)
+
+ # melody musicgen model
+ musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
+ musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
+
+ musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
+ musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
+ eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
diff --git a/audiocraft/losses/__init__.py b/audiocraft/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d55107b2c11822cab749ed3683cf19020802898a
--- /dev/null
+++ b/audiocraft/losses/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loss related classes and functions. In particular the loss balancer from
+EnCodec, and the usual spectral losses."""
+
+# flake8: noqa
+from .balancer import Balancer
+from .sisnr import SISNR
+from .stftloss import (
+ LogSTFTMagnitudeLoss,
+ MRSTFTLoss,
+ SpectralConvergenceLoss,
+ STFTLoss
+)
+from .specloss import (
+ MelSpectrogramL1Loss,
+ MultiScaleMelSpectrogramLoss,
+)
diff --git a/audiocraft/losses/__pycache__/__init__.cpython-310.pyc b/audiocraft/losses/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3f96d9484a8bb16e710c6c09cad8c774c0fee4e
Binary files /dev/null and b/audiocraft/losses/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/losses/__pycache__/balancer.cpython-310.pyc b/audiocraft/losses/__pycache__/balancer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed1268b546ed1eed6dd0ded696e1bdc22a0c3c96
Binary files /dev/null and b/audiocraft/losses/__pycache__/balancer.cpython-310.pyc differ
diff --git a/audiocraft/losses/__pycache__/sisnr.cpython-310.pyc b/audiocraft/losses/__pycache__/sisnr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5a60ac45a81c68651002f7a44b6f1a33284453c
Binary files /dev/null and b/audiocraft/losses/__pycache__/sisnr.cpython-310.pyc differ
diff --git a/audiocraft/losses/__pycache__/specloss.cpython-310.pyc b/audiocraft/losses/__pycache__/specloss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ebf4b3e2ad137a2c83aacaa993799887c5da8ab
Binary files /dev/null and b/audiocraft/losses/__pycache__/specloss.cpython-310.pyc differ
diff --git a/audiocraft/losses/__pycache__/stftloss.cpython-310.pyc b/audiocraft/losses/__pycache__/stftloss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2a0745f846381df970c7d9c09a39b1e34f14b94
Binary files /dev/null and b/audiocraft/losses/__pycache__/stftloss.cpython-310.pyc differ
diff --git a/audiocraft/losses/balancer.py b/audiocraft/losses/balancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ac8adebab8cdee8f82351965195dc02800d18
--- /dev/null
+++ b/audiocraft/losses/balancer.py
@@ -0,0 +1,136 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import flashy
+import torch
+from torch import autograd
+
+
+class Balancer:
+ """Loss balancer.
+
+ The loss balancer combines losses together to compute gradients for the backward.
+ Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
+ not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
+ `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
+ the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
+ going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
+ interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
+
+ Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
+ (with `avg` an exponential moving average over the updates),
+
+ G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
+
+ If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
+ standard sum of the partial gradients with the given weights.
+
+ A call to the backward method of the balancer will compute the the partial gradients,
+ combining all the losses and potentially rescaling the gradients,
+ which can help stabilize the training and reason about multiple losses with varying scales.
+ The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
+
+ Expected usage:
+
+ weights = {'loss_a': 1, 'loss_b': 4}
+ balancer = Balancer(weights, ...)
+ losses: dict = {}
+ losses['loss_a'] = compute_loss_a(x, y)
+ losses['loss_b'] = compute_loss_b(x, y)
+ if model.training():
+ effective_loss = balancer.backward(losses, x)
+
+ Args:
+ weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
+ from the backward method to match the weights keys to assign weight to each of the provided loss.
+ balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
+ overall gradient, rather than a constant multiplier.
+ total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+ emay_decay (float): EMA decay for averaging the norms.
+ per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+ when rescaling the gradients.
+ epsilon (float): Epsilon value for numerical stability.
+ monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
+ coming from each loss, when calling `backward()`.
+ """
+ def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
+ ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
+ monitor: bool = False):
+ self.weights = weights
+ self.per_batch_item = per_batch_item
+ self.total_norm = total_norm or 1.
+ self.averager = flashy.averager(ema_decay or 1.)
+ self.epsilon = epsilon
+ self.monitor = monitor
+ self.balance_grads = balance_grads
+ self._metrics: tp.Dict[str, tp.Any] = {}
+
+ @property
+ def metrics(self):
+ return self._metrics
+
+ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
+ """Compute the backward and return the effective train loss, e.g. the loss obtained from
+ computing the effective weights. If `balance_grads` is True, the effective weights
+ are the one that needs to be applied to each gradient to respect the desired relative
+ scale of gradients coming from each loss.
+
+ Args:
+ losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
+ input (torch.Tensor): the input of the losses, typically the output of the model.
+ This should be the single point of dependence between the losses
+ and the model being trained.
+ """
+ norms = {}
+ grads = {}
+ for name, loss in losses.items():
+ # Compute partial derivative of the less with respect to the input.
+ grad, = autograd.grad(loss, [input], retain_graph=True)
+ if self.per_batch_item:
+ # We do not average the gradient over the batch dimension.
+ dims = tuple(range(1, grad.dim()))
+ norm = grad.norm(dim=dims, p=2).mean()
+ else:
+ norm = grad.norm(p=2)
+ norms[name] = norm
+ grads[name] = grad
+
+ count = 1
+ if self.per_batch_item:
+ count = len(grad)
+ # Average norms across workers. Theoretically we should average the
+ # squared norm, then take the sqrt, but it worked fine like that.
+ avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
+ # We approximate the total norm of the gradient as the sums of the norms.
+ # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
+ total = sum(avg_norms.values())
+
+ self._metrics = {}
+ if self.monitor:
+ # Store the ratio of the total gradient represented by each loss.
+ for k, v in avg_norms.items():
+ self._metrics[f'ratio_{k}'] = v / total
+
+ total_weights = sum([self.weights[k] for k in avg_norms])
+ assert total_weights > 0.
+ desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+ out_grad = torch.zeros_like(input)
+ effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
+ for name, avg_norm in avg_norms.items():
+ if self.balance_grads:
+ # g_balanced = g / avg(||g||) * total_norm * desired_ratio
+ scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+ else:
+ # We just do regular weighted sum of the gradients.
+ scale = self.weights[name]
+ out_grad.add_(grads[name], alpha=scale)
+ effective_loss += scale * losses[name].detach()
+ # Send the computed partial derivative with respect to the output of the model to the model.
+ input.backward(out_grad)
+ return effective_loss
diff --git a/audiocraft/losses/sisnr.py b/audiocraft/losses/sisnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..30f1fa1de9aca22758b6665609a1eacc0bd992ca
--- /dev/null
+++ b/audiocraft/losses/sisnr.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
+ with K the kernel size, by extracting frames with the given stride.
+ This will pad the input so that `F = ceil(T / K)`.
+ see https://github.com/pytorch/pytorch/issues/60466
+ """
+ *shape, length = a.shape
+ n_frames = math.ceil(length / stride)
+ tgt_length = (n_frames - 1) * stride + kernel_size
+ a = F.pad(a, (0, tgt_length - length))
+ strides = list(a.stride())
+ assert strides[-1] == 1, "data should be contiguous"
+ strides = strides[:-1] + [stride, 1]
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
+
+
+def _center(x: torch.Tensor) -> torch.Tensor:
+ return x - x.mean(-1, True)
+
+
+def _norm2(x: torch.Tensor) -> torch.Tensor:
+ return x.pow(2).sum(-1, True)
+
+
+class SISNR(nn.Module):
+ """SISNR loss.
+
+ Input should be [B, C, T], output is scalar.
+
+ Args:
+ sample_rate (int): Sample rate.
+ segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
+ entire audio only.
+ overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
+ epsilon (float): Epsilon value for numerical stability.
+ """
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ segment: tp.Optional[float] = 20,
+ overlap: float = 0.5,
+ epsilon: float = torch.finfo(torch.float32).eps,
+ ):
+ super().__init__()
+ self.sample_rate = sample_rate
+ self.segment = segment
+ self.overlap = overlap
+ self.epsilon = epsilon
+
+ def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
+ B, C, T = ref_sig.shape
+ assert ref_sig.shape == out_sig.shape
+
+ if self.segment is None:
+ frame = T
+ stride = T
+ else:
+ frame = int(self.segment * self.sample_rate)
+ stride = int(frame * (1 - self.overlap))
+
+ epsilon = self.epsilon * frame # make epsilon prop to frame size.
+
+ gt = _unfold(ref_sig, frame, stride)
+ est = _unfold(out_sig, frame, stride)
+ if self.segment is None:
+ assert gt.shape[-1] == 1
+
+ gt = _center(gt)
+ est = _center(est)
+ dot = torch.einsum("bcft,bcft->bcf", gt, est)
+
+ proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
+ noise = est - proj
+
+ sisnr = 10 * (
+ torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
+ )
+ return -1 * sisnr[..., 0].mean()
diff --git a/audiocraft/losses/specloss.py b/audiocraft/losses/specloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f2eb3e5c44b542a02f13db64bfb22fa0d3d212
--- /dev/null
+++ b/audiocraft/losses/specloss.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+from torchaudio.transforms import MelSpectrogram
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ..modules import pad_for_conv1d
+
+
+class MelSpectrogramWrapper(nn.Module):
+ """Wrapper around MelSpectrogram torchaudio transform providing proper padding
+ and additional post-processing including log scaling.
+
+ Args:
+ n_mels (int): Number of mel bins.
+ n_fft (int): Number of fft.
+ hop_length (int): Hop size.
+ win_length (int): Window length.
+ n_mels (int): Number of mel bins.
+ sample_rate (int): Sample rate.
+ f_min (float or None): Minimum frequency.
+ f_max (float or None): Maximum frequency.
+ log (bool): Whether to scale with log.
+ normalized (bool): Whether to normalize the melspectrogram.
+ floor_level (float): Floor level based on human perception (default=1e-5).
+ """
+ def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
+ n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+ super().__init__()
+ self.n_fft = n_fft
+ hop_length = int(hop_length)
+ self.hop_length = hop_length
+ self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
+ win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
+ window_fn=torch.hann_window, center=False)
+ self.floor_level = floor_level
+ self.log = log
+
+ def forward(self, x):
+ p = int((self.n_fft - self.hop_length) // 2)
+ if len(x.shape) == 2:
+ x = x.unsqueeze(1)
+ x = F.pad(x, (p, p), "reflect")
+ # Make sure that all the frames are full.
+ # The combination of `pad_for_conv1d` and the above padding
+ # will make the output of size ceil(T / hop).
+ x = pad_for_conv1d(x, self.n_fft, self.hop_length)
+ self.mel_transform.to(x.device)
+ mel_spec = self.mel_transform(x)
+ B, C, freqs, frame = mel_spec.shape
+ if self.log:
+ mel_spec = torch.log10(self.floor_level + mel_spec)
+ return mel_spec.reshape(B, C * freqs, frame)
+
+
+class MelSpectrogramL1Loss(torch.nn.Module):
+ """L1 Loss on MelSpectrogram.
+
+ Args:
+ sample_rate (int): Sample rate.
+ n_fft (int): Number of fft.
+ hop_length (int): Hop size.
+ win_length (int): Window length.
+ n_mels (int): Number of mel bins.
+ f_min (float or None): Minimum frequency.
+ f_max (float or None): Maximum frequency.
+ log (bool): Whether to scale with log.
+ normalized (bool): Whether to normalize the melspectrogram.
+ floor_level (float): Floor level value based on human perception (default=1e-5).
+ """
+ def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
+ n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+ super().__init__()
+ self.l1 = torch.nn.L1Loss()
+ self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+ log=log, normalized=normalized, floor_level=floor_level)
+
+ def forward(self, x, y):
+ self.melspec.to(x.device)
+ s_x = self.melspec(x)
+ s_y = self.melspec(y)
+ return self.l1(s_x, s_y)
+
+
+class MultiScaleMelSpectrogramLoss(nn.Module):
+ """Multi-Scale spectrogram loss (msspec).
+
+ Args:
+ sample_rate (int): Sample rate.
+ range_start (int): Power of 2 to use for the first scale.
+ range_stop (int): Power of 2 to use for the last scale.
+ n_mels (int): Number of mel bins.
+ f_min (float): Minimum frequency.
+ f_max (float or None): Maximum frequency.
+ normalized (bool): Whether to normalize the melspectrogram.
+ alphas (bool): Whether to use alphas as coefficients or not.
+ floor_level (float): Floor level value based on human perception (default=1e-5).
+ """
+ def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
+ n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+ normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
+ super().__init__()
+ l1s = list()
+ l2s = list()
+ self.alphas = list()
+ self.total = 0
+ self.normalized = normalized
+ for i in range(range_start, range_end):
+ l1s.append(
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+ log=False, normalized=normalized, floor_level=floor_level))
+ l2s.append(
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+ log=True, normalized=normalized, floor_level=floor_level))
+ if alphas:
+ self.alphas.append(np.sqrt(2 ** i - 1))
+ else:
+ self.alphas.append(1)
+ self.total += self.alphas[-1] + 1
+
+ self.l1s = nn.ModuleList(l1s)
+ self.l2s = nn.ModuleList(l2s)
+
+ def forward(self, x, y):
+ loss = 0.0
+ self.l1s.to(x.device)
+ self.l2s.to(x.device)
+ for i in range(len(self.alphas)):
+ s_x_1 = self.l1s[i](x)
+ s_y_1 = self.l1s[i](y)
+ s_x_2 = self.l2s[i](x)
+ s_y_2 = self.l2s[i](y)
+ loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
+ if self.normalized:
+ loss = loss / self.total
+ return loss
diff --git a/audiocraft/losses/stftloss.py b/audiocraft/losses/stftloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ad4b7d3324ee5b0e6064b6f71cf8caf0fdc3be7
--- /dev/null
+++ b/audiocraft/losses/stftloss.py
@@ -0,0 +1,207 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# Adapted from MIT code under the original license
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+# TODO: Replace with torchaudio.STFT?
+def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
+ window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
+ """Perform STFT and convert to magnitude spectrogram.
+
+ Args:
+ x: Input signal tensor (B, C, T).
+ fft_size (int): FFT size.
+ hop_length (int): Hop size.
+ win_length (int): Window length.
+ window (torch.Tensor or None): Window function type.
+ normalized (bool): Whether to normalize the STFT or not.
+
+ Returns:
+ torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
+ """
+ B, C, T = x.shape
+ x_stft = torch.stft(
+ x.view(-1, T), fft_size, hop_length, win_length, window,
+ normalized=normalized, return_complex=True,
+ )
+ x_stft = x_stft.view(B, C, *x_stft.shape[1:])
+ real = x_stft.real
+ imag = x_stft.imag
+
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
+class SpectralConvergenceLoss(nn.Module):
+ """Spectral convergence loss.
+ """
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+ super().__init__()
+ self.epsilon = epsilon
+
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+ """Calculate forward propagation.
+
+ Args:
+ x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ torch.Tensor: Spectral convergence loss value.
+ """
+ return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
+
+
+class LogSTFTMagnitudeLoss(nn.Module):
+ """Log STFT magnitude loss.
+
+ Args:
+ epsilon (float): Epsilon value for numerical stability.
+ """
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+ super().__init__()
+ self.epsilon = epsilon
+
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+ """Calculate forward propagation.
+
+ Args:
+ x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ torch.Tensor: Log STFT magnitude loss value.
+ """
+ return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
+
+
+class STFTLosses(nn.Module):
+ """STFT losses.
+
+ Args:
+ n_fft (int): Size of FFT.
+ hop_length (int): Hop length.
+ win_length (int): Window length.
+ window (str): Window function type.
+ normalized (bool): Whether to use normalized STFT or not.
+ epsilon (float): Epsilon for numerical stability.
+ """
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+ window: str = "hann_window", normalized: bool = False,
+ epsilon: float = torch.finfo(torch.float32).eps):
+ super().__init__()
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.normalized = normalized
+ self.register_buffer("window", getattr(torch, window)(win_length))
+ self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (torch.Tensor): Predicted signal (B, T).
+ y (torch.Tensor): Groundtruth signal (B, T).
+ Returns:
+ torch.Tensor: Spectral convergence loss value.
+ torch.Tensor: Log STFT magnitude loss value.
+ """
+ x_mag = _stft(x, self.n_fft, self.hop_length,
+ self.win_length, self.window, self.normalized) # type: ignore
+ y_mag = _stft(y, self.n_fft, self.hop_length,
+ self.win_length, self.window, self.normalized) # type: ignore
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+ return sc_loss, mag_loss
+
+
+class STFTLoss(nn.Module):
+ """Single Resolution STFT loss.
+
+ Args:
+ n_fft (int): Nb of FFT.
+ hop_length (int): Hop length.
+ win_length (int): Window length.
+ window (str): Window function type.
+ normalized (bool): Whether to use normalized STFT or not.
+ epsilon (float): Epsilon for numerical stability.
+ factor_sc (float): Coefficient for the spectral loss.
+ factor_mag (float): Coefficient for the magnitude loss.
+ """
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+ window: str = "hann_window", normalized: bool = False,
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
+ epsilon: float = torch.finfo(torch.float32).eps):
+ super().__init__()
+ self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
+ self.factor_sc = factor_sc
+ self.factor_mag = factor_mag
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (torch.Tensor): Predicted signal (B, T).
+ y (torch.Tensor): Groundtruth signal (B, T).
+ Returns:
+ torch.Tensor: Single resolution STFT loss.
+ """
+ sc_loss, mag_loss = self.loss(x, y)
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+
+class MRSTFTLoss(nn.Module):
+ """Multi resolution STFT loss.
+
+ Args:
+ n_ffts (Sequence[int]): Sequence of FFT sizes.
+ hop_lengths (Sequence[int]): Sequence of hop sizes.
+ win_lengths (Sequence[int]): Sequence of window lengths.
+ window (str): Window function type.
+ factor_sc (float): Coefficient for the spectral loss.
+ factor_mag (float): Coefficient for the magnitude loss.
+ normalized (bool): Whether to use normalized STFT or not.
+ epsilon (float): Epsilon for numerical stability.
+ """
+ def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
+ win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
+ normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
+ super().__init__()
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+ self.stft_losses = torch.nn.ModuleList()
+ for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
+ self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
+ self.factor_sc = factor_sc
+ self.factor_mag = factor_mag
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (torch.Tensor): Predicted signal (B, T).
+ y (torch.Tensor): Groundtruth signal (B, T).
+ Returns:
+ torch.Tensor: Multi resolution STFT loss.
+ """
+ sc_loss = torch.Tensor([0.0])
+ mag_loss = torch.Tensor([0.0])
+ for f in self.stft_losses:
+ sc_l, mag_l = f(x, y)
+ sc_loss += sc_l
+ mag_loss += mag_l
+ sc_loss /= len(self.stft_losses)
+ mag_loss /= len(self.stft_losses)
+
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
diff --git a/audiocraft/metrics/__init__.py b/audiocraft/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3474bdc4f1c88b21904d2a21ba077c93a8a70c8b
--- /dev/null
+++ b/audiocraft/metrics/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
+"""
+# flake8: noqa
+from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
+from .chroma_cosinesim import ChromaCosineSimilarityMetric
+from .fad import FrechetAudioDistanceMetric
+from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
+from .rvm import RelativeVolumeMel
+from .visqol import ViSQOL
diff --git a/audiocraft/metrics/__pycache__/__init__.cpython-310.pyc b/audiocraft/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b49fa12b6d6295e40ac59eadfd05566672644be2
Binary files /dev/null and b/audiocraft/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-310.pyc b/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8be3faa7bd05d57b9100a7a43bf78bb2d662046c
Binary files /dev/null and b/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/clap_consistency.cpython-310.pyc b/audiocraft/metrics/__pycache__/clap_consistency.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0306e9297ff04efdba0e1b5194f35460a3b8a01a
Binary files /dev/null and b/audiocraft/metrics/__pycache__/clap_consistency.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/fad.cpython-310.pyc b/audiocraft/metrics/__pycache__/fad.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76407e6c4a0d98bf0d6b30fb7f20f076d62a8e38
Binary files /dev/null and b/audiocraft/metrics/__pycache__/fad.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/kld.cpython-310.pyc b/audiocraft/metrics/__pycache__/kld.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..832888cea28be1844eb37e978e247717cb7ed223
Binary files /dev/null and b/audiocraft/metrics/__pycache__/kld.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/rvm.cpython-310.pyc b/audiocraft/metrics/__pycache__/rvm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f3745c7ae024300f04b10bb2a7f8b01809c1fd2
Binary files /dev/null and b/audiocraft/metrics/__pycache__/rvm.cpython-310.pyc differ
diff --git a/audiocraft/metrics/__pycache__/visqol.cpython-310.pyc b/audiocraft/metrics/__pycache__/visqol.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..815e7825c1356117b867456eb92c5295cfb3468b
Binary files /dev/null and b/audiocraft/metrics/__pycache__/visqol.cpython-310.pyc differ
diff --git a/audiocraft/metrics/chroma_cosinesim.py b/audiocraft/metrics/chroma_cosinesim.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c26081b803c2017fae1b6d7d086f0b0e074cef
--- /dev/null
+++ b/audiocraft/metrics/chroma_cosinesim.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchmetrics
+
+from ..data.audio_utils import convert_audio
+from ..modules.chroma import ChromaExtractor
+
+
+class ChromaCosineSimilarityMetric(torchmetrics.Metric):
+ """Chroma cosine similarity metric.
+
+ This metric extracts a chromagram for a reference waveform and
+ a generated waveform and compares each frame using the cosine similarity
+ function. The output is the mean cosine similarity.
+
+ Args:
+ sample_rate (int): Sample rate used by the chroma extractor.
+ n_chroma (int): Number of chroma used by the chroma extractor.
+ radix2_exp (int): Exponent for the chroma extractor.
+ argmax (bool): Whether the chroma extractor uses argmax.
+ eps (float): Epsilon for cosine similarity computation.
+ """
+ def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
+ super().__init__()
+ self.chroma_sample_rate = sample_rate
+ self.n_chroma = n_chroma
+ self.eps = eps
+ self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
+ radix2_exp=radix2_exp, argmax=argmax)
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+ """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
+ if preds.size(0) == 0:
+ return
+
+ assert preds.shape == targets.shape, (
+ f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
+ assert preds.size(0) == sizes.size(0), (
+ f"Number of items in preds ({preds.shape}) mismatch ",
+ f"with sizes ({sizes.shape})")
+ assert preds.size(0) == sample_rates.size(0), (
+ f"Number of items in preds ({preds.shape}) mismatch ",
+ f"with sample_rates ({sample_rates.shape})")
+ assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
+
+ device = self.weight.device
+ preds, targets = preds.to(device), targets.to(device) # type: ignore
+ sample_rate = sample_rates[0].item()
+ preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+ targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+ gt_chroma = self.chroma_extractor(targets)
+ gen_chroma = self.chroma_extractor(preds)
+ chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
+ for i in range(len(gt_chroma)):
+ t = int(chroma_lens[i].item())
+ cosine_sim = torch.nn.functional.cosine_similarity(
+ gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
+ self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
+ self.weight += torch.tensor(t) # type: ignore
+
+ def compute(self) -> float:
+ """Computes the average cosine similarty across all generated/target chromagrams pairs."""
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
+ return (self.cosine_sum / self.weight).item() # type: ignore
diff --git a/audiocraft/metrics/clap_consistency.py b/audiocraft/metrics/clap_consistency.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2a6c61ae177533ca2fb17e25bc77d2acbbe3791
--- /dev/null
+++ b/audiocraft/metrics/clap_consistency.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import typing as tp
+
+import torch
+import torchmetrics
+from transformers import RobertaTokenizer # type: ignore
+
+from ..data.audio_utils import convert_audio
+from ..environment import AudioCraftEnvironment
+from ..utils.utils import load_clap_state_dict
+
+try:
+ import laion_clap # type: ignore
+except ImportError:
+ laion_clap = None
+
+
+class TextConsistencyMetric(torchmetrics.Metric):
+ """Text consistency metric measuring consistency between audio and text pairs."""
+
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+ raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
+
+ def compute(self):
+ raise NotImplementedError("implement how to compute the final metric score.")
+
+
+class CLAPTextConsistencyMetric(TextConsistencyMetric):
+ """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
+
+ This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
+ or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
+
+ As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
+ similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
+ well as the generated audio based on them, and define the MCC metric as the average cosine similarity
+ between these embeddings.
+
+ Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
+ """
+ def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
+ super().__init__()
+ if laion_clap is None:
+ raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self._initialize_model(model_path, model_arch, enable_fusion)
+
+ def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
+ model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+ self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+ self.model_sample_rate = 48_000
+ load_clap_state_dict(self.model, model_path)
+ self.model.eval()
+
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+ # we use the default params from CLAP module here as well
+ return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+ """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
+ assert audio.size(0) == len(text), "Number of audio and text samples should match"
+ assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
+ sample_rate = int(sample_rates[0].item())
+ # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
+ audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
+ audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
+ text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+ # cosine similarity between the text and the audio embedding
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
+ self.cosine_sum += cosine_sim.sum(dim=0)
+ self.weight += torch.tensor(cosine_sim.size(0))
+
+ def compute(self):
+ """Computes the average cosine similarty across all audio/text pairs."""
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
+ return (self.cosine_sum / self.weight).item() # type: ignore
diff --git a/audiocraft/metrics/fad.py b/audiocraft/metrics/fad.py
new file mode 100644
index 0000000000000000000000000000000000000000..de66138dbb14fd4246bbfe590bddfd5beaf1ed8c
--- /dev/null
+++ b/audiocraft/metrics/fad.py
@@ -0,0 +1,329 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from pathlib import Path
+import os
+import subprocess
+import tempfile
+import typing as tp
+
+from audiocraft.data.audio import audio_write
+from audiocraft.data.audio_utils import convert_audio
+import flashy
+import torch
+import torchmetrics
+
+from ..environment import AudioCraftEnvironment
+
+
+logger = logging.getLogger(__name__)
+
+VGGISH_SAMPLE_RATE = 16_000
+VGGISH_CHANNELS = 1
+
+
+class FrechetAudioDistanceMetric(torchmetrics.Metric):
+ """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
+
+ From: D.C. Dowson & B.V. Landau The Fréchet distance between
+ multivariate normal distributions
+ https://doi.org/10.1016/0047-259X(82)90077-X
+ The Fréchet distance between two multivariate gaussians,
+ `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
+ d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
+ = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
+ - 2 * Tr(sqrt(sigma_x*sigma_y)))
+
+ To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
+ from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
+ We provide the below instructions as reference but we do not guarantee for further support
+ in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
+
+ We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
+
+ 1. Get the code and models following the repository instructions. We used the steps below:
+ git clone git@github.com:google-research/google-research.git
+ git clone git@github.com:tensorflow/models.git
+ mkdir google-research/tensorflow_models
+ touch google-research/tensorflow_models/__init__.py
+ cp -r models/research/audioset google-research/tensorflow_models/
+ touch google-research/tensorflow_models/audioset/__init__.py
+ echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
+ google-research/tensorflow_models/audioset/__init__.py
+ # we can now remove the tensorflow models repository
+ # rm -r models
+ cd google-research
+ Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
+ assumes it is placed in the AudioCraft reference dir.
+
+ Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
+ - Update xrange for range in:
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
+ - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
+ `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
+ - Update `import vggish_params as params` to `from . import vggish_params as params` in:
+ https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
+ - Add flag to provide a given batch size for running the AudioSet model in:
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
+ ```
+ flags.DEFINE_integer('batch_size', 64,
+ 'Number of samples in the batch for AudioSet model.')
+ ```
+ Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
+ `batch_size=FLAGS.batch_size` to the provided parameters.
+
+ 2. Follow instructions for the library installation and a valid TensorFlow installation
+ ```
+ # e.g. instructions from: https://www.tensorflow.org/install/pip
+ conda install -c conda-forge cudatoolkit=11.8.0
+ python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
+ mkdir -p $CONDA_PREFIX/etc/conda/activate.d
+ echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ # Verify install: on a machine with GPU device
+ python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
+ ```
+
+ Now install frechet_audio_distance required dependencies:
+ ```
+ # We assume we already have TensorFlow installed from the above steps
+ pip install apache-beam numpy scipy tf_slim
+ ```
+
+ Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
+ (you may want to specify --model_ckpt flag pointing to the model's path).
+
+ 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
+ and Tensorflow library path from the above installation steps:
+ export TF_PYTHON_EXE=""
+ export TF_LIBRARY_PATH=""
+
+ e.g. assuming we have installed everything in a dedicated conda env
+ with python 3.10 that is currently active:
+ export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
+ export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
+
+ Finally you may want to export the following variable:
+ export TF_FORCE_GPU_ALLOW_GROWTH=true
+ See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
+
+ You can save those environment variables in your training conda env, when currently active:
+ `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
+ e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
+ and the training conda env is named audiocraft:
+ ```
+ # activate training env
+ conda activate audiocraft
+ # get path to all envs
+ CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
+ # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
+ touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ # optionally:
+ echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+ # you may need to reactivate the audiocraft env for this to take effect
+ ```
+
+ Args:
+ bin (Path or str): Path to installed frechet audio distance code.
+ model_path (Path or str): Path to Tensorflow checkpoint for the model
+ used to compute statistics over the embedding beams.
+ format (str): Audio format used to save files.
+ log_folder (Path or str, optional): Path where to write process logs.
+ """
+ def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
+ format: str = "wav", batch_size: tp.Optional[int] = None,
+ log_folder: tp.Optional[tp.Union[Path, str]] = None):
+ super().__init__()
+ self.model_sample_rate = VGGISH_SAMPLE_RATE
+ self.model_channels = VGGISH_CHANNELS
+ self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+ assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
+ self.format = format
+ self.batch_size = batch_size
+ self.bin = bin
+ self.tf_env = {"PYTHONPATH": str(self.bin)}
+ self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
+ logger.info("Python exe for TF is %s", self.python_path)
+ if 'TF_LIBRARY_PATH' in os.environ:
+ self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
+ if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
+ self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
+ logger.info("Env for TF is %r", self.tf_env)
+ self.reset(log_folder)
+ self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+ def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
+ """Reset torchmetrics.Metrics state."""
+ log_folder = Path(log_folder or tempfile.mkdtemp())
+ self.tmp_dir = log_folder / 'fad'
+ self.tmp_dir.mkdir(exist_ok=True)
+ self.samples_tests_dir = self.tmp_dir / 'tests'
+ self.samples_tests_dir.mkdir(exist_ok=True)
+ self.samples_background_dir = self.tmp_dir / 'background'
+ self.samples_background_dir.mkdir(exist_ok=True)
+ self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
+ self.manifest_background = self.tmp_dir / 'files_background.cvs'
+ self.stats_tests_dir = self.tmp_dir / 'stats_tests'
+ self.stats_background_dir = self.tmp_dir / 'stats_background'
+ self.counter = 0
+
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
+ sizes: torch.Tensor, sample_rates: torch.Tensor,
+ stems: tp.Optional[tp.List[str]] = None):
+ """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
+ assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
+ num_samples = preds.shape[0]
+ assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
+ assert stems is None or num_samples == len(set(stems))
+ for i in range(num_samples):
+ self.total_files += 1 # type: ignore
+ self.counter += 1
+ wav_len = int(sizes[i].item())
+ sample_rate = int(sample_rates[i].item())
+ pred_wav = preds[i]
+ target_wav = targets[i]
+ pred_wav = pred_wav[..., :wav_len]
+ target_wav = target_wav[..., :wav_len]
+ stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
+ # dump audio files
+ try:
+ pred_wav = convert_audio(
+ pred_wav.unsqueeze(0), from_rate=sample_rate,
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+ audio_write(
+ self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
+ format=self.format, strategy="peak")
+ except Exception as e:
+ logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
+ try:
+ # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
+ # the original audio when writing it
+ target_wav = convert_audio(
+ target_wav.unsqueeze(0), from_rate=sample_rate,
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+ audio_write(
+ self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
+ format=self.format, strategy="peak")
+ except Exception as e:
+ logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
+
+ def _get_samples_name(self, is_background: bool):
+ return 'background' if is_background else 'tests'
+
+ def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
+ if is_background:
+ input_samples_dir = self.samples_background_dir
+ input_filename = self.manifest_background
+ stats_name = self.stats_background_dir
+ else:
+ input_samples_dir = self.samples_tests_dir
+ input_filename = self.manifest_tests
+ stats_name = self.stats_tests_dir
+ beams_name = self._get_samples_name(is_background)
+ log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
+
+ logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
+ with open(input_filename, "w") as fout:
+ for path in Path(input_samples_dir).glob(f"*.{self.format}"):
+ fout.write(f"{str(path)}\n")
+
+ cmd = [
+ self.python_path, "-m",
+ "frechet_audio_distance.create_embeddings_main",
+ "--model_ckpt", f"{self.model_path}",
+ "--input_files", f"{str(input_filename)}",
+ "--stats", f"{str(stats_name)}",
+ ]
+ if self.batch_size is not None:
+ cmd += ["--batch_size", str(self.batch_size)]
+ logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
+ env = os.environ
+ if gpu_index is not None:
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+ process = subprocess.Popen(
+ cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
+ return process, log_file
+
+ def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
+ cmd = [
+ self.python_path, "-m", "frechet_audio_distance.compute_fad",
+ "--test_stats", f"{str(self.stats_tests_dir)}",
+ "--background_stats", f"{str(self.stats_background_dir)}",
+ ]
+ logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
+ env = os.environ
+ if gpu_index is not None:
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+ result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
+ if result.returncode:
+ logger.error(
+ "Error with FAD computation from stats: \n %s \n %s",
+ result.stdout.decode(), result.stderr.decode()
+ )
+ raise RuntimeError("Error while executing FAD computation from stats")
+ try:
+ # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
+ fad_score = float(result.stdout[4:])
+ return fad_score
+ except Exception as e:
+ raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
+
+ def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
+ beams_name = self._get_samples_name(is_background)
+ if returncode:
+ with open(log_file, "r") as f:
+ error_log = f.read()
+ logger.error(error_log)
+ os._exit(1)
+ else:
+ logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
+
+ def _parallel_create_embedding_beams(self, num_of_gpus: int):
+ assert num_of_gpus > 0
+ logger.info("Creating embeddings beams in a parallel manner on different GPUs")
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
+ tests_beams_code = tests_beams_process.wait()
+ bg_beams_code = bg_beams_process.wait()
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+ def _sequential_create_embedding_beams(self):
+ logger.info("Creating embeddings beams in a sequential manner")
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
+ tests_beams_code = tests_beams_process.wait()
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
+ bg_beams_code = bg_beams_process.wait()
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+ @flashy.distrib.rank_zero_only
+ def _local_compute_frechet_audio_distance(self):
+ """Compute Frechet Audio Distance score calling TensorFlow API."""
+ num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+ if num_of_gpus > 1:
+ self._parallel_create_embedding_beams(num_of_gpus)
+ else:
+ self._sequential_create_embedding_beams()
+ fad_score = self._compute_fad_score(gpu_index=0)
+ return fad_score
+
+ def compute(self) -> float:
+ """Compute metrics."""
+ assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
+ fad_score = self._local_compute_frechet_audio_distance()
+ logger.warning(f"FAD score = {fad_score}")
+ fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
+ return fad_score
diff --git a/audiocraft/metrics/kld.py b/audiocraft/metrics/kld.py
new file mode 100644
index 0000000000000000000000000000000000000000..18260bf974bf47d8381223ac39be0c47c031bf8a
--- /dev/null
+++ b/audiocraft/metrics/kld.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+from functools import partial
+import logging
+import os
+import typing as tp
+
+import torch
+import torchmetrics
+
+from ..data.audio_utils import convert_audio
+
+
+logger = logging.getLogger(__name__)
+
+
+class _patch_passt_stft:
+ """Decorator to patch torch.stft in PaSST."""
+ def __init__(self):
+ self.old_stft = torch.stft
+
+ def __enter__(self):
+ # return_complex is a mandatory parameter in latest torch versions
+ # torch is throwing RuntimeErrors when not set
+ torch.stft = partial(torch.stft, return_complex=False)
+
+ def __exit__(self, *exc):
+ torch.stft = self.old_stft
+
+
+def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
+ """Computes the elementwise KL-Divergence loss between probability distributions
+ from generated samples and target samples.
+
+ Args:
+ pred_probs (torch.Tensor): Probabilities for each label obtained
+ from a classifier on generated audio. Expected shape is [B, num_classes].
+ target_probs (torch.Tensor): Probabilities for each label obtained
+ from a classifier on target audio. Expected shape is [B, num_classes].
+ epsilon (float): Epsilon value.
+ Returns:
+ kld (torch.Tensor): KLD loss between each generated sample and target pair.
+ """
+ kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
+ return kl_div.sum(-1)
+
+
+class KLDivergenceMetric(torchmetrics.Metric):
+ """Base implementation for KL Divergence metric.
+
+ The KL divergence is measured between probability distributions
+ of class predictions returned by a pre-trained audio classification model.
+ When the KL-divergence is low, the generated audio is expected to
+ have similar acoustic characteristics as the reference audio,
+ according to the classifier.
+ """
+ def __init__(self):
+ super().__init__()
+ self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+ self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+ """Get model output given provided input tensor.
+
+ Args:
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+ Returns:
+ probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
+ """
+ raise NotImplementedError("implement method to extract label distributions from the model.")
+
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+ """Calculates running KL-Divergence loss between batches of audio
+ preds (generated) and target (ground-truth)
+ Args:
+ preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
+ targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+ """
+ assert preds.shape == targets.shape
+ assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
+ preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
+ targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
+ if preds_probs is not None and targets_probs is not None:
+ assert preds_probs.shape == targets_probs.shape
+ kld_scores = kl_divergence(preds_probs, targets_probs)
+ assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
+ self.kld_pq_sum += torch.sum(kld_scores)
+ kld_qp_scores = kl_divergence(targets_probs, preds_probs)
+ self.kld_qp_sum += torch.sum(kld_qp_scores)
+ self.weight += torch.tensor(kld_scores.size(0))
+
+ def compute(self) -> dict:
+ """Computes KL-Divergence across all evaluated pred/target pairs."""
+ weight: float = float(self.weight.item()) # type: ignore
+ assert weight > 0, "Unable to compute with total number of comparisons <= 0"
+ logger.info(f"Computing KL divergence on a total of {weight} samples")
+ kld_pq = self.kld_pq_sum.item() / weight # type: ignore
+ kld_qp = self.kld_qp_sum.item() / weight # type: ignore
+ kld_both = kld_pq + kld_qp
+ return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
+
+
+class PasstKLDivergenceMetric(KLDivergenceMetric):
+ """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
+
+ From: PaSST: Efficient Training of Audio Transformers with Patchout
+ Paper: https://arxiv.org/abs/2110.05069
+ Implementation: https://github.com/kkoutini/PaSST
+
+ Follow instructions from the github repo:
+ ```
+ pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
+ ```
+
+ Args:
+ pretrained_length (float, optional): Audio duration used for the pretrained model.
+ """
+ def __init__(self, pretrained_length: tp.Optional[float] = None):
+ super().__init__()
+ self._initialize_model(pretrained_length)
+
+ def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
+ """Initialize underlying PaSST audio classifier."""
+ model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
+ self.min_input_frames = min_frames
+ self.max_input_frames = max_frames
+ self.model_sample_rate = sr
+ self.model = model
+ self.model.eval()
+ self.model.to(self.device)
+
+ def _load_base_model(self, pretrained_length: tp.Optional[float]):
+ """Load pretrained model from PaSST."""
+ try:
+ if pretrained_length == 30:
+ from hear21passt.base30sec import get_basic_model # type: ignore
+ max_duration = 30
+ elif pretrained_length == 20:
+ from hear21passt.base20sec import get_basic_model # type: ignore
+ max_duration = 20
+ else:
+ from hear21passt.base import get_basic_model # type: ignore
+ # Original PASST was trained on AudioSet with 10s-long audio samples
+ max_duration = 10
+ min_duration = 0.15
+ min_duration = 0.15
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install hear21passt to compute KL divergence: ",
+ "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
+ )
+ model_sample_rate = 32_000
+ max_input_frames = int(max_duration * model_sample_rate)
+ min_input_frames = int(min_duration * model_sample_rate)
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
+ model = get_basic_model(mode='logits')
+ return model, model_sample_rate, max_input_frames, min_input_frames
+
+ def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.Optional[torch.Tensor]:
+ wav = wav.unsqueeze(0)
+ wav = wav[..., :wav_len]
+ wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
+ wav = wav.squeeze(0)
+ # create chunks of audio to match the classifier processing length
+ segments = torch.split(wav, self.max_input_frames, dim=-1)
+ valid_segments = []
+ for s in segments:
+ if s.size(-1) > self.min_input_frames:
+ s = torch.nn.functional.pad(s, (0, self.max_input_frames - s.shape[-1]))
+ valid_segments.append(s)
+ if len(valid_segments) > 0:
+ return torch.stack(valid_segments, dim=0)
+ else:
+ return None
+
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+ """Get model output given provided input tensor.
+
+ Args:
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+ Returns:
+ probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
+ """
+ all_probs: tp.List[torch.Tensor] = []
+ for i, wav in enumerate(x):
+ sample_rate = int(sample_rates[i].item())
+ wav_len = int(sizes[i].item())
+ wav = self._process_audio(wav, sample_rate, wav_len)
+ if wav is not None:
+ assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
+ wav = wav.mean(dim=1)
+ # PaSST is printing a lot of infos that we are not interested in
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
+ with torch.no_grad(), _patch_passt_stft():
+ logits = self.model(wav.to(self.device))
+ probs = torch.softmax(logits, dim=-1)
+ probs = probs.mean(dim=0)
+ all_probs.append(probs)
+ if len(all_probs) > 0:
+ return torch.stack(all_probs, dim=0)
+ else:
+ return None
diff --git a/audiocraft/metrics/rvm.py b/audiocraft/metrics/rvm.py
new file mode 100644
index 0000000000000000000000000000000000000000..028324529531dd7ee97210dfd890fed717447be0
--- /dev/null
+++ b/audiocraft/metrics/rvm.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+import torch
+from torch import nn
+import torchaudio
+
+
+def db_to_scale(volume: tp.Union[float, torch.Tensor]):
+ return 10 ** (volume / 20)
+
+
+def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
+ min_scale = db_to_scale(min_volume)
+ return 20 * torch.log10(scale.clamp(min=min_scale))
+
+
+class RelativeVolumeMel(nn.Module):
+ """Relative volume melspectrogram measure.
+
+ Computes a measure of distance over two mel spectrogram that is interpretable in terms
+ of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
+ first renormalize both by the ground truth of `x_ref`.
+
+ Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
+ relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
+ clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
+ with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
+ Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
+ average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
+ good (for a neural network output, although sound engineers typically aim for much lower attenuations).
+ Similarly, anything above +30 dB would just be completely missing the target, and there is no point
+ in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
+ in line with what neural nets currently can achieve.
+
+ For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
+ the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
+
+ The metric can be aggregated over a given frequency band in order have different insights for
+ different region of the spectrum. `num_aggregated_bands` controls the number of bands.
+
+ ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
+ is numerically stable when computing its gradient. We thus advise against using it as a training loss.
+
+ Args:
+ sample_rate (int): Sample rate of the input audio.
+ n_mels (int): Number of mel bands to use.
+ n_fft (int): Number of frequency bins for the STFT.
+ hop_length (int): Hop length of the STFT and the mel-spectrogram.
+ min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
+ the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
+ max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
+ max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
+ to that amount, to avoid rescaling near silence. Given in dB.
+ min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
+ bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
+ and anything below that will be considered equally.
+ num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
+ For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
+ """
+ def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
+ hop_length: int = 128, min_relative_volume: float = -25,
+ max_relative_volume: float = 25, max_initial_gain: float = 25,
+ min_activity_volume: float = -25,
+ num_aggregated_bands: int = 4) -> None:
+ super().__init__()
+ self.melspec = torchaudio.transforms.MelSpectrogram(
+ n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
+ normalized=True, sample_rate=sample_rate, power=2)
+ self.min_relative_volume = min_relative_volume
+ self.max_relative_volume = max_relative_volume
+ self.max_initial_gain = max_initial_gain
+ self.min_activity_volume = min_activity_volume
+ self.num_aggregated_bands = num_aggregated_bands
+
+ def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
+ """Compute RVM metric between estimate and reference samples.
+
+ Args:
+ estimate (torch.Tensor): Estimate sample.
+ ground_truth (torch.Tensor): Reference sample.
+
+ Returns:
+ dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
+ for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
+ """
+ min_scale = db_to_scale(-self.max_initial_gain)
+ std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
+ z_gt = self.melspec(ground_truth / std).sqrt()
+ z_est = self.melspec(estimate / std).sqrt()
+
+ delta = z_gt - z_est
+ ref_db = scale_to_db(z_gt, self.min_activity_volume)
+ delta_db = scale_to_db(delta.abs(), min_volume=-120)
+ relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
+ dims = list(range(relative_db.dim()))
+ dims.remove(dims[-2])
+ losses_per_band = relative_db.mean(dim=dims)
+ aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
+ metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
+ metrics['rvm'] = losses_per_band.mean()
+ return metrics
diff --git a/audiocraft/metrics/visqol.py b/audiocraft/metrics/visqol.py
new file mode 100644
index 0000000000000000000000000000000000000000..44f4b0a2c3c6c726857db8386491823dd85dde51
--- /dev/null
+++ b/audiocraft/metrics/visqol.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+import json
+import logging
+from pathlib import Path
+import tempfile
+import typing as tp
+import subprocess
+import shutil
+
+import torch
+import torchaudio
+
+logger = logging.getLogger(__name__)
+
+
+class ViSQOL:
+ """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
+
+ To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
+ instructions available in the open source repository: https://github.com/google/visqol
+
+ ViSQOL is capable of running in two modes:
+
+ Audio Mode:
+ When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+ Audio mode uses support vector regression, with the maximum range at ~4.75.
+
+ Speech Mode:
+ When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
+ Input should be resampled to 16kHz.
+ As part of the speech mode processing, a root mean square implementation for voice activity detection
+ is performed on the reference signal to determine what parts of the signal have voice activity and
+ should therefore be included in the comparison. The signal is normalized before performing the voice
+ activity detection.
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+ Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
+
+ For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
+
+ Args:
+ visqol_bin (str): Path to the ViSQOL binary.
+ mode (str): ViSQOL computation mode, expecting "audio" or "speech".
+ model (str): Name of the model to use for similarity to quality model.
+ debug (bool): Whether to also get debug metrics from ViSQOL or not.
+ """
+ SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
+ ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
+
+ def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
+ model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
+ assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
+ self.visqol_bin = str(bin)
+ self.visqol_mode = mode
+ self.target_sr = self._get_target_sr(self.visqol_mode)
+ self.model = model
+ self.debug = debug
+ assert Path(self.visqol_model).exists(), \
+ f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
+
+ def _get_target_sr(self, mode: str) -> int:
+ # returns target sampling rate for the corresponding ViSQOL mode.
+ if mode not in ViSQOL.SAMPLE_RATES_MODES:
+ raise ValueError(
+ f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
+ )
+ return ViSQOL.SAMPLE_RATES_MODES[mode]
+
+ def _prepare_files(
+ self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
+ ):
+ # prepare files for ViSQOL evaluation.
+ assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
+ assert len(ref_sig) == len(deg_sig), (
+ "Expects same number of ref and degraded inputs",
+ f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
+ )
+ # resample audio if needed
+ if sr != target_sr:
+ transform = torchaudio.transforms.Resample(sr, target_sr)
+ pad = int(0.5 * target_sr)
+ rs_ref = []
+ rs_deg = []
+ for i in range(len(ref_sig)):
+ rs_ref_i = transform(ref_sig[i])
+ rs_deg_i = transform(deg_sig[i])
+ if pad_with_silence:
+ rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
+ rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
+ rs_ref.append(rs_ref_i)
+ rs_deg.append(rs_deg_i)
+ ref_sig = torch.stack(rs_ref)
+ deg_sig = torch.stack(rs_deg)
+ # save audio chunks to tmp dir and create csv
+ tmp_dir = Path(tempfile.mkdtemp())
+ try:
+ tmp_input_csv_path = tmp_dir / "input.csv"
+ tmp_results_csv_path = tmp_dir / "results.csv"
+ tmp_debug_json_path = tmp_dir / "debug.json"
+ with open(tmp_input_csv_path, "w") as csv_file:
+ csv_writer = csv.writer(csv_file)
+ csv_writer.writerow(["reference", "degraded"])
+ for i in range(len(ref_sig)):
+ tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
+ tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
+ torchaudio.save(
+ tmp_ref_filename,
+ torch.clamp(ref_sig[i], min=-0.99, max=0.99),
+ sample_rate=target_sr,
+ bits_per_sample=16,
+ encoding="PCM_S"
+ )
+ torchaudio.save(
+ tmp_deg_filename,
+ torch.clamp(deg_sig[i], min=-0.99, max=0.99),
+ sample_rate=target_sr,
+ bits_per_sample=16,
+ encoding="PCM_S"
+ )
+ csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
+ return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
+ except Exception as e:
+ logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
+ return tmp_dir, None, None, None
+
+ def _flush_files(self, tmp_dir: tp.Union[Path, str]):
+ # flush tmp files used to compute ViSQOL.
+ shutil.rmtree(str(tmp_dir))
+
+ def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
+ # collect results for each evaluated pair and return averaged moslqo score.
+ with open(results_csv_path, "r") as csv_file:
+ reader = csv.DictReader(csv_file)
+ moslqo_scores = [float(row["moslqo"]) for row in reader]
+ if len(moslqo_scores) > 0:
+ return sum(moslqo_scores) / len(moslqo_scores)
+ else:
+ return 0.0
+
+ def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
+ # collect debug data for the visqol inference.
+ with open(debug_json_path, "r") as f:
+ data = json.load(f)
+ return data
+
+ @property
+ def visqol_model(self):
+ return f'{self.visqol_bin}/model/{self.model}'
+
+ def _run_visqol(
+ self,
+ input_csv_path: tp.Union[Path, str],
+ results_csv_path: tp.Union[Path, str],
+ debug_csv_path: tp.Optional[tp.Union[Path, str]],
+ ):
+ input_csv_path = str(input_csv_path)
+ results_csv_path = str(results_csv_path)
+ debug_csv_path = str(debug_csv_path)
+ cmd = [
+ f'{self.visqol_bin}/bazel-bin/visqol',
+ '--batch_input_csv', f'{input_csv_path}',
+ '--results_csv', f'{results_csv_path}'
+ ]
+ if debug_csv_path is not None:
+ cmd += ['--output_debug', f'{debug_csv_path}']
+ if self.visqol_mode == "speech":
+ cmd += ['--use_speech_mode']
+ cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
+ result = subprocess.run(cmd, capture_output=True)
+ if result.returncode:
+ logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
+ raise RuntimeError("Error while executing visqol")
+ result.check_returncode()
+
+ def __call__(
+ self,
+ ref_sig: torch.Tensor,
+ deg_sig: torch.Tensor,
+ sr: int,
+ pad_with_silence: bool = False,
+ ):
+ """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
+ Args:
+ ref_sig (torch.Tensor): Reference signals as [B, C, T].
+ deg_sig (torch.Tensor): Degraded signals as [B, C, T].
+ sr (int): Sample rate of the two audio signals.
+ pad_with_silence (bool): Whether to pad the file with silences as recommended
+ in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
+ Returns:
+ float: The ViSQOL score or mean score for the batch.
+ """
+ logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
+ tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
+ ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
+ )
+ try:
+ if input_csv and results_csv:
+ self._run_visqol(
+ input_csv,
+ results_csv,
+ debug_json if self.debug else None,
+ )
+ mosqol = self._collect_moslqo_score(results_csv)
+ return mosqol
+ else:
+ raise RuntimeError("Something unexpected happened when running VISQOL!")
+ except Exception as e:
+ logger.error("Exception occurred when running ViSQOL: %s", e)
+ finally:
+ self._flush_files(tmp_dir)
diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be6bfe4b787a132aeaabaed1c3437c9ecd5c656c
--- /dev/null
+++ b/audiocraft/models/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
+"""
+# flake8: noqa
+from . import builders, loaders
+from .encodec import (
+ CompressionModel, EncodecModel, DAC,
+ HFEncodecModel, HFEncodecCompressionModel)
+from .audiogen import AudioGen
+from .lm import LMModel
+from .multibanddiffusion import MultiBandDiffusion
+from .musicgen import MusicGen
+from .unet import DiffusionUnet
diff --git a/audiocraft/models/__pycache__/__init__.cpython-310.pyc b/audiocraft/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8026d3a100e97b2b5e96a7a829e857ac6ee21c82
Binary files /dev/null and b/audiocraft/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/audiogen.cpython-310.pyc b/audiocraft/models/__pycache__/audiogen.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c7c3096a9371c82be956bf7d82d72d0723cc894
Binary files /dev/null and b/audiocraft/models/__pycache__/audiogen.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/builders.cpython-310.pyc b/audiocraft/models/__pycache__/builders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ddca8d54c4bb0518376614e29843125afdd2562
Binary files /dev/null and b/audiocraft/models/__pycache__/builders.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/encodec.cpython-310.pyc b/audiocraft/models/__pycache__/encodec.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4eae86f026b93c8f0675b93f899e23bee43ecbb1
Binary files /dev/null and b/audiocraft/models/__pycache__/encodec.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/lm.cpython-310.pyc b/audiocraft/models/__pycache__/lm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e17a05404a6e67140f74027dd9c39fc00c8b9ca
Binary files /dev/null and b/audiocraft/models/__pycache__/lm.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/loaders.cpython-310.pyc b/audiocraft/models/__pycache__/loaders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7c48c9466426d9b309674cd9c306c745cd9199e
Binary files /dev/null and b/audiocraft/models/__pycache__/loaders.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/multibanddiffusion.cpython-310.pyc b/audiocraft/models/__pycache__/multibanddiffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3ba742e63bd447958b3265f17731aae38ffb26a
Binary files /dev/null and b/audiocraft/models/__pycache__/multibanddiffusion.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/musicgen.cpython-310.pyc b/audiocraft/models/__pycache__/musicgen.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ec6de5b70138c4357a33068aeec1c395fd19585
Binary files /dev/null and b/audiocraft/models/__pycache__/musicgen.cpython-310.pyc differ
diff --git a/audiocraft/models/__pycache__/unet.cpython-310.pyc b/audiocraft/models/__pycache__/unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fc769e01fd845bfb783c3bb5daf2155630d4963
Binary files /dev/null and b/audiocraft/models/__pycache__/unet.cpython-310.pyc differ
diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6adefb97401c10422c9711d222c0857f5593dceb
--- /dev/null
+++ b/audiocraft/models/audiogen.py
@@ -0,0 +1,276 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using AudioGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import typing as tp
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes
+from ..utils.autocast import TorchAutocast
+
+
+class AudioGen:
+ """AudioGen main model with convenient generation API.
+
+ Args:
+ name (str): name of the model.
+ compression_model (CompressionModel): Compression model
+ used to map audio to invertible discrete representations.
+ lm (LMModel): Language model over discrete representations.
+ max_duration (float, optional): maximum duration the model can produce,
+ otherwise, inferred from the training params.
+ """
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+ max_duration: tp.Optional[float] = None):
+ self.name = name
+ self.compression_model = compression_model
+ self.lm = lm
+ if max_duration is None:
+ if hasattr(lm, 'cfg'):
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
+ else:
+ raise ValueError("You must provide max_duration when building directly AudioGen")
+ assert max_duration is not None
+ self.max_duration: float = max_duration
+ self.device = next(iter(lm.parameters())).device
+ self.generation_params: dict = {}
+ self.set_generation_params(duration=5) # 5 seconds by default
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+ if self.device.type == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ else:
+ self.autocast = TorchAutocast(
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+ @property
+ def frame_rate(self) -> float:
+ """Roughly the number of AR steps per seconds."""
+ return self.compression_model.frame_rate
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of the generated audio."""
+ return self.compression_model.sample_rate
+
+ @property
+ def audio_channels(self) -> int:
+ """Audio channels of the generated audio."""
+ return self.compression_model.channels
+
+ @staticmethod
+ def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
+ """Return pretrained model, we provide a single model for now:
+ - facebook/audiogen-medium (1.5B), text to sound,
+ # see: https://huggingface.co/facebook/audiogen-medium
+ """
+ if device is None:
+ if torch.cuda.device_count():
+ device = 'cuda'
+ else:
+ device = 'cpu'
+
+ if name == 'debug':
+ # used only for unit tests
+ compression_model = get_debug_compression_model(device, sample_rate=16000)
+ lm = get_debug_lm_model(device)
+ return AudioGen(name, compression_model, lm, max_duration=10)
+
+ compression_model = load_compression_model(name, device=device)
+ lm = load_lm_model(name, device=device)
+ assert 'self_wav' not in lm.condition_provider.conditioners, \
+ "AudioGen do not support waveform conditioning for now"
+ return AudioGen(name, compression_model, lm)
+
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+ top_p: float = 0.0, temperature: float = 1.0,
+ duration: float = 10.0, cfg_coef: float = 3.0,
+ two_step_cfg: bool = False, extend_stride: float = 2):
+ """Set the generation parameters for AudioGen.
+
+ Args:
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+ duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+ instead of batching together the two. This has some impact on how things
+ are padded but seems to have little impact in practice.
+ extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
+ should we extend the audio each time. Larger values will mean less context is
+ preserved, and shorter value will require extra computations.
+ """
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+ self.extend_stride = extend_stride
+ self.duration = duration
+ self.generation_params = {
+ 'use_sampling': use_sampling,
+ 'temp': temperature,
+ 'top_k': top_k,
+ 'top_p': top_p,
+ 'cfg_coef': cfg_coef,
+ 'two_step_cfg': two_step_cfg,
+ }
+
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+ """Override the default progress callback."""
+ self._progress_callback = progress_callback
+
+ def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+ """Generate samples conditioned on text.
+
+ Args:
+ descriptions (list of str): A list of strings used as text conditioning.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+ assert prompt_tokens is None
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+ progress: bool = False) -> torch.Tensor:
+ """Generate samples conditioned on audio prompts.
+
+ Args:
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ if prompt.dim() == 2:
+ prompt = prompt[None]
+ if prompt.dim() != 3:
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+ if descriptions is None:
+ descriptions = [None] * len(prompt)
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+ assert prompt_tokens is not None
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ @torch.no_grad()
+ def _prepare_tokens_and_attributes(
+ self,
+ descriptions: tp.Sequence[tp.Optional[str]],
+ prompt: tp.Optional[torch.Tensor],
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+ """Prepare model inputs.
+
+ Args:
+ descriptions (list of str): A list of strings used as text conditioning.
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ """
+ attributes = [
+ ConditioningAttributes(text={'description': description})
+ for description in descriptions]
+
+ if prompt is not None:
+ if descriptions is not None:
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+ prompt = prompt.to(self.device)
+ prompt_tokens, scale = self.compression_model.encode(prompt)
+ assert scale is None
+ else:
+ prompt_tokens = None
+ return attributes, prompt_tokens
+
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+ """Generate discrete audio tokens given audio prompt and/or conditions.
+
+ Args:
+ attributes (list of ConditioningAttributes): Conditions used for generation (here text).
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ Returns:
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+ """
+ i = 0
+ prompt_list = attributes[0].text['description']
+ total_gen_len = int(self.duration * self.frame_rate)
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+ current_gen_offset: int = 0
+
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+ generated_tokens += current_gen_offset
+ if self._progress_callback is not None:
+ # Note that total_gen_len might be quite wrong depending on the
+ # codebook pattern used, but with delay it is almost accurate.
+ self._progress_callback(generated_tokens, total_gen_len)
+ else:
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+ if prompt_tokens is not None:
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
+ "Prompt is longer than audio to generate"
+
+ callback = None
+ if progress:
+ callback = _progress_callback
+
+ if self.duration <= self.max_duration:
+ # generate by sampling from LM, simple case.
+ with self.autocast:
+ attributes[0].text['description'] = prompt_list[0]
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+ else:
+ all_tokens = []
+ if prompt_tokens is None:
+ prompt_length = 0
+ else:
+ all_tokens.append(prompt_tokens)
+ prompt_length = prompt_tokens.shape[-1]
+
+ stride_tokens = int(self.frame_rate * self.extend_stride)
+
+ while current_gen_offset + prompt_length < total_gen_len:
+ time_offset = current_gen_offset / self.frame_rate
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
+ max_gen_len = int(chunk_duration * self.frame_rate)
+ with self.autocast:
+ if i >= len(prompt_list):
+ i = len(prompt_list) - 1
+ attributes[0].text['description'] = prompt_list[i]
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+ i = i + 1
+ if prompt_tokens is None:
+ all_tokens.append(gen_tokens)
+ else:
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
+ prompt_length = prompt_tokens.shape[-1]
+ current_gen_offset += stride_tokens
+
+ gen_tokens = torch.cat(all_tokens, dim=-1)
+
+ # generate audio
+ assert gen_tokens.dim() == 3
+ with torch.no_grad():
+ gen_audio = self.compression_model.decode(gen_tokens, None)
+ return gen_audio
+
+ def to(self, device: str):
+ self.compression_model.to(device)
+ self.lm.to(device)
+ return self
\ No newline at end of file
diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
new file mode 100644
index 0000000000000000000000000000000000000000..038bf99c3d0fbbb86005683d5a2a1b4edcac4298
--- /dev/null
+++ b/audiocraft/models/builders.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant models and modules
+from the Hydra config.
+"""
+
+import typing as tp
+
+import audiocraft
+import omegaconf
+import torch
+
+from .encodec import CompressionModel, EncodecModel
+from .lm import LMModel
+from ..modules.codebooks_patterns import (
+ CodebooksPatternProvider,
+ DelayedPatternProvider,
+ MusicLMPattern,
+ ParallelPatternProvider,
+ UnrolledPatternProvider,
+ VALLEPattern,
+)
+from ..modules.conditioners import (
+ BaseConditioner,
+ ChromaStemConditioner,
+ CLAPEmbeddingConditioner,
+ ConditionFuser,
+ ConditioningProvider,
+ LUTConditioner,
+ T5Conditioner,
+)
+from .unet import DiffusionUnet
+from .. import quantization as qt
+from ..utils.utils import dict_from_config
+from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+ klass = {
+ 'no_quant': qt.DummyQuantizer,
+ 'rvq': qt.ResidualVectorQuantizer
+ }[quantizer]
+ kwargs = dict_from_config(getattr(cfg, quantizer))
+ if quantizer != 'no_quant':
+ kwargs['dimension'] = dimension
+ return klass(**kwargs)
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+ if encoder_name == 'seanet':
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
+ encoder_override_kwargs = kwargs.pop('encoder')
+ decoder_override_kwargs = kwargs.pop('decoder')
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+ encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+ decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+ return encoder, decoder
+ else:
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+ """Instantiate a compression model."""
+ if cfg.compression_model == 'encodec':
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
+ encoder_name = kwargs.pop('autoencoder')
+ quantizer_name = kwargs.pop('quantizer')
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
+ renormalize = kwargs.pop('renormalize', False)
+ # deprecated params
+ kwargs.pop('renorm', None)
+ return EncodecModel(encoder, decoder, quantizer,
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+ else:
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+ """Instantiate a transformer LM."""
+ if cfg.lm_model == 'transformer_lm':
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+ n_q = kwargs['n_q']
+ q_modeling = kwargs.pop('q_modeling', None)
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+ cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+ fuser = get_condition_fuser(cfg)
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
+ kwargs['cross_attention'] = True
+ if codebooks_pattern_cfg.modeling is None:
+ assert q_modeling is not None, \
+ "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+ {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+ )
+ pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+ return LMModel(
+ pattern_provider=pattern_provider,
+ condition_provider=condition_provider,
+ fuser=fuser,
+ cfg_dropout=cfg_prob,
+ cfg_coef=cfg_coef,
+ attribute_dropout=attribute_dropout,
+ dtype=getattr(torch, cfg.dtype),
+ device=cfg.device,
+ **kwargs
+ ).to(cfg.device)
+ else:
+ raise KeyError(f"Unexpected LM model {cfg.lm_model}")
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+ """Instantiate a conditioning model."""
+ device = cfg.device
+ duration = cfg.dataset.segment_duration
+ cfg = getattr(cfg, 'conditioners')
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
+ conditioners: tp.Dict[str, BaseConditioner] = {}
+ condition_provider_args = dict_cfg.pop('args', {})
+ condition_provider_args.pop('merge_text_conditions_p', None)
+ condition_provider_args.pop('drop_desc_p', None)
+
+ for cond, cond_cfg in dict_cfg.items():
+ model_type = cond_cfg['model']
+ model_args = cond_cfg[model_type]
+ if model_type == 't5':
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+ elif model_type == 'lut':
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+ elif model_type == 'chroma_stem':
+ conditioners[str(cond)] = ChromaStemConditioner(
+ output_dim=output_dim,
+ duration=duration,
+ device=device,
+ **model_args
+ )
+ elif model_type == 'clap':
+ conditioners[str(cond)] = CLAPEmbeddingConditioner(
+ output_dim=output_dim,
+ device=device,
+ **model_args
+ )
+ else:
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
+ conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+ return conditioner
+
+
+def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+ """Instantiate a condition fuser object."""
+ fuser_cfg = getattr(cfg, 'fuser')
+ fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+ return fuser
+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+ """Instantiate a codebooks pattern provider object."""
+ pattern_providers = {
+ 'parallel': ParallelPatternProvider,
+ 'delay': DelayedPatternProvider,
+ 'unroll': UnrolledPatternProvider,
+ 'valle': VALLEPattern,
+ 'musiclm': MusicLMPattern,
+ }
+ name = cfg.modeling
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+ klass = pattern_providers[name]
+ return klass(n_q, **kwargs)
+
+
+def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
+ """Instantiate a debug compression model to be used for unit tests."""
+ assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+ model_ratios = {
+ 16000: [10, 8, 8], # 25 Hz at 16kHz
+ 32000: [10, 8, 16] # 25 Hz at 32kHz
+ }
+ ratios: tp.List[int] = model_ratios[sample_rate]
+ frame_rate = 25
+ seanet_kwargs: dict = {
+ 'n_filters': 4,
+ 'n_residual_layers': 1,
+ 'dimension': 32,
+ 'ratios': ratios,
+ }
+ print(seanet_kwargs)
+ encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+ decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+ quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+ init_x = torch.randn(8, 32, 128)
+ quantizer(init_x, 1) # initialize kmeans etc.
+ compression_model = EncodecModel(
+ encoder, decoder, quantizer,
+ frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
+ return compression_model.eval()
+
+
+def get_diffusion_model(cfg: omegaconf.DictConfig):
+ # TODO Find a way to infer the channels from dset
+ channels = cfg.channels
+ num_steps = cfg.schedule.num_steps
+ return DiffusionUnet(
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+
+
+def get_processor(cfg, sample_rate: int = 24000):
+ sample_processor = SampleProcessor()
+ if cfg.use:
+ kw = dict(cfg)
+ kw.pop('use')
+ kw.pop('name')
+ if cfg.name == "multi_band_processor":
+ sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
+ return sample_processor
+
+
+def get_debug_lm_model(device='cpu'):
+ """Instantiate a debug LM to be used for unit tests."""
+ pattern = DelayedPatternProvider(n_q=4)
+ dim = 16
+ providers = {
+ 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+ }
+ condition_provider = ConditioningProvider(providers)
+ fuser = ConditionFuser(
+ {'cross': ['description'], 'prepend': [],
+ 'sum': [], 'input_interpolate': []})
+ lm = LMModel(
+ pattern, condition_provider, fuser,
+ n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+ cross_attention=True, causal=True)
+ return lm.to(device).eval()
+
+
+def get_wrapped_compression_model(
+ compression_model: CompressionModel,
+ cfg: omegaconf.DictConfig) -> CompressionModel:
+ # more to come.
+ return compression_model
diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cf6b54b582975a01bdb7a06280c766d3d2cc72c
--- /dev/null
+++ b/audiocraft/models/encodec.py
@@ -0,0 +1,392 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Compression models or wrapper around existing models.
+Also defines the main interface that a model must follow to be usable as an audio tokenizer.
+"""
+
+from abc import ABC, abstractmethod
+import logging
+import math
+from pathlib import Path
+import typing as tp
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import EncodecModel as HFEncodecModel
+
+from .. import quantization as qt
+
+
+logger = logging.getLogger()
+
+
+class CompressionModel(ABC, nn.Module):
+ """Base API for all compression model that aim at being used as audio tokenizers
+ with a language model.
+ """
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ ...
+
+ @abstractmethod
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ """See `EncodecModel.encode`."""
+ ...
+
+ @abstractmethod
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ """See `EncodecModel.decode`."""
+ ...
+
+ @abstractmethod
+ def decode_latent(self, codes: torch.Tensor):
+ """Decode from the discrete codes to continuous latent space."""
+ ...
+
+ @property
+ @abstractmethod
+ def channels(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def frame_rate(self) -> float:
+ ...
+
+ @property
+ @abstractmethod
+ def sample_rate(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def cardinality(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def num_codebooks(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def total_codebooks(self) -> int:
+ ...
+
+ @abstractmethod
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer."""
+ ...
+
+ @staticmethod
+ def get_pretrained(
+ name: str, device: tp.Union[torch.device, str] = 'cpu'
+ ) -> 'CompressionModel':
+ """Instantiate a CompressionModel from a given pretrained model.
+
+ Args:
+ name (Path or str): name of the pretrained model. See after.
+ device (torch.device or str): Device on which the model is loaded.
+
+ Pretrained models:
+ - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
+ - dac_24khz (same)
+ - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
+ - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
+ - your own model on HugginFace. Export instructions to come...
+ """
+
+ from . import builders, loaders
+ model: CompressionModel
+ if name in ['dac_44khz', 'dac_24khz']:
+ model_type = name.split('_')[1]
+ logger.info("Getting pretrained compression model from DAC %s", model_type)
+ model = DAC(model_type)
+ elif name in ['debug_compression_model']:
+ logger.info("Getting pretrained compression model for debug")
+ model = builders.get_debug_compression_model()
+ elif Path(name).exists():
+ # We assume here if the paths exist that it is in fact an AC checkpoint
+ # that was exported using `audiocraft.utils.export` functions.
+ model = loaders.load_compression_model(name, device=device)
+ else:
+ logger.info("Getting pretrained compression model from HF %s", name)
+ hf_model = HFEncodecModel.from_pretrained(name)
+ model = HFEncodecCompressionModel(hf_model).to(device)
+ return model.to(device).eval()
+
+
+class EncodecModel(CompressionModel):
+ """Encodec model operating on the raw waveform.
+
+ Args:
+ encoder (nn.Module): Encoder network.
+ decoder (nn.Module): Decoder network.
+ quantizer (qt.BaseQuantizer): Quantizer network.
+ frame_rate (int): Frame rate for the latent representation.
+ sample_rate (int): Audio sample rate.
+ channels (int): Number of audio channels.
+ causal (bool): Whether to use a causal version of the model.
+ renormalize (bool): Whether to renormalize the audio before running the model.
+ """
+ # we need assignment to override the property in the abstract class,
+ # I couldn't find a better way...
+ frame_rate: float = 0
+ sample_rate: int = 0
+ channels: int = 0
+
+ def __init__(self,
+ encoder: nn.Module,
+ decoder: nn.Module,
+ quantizer: qt.BaseQuantizer,
+ frame_rate: int,
+ sample_rate: int,
+ channels: int,
+ causal: bool = False,
+ renormalize: bool = False):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.quantizer = quantizer
+ self.frame_rate = frame_rate
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.renormalize = renormalize
+ self.causal = causal
+ if self.causal:
+ # we force disabling here to avoid handling linear overlap of segments
+ # as supported in original EnCodec codebase.
+ assert not self.renormalize, 'Causal model does not support renormalize'
+
+ @property
+ def total_codebooks(self):
+ """Total number of quantizer codebooks available."""
+ return self.quantizer.total_codebooks
+
+ @property
+ def num_codebooks(self):
+ """Active number of codebooks used by the quantizer."""
+ return self.quantizer.num_codebooks
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer."""
+ self.quantizer.set_num_codebooks(n)
+
+ @property
+ def cardinality(self):
+ """Cardinality of each codebook."""
+ return self.quantizer.bins
+
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ scale: tp.Optional[torch.Tensor]
+ if self.renormalize:
+ mono = x.mean(dim=1, keepdim=True)
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+ scale = 1e-8 + volume
+ x = x / scale
+ scale = scale.view(-1, 1)
+ else:
+ scale = None
+ return x, scale
+
+ def postprocess(self,
+ x: torch.Tensor,
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+ if scale is not None:
+ assert self.renormalize
+ x = x * scale.view(-1, 1, 1)
+ return x
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ assert x.dim() == 3
+ length = x.shape[-1]
+ x, scale = self.preprocess(x)
+
+ emb = self.encoder(x)
+ q_res = self.quantizer(emb, self.frame_rate)
+ out = self.decoder(q_res.x)
+
+ # remove extra padding added by the encoder and decoder
+ assert out.shape[-1] >= length, (out.shape[-1], length)
+ out = out[..., :length]
+
+ q_res.x = self.postprocess(out, scale)
+
+ return q_res
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ """Encode the given input tensor to quantized representation along with scale parameter.
+
+ Args:
+ x (torch.Tensor): Float tensor of shape [B, C, T]
+
+ Returns:
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
+ codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+ scale a float tensor containing the scale for audio renormalizealization.
+ """
+ assert x.dim() == 3
+ x, scale = self.preprocess(x)
+ emb = self.encoder(x)
+ codes = self.quantizer.encode(emb)
+ return codes, scale
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ """Decode the given codes to a reconstructed representation, using the scale to perform
+ audio denormalization if needed.
+
+ Args:
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
+
+ Returns:
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+ """
+ emb = self.decode_latent(codes)
+ out = self.decoder(emb)
+ out = self.postprocess(out, scale)
+ # out contains extra padding added by the encoder and decoder
+ return out
+
+ def decode_latent(self, codes: torch.Tensor):
+ """Decode from the discrete codes to continuous latent space."""
+ return self.quantizer.decode(codes)
+
+
+class DAC(CompressionModel):
+ def __init__(self, model_type: str = "44khz"):
+ super().__init__()
+ try:
+ import dac.utils
+ except ImportError:
+ raise RuntimeError("Could not import dac, make sure it is installed, "
+ "please run `pip install descript-audio-codec`")
+ self.model = dac.utils.load_model(model_type=model_type)
+ self.n_quantizers = self.total_codebooks
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ # We don't support training with this.
+ raise NotImplementedError("Forward and training with DAC not supported.")
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ codes = self.model.encode(x, self.n_quantizers)[1]
+ return codes, None
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ assert scale is None
+ z_q = self.decode_latent(codes)
+ return self.model.decode(z_q)
+
+ def decode_latent(self, codes: torch.Tensor):
+ """Decode from the discrete codes to continuous latent space."""
+ return self.model.quantizer.from_codes(codes)[0]
+
+ @property
+ def channels(self) -> int:
+ return 1
+
+ @property
+ def frame_rate(self) -> float:
+ return self.model.sample_rate / self.model.hop_length
+
+ @property
+ def sample_rate(self) -> int:
+ return self.model.sample_rate
+
+ @property
+ def cardinality(self) -> int:
+ return self.model.codebook_size
+
+ @property
+ def num_codebooks(self) -> int:
+ return self.n_quantizers
+
+ @property
+ def total_codebooks(self) -> int:
+ return self.model.n_codebooks
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+ """
+ assert n >= 1
+ assert n <= self.total_codebooks
+ self.n_quantizers = n
+
+
+class HFEncodecCompressionModel(CompressionModel):
+ """Wrapper around HuggingFace Encodec.
+ """
+ def __init__(self, model: HFEncodecModel):
+ super().__init__()
+ self.model = model
+ bws = self.model.config.target_bandwidths
+ num_codebooks = [
+ bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
+ for bw in bws
+ ]
+ deltas = [nc - int(nc) for nc in num_codebooks]
+ # Checking we didn't do some bad maths and we indeed have integers!
+ assert all(deltas) <= 1e-3, deltas
+ self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
+ self.set_num_codebooks(max(self.possible_num_codebooks))
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ # We don't support training with this.
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
+ bandwidth = self.model.config.target_bandwidths[bandwidth_index]
+ res = self.model.encode(x, None, bandwidth)
+ assert len(res[0]) == 1
+ assert len(res[1]) == 1
+ return res[0][0], res[1][0]
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ if scale is None:
+ scales = [None] # type: ignore
+ else:
+ scales = scale # type: ignore
+ res = self.model.decode(codes[None], scales)
+ return res[0]
+
+ def decode_latent(self, codes: torch.Tensor):
+ """Decode from the discrete codes to continuous latent space."""
+ return self.model.quantizer.decode(codes.transpose(0, 1))
+
+ @property
+ def channels(self) -> int:
+ return self.model.config.audio_channels
+
+ @property
+ def frame_rate(self) -> float:
+ hop_length = int(np.prod(self.model.config.upsampling_ratios))
+ return self.sample_rate / hop_length
+
+ @property
+ def sample_rate(self) -> int:
+ return self.model.config.sampling_rate
+
+ @property
+ def cardinality(self) -> int:
+ return self.model.config.codebook_size
+
+ @property
+ def num_codebooks(self) -> int:
+ return self._num_codebooks
+
+ @property
+ def total_codebooks(self) -> int:
+ return max(self.possible_num_codebooks)
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+ """
+ if n not in self.possible_num_codebooks:
+ raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
+ self._num_codebooks = n
diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cefd2c58c3a337378579d6cd6469fd038cbb1ee
--- /dev/null
+++ b/audiocraft/models/lm.py
@@ -0,0 +1,531 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from functools import partial
+import logging
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from ..utils import utils
+from ..modules.streaming import StreamingModule, State
+from ..modules.transformer import StreamingTransformer, create_norm_fn
+from ..modules.conditioners import (
+ ConditionFuser,
+ ClassifierFreeGuidanceDropout,
+ AttributeDropout,
+ ConditioningProvider,
+ ConditioningAttributes,
+ ConditionType,
+)
+from ..modules.codebooks_patterns import CodebooksPatternProvider
+from ..modules.activations import get_activation_fn
+
+
+logger = logging.getLogger(__name__)
+ConditionTensors = tp.Dict[str, ConditionType]
+CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
+
+
+def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+ """LM layer initialization.
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+ Args:
+ method (str): Method name for init function. Valid options are:
+ 'gaussian', 'uniform'.
+ input_dim (int): Input dimension of the initialized module.
+ init_depth (int, optional): Optional init depth value used to rescale
+ the standard deviation if defined.
+ """
+ # Compute std
+ std = 1 / math.sqrt(input_dim)
+ # Rescale with depth
+ if init_depth is not None:
+ std = std / math.sqrt(2 * init_depth)
+
+ if method == 'gaussian':
+ return partial(
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+ )
+ elif method == 'uniform':
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+ else:
+ raise ValueError("Unsupported layer initialization method")
+
+
+def init_layer(m: nn.Module,
+ method: str,
+ init_depth: tp.Optional[int] = None,
+ zero_bias_init: bool = False):
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+ Args:
+ m (nn.Module): Module to initialize.
+ method (str): Method name for the init function.
+ init_depth (int, optional): Optional init depth value used to rescale
+ the standard deviation if defined.
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+ """
+ if isinstance(m, nn.Linear):
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+ weight = m.weight.float()
+ init_fn(weight)
+ m.weight.data[:] = weight.half()
+ else:
+ init_fn(m.weight)
+ if zero_bias_init and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Embedding):
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+ weight = m.weight.float()
+ init_fn(weight)
+ m.weight.data[:] = weight.half()
+ else:
+ init_fn(m.weight)
+
+
+class ScaledEmbedding(nn.Embedding):
+ """Boost learning rate for embeddings (with `scale`).
+ """
+ def __init__(self, *args, lr=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lr = lr
+
+ def make_optim_group(self):
+ group = {"params": list(self.parameters())}
+ if self.lr is not None:
+ group["lr"] = self.lr
+ return group
+
+
+@dataclass
+class LMOutput:
+ # The logits are already re-aligned with the input codes
+ # hence no extra shift is required, e.g. when computing CE
+ logits: torch.Tensor # [B, K, T, card]
+ mask: torch.Tensor # [B, K, T]
+
+
+class LMModel(StreamingModule):
+ """Transformer-based language model on multiple streams of codes.
+
+ Args:
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+ n_q (int): Number of parallel streams to model.
+ card (int): Cardinality, vocabulary size.
+ dim (int): Dimension of the transformer encoder.
+ num_heads (int): Number of heads for the transformer encoder.
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+ norm (str): Normalization method.
+ norm_first (bool): Use pre-norm instead of post-norm.
+ emb_lr (float, optional): Embedding-specific learning rate.
+ bias_proj (bool): Use bias for output projections.
+ weight_init (str, optional): Method for weight initialization.
+ depthwise_init (str, optional): Method for depthwise weight initialization.
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+ cfg_dropout (float): Classifier-free guidance dropout.
+ cfg_coef (float): Classifier-free guidance coefficient.
+ attribute_dropout (dict): Attribute dropout probabilities.
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
+ **kwargs: Additional parameters for the transformer encoder.
+ """
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+ **kwargs):
+ super().__init__()
+ self.cfg_coef = cfg_coef
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
+ self.condition_provider = condition_provider
+ self.fuser = fuser
+ self.card = card
+ embed_dim = self.card + 1
+ self.n_q = n_q
+ self.dim = dim
+ self.pattern_provider = pattern_provider
+ self.two_step_cfg = two_step_cfg
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+ if 'activation' in kwargs:
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
+ self.transformer = StreamingTransformer(
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+ norm=norm, norm_first=norm_first, **kwargs)
+ self.out_norm: tp.Optional[nn.Module] = None
+ if norm_first:
+ self.out_norm = create_norm_fn(norm, dim)
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
+ self._fsdp: tp.Optional[nn.Module]
+ self.__dict__['_fsdp'] = None
+
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+ """Initialization of the transformer module weights.
+
+ Args:
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
+ """
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
+ assert depthwise_init is None or weight_init is not None, \
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+ assert not zero_bias_init or weight_init is not None, \
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+ if weight_init is None:
+ return
+
+ for emb_layer in self.emb:
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
+ depth = None
+ if depthwise_init == 'current':
+ depth = layer_idx + 1
+ elif depthwise_init == 'global':
+ depth = len(self.transformer.layers)
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+ tr_layer.apply(init_fn)
+
+ for linear in self.linears:
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+ @property
+ def special_token_id(self) -> int:
+ return self.card
+
+ @property
+ def num_codebooks(self) -> int:
+ return self.n_q
+
+ def forward(self, sequence: torch.Tensor,
+ conditions: tp.List[ConditioningAttributes],
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+ """Apply language model on sequence and conditions.
+ Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+ S the sequence steps, return the logits with shape [B, card, K, S].
+
+ Args:
+ indices (torch.Tensor): Indices of the codes to model.
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
+ the given codes. Note that when evaluating multiple time with the same conditioning
+ you should pre-compute those and pass them as `condition_tensors`.
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
+ tensors, see `conditions`.
+ Returns:
+ torch.Tensor: Logits.
+ """
+ B, K, S = sequence.shape
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+ if condition_tensors is None:
+ assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+ # apply dropout modules
+ conditions = self.cfg_dropout(conditions)
+ conditions = self.att_dropout(conditions)
+ tokenized = self.condition_provider.tokenize(conditions)
+ # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+ condition_tensors = self.condition_provider(tokenized)
+ else:
+ assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+ out = self.transformer(input_, cross_attention_src=cross_attention_input)
+ if self.out_norm:
+ out = self.out_norm(out)
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
+
+ # remove the prefix from the model outputs
+ if len(self.fuser.fuse2cond['prepend']) > 0:
+ logits = logits[:, :, -S:]
+
+ return logits # [B, K, S, card]
+
+ def compute_predictions(
+ self, codes: torch.Tensor,
+ conditions: tp.List[ConditioningAttributes],
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+ forward using the specified codes interleaving pattern.
+
+ Args:
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+ K the number of codebooks and T the number of timesteps.
+ conditions (list of ConditioningAttributes): conditionings to use when modeling
+ the given codes. Note that when evaluating multiple time with the same conditioning
+ you should pre-compute those and pass them as `condition_tensors`.
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
+ tensors, see `conditions`.
+ Returns:
+ LMOutput: Language model outputs
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+ i.e. the first item corresponds to logits to predict the first code, meaning that
+ no additional shifting of codes and logits is required.
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+ Given the specified interleaving strategies, parts of the logits and codes should
+ not be considered as valid predictions because of invalid context.
+ """
+ B, K, T = codes.shape
+ codes = codes.contiguous()
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+ pattern = self.pattern_provider.get_pattern(T)
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+ codes, self.special_token_id, keep_only_valid_steps=True
+ )
+ # apply model on pattern sequence
+ model = self if self._fsdp is None else self._fsdp
+ logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+ # and provide the corresponding mask over invalid positions of tokens
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+ logits, float('nan'), keep_only_valid_steps=True
+ )
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
+ return LMOutput(logits, logits_mask)
+
+ def _sample_next_token(self,
+ sequence: torch.Tensor,
+ cfg_conditions: CFGConditions,
+ unconditional_state: State,
+ use_sampling: bool = False,
+ temp: float = 1.0,
+ top_k: int = 0,
+ top_p: float = 0.0,
+ cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+ Args:
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
+ with K corresponding to the number of codebooks and S the number of sequence steps.
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+ condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
+ use_sampling (bool): Whether to use a sampling strategy or not.
+ temp (float): Sampling temperature.
+ top_k (int): K for "top-k" sampling.
+ top_p (float): P for "top-p" sampling.
+ cfg_coef (float, optional): classifier free guidance coefficient
+ Returns:
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+ """
+ B = sequence.shape[0]
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+ model = self if self._fsdp is None else self._fsdp
+ if self.two_step_cfg and cfg_conditions != {}:
+ assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
+ condition_tensors, null_condition_tensors = cfg_conditions
+ cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+ state = self.get_streaming_state()
+ self.set_streaming_state(unconditional_state)
+ uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+ unconditional_state.update(self.get_streaming_state())
+ self.set_streaming_state(state)
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+ else:
+ assert isinstance(cfg_conditions, dict)
+ condition_tensors = cfg_conditions
+ if condition_tensors:
+ # Preparing for CFG, predicting both conditional and unconditional logits.
+ sequence = torch.cat([sequence, sequence], dim=0)
+ all_logits = model(
+ sequence,
+ conditions=[], condition_tensors=condition_tensors)
+ if condition_tensors:
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+ else:
+ logits = all_logits
+
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
+ logits = logits[..., -1] # [B x K x card]
+
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+ if use_sampling and temp > 0.0:
+ probs = torch.softmax(logits / temp, dim=-1)
+ if top_p > 0.0:
+ next_token = utils.sample_top_p(probs, p=top_p)
+ elif top_k > 0:
+ next_token = utils.sample_top_k(probs, k=top_k)
+ else:
+ next_token = utils.multinomial(probs, num_samples=1)
+ else:
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+ return next_token
+
+ @torch.no_grad()
+ def generate(self,
+ prompt: tp.Optional[torch.Tensor] = None,
+ conditions: tp.List[ConditioningAttributes] = [],
+ num_samples: tp.Optional[int] = None,
+ max_gen_len: int = 256,
+ use_sampling: bool = True,
+ temp: float = 1.0,
+ top_k: int = 250,
+ top_p: float = 0.0,
+ cfg_coef: tp.Optional[float] = None,
+ two_step_cfg: tp.Optional[bool] = None,
+ remove_prompts: bool = False,
+ check: bool = False,
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+ be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+ Args:
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
+ conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
+ max_gen_len (int): Maximum generation length.
+ use_sampling (bool): Whether to use a sampling strategy or not.
+ temp (float): Sampling temperature.
+ top_k (int): K for "top-k" sampling.
+ top_p (float): P for "top-p" sampling.
+ cfg_coeff (float, optional): Classifier-free guidance coefficient.
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
+ remove_prompts (bool): Whether to remove prompts from generation or not.
+ check (bool): Whether to apply further checks on generated sequence.
+ callback (Callback, optional): Callback function to report generation progress.
+ Returns:
+ torch.Tensor: Generated tokens.
+ """
+ assert not self.training, "generation shouldn't be used in training mode."
+ first_param = next(iter(self.parameters()))
+ device = first_param.device
+
+ # Checking all input shapes are consistent.
+ possible_num_samples = []
+ if num_samples is not None:
+ possible_num_samples.append(num_samples)
+ elif prompt is not None:
+ possible_num_samples.append(prompt.shape[0])
+ elif conditions:
+ possible_num_samples.append(len(conditions))
+ else:
+ possible_num_samples.append(1)
+ assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
+ num_samples = possible_num_samples[0]
+
+ # below we create set of conditions: one conditional and one unconditional
+ # to do that we merge the regular condition together with the null condition
+ # we then do 1 forward pass instead of 2.
+ # the reason for that is two-fold:
+ # 1. it is about x2 faster than doing 2 forward passes
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
+ # We also support doing two different passes, in particular to ensure that
+ # the padding structure is exactly the same between train and test.
+ # With a batch size of 1, this can be slower though.
+ cfg_conditions: CFGConditions
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+ if conditions:
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+ if two_step_cfg:
+ cfg_conditions = (
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+ )
+ else:
+ conditions = conditions + null_conditions
+ tokenized = self.condition_provider.tokenize(conditions)
+ cfg_conditions = self.condition_provider(tokenized)
+ else:
+ cfg_conditions = {}
+
+ if prompt is None:
+ assert num_samples > 0
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+ B, K, T = prompt.shape
+ start_offset = T
+ assert start_offset < max_gen_len
+
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
+ # this token is used as default value for codes that are not generated yet
+ unknown_token = -1
+
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+ # filling the gen_codes with the prompt if needed
+ gen_codes[..., :start_offset] = prompt
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+ # retrieve the start_offset in the sequence:
+ # it is the first sequence step that contains the `start_offset` timestep
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+ assert start_offset_sequence is not None
+
+ with self.streaming():
+ unconditional_state = self.get_streaming_state()
+ prev_offset = 0
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
+ for offset in range(start_offset_sequence, gen_sequence_len):
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
+ curr_sequence = gen_sequence[..., prev_offset:offset]
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+ if check:
+ # check coherence between mask and sequence
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+ # should never happen as gen_sequence is filled progressively
+ assert not (curr_sequence == unknown_token).any()
+ # sample next token from the model, next token shape is [B, K, 1]
+ next_token = self._sample_next_token(
+ curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+ cfg_coef=cfg_coef)
+ # ensure the tokens that should be masked are properly set to special_token_id
+ # as the model never output special_token_id
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+ next_token[~valid_mask] = self.special_token_id
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+ # (then mask tokens should be left as is as well, which is correct)
+ gen_sequence[..., offset:offset+1] = torch.where(
+ gen_sequence[..., offset:offset+1] == unknown_token,
+ next_token, gen_sequence[..., offset:offset+1]
+ )
+ prev_offset = offset
+ if callback is not None:
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+ unconditional_state.clear()
+
+ # ensure sequence has been entirely filled
+ assert not (gen_sequence == unknown_token).any()
+ # ensure gen_sequence pattern and mask are matching
+ # which means the gen_sequence is valid according to the pattern
+ assert (
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+ ).all()
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+ # sanity checks over the returned codes and corresponding masks
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
+ assert (out_mask[..., :max_gen_len] == 1).all()
+
+ out_start_offset = start_offset if remove_prompts else 0
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+ # ensure the returned codes are all valid
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+ return out_codes
diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c7808a0588bd1a8084157b072bae42aa7efaf84
--- /dev/null
+++ b/audiocraft/models/loaders.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions to load from the checkpoints.
+Each checkpoint is a torch.saved dict with the following keys:
+- 'xp.cfg': the hydra config as dumped during training. This should be used
+ to rebuild the object using the audiocraft.models.builders functions,
+- 'model_best_state': a readily loadable best state for the model, including
+ the conditioner. The model obtained from `xp.cfg` should be compatible
+ with this state dict. In the case of a LM, the encodec model would not be
+ bundled along but instead provided separately.
+
+Those functions also support loading from a remote location with the Torch Hub API.
+They also support overriding some parameters, in particular the device and dtype
+of the returned model.
+"""
+
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+import typing as tp
+import os
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+from . import builders
+from .encodec import CompressionModel
+
+
+def get_audiocraft_cache_dir() -> tp.Optional[str]:
+ return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
+
+
+def _get_state_dict(
+ file_or_url_or_id: tp.Union[Path, str],
+ filename: tp.Optional[str] = None,
+ device='cpu',
+ cache_dir: tp.Optional[str] = None,
+):
+ if cache_dir is None:
+ cache_dir = get_audiocraft_cache_dir()
+ # Return the state dict either from a file or url
+ file_or_url_or_id = str(file_or_url_or_id)
+ assert isinstance(file_or_url_or_id, str)
+
+ if os.path.isfile(file_or_url_or_id):
+ return torch.load(file_or_url_or_id, map_location=device)
+
+ if os.path.isdir(file_or_url_or_id):
+ file = f"{file_or_url_or_id}/{filename}"
+ return torch.load(file, map_location=device)
+
+ elif file_or_url_or_id.startswith('https://'):
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
+
+ else:
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
+
+ file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+ return torch.load(file, map_location=device)
+
+
+def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+ return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+
+
+def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+ pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+ if 'pretrained' in pkg:
+ return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
+ cfg = OmegaConf.create(pkg['xp.cfg'])
+ cfg.device = str(device)
+ model = builders.get_compression_model(cfg)
+ model.load_state_dict(pkg['best_state'])
+ model.eval()
+ return model
+
+
+def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+ return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+
+
+def _delete_param(cfg: DictConfig, full_name: str):
+ parts = full_name.split('.')
+ for part in parts[:-1]:
+ if part in cfg:
+ cfg = cfg[part]
+ else:
+ return
+ OmegaConf.set_struct(cfg, False)
+ if parts[-1] in cfg:
+ del cfg[parts[-1]]
+ OmegaConf.set_struct(cfg, True)
+
+
+def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+ cfg = OmegaConf.create(pkg['xp.cfg'])
+ cfg.device = str(device)
+ if cfg.device == 'cpu':
+ cfg.dtype = 'float32'
+ else:
+ cfg.dtype = 'float16'
+ _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
+ model = builders.get_lm_model(cfg)
+ model.load_state_dict(pkg['best_state'])
+ model.eval()
+ model.cfg = cfg
+ return model
+
+
+def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+ return _get_state_dict(file_or_url_or_id, filename="all_in_one.pt", cache_dir=cache_dir)
+
+
+def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+ pkg = load_mbd_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+ models = []
+ processors = []
+ cfgs = []
+ sample_rate = pkg['sample_rate']
+ for i in range(pkg['n_bands']):
+ cfg = pkg[i]['cfg']
+ model = builders.get_diffusion_model(cfg)
+ model_dict = pkg[i]['model_state']
+ model.load_state_dict(model_dict)
+ model.to(device)
+ processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
+ processor_dict = pkg[i]['processor_state']
+ processor.load_state_dict(processor_dict)
+ processor.to(device)
+ models.append(model)
+ processors.append(processor)
+ cfgs.append(cfg)
+ return models, processors, cfgs
diff --git a/audiocraft/models/multibanddiffusion.py b/audiocraft/models/multibanddiffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a2f169d516ed5aaf5da61fb482d94dd142f55e9
--- /dev/null
+++ b/audiocraft/models/multibanddiffusion.py
@@ -0,0 +1,194 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Multi Band Diffusion models as described in
+"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
+(paper link).
+"""
+
+import typing as tp
+
+import torch
+import julius
+
+from .unet import DiffusionUnet
+from ..modules.diffusion_schedule import NoiseSchedule
+from .encodec import CompressionModel
+from ..solvers.compression import CompressionSolver
+from .loaders import load_compression_model, load_diffusion_models
+
+
+class DiffusionProcess:
+ """Sampling for a diffusion Model.
+
+ Args:
+ model (DiffusionUnet): Diffusion U-Net model.
+ noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
+ """
+ def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
+ """
+ """
+ self.model = model
+ self.schedule = noise_schedule
+
+ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
+ step_list: tp.Optional[tp.List[int]] = None):
+ """Perform one diffusion process to generate one of the bands.
+
+ Args:
+ condition (tensor): The embeddings form the compression model.
+ initial_noise (tensor): The initial noise to start the process/
+ """
+ return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
+ condition=condition)
+
+
+class MultiBandDiffusion:
+ """Sample from multiple diffusion models.
+
+ Args:
+ DPs (list of DiffusionProcess): Diffusion processes.
+ codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
+ """
+ def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
+ self.DPs = DPs
+ self.codec_model = codec_model
+ self.device = next(self.codec_model.parameters()).device
+
+ @property
+ def sample_rate(self) -> int:
+ return self.codec_model.sample_rate
+
+ @staticmethod
+ def get_mbd_musicgen(device=None):
+ """Load our diffusion models trained for MusicGen."""
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ path = 'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_musicgen_32khz.th'
+ name = 'facebook/musicgen-small'
+ codec_model = load_compression_model(name, device=device)
+ models, processors, cfgs = load_diffusion_models(path, device=device)
+ DPs = []
+ for i in range(len(models)):
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i])
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+ @staticmethod
+ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
+ device: tp.Optional[tp.Union[torch.device, str]] = None,
+ n_q: tp.Optional[int] = None):
+ """Get the pretrained Models for MultibandDiffusion.
+
+ Args:
+ bw (float): Bandwidth of the compression model.
+ pretrained (bool): Whether to use / download if necessary the models.
+ device (torch.device or str, optional): Device on which the models are loaded.
+ n_q (int, optional): Number of quantizers to use within the compression model.
+ """
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
+ if n_q is not None:
+ assert n_q in [2, 4, 8]
+ assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
+ f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
+ n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
+ codec_model = CompressionSolver.model_from_checkpoint(
+ '//pretrained/facebook/encodec_24khz', device=device)
+ codec_model.set_num_codebooks(n_q)
+ codec_model = codec_model.to(device)
+ path = f'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_comp_{n_q}.pt'
+ models, processors, cfgs = load_diffusion_models(path, device=device)
+ DPs = []
+ for i in range(len(models)):
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i])
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+ return MultiBandDiffusion(DPs, codec_model)
+
+ @torch.no_grad()
+ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+ """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
+ Args:
+ wav (torch.Tensor): The audio that we want to extract the conditioning from
+ sample_rate (int): sample rate of the audio"""
+ if sample_rate != self.sample_rate:
+ wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
+ codes, scale = self.codec_model.encode(wav)
+ assert scale is None, "Scaled compression models not supported."
+ emb = self.get_emb(codes)
+ return emb
+
+ @torch.no_grad()
+ def get_emb(self, codes: torch.Tensor):
+ """Get latent representation from the discrete codes
+ Argrs:
+ codes (torch.Tensor): discrete tokens"""
+ emb = self.codec_model.decode_latent(codes)
+ return emb
+
+ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
+ step_list: tp.Optional[tp.List[int]] = None):
+ """Generate Wavform audio from the latent embeddings of the compression model
+ Args:
+ emb (torch.Tensor): Conditioning embeddinds
+ size (none torch.Size): size of the output
+ if None this is computed from the typical upsampling of the model
+ step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
+ """
+ if size is None:
+ upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
+ size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
+ assert size[0] == emb.size(0)
+ out = torch.zeros(size).to(self.device)
+ for DP in self.DPs:
+ out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
+ return out
+
+ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
+ """match the eq to the encodec output by matching the standard deviation of some frequency bands
+ Args:
+ wav (torch.Tensor): audio to equalize
+ ref (torch.Tensor):refenrence audio from which we match the spectrogram.
+ n_bands (int): number of bands of the eq
+ strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
+ """
+ split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
+ bands = split(wav)
+ bands_ref = split(ref)
+ out = torch.zeros_like(ref)
+ for i in range(n_bands):
+ out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
+ return out
+
+ def regenerate(self, wav: torch.Tensor, sample_rate: int):
+ """Regenerate a wavform through compression and diffusion regeneration.
+ Args:
+ wav (torch.Tensor): Original 'ground truth' audio
+ sample_rate (int): sample rate of the input (and output) wav
+ """
+ if sample_rate != self.codec_model.sample_rate:
+ wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
+ emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
+ size = wav.size()
+ out = self.generate(emb, size=size)
+ if sample_rate != self.codec_model.sample_rate:
+ out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
+ return out
+
+ def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
+ """Generate Waveform audio with diffusion from the discrete codes.
+ Args:
+ tokens (torch.Tensor): discrete codes
+ n_bands (int): bands for the eq matching.
+ """
+ wav_encodec = self.codec_model.decode(tokens)
+ condition = self.get_emb(tokens)
+ wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
+ return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4b2292eaec5016e208bbdf61ec5c99b40b67da
--- /dev/null
+++ b/audiocraft/models/musicgen.py
@@ -0,0 +1,409 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using MusicGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import typing as tp
+import warnings
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes, WavCondition
+from ..utils.autocast import TorchAutocast
+
+
+MelodyList = tp.List[tp.Optional[torch.Tensor]]
+MelodyType = tp.Union[torch.Tensor, MelodyList]
+
+
+# backward compatible names mapping
+_HF_MODEL_CHECKPOINTS_MAP = {
+ "small": "GrandaddyShmax/musicgen-small",
+ "medium": "GrandaddyShmax/musicgen-medium",
+ "large": "GrandaddyShmax/musicgen-large",
+ "melody": "GrandaddyShmax/musicgen-melody",
+}
+
+
+class MusicGen:
+ """MusicGen main model with convenient generation API.
+
+ Args:
+ name (str): name of the model.
+ compression_model (CompressionModel): Compression model
+ used to map audio to invertible discrete representations.
+ lm (LMModel): Language model over discrete representations.
+ max_duration (float, optional): maximum duration the model can produce,
+ otherwise, inferred from the training params.
+ """
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+ max_duration: tp.Optional[float] = None):
+ self.name = name
+ self.compression_model = compression_model
+ self.lm = lm
+ if max_duration is None:
+ if hasattr(lm, 'cfg'):
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
+ else:
+ raise ValueError("You must provide max_duration when building directly MusicGen")
+ assert max_duration is not None
+ self.max_duration: float = max_duration
+ self.device = next(iter(lm.parameters())).device
+ self.generation_params: dict = {}
+ self.set_generation_params(duration=15) # 15 seconds by default
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+ if self.device.type == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ else:
+ self.autocast = TorchAutocast(
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+ @property
+ def frame_rate(self) -> float:
+ """Roughly the number of AR steps per seconds."""
+ return self.compression_model.frame_rate
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of the generated audio."""
+ return self.compression_model.sample_rate
+
+ @property
+ def audio_channels(self) -> int:
+ """Audio channels of the generated audio."""
+ return self.compression_model.channels
+
+ @staticmethod
+ def get_pretrained(name: str = 'GrandaddyShmax/musicgen-melody', device=None):
+ """Return pretrained model, we provide four models:
+ - facebook/musicgen-small (300M), text to music,
+ # see: https://huggingface.co/facebook/musicgen-small
+ - facebook/musicgen-medium (1.5B), text to music,
+ # see: https://huggingface.co/facebook/musicgen-medium
+ - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
+ # see: https://huggingface.co/facebook/musicgen-melody
+ - facebook/musicgen-large (3.3B), text to music,
+ # see: https://huggingface.co/facebook/musicgen-large
+ """
+ if device is None:
+ if torch.cuda.device_count():
+ device = 'cuda'
+ else:
+ device = 'cpu'
+
+ if name == 'debug':
+ # used only for unit tests
+ compression_model = get_debug_compression_model(device)
+ lm = get_debug_lm_model(device)
+ return MusicGen(name, compression_model, lm, max_duration=30)
+
+ lm = load_lm_model(name, device=device)
+ compression_model = load_compression_model(name, device=device)
+ if 'self_wav' in lm.condition_provider.conditioners:
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+ return MusicGen(name, compression_model, lm)
+
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+ top_p: float = 0.0, temperature: float = 1.0,
+ duration: float = 30.0, cfg_coef: float = 3.0,
+ two_step_cfg: bool = False, extend_stride: float = 18):
+ """Set the generation parameters for MusicGen.
+
+ Args:
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+ instead of batching together the two. This has some impact on how things
+ are padded but seems to have little impact in practice.
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+ should we extend the audio each time. Larger values will mean less context is
+ preserved, and shorter value will require extra computations.
+ """
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+ self.extend_stride = extend_stride
+ self.duration = duration
+ self.generation_params = {
+ 'use_sampling': use_sampling,
+ 'temp': temperature,
+ 'top_k': top_k,
+ 'top_p': top_p,
+ 'cfg_coef': cfg_coef,
+ 'two_step_cfg': two_step_cfg,
+ }
+
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+ """Override the default progress callback."""
+ self._progress_callback = progress_callback
+
+ def generate_unconditional(self, num_samples: int, progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+ """Generate samples in an unconditional manner.
+
+ Args:
+ num_samples (int): Number of samples to be generated.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+ if return_tokens:
+ return self.generate_audio(tokens), tokens
+ return self.generate_audio(tokens)
+
+ def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+ """Generate samples conditioned on text.
+
+ Args:
+ descriptions (list of str): A list of strings used as text conditioning.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+ assert prompt_tokens is None
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+ if return_tokens:
+ return self.generate_audio(tokens), tokens
+ return self.generate_audio(tokens)
+
+ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, melody_sample_rate: int, progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+ """Generate samples conditioned on text and melody.
+
+ Args:
+ descriptions (list of str): A list of strings used as text conditioning.
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
+ a list of [C, T] tensors.
+ melody_sample_rate: (int): Sample rate of the melody waveforms.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ if isinstance(melody_wavs, torch.Tensor):
+ if melody_wavs.dim() == 2:
+ melody_wavs = melody_wavs[None]
+ if melody_wavs.dim() != 3:
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
+ melody_wavs = list(melody_wavs)
+ else:
+ for melody in melody_wavs:
+ if melody is not None:
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+ melody_wavs = [
+ convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+ if wav is not None else None
+ for wav in melody_wavs]
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+ melody_wavs=melody_wavs)
+ assert prompt_tokens is None
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+ if return_tokens:
+ return self.generate_audio(tokens), tokens
+ return self.generate_audio(tokens)
+
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+ progress: bool = False, return_tokens: bool = False) \
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+ """Generate samples conditioned on audio prompts.
+
+ Args:
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ if prompt.dim() == 2:
+ prompt = prompt[None]
+ if prompt.dim() != 3:
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+ if descriptions is None:
+ descriptions = [None] * len(prompt)
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+ assert prompt_tokens is not None
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+ if return_tokens:
+ return self.generate_audio(tokens), tokens
+ return self.generate_audio(tokens)
+
+ @torch.no_grad()
+ def _prepare_tokens_and_attributes(
+ self,
+ descriptions: tp.Sequence[tp.Optional[str]],
+ prompt: tp.Optional[torch.Tensor],
+ melody_wavs: tp.Optional[MelodyList] = None,
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+ """Prepare model inputs.
+
+ Args:
+ descriptions (list of str): A list of strings used as text conditioning.
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
+ used as melody conditioning. Defaults to None.
+ """
+ attributes = [
+ ConditioningAttributes(text={'description': description})
+ for description in descriptions]
+
+ if melody_wavs is None:
+ for attr in attributes:
+ attr.wav['self_wav'] = WavCondition(
+ torch.zeros((1, 1, 1), device=self.device),
+ torch.tensor([0], device=self.device),
+ sample_rate=[self.sample_rate],
+ path=[None])
+ else:
+ if 'self_wav' not in self.lm.condition_provider.conditioners:
+ raise RuntimeError("This model doesn't support melody conditioning. "
+ "Use the `melody` model.")
+ assert len(melody_wavs) == len(descriptions), \
+ f"number of melody wavs must match number of descriptions! " \
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+ for attr, melody in zip(attributes, melody_wavs):
+ if melody is None:
+ attr.wav['self_wav'] = WavCondition(
+ torch.zeros((1, 1, 1), device=self.device),
+ torch.tensor([0], device=self.device),
+ sample_rate=[self.sample_rate],
+ path=[None])
+ else:
+ attr.wav['self_wav'] = WavCondition(
+ melody[None].to(device=self.device),
+ torch.tensor([melody.shape[-1]], device=self.device),
+ sample_rate=[self.sample_rate],
+ path=[None],
+ )
+
+ if prompt is not None:
+ if descriptions is not None:
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+ prompt = prompt.to(self.device)
+ prompt_tokens, scale = self.compression_model.encode(prompt)
+ assert scale is None
+ else:
+ prompt_tokens = None
+ return attributes, prompt_tokens
+
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+ """Generate discrete audio tokens given audio prompt and/or conditions.
+
+ Args:
+ attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ Returns:
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+ """
+ i = 0
+ prompt_list = attributes[0].text['description']
+ total_gen_len = int(self.duration * self.frame_rate)
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+ current_gen_offset: int = 0
+
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+ generated_tokens += current_gen_offset
+ if current_gen_offset > 0:
+ generated_tokens += (self.max_duration - self.extend_stride) * self.frame_rate
+ if self._progress_callback is not None:
+ # Note that total_gen_len might be quite wrong depending on the
+ # codebook pattern used, but with delay it is almost accurate.
+ self._progress_callback(generated_tokens, total_gen_len)
+ else:
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+ if prompt_tokens is not None:
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
+ "Prompt is longer than audio to generate"
+
+ callback = None
+ if progress:
+ callback = _progress_callback
+
+ if self.duration <= self.max_duration:
+ # generate by sampling from LM, simple case.
+ with self.autocast:
+ attributes[0].text['description'] = prompt_list[0]
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+ else:
+ # now this gets a bit messier, we need to handle prompts,
+ # melody conditioning etc.
+ ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+ all_tokens = []
+ if prompt_tokens is None:
+ prompt_length = 0
+ else:
+ all_tokens.append(prompt_tokens)
+ prompt_length = prompt_tokens.shape[-1]
+
+ stride_tokens = int(self.frame_rate * self.extend_stride)
+
+ while current_gen_offset + prompt_length < total_gen_len:
+ time_offset = current_gen_offset / self.frame_rate
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
+ max_gen_len = int(chunk_duration * self.frame_rate)
+ for attr, ref_wav in zip(attributes, ref_wavs):
+ wav_length = ref_wav.length.item()
+ if wav_length == 0:
+ continue
+ # We will extend the wav periodically if it not long enough.
+ # we have to do it here rather than in conditioners.py as otherwise
+ # we wouldn't have the full wav.
+ initial_position = int(time_offset * self.sample_rate)
+ wav_target_length = int(self.max_duration * self.sample_rate)
+ positions = torch.arange(initial_position,
+ initial_position + wav_target_length, device=self.device)
+ attr.wav['self_wav'] = WavCondition(
+ ref_wav[0][..., positions % wav_length],
+ torch.full_like(ref_wav[1], wav_target_length),
+ [self.sample_rate] * ref_wav[0].size(0),
+ [None], [0.])
+ with self.autocast:
+ if i >= len(prompt_list):
+ i = len(prompt_list) - 1
+ attributes[0].text['description'] = prompt_list[i]
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+ i = i + 1
+ if prompt_tokens is None:
+ all_tokens.append(gen_tokens)
+ else:
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
+ prompt_length = prompt_tokens.shape[-1]
+ current_gen_offset += stride_tokens
+
+ gen_tokens = torch.cat(all_tokens, dim=-1)
+ return gen_tokens
+
+ def generate_audio(self, gen_tokens: torch.Tensor):
+ """Generate Audio from tokens"""
+ assert gen_tokens.dim() == 3
+ with torch.no_grad():
+ gen_audio = self.compression_model.decode(gen_tokens, None)
+ return gen_audio
+
+ def to(self, device: str):
+ self.compression_model.to(device)
+ self.lm.to(device)
+ return self
\ No newline at end of file
diff --git a/audiocraft/models/unet.py b/audiocraft/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4a6df8e309c21fede37abdbe3c862932027641
--- /dev/null
+++ b/audiocraft/models/unet.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Pytorch Unet Module used for diffusion.
+"""
+
+from dataclasses import dataclass
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
+
+
+@dataclass
+class Output:
+ sample: torch.Tensor
+
+
+def get_model(cfg, channels: int, side: int, num_steps: int):
+ if cfg.model == 'unet':
+ return DiffusionUnet(
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+ else:
+ raise RuntimeError('Not Implemented')
+
+
+class ResBlock(nn.Module):
+ def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
+ dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+ dropout: float = 0.):
+ super().__init__()
+ stride = 1
+ padding = dilation * (kernel - stride) // 2
+ Conv = nn.Conv1d
+ Drop = nn.Dropout1d
+ self.norm1 = nn.GroupNorm(norm_groups, channels)
+ self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+ self.activation1 = activation()
+ self.dropout1 = Drop(dropout)
+
+ self.norm2 = nn.GroupNorm(norm_groups, channels)
+ self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+ self.activation2 = activation()
+ self.dropout2 = Drop(dropout)
+
+ def forward(self, x):
+ h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+ h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+ return x + h
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+ dropout: float = 0.):
+ super().__init__()
+ padding = (kernel - stride) // 2
+ self.res_blocks = nn.Sequential(
+ *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+ for idx in range(res_blocks)])
+ self.norm = nn.GroupNorm(norm_groups, chin)
+ ConvTr = nn.ConvTranspose1d
+ self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
+ self.activation = activation()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.res_blocks(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ x = self.convtr(x)
+ return x
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+ dropout: float = 0.):
+ super().__init__()
+ padding = (kernel - stride) // 2
+ Conv = nn.Conv1d
+ self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
+ self.norm = nn.GroupNorm(norm_groups, chout)
+ self.activation = activation()
+ self.res_blocks = nn.Sequential(
+ *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+ for idx in range(res_blocks)])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, T = x.shape
+ stride, = self.conv.stride
+ pad = (stride - (T % stride)) % stride
+ x = F.pad(x, (0, pad))
+
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ x = self.res_blocks(x)
+ return x
+
+
+class BLSTM(nn.Module):
+ """BiLSTM with same hidden units as input dim.
+ """
+ def __init__(self, dim, layers=2):
+ super().__init__()
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+ self.linear = nn.Linear(2 * dim, dim)
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ x = self.lstm(x)[0]
+ x = self.linear(x)
+ x = x.permute(1, 2, 0)
+ return x
+
+
+class DiffusionUnet(nn.Module):
+ def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
+ max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
+ bilstm: bool = False, transformer: bool = False,
+ codec_dim: tp.Optional[int] = None, **kwargs):
+ super().__init__()
+ self.encoders = nn.ModuleList()
+ self.decoders = nn.ModuleList()
+ self.embeddings: tp.Optional[nn.ModuleList] = None
+ self.embedding = nn.Embedding(num_steps, hidden)
+ if emb_all_layers:
+ self.embeddings = nn.ModuleList()
+ self.condition_embedding: tp.Optional[nn.Module] = None
+ for d in range(depth):
+ encoder = EncoderLayer(chin, hidden, **kwargs)
+ decoder = DecoderLayer(hidden, chin, **kwargs)
+ self.encoders.append(encoder)
+ self.decoders.insert(0, decoder)
+ if emb_all_layers and d > 0:
+ assert self.embeddings is not None
+ self.embeddings.append(nn.Embedding(num_steps, hidden))
+ chin = hidden
+ hidden = min(int(chin * growth), max_channels)
+ self.bilstm: tp.Optional[nn.Module]
+ if bilstm:
+ self.bilstm = BLSTM(chin)
+ else:
+ self.bilstm = None
+ self.use_transformer = transformer
+ self.cross_attention = False
+ if transformer:
+ self.cross_attention = cross_attention
+ self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
+ cross_attention=cross_attention)
+
+ self.use_codec = False
+ if codec_dim is not None:
+ self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
+ self.use_codec = True
+
+ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+ skips = []
+ bs = x.size(0)
+ z = x
+ view_args = [1]
+ if type(step) is torch.Tensor:
+ step_tensor = step
+ else:
+ step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+ for idx, encoder in enumerate(self.encoders):
+ z = encoder(z)
+ if idx == 0:
+ z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+ elif self.embeddings is not None:
+ z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+ skips.append(z)
+
+ if self.use_codec: # insert condition in the bottleneck
+ assert condition is not None, "Model defined for conditionnal generation"
+ condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
+ assert condition_emb.size(-1) <= 2 * z.size(-1), \
+ f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
+ if not self.cross_attention:
+
+ condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+ assert z.size() == condition_emb.size()
+ z += condition_emb
+ cross_attention_src = None
+ else:
+ cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
+ B, T, C = cross_attention_src.shape
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
+ pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+ cross_attention_src = cross_attention_src + pos_emb
+ if self.use_transformer:
+ z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+ else:
+ if self.bilstm is None:
+ z = torch.zeros_like(z)
+ else:
+ z = self.bilstm(z)
+
+ for decoder in self.decoders:
+ s = skips.pop(-1)
+ z = z[:, :, :s.shape[2]]
+ z = z + s
+ z = decoder(z)
+
+ z = z[:, :, :x.shape[2]]
+ return Output(z)
diff --git a/audiocraft/modules/__init__.py b/audiocraft/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61418616ef18f0ecca56a007c43af4a731d98b9b
--- /dev/null
+++ b/audiocraft/modules/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Modules used for building the models."""
+
+# flake8: noqa
+from .conv import (
+ NormConv1d,
+ NormConv2d,
+ NormConvTranspose1d,
+ NormConvTranspose2d,
+ StreamableConv1d,
+ StreamableConvTranspose1d,
+ pad_for_conv1d,
+ pad1d,
+ unpad1d,
+)
+from .lstm import StreamableLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformer
\ No newline at end of file
diff --git a/audiocraft/modules/__pycache__/__init__.cpython-310.pyc b/audiocraft/modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16128c9090c63b33bed7eca6dba0a062fd74e3be
Binary files /dev/null and b/audiocraft/modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/activations.cpython-310.pyc b/audiocraft/modules/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5559f532116b8b5cbc03544ce73964869552e5f
Binary files /dev/null and b/audiocraft/modules/__pycache__/activations.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/chroma.cpython-310.pyc b/audiocraft/modules/__pycache__/chroma.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82a258aaf0454b69810617623db9c771d0a5b672
Binary files /dev/null and b/audiocraft/modules/__pycache__/chroma.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/codebooks_patterns.cpython-310.pyc b/audiocraft/modules/__pycache__/codebooks_patterns.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a94527edf3980ed2b24ce8b963c96af979aa9d7
Binary files /dev/null and b/audiocraft/modules/__pycache__/codebooks_patterns.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/conditioners.cpython-310.pyc b/audiocraft/modules/__pycache__/conditioners.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48617cf2f490744c73471580f58f7195aa842660
Binary files /dev/null and b/audiocraft/modules/__pycache__/conditioners.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/conv.cpython-310.pyc b/audiocraft/modules/__pycache__/conv.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..086b957dcb48fc69edbf87df5e0169a227207165
Binary files /dev/null and b/audiocraft/modules/__pycache__/conv.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/diffusion_schedule.cpython-310.pyc b/audiocraft/modules/__pycache__/diffusion_schedule.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2bb2b600c025ae27e60e42d0adba9712ea69e0c
Binary files /dev/null and b/audiocraft/modules/__pycache__/diffusion_schedule.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/lstm.cpython-310.pyc b/audiocraft/modules/__pycache__/lstm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..145810605494378dc35c445fef14df7bc9a50eae
Binary files /dev/null and b/audiocraft/modules/__pycache__/lstm.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/rope.cpython-310.pyc b/audiocraft/modules/__pycache__/rope.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e8bdf73f68e8ffea96e69bb8b3acc4520bb5cfc
Binary files /dev/null and b/audiocraft/modules/__pycache__/rope.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/seanet.cpython-310.pyc b/audiocraft/modules/__pycache__/seanet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..278875bcd7bbfc807bc831efe02a647d1863447d
Binary files /dev/null and b/audiocraft/modules/__pycache__/seanet.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/streaming.cpython-310.pyc b/audiocraft/modules/__pycache__/streaming.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3f6f33575ae4ee37d5952ae8081c2194f8c0fa4
Binary files /dev/null and b/audiocraft/modules/__pycache__/streaming.cpython-310.pyc differ
diff --git a/audiocraft/modules/__pycache__/transformer.cpython-310.pyc b/audiocraft/modules/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64dbfc5130bb90d19d409671baa4ae7d1cee2d6a
Binary files /dev/null and b/audiocraft/modules/__pycache__/transformer.cpython-310.pyc differ
diff --git a/audiocraft/modules/activations.py b/audiocraft/modules/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d83d7c4c2dc84c64b724eadbe06157507d4f20d
--- /dev/null
+++ b/audiocraft/modules/activations.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, Callable
+
+
+class CustomGLU(nn.Module):
+ """Custom Gated Linear Unit activation.
+ Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+ of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+ function (i.e. sigmoid, swish, etc.).
+
+ Args:
+ activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+ dim (int): the dimension on which to split the input. Default: -1
+
+ Shape:
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+ Examples::
+ >>> m = CustomGLU(nn.Sigmoid())
+ >>> input = torch.randn(4, 2)
+ >>> output = m(input)
+ """
+ def __init__(self, activation: nn.Module, dim: int = -1):
+ super(CustomGLU, self).__init__()
+ self.dim = dim
+ self.activation = activation
+
+ def forward(self, x: Tensor):
+ assert x.shape[self.dim] % 2 == 0 # M = N / 2
+ a, b = torch.chunk(x, 2, dim=self.dim)
+ return a * self.activation(b)
+
+
+class SwiGLU(CustomGLU):
+ """SiLU Gated Linear Unit activation.
+ Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+class GeGLU(CustomGLU):
+ """GeLU Gated Linear Unit activation.
+ Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+class ReGLU(CustomGLU):
+ """ReLU Gated Linear Unit activation.
+ Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+def get_activation_fn(
+ activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+ """Helper function to map an activation string to the activation class.
+ If the supplied activation is not a string that is recognized, the activation is passed back.
+
+ Args:
+ activation (str, or Callable[[Tensor], Tensor]): Activation to check
+ """
+ if isinstance(activation, str):
+ if activation == "reglu":
+ return ReGLU()
+ elif activation == "geglu":
+ return GeGLU()
+ elif activation == "swiglu":
+ return SwiGLU()
+ return activation
diff --git a/audiocraft/modules/chroma.py b/audiocraft/modules/chroma.py
new file mode 100644
index 0000000000000000000000000000000000000000..e84fb66b4a4aaefb0b3ccac8a9a44c3b20e48f61
--- /dev/null
+++ b/audiocraft/modules/chroma.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import typing as tp
+
+from einops import rearrange
+from librosa import filters
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class ChromaExtractor(nn.Module):
+ """Chroma extraction and quantization.
+
+ Args:
+ sample_rate (int): Sample rate for the chroma extraction.
+ n_chroma (int): Number of chroma bins for the chroma extraction.
+ radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+ nfft (int, optional): Number of FFT.
+ winlen (int, optional): Window length.
+ winhop (int, optional): Window hop size.
+ argmax (bool, optional): Whether to use argmax. Defaults to False.
+ norm (float, optional): Norm for chroma normalization. Defaults to inf.
+ """
+ def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+ winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+ norm: float = torch.inf):
+ super().__init__()
+ self.winlen = winlen or 2 ** radix2_exp
+ self.nfft = nfft or self.winlen
+ self.winhop = winhop or (self.winlen // 4)
+ self.sample_rate = sample_rate
+ self.n_chroma = n_chroma
+ self.norm = norm
+ self.argmax = argmax
+ self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+ n_chroma=self.n_chroma)), persistent=False)
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+ hop_length=self.winhop, power=2, center=True,
+ pad=0, normalized=True)
+
+ def forward(self, wav: torch.Tensor) -> torch.Tensor:
+ T = wav.shape[-1]
+ # in case we are getting a wav that was dropped out (nullified)
+ # from the conditioner, make sure wav length is no less that nfft
+ if T < self.nfft:
+ pad = self.nfft - T
+ r = 0 if pad % 2 == 0 else 1
+ wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+ assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+ spec = self.spec(wav).squeeze(1)
+ raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+ norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+ if self.argmax:
+ idx = norm_chroma.argmax(-1, keepdim=True)
+ norm_chroma[:] = 0
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+ return norm_chroma
diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf3bb41774700a679ffe4325236d0324a99c546
--- /dev/null
+++ b/audiocraft/modules/codebooks_patterns.py
@@ -0,0 +1,539 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+ """Base implementation of a pattern over a sequence with multiple codebooks.
+
+ The codebook pattern consists in a layout, defining for each sequence step
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+ The first item of the pattern is always an empty list in order to properly insert a special token
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+ is returned along with a mask indicating valid tokens.
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+ to fill and specify invalid positions if needed.
+ See the dedicated methods for more details.
+ """
+ # Pattern layout, for each sequence step, we have a list of coordinates
+ # corresponding to the original codebook timestep and position.
+ # The first list is always an empty list in order to properly insert
+ # a special token to start with.
+ layout: PatternLayout
+ timesteps: int
+ n_q: int
+
+ def __post_init__(self):
+ assert len(self.layout) > 0
+ assert self.layout[0] == []
+ self._validate_layout()
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+ def _validate_layout(self):
+ """Runs checks on the layout to ensure a valid pattern is defined.
+ A pattern is considered invalid if:
+ - Multiple timesteps for a same codebook are defined in the same sequence step
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+ (this would mean that we have future timesteps before past timesteps).
+ """
+ q_timesteps = {q: 0 for q in range(self.n_q)}
+ for s, seq_coords in enumerate(self.layout):
+ if len(seq_coords) > 0:
+ qs = set()
+ for coord in seq_coords:
+ qs.add(coord.q)
+ last_q_timestep = q_timesteps[coord.q]
+ assert coord.t >= last_q_timestep, \
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+ q_timesteps[coord.q] = coord.t
+ # each sequence step contains at max 1 coordinate per codebook
+ assert len(qs) == len(seq_coords), \
+ f"Multiple entries for a same codebook are found at step {s}"
+
+ @property
+ def num_sequence_steps(self):
+ return len(self.layout) - 1
+
+ @property
+ def max_delay(self):
+ max_t_in_seq_coords = 0
+ for seq_coords in self.layout[1:]:
+ for coords in seq_coords:
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+ return max_t_in_seq_coords - self.timesteps
+
+ @property
+ def valid_layout(self):
+ valid_step = len(self.layout) - self.max_delay
+ return self.layout[:valid_step]
+
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+ and the actual codebook coordinates.
+ """
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+ if q is not None:
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+ coords = []
+ for s, seq_codes in enumerate(self.layout):
+ for code in seq_codes:
+ if code.t == t and (q is None or code.q == q):
+ coords.append((s, code))
+ return coords
+
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+ Args:
+ timesteps (int): Maximum number of timesteps steps to consider.
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+ device (torch.device or str): Device for created tensors.
+ Returns:
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+ """
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+ # which will correspond to the index: n_q * timesteps
+ indexes[:] = n_q * timesteps
+ # iterate over the pattern and fill scattered indexes and mask
+ for s, sequence_coords in enumerate(ref_layout):
+ for coords in sequence_coords:
+ if coords.t < timesteps:
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
+ mask[coords.q, s] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Build sequence corresponding to the pattern from the input tensor z.
+ The sequence is built using up to sequence_steps if specified, and non-pattern
+ coordinates are filled with the special token.
+
+ Args:
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+ """
+ B, K, T = z.shape
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+ )
+ z = z.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+ values = z[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+ keep_only_valid_steps: bool = False,
+ is_model_output: bool = False,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
+ from interleaving pattern.
+
+ Args:
+ sequence_steps (int): Sequence steps.
+ n_q (int): Number of codebooks.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+ device (torch.device or str): Device for created tensors.
+ Returns:
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+ timesteps = self.timesteps
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert sequence_steps <= len(ref_layout), \
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
+ if is_model_output:
+ ref_layout = ref_layout[1:]
+
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ indexes[:] = n_q * sequence_steps
+ for s, sequence_codes in enumerate(ref_layout):
+ if s < sequence_steps:
+ for code in sequence_codes:
+ if code.t < timesteps:
+ indexes[code.q, code.t] = s + code.q * sequence_steps
+ mask[code.q, code.t] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+ are filled with the special token.
+
+ Args:
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ B, K, S = s.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+ )
+ s = s.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+ values = s[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+ """Revert model logits obtained on a sequence built from the pattern
+ back to a tensor matching the original sequence.
+
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
+ 1. It is designed to work with the extra cardinality dimension
+ 2. We return the logits for the first sequence item that matches the special_token and
+ which matching target in the original sequence is the first item of the sequence,
+ while we skip the last logits as there is no matching target
+ """
+ B, card, K, S = logits.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+ )
+ logits = logits.reshape(B, card, -1)
+ # we append the special token as the last index of our flattened z tensor
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
+ values = logits[:, :, indexes.view(-1)]
+ values = values.view(B, card, K, indexes.shape[-1])
+ return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+ """Abstraction around providing pattern for interleaving codebooks.
+
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+ can be used to construct a new sequence from the original codes respecting the specified
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+ being a tuple with the original timestep and codebook to build the new sequence.
+ Note that all patterns must start with an empty list that is then used to insert a first
+ sequence step of special tokens in the newly generated sequence.
+
+ Args:
+ n_q (int): number of codebooks.
+ cached (bool): if True, patterns for a given length are cached. In general
+ that should be true for efficiency reason to avoid synchronization points.
+ """
+ def __init__(self, n_q: int, cached: bool = True):
+ assert n_q > 0
+ self.n_q = n_q
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
+
+ @abstractmethod
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern with specific interleaving between codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+ """Provider for delayed pattern across delayed codebooks.
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
+ from different timesteps.
+
+ Example:
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ The resulting sequence obtained from the returned pattern is:
+ [[S, 1, 2, 3, 4],
+ [S, S, 1, 2, 3],
+ [S, S, S, 1, 2]]
+ (with S being a special token)
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (list of int, optional): Delay for each of the codebooks.
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
+ flatten_first (int): Flatten the first N timesteps.
+ empty_initial (int): Prepend with N empty list of coordinates.
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+ flatten_first: int = 0, empty_initial: int = 0):
+ super().__init__(n_q)
+ if delays is None:
+ delays = list(range(n_q))
+ self.delays = delays
+ self.flatten_first = flatten_first
+ self.empty_initial = empty_initial
+ assert len(self.delays) == self.n_q
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ max_delay = max(self.delays)
+ if self.empty_initial:
+ out += [[] for _ in range(self.empty_initial)]
+ if self.flatten_first:
+ for t in range(min(timesteps, self.flatten_first)):
+ for q in range(self.n_q):
+ out.append([LayoutCoord(t, q)])
+ for t in range(self.flatten_first, timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= self.flatten_first:
+ v.append(LayoutCoord(t_for_q, q))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+ """Provider for parallel pattern across codebooks.
+ This pattern provider is a special case of the delayed pattern with actually no delay,
+ hence delays=repeat(0, n_q).
+
+ Args:
+ n_q (int): Number of codebooks.
+ """
+ def __init__(self, n_q: int):
+ super().__init__(n_q, [0] * n_q)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+ """Provider for unrolling codebooks pattern.
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
+ while also specifying a given delay between the flattened codebooks representation, allowing to
+ unroll the codebooks in the sequence.
+
+ Example:
+ 1. Flattening of the codebooks.
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+ taking n_q = 3 and timesteps = 4:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+ and delays = [0, 3, 3]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
+ [S, S, S, 1, S, 2, S, 3, S, 4],
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+ Args:
+ n_q (int): Number of codebooks.
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+ have n_q extra steps for each timestep.
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
+ no delay is added and therefore will default to [0] * ``n_q``.
+ Note that two codebooks that will be flattened to the same inner step
+ should have the same delay, otherwise the pattern is considered as invalid.
+ """
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+ delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if flattening is None:
+ flattening = list(range(n_q))
+ if delays is None:
+ delays = [0] * n_q
+ assert len(flattening) == n_q
+ assert len(delays) == n_q
+ assert sorted(flattening) == flattening
+ assert sorted(delays) == delays
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+ self.max_delay = max(delays)
+
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+ """Build a flattened codebooks representation as a dictionary of inner step
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+ """
+ flattened_codebooks: dict = {}
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+ if inner_step not in flattened_codebooks:
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+ else:
+ flat_codebook = flattened_codebooks[inner_step]
+ assert flat_codebook.delay == delay, (
+ "Delay and flattening between codebooks is inconsistent: ",
+ "two codebooks flattened to the same position should have the same delay."
+ )
+ flat_codebook.codebooks.append(q)
+ flattened_codebooks[inner_step] = flat_codebook
+ return flattened_codebooks
+
+ @property
+ def _num_inner_steps(self):
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+ """
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+ def num_virtual_steps(self, timesteps: int) -> int:
+ return timesteps * self._num_inner_steps + 1
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern for delay across codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
+ indexed_out: list = [(-1, [])]
+ max_timesteps = timesteps + self.max_delay
+ for t in range(max_timesteps):
+ # for each timestep, we unroll the flattened codebooks,
+ # emitting the sequence step with the corresponding delay
+ for step in range(self._num_inner_steps):
+ if step in self._flattened_codebooks:
+ # we have codebooks at this virtual step to emit
+ step_codebooks = self._flattened_codebooks[step]
+ t_for_q = t + step_codebooks.delay
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+ if t_for_q < max_timesteps and t < max_timesteps:
+ indexed_out.append((t_for_q, coords))
+ else:
+ # there is no codebook in this virtual step so we emit an empty list
+ indexed_out.append((t, []))
+ out = [coords for _, coords in sorted(indexed_out)]
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class VALLEPattern(CodebooksPatternProvider):
+ """Almost VALL-E style pattern.
+ We further allow some delays for the codebooks other than the first one.
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (list of int, optional): Delay for each of the codebooks.
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if delays is None:
+ delays = [0] * (n_q - 1)
+ self.delays = delays
+ assert len(self.delays) == self.n_q - 1
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for t in range(timesteps):
+ out.append([LayoutCoord(t, 0)])
+ max_delay = max(self.delays)
+ for t in range(timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= 0:
+ v.append(LayoutCoord(t_for_q, q + 1))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+ """Almost MusicLM style pattern. This is equivalent to full flattening
+ but in a different order.
+
+ Args:
+ n_q (int): Number of codebooks.
+ group_by (int): Number of codebooks to group together.
+ """
+ def __init__(self, n_q: int, group_by: int = 2):
+ super().__init__(n_q)
+ self.group_by = group_by
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for offset in range(0, self.n_q, self.group_by):
+ for t in range(timesteps):
+ for q in range(offset, offset + self.group_by):
+ out.append([LayoutCoord(t, q)])
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10ac8dc96466375379c883cd62f7c04a1bb0a73
--- /dev/null
+++ b/audiocraft/modules/conditioners.py
@@ -0,0 +1,1411 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from itertools import chain
+import logging
+import math
+from pathlib import Path
+import random
+import re
+import typing as tp
+import warnings
+
+import einops
+from num2words import num2words
+import spacy
+from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from .chroma import ChromaExtractor
+from .streaming import StreamingModule
+from .transformer import create_sin_embedding
+from ..data.audio import audio_read
+from ..data.audio_dataset import SegmentInfo
+from ..data.audio_utils import convert_audio
+from ..environment import AudioCraftEnvironment
+from ..quantization import ResidualVectorQuantizer
+from ..utils.autocast import TorchAutocast
+from ..utils.cache import EmbeddingCache
+from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+
+
+logger = logging.getLogger(__name__)
+TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
+ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
+
+
+class WavCondition(tp.NamedTuple):
+ wav: torch.Tensor
+ length: torch.Tensor
+ sample_rate: tp.List[int]
+ path: tp.List[tp.Optional[str]] = []
+ seek_time: tp.List[tp.Optional[float]] = []
+
+
+class JointEmbedCondition(tp.NamedTuple):
+ wav: torch.Tensor
+ text: tp.List[tp.Optional[str]]
+ length: torch.Tensor
+ sample_rate: tp.List[int]
+ path: tp.List[tp.Optional[str]] = []
+ seek_time: tp.List[tp.Optional[float]] = []
+
+
+@dataclass
+class ConditioningAttributes:
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+ def __getitem__(self, item):
+ return getattr(self, item)
+
+ @property
+ def text_attributes(self):
+ return self.text.keys()
+
+ @property
+ def wav_attributes(self):
+ return self.wav.keys()
+
+ @property
+ def joint_embed_attributes(self):
+ return self.joint_embed.keys()
+
+ @property
+ def attributes(self):
+ return {
+ "text": self.text_attributes,
+ "wav": self.wav_attributes,
+ "joint_embed": self.joint_embed_attributes,
+ }
+
+ def to_flat_dict(self):
+ return {
+ **{f"text.{k}": v for k, v in self.text.items()},
+ **{f"wav.{k}": v for k, v in self.wav.items()},
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
+ }
+
+ @classmethod
+ def from_flat_dict(cls, x):
+ out = cls()
+ for k, v in x.items():
+ kind, att = k.split(".")
+ out[kind][att] = v
+ return out
+
+
+class SegmentWithAttributes(SegmentInfo):
+ """Base class for all dataclasses that are used for conditioning.
+ All child classes should implement `to_condition_attributes` that converts
+ the existing attributes to a dataclass of type ConditioningAttributes.
+ """
+ def to_condition_attributes(self) -> ConditioningAttributes:
+ raise NotImplementedError()
+
+
+def nullify_condition(condition: ConditionType, dim: int = 1):
+ """Transform an input condition to a null condition.
+ The way it is done by converting it to a single zero vector similarly
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+ Args:
+ condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+ dim (int): The dimension that will be truncated (should be the time dimension)
+ WARNING!: dim should not be the batch dimension!
+ Returns:
+ ConditionType: A tuple of null condition and mask
+ """
+ assert dim != 0, "dim cannot be the batch dimension!"
+ assert isinstance(condition, tuple) and \
+ isinstance(condition[0], torch.Tensor) and \
+ isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
+ cond, mask = condition
+ B = cond.shape[0]
+ last_dim = cond.dim() - 1
+ out = cond.transpose(dim, last_dim)
+ out = 0. * out[..., :1]
+ out = out.transpose(dim, last_dim)
+ mask = torch.zeros((B, 1), device=out.device).int()
+ assert cond.dim() == out.dim()
+ return out, mask
+
+
+def nullify_wav(cond: WavCondition) -> WavCondition:
+ """Transform a WavCondition to a nullified WavCondition.
+ It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
+
+ Args:
+ cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
+ Returns:
+ WavCondition: Nullified wav condition.
+ """
+ null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
+ return WavCondition(
+ wav=null_wav,
+ length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
+ sample_rate=cond.sample_rate,
+ path=[None] * cond.wav.shape[0],
+ seek_time=[None] * cond.wav.shape[0],
+ )
+
+
+def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
+ """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
+ and replacing metadata by dummy attributes.
+
+ Args:
+ cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+ """
+ null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
+ return JointEmbedCondition(
+ wav=null_wav, text=[None] * len(embed.text),
+ length=torch.LongTensor([0]).to(embed.wav.device),
+ sample_rate=embed.sample_rate,
+ path=[None] * embed.wav.shape[0],
+ seek_time=[0] * embed.wav.shape[0],
+ )
+
+
+class Tokenizer:
+ """Base tokenizer implementation
+ (in case we want to introduce more advances tokenizers in the future).
+ """
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError()
+
+
+class WhiteSpaceTokenizer(Tokenizer):
+ """This tokenizer should be used for natural language descriptions.
+ For example:
+ ["he didn't, know he's going home.", 'shorter sentence'] =>
+ [[78, 62, 31, 4, 78, 25, 19, 34],
+ [59, 77, 0, 0, 0, 0, 0, 0]]
+ """
+ PUNCTUATION = "?:!.,;"
+
+ def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+ lemma: bool = True, stopwords: bool = True) -> None:
+ self.n_bins = n_bins
+ self.pad_idx = pad_idx
+ self.lemma = lemma
+ self.stopwords = stopwords
+ try:
+ self.nlp = spacy.load(language)
+ except IOError:
+ spacy.cli.download(language) # type: ignore
+ self.nlp = spacy.load(language)
+
+ @tp.no_type_check
+ def __call__(self, texts: tp.List[tp.Optional[str]],
+ return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Take a list of strings and convert them to a tensor of indices.
+
+ Args:
+ texts (list[str]): List of strings.
+ return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]:
+ - Indices of words in the LUT.
+ - And a mask indicating where the padding tokens are
+ """
+ output, lengths = [], []
+ texts = deepcopy(texts)
+ for i, text in enumerate(texts):
+ # if current sample doesn't have a certain attribute, replace with pad token
+ if text is None:
+ output.append(torch.Tensor([self.pad_idx]))
+ lengths.append(0)
+ continue
+
+ # convert numbers to words
+ text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
+ # normalize text
+ text = self.nlp(text) # type: ignore
+ # remove stopwords
+ if self.stopwords:
+ text = [w for w in text if not w.is_stop] # type: ignore
+ # remove punctuation
+ text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
+ # lemmatize if needed
+ text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
+
+ texts[i] = " ".join(text)
+ lengths.append(len(text))
+ # convert to tensor
+ tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
+ output.append(tokens)
+
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
+ padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+ if return_text:
+ return padded_output, mask, texts # type: ignore
+ return padded_output, mask
+
+
+class NoopTokenizer(Tokenizer):
+ """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+ The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+ strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+ split it to ["Jeff", "Buckley"] and return an index per word.
+
+ For example:
+ ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+ ["Metal", "Rock", "Classical"] => [0, 223, 51]
+ """
+ def __init__(self, n_bins: int, pad_idx: int = 0):
+ self.n_bins = n_bins
+ self.pad_idx = pad_idx
+
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ output, lengths = [], []
+ for text in texts:
+ # if current sample doesn't have a certain attribute, replace with pad token
+ if text is None:
+ output.append(self.pad_idx)
+ lengths.append(0)
+ else:
+ output.append(hash_trick(text, self.n_bins))
+ lengths.append(1)
+
+ tokens = torch.LongTensor(output).unsqueeze(1)
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
+ return tokens, mask
+
+
+class BaseConditioner(nn.Module):
+ """Base model for all conditioner modules.
+ We allow the output dim to be different than the hidden dim for two reasons:
+ 1) keep our LUTs small when the vocab is large;
+ 2) make all condition dims consistent.
+
+ Args:
+ dim (int): Hidden dim of the model.
+ output_dim (int): Output dim of the conditioner.
+ """
+ def __init__(self, dim: int, output_dim: int):
+ super().__init__()
+ self.dim = dim
+ self.output_dim = output_dim
+ self.output_proj = nn.Linear(dim, output_dim)
+
+ def tokenize(self, *args, **kwargs) -> tp.Any:
+ """Should be any part of the processing that will lead to a synchronization
+ point, e.g. BPE tokenization with transfer to the GPU.
+
+ The returned value will be saved and return later when calling forward().
+ """
+ raise NotImplementedError()
+
+ def forward(self, inputs: tp.Any) -> ConditionType:
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+ Returns:
+ ConditionType:
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+ output embedding and D is the dimension of the embedding.
+ - And a mask indicating where the padding tokens.
+ """
+ raise NotImplementedError()
+
+
+class TextConditioner(BaseConditioner):
+ ...
+
+
+class LUTConditioner(TextConditioner):
+ """Lookup table TextConditioner.
+
+ Args:
+ n_bins (int): Number of bins.
+ dim (int): Hidden dim of the model (text-encoder/LUT).
+ output_dim (int): Output dim of the conditioner.
+ tokenizer (str): Name of the tokenizer.
+ pad_idx (int, optional): Index for padding token. Defaults to 0.
+ """
+ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+ super().__init__(dim, output_dim)
+ self.embed = nn.Embedding(n_bins, dim)
+ self.tokenizer: Tokenizer
+ if tokenizer == 'whitespace':
+ self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+ elif tokenizer == 'noop':
+ self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+ else:
+ raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ device = self.embed.weight.device
+ tokens, mask = self.tokenizer(x)
+ tokens, mask = tokens.to(device), mask.to(device)
+ return tokens, mask
+
+ def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+ tokens, mask = inputs
+ embeds = self.embed(tokens)
+ embeds = self.output_proj(embeds)
+ embeds = (embeds * mask.unsqueeze(-1))
+ return embeds, mask
+
+
+class T5Conditioner(TextConditioner):
+ """T5-based TextConditioner.
+
+ Args:
+ name (str): Name of the T5 model.
+ output_dim (int): Output dim of the conditioner.
+ finetune (bool): Whether to fine-tune T5 at train time.
+ device (str): Device for T5 Conditioner.
+ autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+ word_dropout (float, optional): Word dropout probability.
+ normalize_text (bool, optional): Whether to apply text normalization.
+ """
+ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
+ MODELS_DIMS = {
+ "t5-small": 512,
+ "t5-base": 768,
+ "t5-large": 1024,
+ "t5-3b": 1024,
+ "t5-11b": 1024,
+ "google/flan-t5-small": 512,
+ "google/flan-t5-base": 768,
+ "google/flan-t5-large": 1024,
+ "google/flan-t5-3b": 1024,
+ "google/flan-t5-11b": 1024,
+ }
+
+ def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+ autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+ normalize_text: bool = False):
+ assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
+ super().__init__(self.MODELS_DIMS[name], output_dim)
+ self.device = device
+ self.name = name
+ self.finetune = finetune
+ self.word_dropout = word_dropout
+ if autocast_dtype is None or self.device == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ if self.device != 'cpu':
+ logger.warning("T5 has no autocast, this might lead to NaN")
+ else:
+ dtype = getattr(torch, autocast_dtype)
+ assert isinstance(dtype, torch.dtype)
+ logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+ # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+ # thanks https://gist.github.com/simon-weber/7853144
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+ t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+ finally:
+ logging.disable(previous_level)
+ if finetune:
+ self.t5 = t5
+ else:
+ # this makes sure that the t5 models is not part
+ # of the saved checkpoint
+ self.__dict__['t5'] = t5.to(device)
+
+ self.normalize_text = normalize_text
+ if normalize_text:
+ self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+ # if current sample doesn't have a certain attribute, replace with empty string
+ entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+ if self.normalize_text:
+ _, _, entries = self.text_normalizer(entries, return_text=True)
+ if self.word_dropout > 0. and self.training:
+ new_entries = []
+ for entry in entries:
+ words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+ new_entries.append(" ".join(words))
+ entries = new_entries
+
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+ inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
+ mask = inputs['attention_mask']
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
+ return inputs
+
+ def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+ mask = inputs['attention_mask']
+ with torch.set_grad_enabled(self.finetune), self.autocast:
+ embeds = self.t5(**inputs).last_hidden_state
+ embeds = self.output_proj(embeds.to(self.output_proj.weight))
+ embeds = (embeds * mask.unsqueeze(-1))
+ return embeds, mask
+
+
+class WaveformConditioner(BaseConditioner):
+ """Base class for all conditioners that take a waveform as input.
+ Classes that inherit must implement `_get_wav_embedding` that outputs
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+ factor of the embedding model.
+
+ Args:
+ dim (int): The internal representation dimension.
+ output_dim (int): Output dimension.
+ device (tp.Union[torch.device, str]): Device.
+ """
+ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+ super().__init__(dim, output_dim)
+ self.device = device
+
+ def tokenize(self, x: WavCondition) -> WavCondition:
+ wav, length, sample_rate, path, seek_time = x
+ assert length is not None
+ return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
+
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+ """Gets as input a WavCondition and returns a dense embedding."""
+ raise NotImplementedError()
+
+ def _downsampling_factor(self):
+ """Returns the downsampling factor of the embedding model."""
+ raise NotImplementedError()
+
+ def forward(self, x: WavCondition) -> ConditionType:
+ """Extract condition embedding and mask from a waveform and its metadata.
+ Args:
+ x (WavCondition): Waveform condition containing raw waveform and metadata.
+ Returns:
+ ConditionType: a dense vector representing the conditioning along with its mask
+ """
+ wav, lengths, *_ = x
+ with torch.no_grad():
+ embeds = self._get_wav_embedding(x)
+ embeds = embeds.to(self.output_proj.weight)
+ embeds = self.output_proj(embeds)
+
+ if lengths is not None:
+ lengths = lengths / self._downsampling_factor()
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
+ else:
+ mask = torch.ones_like(embeds)
+ embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+ return embeds, mask
+
+
+class ChromaStemConditioner(WaveformConditioner):
+ """Chroma conditioner based on stems.
+ The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
+ the drums and bass often dominate the chroma leading to the chroma features
+ not containing information about the melody.
+
+ Args:
+ output_dim (int): Output dimension for the conditioner.
+ sample_rate (int): Sample rate for the chroma extractor.
+ n_chroma (int): Number of chroma bins for the chroma extractor.
+ radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
+ duration (int): duration used during training. This is later used for correct padding
+ in case we are using chroma as prefix.
+ match_len_on_eval (bool, optional): if True then all chromas are padded to the training
+ duration. Defaults to False.
+ eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
+ conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+ Defaults to None.
+ n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
+ **kwargs: Additional parameters for the chroma extractor.
+ """
+ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+ duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+ n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
+ device: tp.Union[torch.device, str] = 'cpu', **kwargs):
+ from demucs import pretrained
+ super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+ self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
+ self.sample_rate = sample_rate
+ self.match_len_on_eval = match_len_on_eval
+ self.duration = duration
+ self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
+ stem_sources: list = self.demucs.sources # type: ignore
+ self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
+ radix2_exp=radix2_exp, **kwargs).to(device)
+ self.chroma_len = self._get_chroma_len()
+ self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
+ self.cache = None
+ if cache_path is not None:
+ self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+ compute_embed_fn=self._get_full_chroma_for_cache,
+ extract_embed_fn=self._extract_chroma_chunk)
+
+ def _downsampling_factor(self) -> int:
+ return self.chroma.winhop
+
+ def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
+ """Load pre-defined waveforms from a json.
+ These waveforms will be used for chroma extraction during evaluation.
+ This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
+ """
+ if path is None:
+ return None
+
+ logger.info(f"Loading evaluation wavs from {path}")
+ from audiocraft.data.audio_dataset import AudioDataset
+ dataset: AudioDataset = AudioDataset.from_meta(
+ path, segment_duration=self.duration, min_audio_duration=self.duration,
+ sample_rate=self.sample_rate, channels=1)
+
+ if len(dataset) > 0:
+ eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
+ logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
+ return eval_wavs
+ else:
+ raise ValueError("Could not find evaluation wavs, check lengths of wavs")
+
+ def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
+ self.eval_wavs = eval_wavs
+
+ def has_eval_wavs(self) -> bool:
+ return self.eval_wavs is not None
+
+ def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
+ """Sample wavs from a predefined list."""
+ assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
+ total_eval_wavs = len(self.eval_wavs)
+ out = self.eval_wavs
+ if num_samples > total_eval_wavs:
+ out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
+ return out[torch.randperm(len(out))][:num_samples]
+
+ def _get_chroma_len(self) -> int:
+ """Get length of chroma during training."""
+ dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
+ dummy_chr = self.chroma(dummy_wav)
+ return dummy_chr.shape[1]
+
+ @torch.no_grad()
+ def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+ """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
+ from demucs.apply import apply_model
+ from demucs.audio import convert_audio
+ with self.autocast:
+ wav = convert_audio(
+ wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
+ stems = apply_model(self.demucs, wav, device=self.device)
+ stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
+ mix_wav = stems.sum(1) # merge extracted stems to single waveform
+ mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
+ return mix_wav
+
+ @torch.no_grad()
+ def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
+ """Extract chroma features from the waveform."""
+ with self.autocast:
+ return self.chroma(wav)
+
+ @torch.no_grad()
+ def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+ """Compute wav embedding, applying stem and chroma extraction."""
+ # avoid 0-size tensors when we are working with null conds
+ if wav.shape[-1] == 1:
+ return self._extract_chroma(wav)
+ stems = self._get_stemmed_wav(wav, sample_rate)
+ chroma = self._extract_chroma(stems)
+ return chroma
+
+ @torch.no_grad()
+ def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
+ """Extract chroma from the whole audio waveform at the given path."""
+ wav, sr = audio_read(path)
+ wav = wav[None].to(self.device)
+ wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
+ chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
+ return chroma
+
+ def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
+ """Extract a chunk of chroma from the full chroma derived from the full waveform."""
+ wav_length = x.wav.shape[-1]
+ seek_time = x.seek_time[idx]
+ assert seek_time is not None, (
+ "WavCondition seek_time is required "
+ "when extracting chroma chunks from pre-computed chroma.")
+ full_chroma = full_chroma.float()
+ frame_rate = self.sample_rate / self._downsampling_factor()
+ target_length = int(frame_rate * wav_length / self.sample_rate)
+ index = int(frame_rate * seek_time)
+ out = full_chroma[index: index + target_length]
+ out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
+ return out.to(self.device)
+
+ @torch.no_grad()
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+ """Get the wav embedding from the WavCondition.
+ The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
+ or will rely on the embedding cache to load the pre-computed embedding if relevant.
+ """
+ sampled_wav: tp.Optional[torch.Tensor] = None
+ if not self.training and self.eval_wavs is not None:
+ warn_once(logger, "Using precomputed evaluation wavs!")
+ sampled_wav = self._sample_eval_wavs(len(x.wav))
+
+ no_undefined_paths = all(p is not None for p in x.path)
+ no_nullified_cond = x.wav.shape[-1] > 1
+ if sampled_wav is not None:
+ chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
+ elif self.cache is not None and no_undefined_paths and no_nullified_cond:
+ paths = [Path(p) for p in x.path if p is not None]
+ chroma = self.cache.get_embed_from_cache(paths, x)
+ else:
+ assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
+ chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
+
+ if self.match_len_on_eval:
+ B, T, C = chroma.shape
+ if T > self.chroma_len:
+ chroma = chroma[:, :self.chroma_len]
+ logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
+ elif T < self.chroma_len:
+ n_repeat = int(math.ceil(self.chroma_len / T))
+ chroma = chroma.repeat(1, n_repeat, 1)
+ chroma = chroma[:, :self.chroma_len]
+ logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
+
+ return chroma
+
+ def tokenize(self, x: WavCondition) -> WavCondition:
+ """Apply WavConditioner tokenization and populate cache if needed."""
+ x = super().tokenize(x)
+ no_undefined_paths = all(p is not None for p in x.path)
+ if self.cache is not None and no_undefined_paths:
+ paths = [Path(p) for p in x.path if p is not None]
+ self.cache.populate_embed_cache(paths, x)
+ return x
+
+
+class JointEmbeddingConditioner(BaseConditioner):
+ """Joint embedding conditioning supporting both audio or text conditioning.
+
+ Args:
+ dim (int): Dimension.
+ output_dim (int): Output dimension.
+ device (str): Device.
+ attribute (str): Attribute used by the conditioner.
+ autocast_dtype (str): Autocast for the conditioner.
+ quantize (bool): Whether to quantize the CLAP embedding.
+ n_q (int): Number of residual quantizers (used if quantize is true).
+ bins (int): Quantizers' codebooks size (used if quantize is true).
+ kwargs: Additional parameters for residual vector quantizer.
+ """
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+ autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
+ n_q: int = 12, bins: int = 1024, **kwargs):
+ super().__init__(dim=dim, output_dim=output_dim)
+ self.device = device
+ self.attribute = attribute
+ if autocast_dtype is None or device == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
+ else:
+ dtype = getattr(torch, autocast_dtype)
+ assert isinstance(dtype, torch.dtype)
+ logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+ # residual vector quantizer to discretize the conditioned embedding
+ self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
+ if quantize:
+ self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
+
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Get joint embedding in latent space from the inputs.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
+ and corresponding empty indexes.
+ """
+ raise NotImplementedError()
+
+ def forward(self, x: JointEmbedCondition) -> ConditionType:
+ with self.autocast:
+ embed, empty_idx = self._get_embed(x)
+ if self.quantizer is not None:
+ embed = embed.view(-1, self.dim, 1)
+ q_res = self.quantizer(embed, frame_rate=1)
+ out_embed = q_res.x.view(-1, self.dim)
+ else:
+ out_embed = embed
+ out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
+ mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
+ out_embed = (out_embed * mask.unsqueeze(-1))
+ return out_embed, mask
+
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+ return x
+
+
+class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
+ """Joint Embedding conditioner based on pre-trained CLAP model.
+
+ This CLAP-based conditioner supports a caching mechanism
+ over the computed embeddings for faster training.
+
+ Args:
+ dim (int): Dimension.
+ output_dim (int): Output dimension.
+ device (str): Device.
+ attribute (str): Attribute used by the conditioner.
+ quantize (bool): Whether to quantize the CLAP embedding.
+ n_q (int): Number of residual quantizers (used if quantize is true).
+ bins (int): Quantizers' codebooks size (used if quantize is true).
+ checkpoint (str): Path to CLAP checkpoint.
+ model_arch (str): CLAP model architecture.
+ enable_fusion (bool): Enable fusion for CLAP model.
+ sample_rate (int): Sample rate used by CLAP model.
+ max_audio_length (float): Maximum audio length for CLAP model.
+ audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
+ normalize (bool): Whether to normalize the CLAP embedding.
+ text_p (float): Probability of using text representation instead of audio at train time.
+ batch_size (Optional[int]): Batch size for CLAP embedding computation.
+ autocast_dtype (str): Autocast for the conditioner.
+ cache_path (Optional[str]): Path for pre-computed embeddings caching.
+ kwargs: Additional parameters for residual vector quantizer.
+ """
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+ quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
+ enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
+ normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
+ autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
+ try:
+ import laion_clap # type: ignore
+ except ImportError:
+ raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
+ checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
+ clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+ clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+ load_clap_state_dict(clap_model, checkpoint)
+ clap_model.eval()
+ clap_model.to(device)
+ super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
+ autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
+ **kwargs)
+ self.checkpoint = checkpoint
+ self.enable_fusion = enable_fusion
+ self.model_arch = model_arch
+ self.clap: laion_clap.CLAP_Module
+ self.clap_tokenize: RobertaTokenizer
+ self.clap_sample_rate = sample_rate
+ self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
+ self.clap_stride = int(self.clap_sample_rate * audio_stride)
+ self.batch_size = batch_size or 1
+ self.normalize = normalize
+ self.text_p = text_p
+ self.__dict__['clap_tokenize'] = clap_tokenize
+ self.__dict__['clap'] = clap_model
+ self.wav_cache, self.text_cache = None, None
+ if cache_path is not None:
+ self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+ compute_embed_fn=self._get_wav_embedding_for_cache,
+ extract_embed_fn=self._extract_wav_embedding_chunk)
+ self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
+ compute_embed_fn=self._get_text_embedding_for_cache)
+
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+ # we use the default params from CLAP module here as well
+ return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+ def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
+ """Compute text embedding from CLAP model on a given a batch of text.
+
+ Args:
+ text (list[str]): List of text for the batch, with B items.
+ Returns:
+ torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
+ """
+ with torch.no_grad():
+ embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+ return embed.view(embed.size(0), 1, embed.size(-1))
+
+ def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
+ """Get text embedding function for the cache."""
+ text = x.text[idx]
+ text = text if text is not None else ""
+ return self._compute_text_embedding([text])[0]
+
+ def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
+ """Preprocess wav to expected format by CLAP model.
+
+ Args:
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+ sample_rates (list[int]): Sample rates for each sample in the batch
+ Returns:
+ torch.Tensor: Audio wav of shape [B, T].
+ """
+ assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
+ if sample_rates is not None:
+ _wav = []
+ for i, audio in enumerate(wav):
+ sr = sample_rates[i]
+ audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
+ _wav.append(audio)
+ wav = torch.stack(_wav, dim=0)
+ wav = wav.mean(dim=1)
+ return wav
+
+ def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
+ sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
+ """Compute audio wave embedding from CLAP model.
+
+ Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
+ we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
+ average the resulting embeddings.
+
+ Args:
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+ sample_rates (list[int]): Sample rates for each sample in the batch.
+ reduce_mean (bool): Whether to get the average tensor.
+ Returns:
+ torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
+ """
+ with torch.no_grad():
+ wav = self._preprocess_wav(wav, length, sample_rates)
+ B, T = wav.shape
+ if T >= self.clap_max_frames:
+ wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
+ else:
+ wav = wav.view(-1, 1, T) # [B, F, T] with F=1
+ wav = einops.rearrange(wav, 'b f t -> (b f) t')
+ embed_list = []
+ for i in range(0, wav.size(0), self.batch_size):
+ _wav = wav[i:i+self.batch_size, ...]
+ _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
+ embed_list.append(_embed)
+ embed = torch.cat(embed_list, dim=0)
+ embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
+ if reduce_mean:
+ embed = embed.mean(dim=1, keepdim=True)
+ return embed # [B, F, D] with F=1 if reduce_mean is True
+
+ def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
+ """Compute audio wave embedding for the cache.
+ The embedding is computed on a given audio read from file.
+
+ Args:
+ path (str or Path): Path to the full audio file.
+ Returns:
+ torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
+ """
+ wav, sr = audio_read(path) # [C, T]
+ wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
+ wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
+ embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
+ return embed.squeeze(0) # [F, D]
+
+ def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
+ """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
+
+ Args:
+ full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
+ x (JointEmbedCondition): Joint embedding condition for the full batch.
+ idx (int): Index considered for the given embedding to extract.
+ Returns:
+ torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
+ """
+ sample_rate = x.sample_rate[idx]
+ seek_time = x.seek_time[idx]
+ seek_time = 0. if seek_time is None else seek_time
+ clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
+ end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
+ start_offset = int(seek_time * sample_rate // clap_stride)
+ end_offset = int(end_seek_time * sample_rate // clap_stride)
+ wav_embed = full_embed[start_offset:end_offset, ...]
+ wav_embed = wav_embed.mean(dim=0, keepdim=True)
+ return wav_embed.to(self.device) # [F, D]
+
+ def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+ """Get CLAP embedding from a batch of text descriptions."""
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
+ if self.text_cache is not None and no_nullified_cond:
+ assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
+ paths = [Path(p) for p in x.path if p is not None]
+ embed = self.text_cache.get_embed_from_cache(paths, x)
+ else:
+ text = [xi if xi is not None else "" for xi in x.text]
+ embed = self._compute_text_embedding(text)
+ if self.normalize:
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+ return embed
+
+ def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+ """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
+ no_undefined_paths = all(p is not None for p in x.path)
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
+ if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
+ paths = [Path(p) for p in x.path if p is not None]
+ embed = self.wav_cache.get_embed_from_cache(paths, x)
+ else:
+ embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
+ if self.normalize:
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+ return embed
+
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+ # Trying to limit as much as possible sync points when the cache is warm.
+ no_undefined_paths = all(p is not None for p in x.path)
+ if self.wav_cache is not None and no_undefined_paths:
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+ paths = [Path(p) for p in x.path if p is not None]
+ self.wav_cache.populate_embed_cache(paths, x)
+ if self.text_cache is not None and no_undefined_paths:
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+ paths = [Path(p) for p in x.path if p is not None]
+ self.text_cache.populate_embed_cache(paths, x)
+ return x
+
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Extract shared latent representation from either the wav or the text using CLAP."""
+ # decide whether to use text embedding at train time or not
+ use_text_embed = random.random() < self.text_p
+ if self.training and not use_text_embed:
+ embed = self._get_wav_embedding(x)
+ empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
+ else:
+ embed = self._get_text_embedding(x)
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
+ return embed, empty_idx
+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+ If the condition is of type "wav", then nullify it using `nullify_condition` function.
+ If the condition is of any other type, set its value to None.
+ Works in-place.
+ """
+ if condition_type not in ['text', 'wav', 'joint_embed']:
+ raise ValueError(
+ "dropout_condition got an unexpected condition type!"
+ f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
+ )
+
+ if condition not in getattr(sample, condition_type):
+ raise ValueError(
+ "dropout_condition received an unexpected condition!"
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+ f" but got '{condition}' of type '{condition_type}'!"
+ )
+
+ if condition_type == 'wav':
+ wav_cond = sample.wav[condition]
+ sample.wav[condition] = nullify_wav(wav_cond)
+ elif condition_type == 'joint_embed':
+ embed = sample.joint_embed[condition]
+ sample.joint_embed[condition] = nullify_joint_embed(embed)
+ else:
+ sample.text[condition] = None
+
+ return sample
+
+
+class DropoutModule(nn.Module):
+ """Base module for all dropout modules."""
+ def __init__(self, seed: int = 1234):
+ super().__init__()
+ self.rng = torch.Generator()
+ self.rng.manual_seed(seed)
+
+
+class AttributeDropout(DropoutModule):
+ """Dropout with a given probability per attribute.
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
+ must also be dropped.
+
+ Args:
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+ ...
+ "genre": 0.1,
+ "artist": 0.5,
+ "wav": 0.25,
+ ...
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+ seed (int, optional): Random seed.
+ """
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+ super().__init__(seed=seed)
+ self.active_on_eval = active_on_eval
+ # construct dict that return the values from p otherwise 0
+ self.p = {}
+ for condition_type, probs in p.items():
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+ """
+ Args:
+ samples (list[ConditioningAttributes]): List of conditions.
+ Returns:
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+ """
+ if not self.training and not self.active_on_eval:
+ return samples
+
+ samples = deepcopy(samples)
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
+ if torch.rand(1, generator=self.rng).item() < p:
+ for sample in samples:
+ dropout_condition(sample, condition_type, condition)
+ return samples
+
+ def __repr__(self):
+ return f"AttributeDropout({dict(self.p)})"
+
+
+class ClassifierFreeGuidanceDropout(DropoutModule):
+ """Classifier Free Guidance dropout.
+ All attributes are dropped with the same probability.
+
+ Args:
+ p (float): Probability to apply condition dropout during training.
+ seed (int): Random seed.
+ """
+ def __init__(self, p: float, seed: int = 1234):
+ super().__init__(seed=seed)
+ self.p = p
+
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+ """
+ Args:
+ samples (list[ConditioningAttributes]): List of conditions.
+ Returns:
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
+ """
+ if not self.training:
+ return samples
+
+ # decide on which attributes to drop in a batched fashion
+ drop = torch.rand(1, generator=self.rng).item() < self.p
+ if not drop:
+ return samples
+
+ # nullify conditions of all attributes
+ samples = deepcopy(samples)
+ for condition_type in ["wav", "text"]:
+ for sample in samples:
+ for condition in sample.attributes[condition_type]:
+ dropout_condition(sample, condition_type, condition)
+ return samples
+
+ def __repr__(self):
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+
+class ConditioningProvider(nn.Module):
+ """Prepare and provide conditions given all the supported conditioners.
+
+ Args:
+ conditioners (dict): Dictionary of conditioners.
+ device (torch.device or str, optional): Device for conditioners and output condition types.
+ """
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
+ super().__init__()
+ self.device = device
+ self.conditioners = nn.ModuleDict(conditioners)
+
+ @property
+ def joint_embed_conditions(self):
+ return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
+
+ @property
+ def has_joint_embed_conditions(self):
+ return len(self.joint_embed_conditions) > 0
+
+ @property
+ def text_conditions(self):
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+ @property
+ def wav_conditions(self):
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+ @property
+ def has_wav_condition(self):
+ return len(self.wav_conditions) > 0
+
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+ This should be called before starting any real GPU work to avoid synchronization points.
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+ Args:
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
+ text and wav conditions.
+ """
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
+ f" but types were {set([type(x) for x in inputs])}"
+ )
+
+ output = {}
+ text = self._collate_text(inputs)
+ wavs = self._collate_wavs(inputs)
+ joint_embeds = self._collate_joint_embeds(inputs)
+
+ assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+ f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+ )
+
+ for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
+ return output
+
+ def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
+ The output is for example:
+ {
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+ ...
+ }
+
+ Args:
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+ """
+ output = {}
+ for attribute, inputs in tokenized.items():
+ condition, mask = self.conditioners[attribute](inputs)
+ output[attribute] = (condition, mask)
+ return output
+
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+ are the attributes and the values are the aggregated input per attribute.
+ For example:
+ Input:
+ [
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+ ]
+ Output:
+ {
+ "genre": ["Rock", "Hip-hop"],
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+ }
+
+ Args:
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+ Returns:
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
+ """
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+ texts = [x.text for x in samples]
+ for text in texts:
+ for condition in self.text_conditions:
+ out[condition].append(text[condition])
+ return out
+
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
+ and the values are Tensors of wavs according to said attributes.
+
+ *Note*: by the time the samples reach this function, each sample should have some waveform
+ inside the "wav" attribute. It should be either:
+ 1. A real waveform
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+ Args:
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+ Returns:
+ dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
+ """
+ wavs = defaultdict(list)
+ lengths = defaultdict(list)
+ sample_rates = defaultdict(list)
+ paths = defaultdict(list)
+ seek_times = defaultdict(list)
+ out: tp.Dict[str, WavCondition] = {}
+
+ for sample in samples:
+ for attribute in self.wav_conditions:
+ wav, length, sample_rate, path, seek_time = sample.wav[attribute]
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
+ # mono-channel conditioning
+ wav = wav.mean(1, keepdim=True) # [1, 1, T]
+ wavs[attribute].append(wav.flatten()) # [T]
+ lengths[attribute].append(length)
+ sample_rates[attribute].extend(sample_rate)
+ paths[attribute].extend(path)
+ seek_times[attribute].extend(seek_time)
+
+ # stack all wavs to a single tensor
+ for attribute in self.wav_conditions:
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
+ out[attribute] = WavCondition(
+ stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
+ paths[attribute], seek_times[attribute])
+
+ return out
+
+ def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
+ """Generate a dict where the keys are attributes by which we compute joint embeddings,
+ and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
+
+ Args:
+ samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
+ Returns:
+ A dictionary mapping an attribute name to joint embeddings.
+ """
+ texts = defaultdict(list)
+ wavs = defaultdict(list)
+ lengths = defaultdict(list)
+ sample_rates = defaultdict(list)
+ paths = defaultdict(list)
+ seek_times = defaultdict(list)
+ channels: int = 0
+
+ out = {}
+ for sample in samples:
+ for attribute in self.joint_embed_conditions:
+ wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
+ assert wav.dim() == 3
+ if channels == 0:
+ channels = wav.size(1)
+ else:
+ assert channels == wav.size(1), "not all audio has same number of channels in batch"
+ assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
+ wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
+ wavs[attribute].append(wav)
+ texts[attribute].extend(text)
+ lengths[attribute].append(length)
+ sample_rates[attribute].extend(sample_rate)
+ paths[attribute].extend(path)
+ seek_times[attribute].extend(seek_time)
+
+ for attribute in self.joint_embed_conditions:
+ stacked_texts = texts[attribute]
+ stacked_paths = paths[attribute]
+ stacked_seek_times = seek_times[attribute]
+ stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
+ stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
+ stacked_sample_rates = sample_rates[attribute]
+ stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
+ assert stacked_lengths.size(0) == stacked_wavs.size(0)
+ assert len(stacked_sample_rates) == stacked_wavs.size(0)
+ assert len(stacked_texts) == stacked_wavs.size(0)
+ out[attribute] = JointEmbedCondition(
+ text=stacked_texts, wav=stacked_wavs,
+ length=stacked_lengths, sample_rate=stacked_sample_rates,
+ path=stacked_paths, seek_time=stacked_seek_times)
+
+ return out
+
+
+class ConditionFuser(StreamingModule):
+ """Condition fuser handles the logic to combine the different conditions
+ to the actual model input.
+
+ Args:
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+ each condition. For example:
+ {
+ "prepend": ["description"],
+ "sum": ["genre", "bpm"],
+ "cross": ["description"],
+ }
+ cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+ cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
+ """
+ FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+ cross_attention_pos_emb_scale: float = 1.0):
+ super().__init__()
+ assert all(
+ [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+ ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
+ self.cross_attention_pos_emb = cross_attention_pos_emb
+ self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+ self.cond2fuse: tp.Dict[str, str] = {}
+ for fuse_method, conditions in fuse2cond.items():
+ for condition in conditions:
+ self.cond2fuse[condition] = fuse_method
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ conditions: tp.Dict[str, ConditionType]
+ ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ """Fuse the conditions to the provided model input.
+
+ Args:
+ input (torch.Tensor): Transformer input.
+ conditions (dict[str, ConditionType]): Dict of conditions.
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
+ after the conditions have been fused. The second output tensor is the tensor
+ used for cross-attention or None if no cross attention inputs exist.
+ """
+ B, T, _ = input.shape
+
+ if 'offsets' in self._streaming_state:
+ first_step = False
+ offsets = self._streaming_state['offsets']
+ else:
+ first_step = True
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+ f"given conditions contain unknown attributes for fuser, " \
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+ cross_attention_output = None
+ for cond_type, (cond, cond_mask) in conditions.items():
+ op = self.cond2fuse[cond_type]
+ if op == 'sum':
+ input += cond
+ elif op == 'input_interpolate':
+ cond = einops.rearrange(cond, "b t d -> b d t")
+ cond = F.interpolate(cond, size=input.shape[1])
+ input += einops.rearrange(cond, "b d t -> b t d")
+ elif op == 'prepend':
+ if first_step:
+ input = torch.cat([cond, input], dim=1)
+ elif op == 'cross':
+ if cross_attention_output is not None:
+ cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+ else:
+ cross_attention_output = cond
+ else:
+ raise ValueError(f"unknown op ({op})")
+
+ if self.cross_attention_pos_emb and cross_attention_output is not None:
+ positions = torch.arange(
+ cross_attention_output.shape[1],
+ device=cross_attention_output.device
+ ).view(1, -1, 1)
+ pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+ cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+ if self._is_streaming:
+ self._streaming_state['offsets'] = offsets + T
+
+ return input, cross_attention_output
diff --git a/audiocraft/modules/conv.py b/audiocraft/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d115cbf8729b642ed78608bd00a4d0fd5afae6fd
--- /dev/null
+++ b/audiocraft/modules/conv.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+ 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'weight_norm':
+ return weight_norm(module)
+ elif norm == 'spectral_norm':
+ return spectral_norm(module)
+ else:
+ # We already check was in CONV_NORMALIZATION, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'time_group_norm':
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+ padding_total: int = 0) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once you removed padding, we are missing one time step !
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == 'reflect':
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class StreamableConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, dilation: int = 1,
+ groups: int = 1, bias: bool = True, causal: bool = False,
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = 'reflect'):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
+ norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
+ padding_total = kernel_size - stride
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+ return self.conv(x)
+
+
+class StreamableConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, causal: bool = False,
+ norm: str = 'none', trim_right_ratio: float = 1.,
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert self.causal or self.trim_right_ratio == 1., \
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
diff --git a/audiocraft/modules/diffusion_schedule.py b/audiocraft/modules/diffusion_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ca6e3f2e7c4ff904d96dade315b0b46856778d
--- /dev/null
+++ b/audiocraft/modules/diffusion_schedule.py
@@ -0,0 +1,272 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
+"""
+
+from collections import namedtuple
+import random
+import typing as tp
+import julius
+import torch
+
+TrainingItem = namedtuple("TrainingItem", "noisy noise step")
+
+
+def betas_from_alpha_bar(alpha_bar):
+ alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
+ return 1 - alphas
+
+
+class SampleProcessor(torch.nn.Module):
+ def project_sample(self, x: torch.Tensor):
+ """Project the original sample to the 'space' where the diffusion will happen."""
+ return x
+
+ def return_sample(self, z: torch.Tensor):
+ """Project back from diffusion space to the actual sample space."""
+ return z
+
+
+class MultiBandProcessor(SampleProcessor):
+ """
+ MultiBand sample processor. The input audio is splitted across
+ frequency bands evenly distributed in mel-scale.
+
+ Each band will be rescaled to match the power distribution
+ of Gaussian noise in that band, using online metrics
+ computed on the first few samples.
+
+ Args:
+ n_bands (int): Number of mel-bands to split the signal over.
+ sample_rate (int): Sample rate of the audio.
+ num_samples (int): Number of samples to use to fit the rescaling
+ for each band. The processor won't be stable
+ until it has seen that many samples.
+ power_std (float or list/tensor): The rescaling factor computed to match the
+ power of Gaussian noise in each band is taken to
+ that power, i.e. `1.` means full correction of the energy
+ in each band, and values less than `1` means only partial
+ correction. Can be used to balance the relative importance
+ of low vs. high freq in typical audio signals.
+ """
+ def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
+ num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
+ super().__init__()
+ self.n_bands = n_bands
+ self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
+ self.num_samples = num_samples
+ self.power_std = power_std
+ if isinstance(power_std, list):
+ assert len(power_std) == n_bands
+ power_std = torch.tensor(power_std)
+ self.register_buffer('counts', torch.zeros(1))
+ self.register_buffer('sum_x', torch.zeros(n_bands))
+ self.register_buffer('sum_x2', torch.zeros(n_bands))
+ self.register_buffer('sum_target_x2', torch.zeros(n_bands))
+ self.counts: torch.Tensor
+ self.sum_x: torch.Tensor
+ self.sum_x2: torch.Tensor
+ self.sum_target_x2: torch.Tensor
+
+ @property
+ def mean(self):
+ mean = self.sum_x / self.counts
+ return mean
+
+ @property
+ def std(self):
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
+ return std
+
+ @property
+ def target_std(self):
+ target_std = self.sum_target_x2 / self.counts
+ return target_std
+
+ def project_sample(self, x: torch.Tensor):
+ assert x.dim() == 3
+ bands = self.split_bands(x)
+ if self.counts.item() < self.num_samples:
+ ref_bands = self.split_bands(torch.randn_like(x))
+ self.counts += len(x)
+ self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
+ self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+ self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
+ bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
+ return bands.sum(dim=0)
+
+ def return_sample(self, x: torch.Tensor):
+ assert x.dim() == 3
+ bands = self.split_bands(x)
+ rescale = (self.std / self.target_std) ** self.power_std
+ bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
+ return bands.sum(dim=0)
+
+
+class NoiseSchedule:
+ """Noise schedule for diffusion.
+
+ Args:
+ beta_t0 (float): Variance of the first diffusion step.
+ beta_t1 (float): Variance of the last diffusion step.
+ beta_exp (float): Power schedule exponent
+ num_steps (int): Number of diffusion step.
+ variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
+ clip (float): clipping value for the denoising steps
+ rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
+ repartition (str): shape of the schedule only power schedule is supported
+ sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
+ noise_scale (float): Scaling factor for the noise
+ """
+ def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
+ clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
+ repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
+ sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
+
+ self.beta_t0 = beta_t0
+ self.beta_t1 = beta_t1
+ self.variance = variance
+ self.num_steps = num_steps
+ self.clip = clip
+ self.sample_processor = sample_processor
+ self.rescale = rescale
+ self.n_bands = n_bands
+ self.noise_scale = noise_scale
+ assert n_bands is None
+ if repartition == "power":
+ self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
+ device=device, dtype=torch.float) ** beta_exp
+ else:
+ raise RuntimeError('Not implemented')
+ self.rng = random.Random(1234)
+
+ def get_beta(self, step: tp.Union[int, torch.Tensor]):
+ if self.n_bands is None:
+ return self.betas[step]
+ else:
+ return self.betas[:, step] # [n_bands, len(step)]
+
+ def get_initial_noise(self, x: torch.Tensor):
+ if self.n_bands is None:
+ return torch.randn_like(x)
+ return torch.randn((x.size(0), self.n_bands, x.size(2)))
+
+ def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
+ """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
+ if step is None:
+ return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
+ if type(step) is int:
+ return (1 - self.betas[:step + 1]).prod()
+ else:
+ return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
+
+ def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
+ """Create a noisy data item for diffusion model training:
+
+ Args:
+ x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
+ tensor_step (bool): If tensor_step = false, only one step t is sample,
+ the whole batch is diffused to the same step and t is int.
+ If tensor_step = true, t is a tensor of size (x.size(0),)
+ every element of the batch is diffused to a independently sampled.
+ """
+ step: tp.Union[int, torch.Tensor]
+ if tensor_step:
+ bs = x.size(0)
+ step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
+ else:
+ step = self.rng.randrange(self.num_steps)
+ alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
+
+ x = self.sample_processor.project_sample(x)
+ noise = torch.randn_like(x)
+ noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
+ return TrainingItem(noisy, noise, step)
+
+ def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+ """Full ddpm reverse process.
+
+ Args:
+ model (nn.Module): Diffusion model.
+ initial (tensor): Initial Noise.
+ condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
+ return_list (bool): Whether to return the whole process or only the sampled point.
+ """
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+ current = initial
+ iterates = [initial]
+ for step in range(self.num_steps)[::-1]:
+ with torch.no_grad():
+ estimate = model(current, step, condition=condition).sample
+ alpha = 1 - self.betas[step]
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+ previous_alpha_bar = self.get_alpha_bar(step=step - 1)
+ if step == 0:
+ sigma2 = 0
+ elif self.variance == 'beta':
+ sigma2 = 1 - alpha
+ elif self.variance == 'beta_tilde':
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+ elif self.variance == 'none':
+ sigma2 = 0
+ else:
+ raise ValueError(f'Invalid variance type {self.variance}')
+
+ if sigma2 > 0:
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+ if self.clip:
+ previous = previous.clamp(-self.clip, self.clip)
+ current = previous
+ alpha_bar = previous_alpha_bar
+ if step == 0:
+ previous *= self.rescale
+ if return_list:
+ iterates.append(previous.cpu())
+
+ if return_list:
+ return iterates
+ else:
+ return self.sample_processor.return_sample(previous)
+
+ def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+ """Reverse process that only goes through Markov chain states in step_list."""
+ if step_list is None:
+ step_list = list(range(1000))[::-50] + [0]
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+ alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
+ betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
+ current = initial * self.noise_scale
+ iterates = [current]
+ for idx, step in enumerate(step_list[:-1]):
+ with torch.no_grad():
+ estimate = model(current, step, condition=condition).sample * self.noise_scale
+ alpha = 1 - betas_subsampled[-1 - idx]
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+ previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
+ if step == step_list[-2]:
+ sigma2 = 0
+ previous_alpha_bar = torch.tensor(1.0)
+ else:
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+ if sigma2 > 0:
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+ if self.clip:
+ previous = previous.clamp(-self.clip, self.clip)
+ current = previous
+ alpha_bar = previous_alpha_bar
+ if step == 0:
+ previous *= self.rescale
+ if return_list:
+ iterates.append(previous.cpu())
+ if return_list:
+ return iterates
+ else:
+ return self.sample_processor.return_sample(previous)
diff --git a/audiocraft/modules/lstm.py b/audiocraft/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0866175950c1ca4f6cca98649525e6481853bba
--- /dev/null
+++ b/audiocraft/modules/lstm.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+
+
+class StreamableLSTM(nn.Module):
+ """LSTM without worrying about the hidden state, nor the layout of the data.
+ Expects input as convolutional layout.
+ """
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+ super().__init__()
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ y, _ = self.lstm(x)
+ if self.skip:
+ y = y + x
+ y = y.permute(1, 2, 0)
+ return y
diff --git a/audiocraft/modules/rope.py b/audiocraft/modules/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..503e6748df2bb72b3c864c20b37cba5498ffdd21
--- /dev/null
+++ b/audiocraft/modules/rope.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch import nn
+import torch
+
+
+class XPos(nn.Module):
+ """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+ This applies an exponential decay to the RoPE rotation matrix.
+
+ Args:
+ dim (int): Embedding dimension.
+ smoothing (float): Smoothing factor applied to the decay rates.
+ base_scale (int): Base decay rate, given in terms of scaling time.
+ device (torch.device, optional): Device on which to initialize the module.
+ dtype (torch.dtype): dtype to use to generate the embedding.
+ """
+ def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+ device=None, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ assert dim % 2 == 0
+ assert dtype in [torch.float64, torch.float32]
+ self.dtype = dtype
+ self.base_scale = base_scale
+
+ half_dim = dim // 2
+ adim = torch.arange(half_dim, device=device, dtype=dtype)
+ decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+ self.register_buffer("decay_rates", decay_rates)
+ self.decay: tp.Optional[torch.Tensor] = None
+
+ def get_decay(self, start: int, end: int):
+ """Create complex decay tensor, cache values for fast computation."""
+ if self.decay is None or end > self.decay.shape[0]:
+ assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
+ idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+ power = idx / self.base_scale
+ scale = self.decay_rates ** power.unsqueeze(-1)
+ self.decay = torch.polar(scale, torch.zeros_like(scale))
+ return self.decay[start:end] # [T, C/2]
+
+
+class RotaryEmbedding(nn.Module):
+ """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+ Args:
+ dim (int): Embedding dimension (twice the number of frequencies).
+ max_period (float): Maximum period of the rotation frequencies.
+ xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+ scale (float): Scale of positional embedding, set to 0 to deactivate.
+ device (torch.device, optional): Device on which to initialize the module.
+ dtype (torch.dtype): dtype to use to generate the embedding.
+ """
+ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+ scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ assert dim % 2 == 0
+ self.scale = scale
+ assert dtype in [torch.float64, torch.float32]
+ self.dtype = dtype
+
+ adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+ frequencies = 1.0 / (max_period ** (adim / dim))
+ self.register_buffer("frequencies", frequencies)
+ self.rotation: tp.Optional[torch.Tensor] = None
+
+ self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+ def get_rotation(self, start: int, end: int):
+ """Create complex rotation tensor, cache values for fast computation."""
+ if self.rotation is None or end > self.rotation.shape[0]:
+ assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
+ idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+ angles = torch.outer(idx, self.frequencies)
+ self.rotation = torch.polar(torch.ones_like(angles), angles)
+ return self.rotation[start:end]
+
+ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+ """Apply rope rotation to query or key tensor."""
+ T = x.shape[1]
+ rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+
+ if self.xpos:
+ decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+ else:
+ decay = 1.0
+
+ if invert_decay:
+ decay = decay ** -1
+
+ x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+ scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+ x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+
+ return x_out.type_as(x)
+
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+ """ Apply rope rotation to both query and key tensors.
+ Supports streaming mode, in which query and key are not expected to have the same shape.
+ In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
+ query will be [C] (typically C == 1).
+
+ Args:
+ query (torch.Tensor): Query to rotate.
+ key (torch.Tensor): Key to rotate.
+ start (int): Start index of the sequence for time offset.
+ """
+ query_timesteps = query.shape[1]
+ key_timesteps = key.shape[1]
+ streaming_offset = key_timesteps - query_timesteps
+
+ query_out = self.rotate(query, start + streaming_offset)
+ key_out = self.rotate(key, start, invert_decay=True)
+
+ return query_out, key_out
diff --git a/audiocraft/modules/seanet.py b/audiocraft/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e5998e9153afb6e68ea410d565e00ea835db248
--- /dev/null
+++ b/audiocraft/modules/seanet.py
@@ -0,0 +1,258 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from .conv import StreamableConv1d, StreamableConvTranspose1d
+from .lstm import StreamableLSTM
+
+
+class SEANetResnetBlock(nn.Module):
+ """Residual block from SEANet model.
+
+ Args:
+ dim (int): Dimension of the input/output.
+ kernel_sizes (list): List of kernel sizes for the convolutions.
+ dilations (list): List of dilations for the convolutions.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection.
+ """
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+ super().__init__()
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+ act = getattr(nn, activation)
+ hidden = dim // compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [
+ act(**activation_params),
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+ norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ self.block = nn.Sequential(*block)
+ self.shortcut: nn.Module
+ if true_skip:
+ self.shortcut = nn.Identity()
+ else:
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+ """SEANet encoder.
+
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+ that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the initial convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+ For the encoder, it corresponds to the N first blocks.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+ disable_norm_outer_blocks: int = 0):
+ super().__init__()
+ self.channels = channels
+ self.dimension = dimension
+ self.n_filters = n_filters
+ self.ratios = list(reversed(ratios))
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+ "Number of blocks for which to disable norm is invalid." \
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+ act = getattr(nn, activation)
+ mult = 1
+ model: tp.List[nn.Module] = [
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+ # Downsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ norm=block_norm, norm_params=norm_params,
+ activation=activation, activation_params=activation_params,
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ # Add downsampling layers
+ model += [
+ act(**activation_params),
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=block_norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ mult *= 2
+
+ if lstm:
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+ model += [
+ act(**activation_params),
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+ """SEANet decoder.
+
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ final_activation (str): Final activation function after all convolutions.
+ final_activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the initial convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple.
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+ For the decoder, it corresponds to the N last blocks.
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+ If equal to 1.0, it means that all the trimming is done at the right.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+ super().__init__()
+ self.dimension = dimension
+ self.channels = channels
+ self.n_filters = n_filters
+ self.ratios = ratios
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+ "Number of blocks for which to disable norm is invalid." \
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+ act = getattr(nn, activation)
+ mult = int(2 ** len(self.ratios))
+ model: tp.List[nn.Module] = [
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+
+ if lstm:
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+ # Upsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+ # Add upsampling layers
+ model += [
+ act(**activation_params),
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=block_norm, norm_kwargs=norm_params,
+ causal=causal, trim_right_ratio=trim_right_ratio),
+ ]
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ activation=activation, activation_params=activation_params,
+ norm=block_norm, norm_params=norm_params, causal=causal,
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ mult //= 2
+
+ # Add final layers
+ model += [
+ act(**activation_params),
+ StreamableConv1d(n_filters, channels, last_kernel_size,
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+ # Add optional final activation to decoder (eg. tanh)
+ if final_activation is not None:
+ final_act = getattr(nn, final_activation)
+ final_activation_params = final_activation_params or {}
+ model += [
+ final_act(**final_activation_params)
+ ]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, z):
+ y = self.model(z)
+ return y
diff --git a/audiocraft/modules/streaming.py b/audiocraft/modules/streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..fba06936294ca15d72acd2d44f9dbda39a638107
--- /dev/null
+++ b/audiocraft/modules/streaming.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Streaming module API that should be implemented by all Streaming components,
+"""
+
+from contextlib import contextmanager
+import typing as tp
+from torch import nn
+import torch
+
+
+State = tp.Dict[str, torch.Tensor]
+
+
+class StreamingModule(nn.Module):
+ """Common API for streaming components.
+
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
+ By convention, the first dim of each tensor must be the batch size.
+ Don't use dots in the key names, as this would clash with submodules
+ (like in state_dict).
+
+ If `self._is_streaming` is True, the component should use and remember
+ the proper state inside `self._streaming_state`.
+
+ To set a streaming component in streaming state, use
+
+ with module.streaming():
+ ...
+
+ This will automatically reset the streaming state when exiting the context manager.
+ This also automatically propagates to all streaming children module.
+
+ Some module might also implement the `StreamingModule.flush` method, although
+ this one is trickier, as all parents module must be StreamingModule and implement
+ it as well for it to work properly. See `StreamingSequential` after.
+ """
+ def __init__(self) -> None:
+ super().__init__()
+ self._streaming_state: State = {}
+ self._is_streaming = False
+
+ def _apply_named_streaming(self, fn: tp.Any):
+ for name, module in self.named_modules():
+ if isinstance(module, StreamingModule):
+ fn(name, module)
+
+ def _set_streaming(self, streaming: bool):
+ def _set_streaming(name, module):
+ module._is_streaming = streaming
+ self._apply_named_streaming(_set_streaming)
+
+ @contextmanager
+ def streaming(self):
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
+ self._set_streaming(True)
+ try:
+ yield
+ finally:
+ self._set_streaming(False)
+ self.reset_streaming()
+
+ def reset_streaming(self):
+ """Reset the streaming state."""
+ def _reset(name: str, module: StreamingModule):
+ module._streaming_state.clear()
+
+ self._apply_named_streaming(_reset)
+
+ def get_streaming_state(self) -> State:
+ """Return the streaming state, including that of sub-modules."""
+ state: State = {}
+
+ def _add(name: str, module: StreamingModule):
+ if name:
+ name += "."
+ for key, value in module._streaming_state.items():
+ state[name + key] = value
+
+ self._apply_named_streaming(_add)
+ return state
+
+ def set_streaming_state(self, state: State):
+ """Set the streaming state, including that of sub-modules."""
+ state = dict(state)
+
+ def _set(name: str, module: StreamingModule):
+ if name:
+ name += "."
+ module._streaming_state.clear()
+ for key, value in list(state.items()):
+ # complexity is not ideal here, but probably fine.
+ if key.startswith(name):
+ local_key = key[len(name):]
+ if '.' not in local_key:
+ module._streaming_state[local_key] = value
+ del state[key]
+
+ self._apply_named_streaming(_set)
+ assert len(state) == 0, list(state.keys())
+
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
+ """Flush any remaining outputs that were waiting for completion.
+ Typically, for convolutions, this will add the final padding
+ and process the last buffer.
+
+ This should take an optional argument `x`, which will be provided
+ if a module before this one in the streaming pipeline has already
+ spitted out a flushed out buffer.
+ """
+ if x is None:
+ return None
+ else:
+ return self(x)
+
+
+class StreamingSequential(StreamingModule, nn.Sequential):
+ """A streaming compatible alternative of `nn.Sequential`.
+ """
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
+ for module in self:
+ if isinstance(module, StreamingModule):
+ x = module.flush(x)
+ elif x is not None:
+ x = module(x)
+ return x
diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..048c06dfbb0ab4167afce95dffb73dcc343c2344
--- /dev/null
+++ b/audiocraft/modules/transformer.py
@@ -0,0 +1,747 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Transformer model, with streaming support, xformer attention support
+and easy causal attention with a potentially finite receptive field.
+
+See `StreamingTransformer` for more information.
+
+Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
+"""
+
+import typing as tp
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
+from xformers import ops
+
+from .rope import RotaryEmbedding
+from .streaming import StreamingModule
+
+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+ # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+ global _efficient_attention_backend
+ assert _efficient_attention_backend in ['xformers', 'torch']
+ _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension() -> int:
+ if _efficient_attention_backend == 'torch':
+ return 2
+ else:
+ return 1
+
+
+def _is_profiled() -> bool:
+ # Return true if we are currently running with a xformers profiler activated.
+ try:
+ from xformers.profiler import profiler
+ except ImportError:
+ return False
+ return profiler._Profiler._CURRENT_PROFILER is not None
+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+ """Create normalization module for transformer encoder layer.
+
+ Args:
+ norm_type (str): Normalization method.
+ dim (int): Dimension of the normalized layer.
+ **kwargs (dict): Additional parameters for normalization layer.
+ Returns:
+ nn.Module: Normalization module.
+ """
+ if norm_type == 'layer_norm':
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+ else:
+ raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+ Args:
+ positions (torch.Tensor): LongTensor of positions.
+ dim (int): Dimension of the embedding.
+ max_period (float): Maximum period of the cosine/sine functions.
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
+ Returns:
+ torch.Tensor: Sinusoidal positional embedding.
+ """
+ # We aim for BTC format
+ assert dim % 2 == 0
+ half_dim = dim // 2
+ positions = positions.to(dtype)
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+ max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
+ if n_rep == 1:
+ return x
+ if _efficient_attention_backend == 'torch':
+ bs, n_kv_heads, slen, head_dim = x.shape
+ return (
+ x[:, :, None, :, :]
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+ )
+ else:
+ bs, slen, n_kv_heads, head_dim = x.shape
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
+
+
+class LayerScale(nn.Module):
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+ This rescales diagonally the residual outputs close to 0, with a learnt scale.
+
+ Args:
+ channels (int): Number of channels.
+ init (float): Initial scale.
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+ device (torch.device or str, optional): Device on which to initialize the module.
+ dtype (torch.dtype, optional): dtype to use to initialize the module.
+ """
+ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+ device=None, dtype=None):
+ super().__init__()
+ self.channel_last = channel_last
+ self.scale = nn.Parameter(
+ torch.full((channels,), init,
+ requires_grad=True, device=device, dtype=dtype))
+
+ def forward(self, x: torch.Tensor):
+ if self.channel_last:
+ return self.scale * x
+ else:
+ return self.scale[:, None] * x
+
+
+class StreamingMultiheadAttention(StreamingModule):
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+ Args:
+ embed_dim (int): Dimension to project to.
+ num_heads (int): Number of heads.
+ dropout (float): Dropout level.
+ bias (bool): Use bias in projections.
+ causal (bool): Causal mask applied automatically.
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
+ cross_attention: Should be true when used as a cross attention.
+ All keys and values must be available at once, streaming is only for the queries.
+ Cannot be used with `causal` or `rope` (as it wouldn't make sens to
+ interpret the time steps in the keys relative to those in the queries).
+ safe_streaming (bool): Bug fix, will go away with xformers update.
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+ kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+ device (torch.device, optional): Device on which to initialize.
+ dtype (torch.dtype, optional): dtype to use.
+ """
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+ causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+ memory_efficient: bool = False, attention_as_float32: bool = False,
+ rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+ safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+ device=None, dtype=None):
+ super().__init__()
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ if past_context is not None:
+ assert causal
+
+ self.embed_dim = embed_dim
+ self.causal = causal
+ self.past_context = past_context
+ self.memory_efficient = memory_efficient
+ self.attention_as_float32 = attention_as_float32
+ self.rope = rope
+ self.cross_attention = cross_attention
+ self.safe_streaming = safe_streaming
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.kv_repeat = kv_repeat
+ if cross_attention:
+ assert not causal, "Causal cannot work with cross attention."
+ assert rope is None, "Rope cannot work with cross attention."
+
+ if memory_efficient:
+ _verify_xformers_memory_efficient_compat()
+
+ self.custom = _is_custom(custom, memory_efficient)
+ if self.custom:
+ out_dim = embed_dim
+ assert num_heads % kv_repeat == 0
+ assert not cross_attention or kv_repeat == 1
+ num_kv = num_heads // kv_repeat
+ kv_dim = (embed_dim // num_heads) * num_kv
+ out_dim += 2 * kv_dim
+ in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
+ self.in_proj_weight = in_proj.weight
+ self.in_proj_bias = in_proj.bias
+ if bias:
+ self.in_proj_bias.data.zero_() # Following Pytorch convention
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+ if bias:
+ self.out_proj.bias.data.zero_()
+ else:
+ assert not qk_layer_norm
+ assert kv_repeat == 1
+ self.mha = nn.MultiheadAttention(
+ embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+ **factory_kwargs)
+ self.qk_layer_norm = qk_layer_norm
+ if qk_layer_norm:
+ assert self.custom
+ assert kv_repeat == 1
+ ln_dim = embed_dim
+ self.q_layer_norm = nn.LayerNorm(ln_dim)
+ self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if not self.custom:
+ # Support compat with regular MHA
+ keys = [n for n, _ in self.mha.named_parameters()]
+ for key in keys:
+ if prefix + key in state_dict:
+ state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+ # Return a causal mask, accounting for potentially stored past keys/values
+ # We actually return a bias for the attention score, as this has the same
+ # convention both in the builtin MHA in Pytorch, and Xformers functions.
+ time_dim = _get_attention_time_dimension()
+ if self.memory_efficient:
+ from xformers.ops import LowerTriangularMask
+ if current_steps == 1:
+ # If we only have one step, then we do not need a mask.
+ return None
+ elif 'past_keys' in self._streaming_state:
+ raise RuntimeError("Not supported at the moment")
+ else:
+ # Then we can safely use a lower triangular mask
+ return LowerTriangularMask()
+ if self._streaming_state:
+ past_keys = self._streaming_state['past_keys']
+ past_steps = past_keys.shape[time_dim]
+ else:
+ past_steps = 0
+
+ queries_pos = torch.arange(
+ past_steps, current_steps + past_steps, device=device).view(-1, 1)
+ keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+ delta = queries_pos - keys_pos
+ valid = delta >= 0
+ if self.past_context is not None:
+ valid &= (delta <= self.past_context)
+ return torch.where(
+ valid,
+ torch.zeros([], device=device, dtype=dtype),
+ torch.full([], float('-inf'), device=device, dtype=dtype))
+
+ def _complete_kv(self, k, v):
+ time_dim = _get_attention_time_dimension()
+ if self.cross_attention:
+ # With cross attention we assume all keys and values
+ # are already available, and streaming is with respect
+ # to the queries only.
+ return k, v
+ # Complete the key/value pair using the streaming state.
+ if self._streaming_state:
+ pk = self._streaming_state['past_keys']
+ nk = torch.cat([pk, k], dim=time_dim)
+ if v is k:
+ nv = nk
+ else:
+ pv = self._streaming_state['past_values']
+ nv = torch.cat([pv, v], dim=time_dim)
+ else:
+ nk = k
+ nv = v
+
+ assert nk.shape[time_dim] == nv.shape[time_dim]
+ offset = 0
+ if self.past_context is not None:
+ offset = max(0, nk.shape[time_dim] - self.past_context)
+ if self._is_streaming:
+ self._streaming_state['past_keys'] = nk[:, offset:]
+ if v is not k:
+ self._streaming_state['past_values'] = nv[:, offset:]
+ if 'offset' in self._streaming_state:
+ self._streaming_state['offset'] += offset
+ else:
+ self._streaming_state['offset'] = torch.tensor(0)
+ return nk, nv
+
+ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+ # TODO: fix and verify layout.
+ assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
+ # Apply rope embeddings to query and key tensors.
+ assert self.rope is not None
+ if 'past_keys' in self._streaming_state:
+ past_keys_offset = self._streaming_state['past_keys'].shape[1]
+ else:
+ past_keys_offset = 0
+ if 'offset' in self._streaming_state:
+ past_context_offset = int(self._streaming_state['offset'].item())
+ else:
+ past_context_offset = 0
+ streaming_offset = past_context_offset + past_keys_offset
+ return self.rope.rotate_qk(query, key, start=streaming_offset)
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+ key_padding_mask=None, need_weights=False, attn_mask=None,
+ average_attn_weights=True, is_causal=False):
+ assert attn_mask is None
+ assert not is_causal, ("New param added in torch 2.0.1 not supported, "
+ "use the causal args in the constructor.")
+
+ time_dim = _get_attention_time_dimension()
+ if time_dim == 2:
+ layout = "b h t d"
+ else:
+ layout = "b t h d"
+ dtype = query.dtype
+ if self._is_streaming:
+ assert self.causal or self.cross_attention, \
+ "Streaming only available for causal or cross attention"
+
+ if self.causal:
+ # At the moment we specialize only for the self-attention case.
+ assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+ assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+ attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+ if self.custom:
+ # custom implementation
+ assert need_weights is False
+ assert key_padding_mask is None
+ if self.cross_attention:
+ # Different queries, keys, values, we have to spit manually the weights
+ # before applying the linear.
+ dim = self.in_proj_weight.shape[0] // 3
+ if self.in_proj_bias is None:
+ bias_q, bias_k, bias_v = None, None, None
+ else:
+ bias_q = self.in_proj_bias[:dim]
+ bias_k = self.in_proj_bias[dim: 2 * dim]
+ bias_v = self.in_proj_bias[2 * dim:]
+ q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+ # todo: when streaming, we could actually save k, v and check the shape actually match.
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+ if self.qk_layer_norm is True:
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+ q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+ else:
+ if not _is_profiled():
+ # profiling breaks that propertysomehow.
+ assert query is key, "specialized implementation"
+ assert value is key, "specialized implementation"
+ projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+ if self.kv_repeat == 1:
+ if time_dim == 2:
+ bound_layout = "b h p t d"
+ else:
+ bound_layout = "b t p h d"
+ packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+ q, k, v = ops.unbind(packed, dim=2)
+ else:
+ embed_dim = self.embed_dim
+ per_head_dim = (embed_dim // self.num_heads)
+ kv_heads = self.num_heads // self.kv_repeat
+ q = projected[:, :, :embed_dim]
+ start = embed_dim
+ end = start + per_head_dim * kv_heads
+ k = projected[:, :, start: end]
+ v = projected[:, :, end:]
+ q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+ k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+ v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+ if self.qk_layer_norm is True:
+ assert self.kv_repeat == 1
+ q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+ q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+ if self.rope:
+ q, k = self._apply_rope(q, k)
+ k, v = self._complete_kv(k, v)
+ if self.kv_repeat > 1:
+ k = expand_repeated_kv(k, self.kv_repeat)
+ v = expand_repeated_kv(v, self.kv_repeat)
+ if self.attention_as_float32:
+ q, k, v = [x.float() for x in [q, k, v]]
+ if self.memory_efficient:
+ p = self.dropout if self.training else 0
+ if _efficient_attention_backend == 'torch':
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+ else:
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+ else:
+ # We include the dot product as float32, for consistency
+ # with the other implementations that include that step
+ # as part of the attention. Note that when using `autocast`,
+ # the einsums would be done as bfloat16, but the softmax
+ # would be done as bfloat16, so `attention_as_float32` will
+ # extend a bit the range of operations done in float32,
+ # although this should make no difference.
+ q = q / q.shape[-1] ** 0.5
+ key_layout = layout.replace('t', 'k')
+ query_layout = layout
+ if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+ with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+ else:
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+ if attn_mask is not None:
+ pre_w = pre_w + attn_mask
+ w = torch.softmax(pre_w, dim=-1)
+ w = F.dropout(w, self.dropout, training=self.training).to(v)
+ # Key and value have the same format.
+ x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+ x = x.to(dtype)
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+ x = self.out_proj(x)
+ else:
+ key, value = self._complete_kv(key, value)
+ if self.attention_as_float32:
+ query, key, value = [x.float() for x in [query, key, value]]
+ x, _ = self.mha(
+ query, key, value, key_padding_mask,
+ need_weights, attn_mask, average_attn_weights)
+ x = x.to(dtype)
+
+ return x, None
+
+
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+ """TransformerLayer with Streaming / Causal support.
+ This also integrates cross_attention, when passing `cross_attention=True`,
+ rather than having two separate classes like in PyTorch.
+
+ Args:
+ d_model (int): Dimension of the data.
+ num_heads (int): Number of heads.
+ dim_feedforward (int): Intermediate dimension of FF module.
+ dropout (float): Dropout both for MHA and FF.
+ bias_ff (bool): Use bias for FF.
+ bias_attn (bool): Use bias for MHA.
+ causal (bool): Causal mask applied automatically.
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+ qk_layer_norm_cross (bool): Same for the cross attention.
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
+ Cross attention will use the default MHA, as it typically won't require
+ special treatment.
+ layer_scale (float, optional): If not None, LayerScale will be used with
+ the given value as initial scale.
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
+ attention_dropout (float, optional): If not None, separate the value of the dimension dropout
+ in FFN and of the attention dropout.
+ kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+ device (torch.device, optional): Device on which to initialize.
+ dtype (torch.dtype, optional): dtype to use.
+ **kwargs: See `nn.TransformerEncoderLayer`.
+ """
+ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+ bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+ past_context: tp.Optional[int] = None, custom: bool = False,
+ memory_efficient: bool = False, attention_as_float32: bool = False,
+ qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+ rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+ kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+ super().__init__(d_model, num_heads, dim_feedforward, dropout,
+ device=device, dtype=dtype, batch_first=True, **kwargs)
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ # Redefine self_attn to our streaming multi-head attention
+ attn_kwargs: tp.Dict[str, tp.Any] = {
+ 'embed_dim': d_model,
+ 'num_heads': num_heads,
+ 'dropout': dropout if attention_dropout is None else attention_dropout,
+ 'bias': bias_attn,
+ 'custom': custom,
+ 'memory_efficient': memory_efficient,
+ 'attention_as_float32': attention_as_float32,
+ }
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+ causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+ kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
+ # Redefine feedforward layers to expose bias parameter
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+ self.layer_scale_1: nn.Module
+ self.layer_scale_2: nn.Module
+ if layer_scale is None:
+ self.layer_scale_1 = nn.Identity()
+ self.layer_scale_2 = nn.Identity()
+ else:
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+ self.cross_attention: tp.Optional[nn.Module] = None
+ if cross_attention:
+ self.cross_attention = StreamingMultiheadAttention(
+ cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+ **attn_kwargs, **factory_kwargs)
+ # Norm and dropout
+ self.dropout_cross = nn.Dropout(dropout)
+ # eps value matching that used in PyTorch reference implementation.
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+ self.layer_scale_cross: nn.Module
+ if layer_scale is None:
+ self.layer_scale_cross = nn.Identity()
+ else:
+ self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
+
+ def _cross_attention_block(self, src: torch.Tensor,
+ cross_attention_src: torch.Tensor) -> torch.Tensor:
+ assert self.cross_attention is not None
+ # queries are from src, keys and values from cross_attention_src.
+ x = self.cross_attention(
+ src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+ return self.dropout_cross(x) # type: ignore
+
+ def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
+ src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+ cross_attention_src: tp.Optional[torch.Tensor] = None):
+ if self.cross_attention is None:
+ assert cross_attention_src is None
+ else:
+ assert cross_attention_src is not None
+ x = src
+ if self.norm_first:
+ x = x + self.layer_scale_1(
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+ if cross_attention_src is not None:
+ x = x + self.layer_scale_cross(
+ self._cross_attention_block(
+ self.norm_cross(x), cross_attention_src))
+ x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+ else:
+ x = self.norm1(x + self.layer_scale_1(
+ self._sa_block(x, src_mask, src_key_padding_mask)))
+ if cross_attention_src is not None:
+ x = self.norm_cross(
+ x + self.layer_scale_cross(
+ self._cross_attention_block(src, cross_attention_src)))
+ x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+ return x
+
+
+class StreamingTransformer(StreamingModule):
+ """Transformer with Streaming / Causal support.
+
+ Args:
+ d_model (int): Dimension of the data.
+ num_heads (int): Number of heads.
+ dim_feedforward (int): Intermediate dimension of FF module.
+ dropout (float): Dropout both for MHA and FF.
+ bias_ff (bool): Use bias for FF.
+ bias_attn (bool): Use bias for MHA.
+ causal (bool): Causal mask applied automatically.
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
+ layer_scale (float, optional): If not None, LayerScale will be used
+ with the given value as initial scale.
+ positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+ max_period (float): Maximum period of the time embedding.
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+ xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+ lr (float, optional): learning rate override through the `make_optim_group` API.
+ weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
+ layer_class: (subclass of `StreamingTransformerLayer): class to use
+ to initialize the layers, allowing further customization outside of AudioCraft.
+ checkpointing (str): Checkpointing strategy to reduce memory usage.
+ No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+ if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+ minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
+ a policy for opting-out some operations of the checkpointing like
+ linear layers and attention, providing a middle ground between speed and memory.
+ device (torch.device, optional): Device on which to initialize.
+ dtype (torch.dtype, optional): dtype to use.
+ **kwargs: See `nn.TransformerEncoderLayer`.
+ """
+ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+ dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+ causal: bool = False, past_context: tp.Optional[int] = None,
+ custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+ positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+ xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+ checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+ super().__init__()
+ assert d_model % num_heads == 0
+
+ self.positional_embedding = positional_embedding
+ self.max_period = max_period
+ self.positional_scale = positional_scale
+ self.weight_decay = weight_decay
+ self.lr = lr
+
+ assert positional_embedding in ['sin', 'rope', 'sin_rope']
+ self.rope: tp.Optional[RotaryEmbedding] = None
+ if self.positional_embedding in ['rope', 'sin_rope']:
+ assert _is_custom(custom, memory_efficient)
+ self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+ xpos=xpos, scale=positional_scale, device=device)
+
+ self.checkpointing = checkpointing
+
+ assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+ if self.checkpointing.startswith('xformers'):
+ _verify_xformers_internal_compat()
+
+ self.layers = nn.ModuleList()
+ for idx in range(num_layers):
+ self.layers.append(
+ layer_class(
+ d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+ dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+ causal=causal, past_context=past_context, custom=custom,
+ memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+ cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+ device=device, dtype=dtype, **kwargs))
+
+ if self.checkpointing != 'none':
+ for layer in self.layers:
+ # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+ # backward hook inside of FSDP...
+ layer._magma_checkpointed = True # type: ignore
+ assert layer.layer_drop == 0., "Need further checking" # type: ignore
+
+ def _apply_layer(self, layer, *args, **kwargs):
+ method = self.checkpointing
+ if method == 'none':
+ return layer(*args, **kwargs)
+ elif method == 'torch':
+ return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+ elif method.startswith('xformers'):
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+ if method == 'xformers_default':
+ # those operations will be saved, and not recomputed.
+ # According to Francisco we can get smarter policies but this is a good start.
+ allow_list = [
+ "xformers.efficient_attention_forward_cutlass.default",
+ "xformers_flash.flash_fwd.default",
+ "aten.addmm.default",
+ "aten.mm.default",
+ ]
+ elif method == 'xformers_mm':
+ # those operations will be saved, and not recomputed.
+ # According to Francisco we can get smarter policies but this is a good start.
+ allow_list = [
+ "aten.addmm.default",
+ "aten.mm.default",
+ ]
+ else:
+ raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+ policy_fn = _get_default_policy(allow_list)
+ return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+ else:
+ raise ValueError(f"Checkpointing method {method} is unknown.")
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ B, T, C = x.shape
+
+ if 'offsets' in self._streaming_state:
+ offsets = self._streaming_state['offsets']
+ else:
+ offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+ if self.positional_embedding in ['sin', 'sin_rope']:
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
+ positions = positions + offsets.view(-1, 1, 1)
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+ x = x + self.positional_scale * pos_emb
+
+ for layer in self.layers:
+ x = self._apply_layer(layer, x, *args, **kwargs)
+
+ if self._is_streaming:
+ self._streaming_state['offsets'] = offsets + T
+
+ return x
+
+ def make_optim_group(self):
+ group = {"params": list(self.parameters())}
+ if self.lr is not None:
+ group["lr"] = self.lr
+ if self.weight_decay is not None:
+ group["weight_decay"] = self.weight_decay
+ return group
+
+
+# special attention related function
+
+def _verify_xformers_memory_efficient_compat():
+ try:
+ from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
+ except ImportError:
+ raise ImportError(
+ "xformers is not installed. Please install it and try again.\n"
+ "To install on AWS and Azure, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+ "To install on FAIR Cluster, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _verify_xformers_internal_compat():
+ try:
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
+ except ImportError:
+ raise ImportError(
+ "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
+ "To install on AWS and Azure, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+ "To install on FAIR Cluster, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _is_custom(custom: bool, memory_efficient: bool):
+ return custom or memory_efficient
diff --git a/audiocraft/optim/__init__.py b/audiocraft/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48c17dfafa9a2be46a91ed1fb64f54c5572a730
--- /dev/null
+++ b/audiocraft/optim/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers
+and Exponential Moving Average.
+"""
+
+# flake8: noqa
+from .cosine_lr_scheduler import CosineLRScheduler
+from .dadam import DAdaptAdam
+from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler
+from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler
+from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler
+from .ema import ModuleDictEMA
diff --git a/audiocraft/optim/__pycache__/__init__.cpython-310.pyc b/audiocraft/optim/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a8fa09fc77f29ff6c89e162cadee7f240b14a13
Binary files /dev/null and b/audiocraft/optim/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-310.pyc b/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e51656014e1f8d36f7ba249527b04827d186e9f8
Binary files /dev/null and b/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/dadam.cpython-310.pyc b/audiocraft/optim/__pycache__/dadam.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48bd9b287a9cfb1958f980f2b77f1b497a3dc1dc
Binary files /dev/null and b/audiocraft/optim/__pycache__/dadam.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/ema.cpython-310.pyc b/audiocraft/optim/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc3f7601700f72d9f1f7d552b9022c2c9030183c
Binary files /dev/null and b/audiocraft/optim/__pycache__/ema.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/fsdp.cpython-310.pyc b/audiocraft/optim/__pycache__/fsdp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae2738e2f44890001cea399bb1955a48e0409211
Binary files /dev/null and b/audiocraft/optim/__pycache__/fsdp.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-310.pyc b/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..184bfe00927b8e9aecbf587eb11f48287b4c0f5a
Binary files /dev/null and b/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-310.pyc b/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4854a32da7d1de62b30e4aef33136a0f695e788b
Binary files /dev/null and b/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-310.pyc differ
diff --git a/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-310.pyc b/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d7e9f2e3020fec85d99725f47c24bcecf210347
Binary files /dev/null and b/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-310.pyc differ
diff --git a/audiocraft/optim/cosine_lr_scheduler.py b/audiocraft/optim/cosine_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4f0bbf28f1ad893a301f1bfac1da8e97370337
--- /dev/null
+++ b/audiocraft/optim/cosine_lr_scheduler.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class CosineLRScheduler(_LRScheduler):
+ """Cosine LR scheduler.
+
+ Args:
+ optimizer (Optimizer): Torch optimizer.
+ warmup_steps (int): Number of warmup steps.
+ total_steps (int): Total number of steps.
+ lr_min_ratio (float): Minimum learning rate.
+ cycle_length (float): Cycle length.
+ """
+ def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
+ lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
+ self.warmup_steps = warmup_steps
+ assert self.warmup_steps >= 0
+ self.total_steps = total_steps
+ assert self.total_steps >= 0
+ self.lr_min_ratio = lr_min_ratio
+ self.cycle_length = cycle_length
+ super().__init__(optimizer)
+
+ def _get_sched_lr(self, lr: float, step: int):
+ if step < self.warmup_steps:
+ lr_ratio = step / self.warmup_steps
+ lr = lr_ratio * lr
+ elif step <= self.total_steps:
+ s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+ lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
+ (1. + math.cos(math.pi * s / self.cycle_length))
+ lr = lr_ratio * lr
+ else:
+ lr_ratio = self.lr_min_ratio
+ lr = lr_ratio * lr
+ return lr
+
+ def get_lr(self):
+ return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
diff --git a/audiocraft/optim/dadam.py b/audiocraft/optim/dadam.py
new file mode 100644
index 0000000000000000000000000000000000000000..a84402f744867610180b9576b2ee3302501fd035
--- /dev/null
+++ b/audiocraft/optim/dadam.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+import torch
+import torch.optim
+import torch.distributed as dist
+
+if TYPE_CHECKING:
+ from torch.optim.optimizer import _params_t
+else:
+ _params_t = Any
+
+
+logger = logging.getLogger(__name__)
+
+
+def to_real(x):
+ if torch.is_complex(x):
+ return x.real
+ else:
+ return x
+
+
+class DAdaptAdam(torch.optim.Optimizer):
+ """Adam with D-Adaptation automatic step-sizes.
+ Leave LR set to 1 unless you encounter instability.
+
+ Args:
+ params (iterable):
+ Iterable of parameters to optimize or dicts defining parameter groups.
+ lr (float):
+ Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
+ betas (tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ momentum (float):
+ Momentum value in the range [0,1) (default: 0.9).
+ eps (float):
+ Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
+ weight_decay (float):
+ Weight decay, i.e. a L2 penalty (default: 0).
+ log_every (int):
+ Log using print every k steps, default 0 (no logging).
+ decouple (boolean):
+ Use AdamW style decoupled weight decay
+ d0 (float):
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+ growth_rate (float):
+ prevent the D estimate from growing faster than this multiplicative rate.
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+ rate warmup effect.
+ fsdp_in_use (bool):
+ If you're using sharded parameters, this should be set to True. The optimizer
+ will attempt to auto-detect this, but if you're using an implementation other
+ than PyTorch's builtin version, the auto-detection won't work.
+ """
+ def __init__(self, params, lr=1.0,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ log_every=0,
+ decouple=True,
+ d0=1e-6,
+ growth_rate=float('inf')):
+ if not 0.0 < d0:
+ raise ValueError("Invalid d0 value: {}".format(d0))
+ if not 0.0 < lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 < eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+ if decouple:
+ logger.info("Using decoupled weight decay")
+
+ from .fsdp import is_fsdp_used
+ fsdp_in_use = is_fsdp_used()
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay,
+ d=d0,
+ k=0,
+ gsq_weighted=0.0,
+ log_every=log_every,
+ decouple=decouple,
+ growth_rate=growth_rate,
+ fsdp_in_use=fsdp_in_use)
+
+ super().__init__(params, defaults)
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return False
+
+ @property
+ def supports_flat_params(self):
+ return True
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ g_sq = 0.0
+ sksq_weighted = 0.0
+ sk_l1 = 0.0
+
+ lr = max(group['lr'] for group in self.param_groups)
+
+ group = self.param_groups[0]
+ gsq_weighted = group['gsq_weighted']
+ d = group['d']
+ dlr = d*lr
+
+ growth_rate = group['growth_rate']
+ decouple = group['decouple']
+ fsdp_in_use = group['fsdp_in_use']
+ log_every = group['log_every']
+
+ beta1, beta2 = group['betas']
+
+ for group in self.param_groups:
+ group_lr = group['lr']
+ decay = group['weight_decay']
+ k = group['k']
+ eps = group['eps']
+
+ if group_lr not in [lr, 0.0]:
+ raise RuntimeError("Setting different lr values in different parameter "
+ "groups is only supported for values of 0")
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ if hasattr(p, "_fsdp_flattened"):
+ fsdp_in_use = True
+ grad = p.grad.data
+
+ # Apply weight decay (coupled variant)
+ if decay != 0 and not decouple:
+ grad.add_(p.data, alpha=decay)
+
+ state = self.state[p]
+
+ # State initialization
+ if 'step' not in state:
+ state['step'] = 0
+ state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(
+ to_real(p.data), memory_format=torch.preserve_format).detach()
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+ grad_grad = to_real(grad * grad.conj())
+
+ # Adam EMA updates
+ if group_lr > 0:
+ exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
+ exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
+
+ denom = exp_avg_sq.sqrt().add_(eps)
+
+ g_sq += grad_grad.div_(denom).sum().item()
+
+ s = state['s']
+ s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
+ sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
+ sk_l1 += s.abs().sum().item()
+
+ ######
+
+ gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
+ d_hat = d
+
+ # if we have not done any progres, return
+ # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
+ if sk_l1 == 0:
+ return loss
+
+ if lr > 0.0:
+ if fsdp_in_use:
+ dist_tensor = torch.zeros(3, device='cuda')
+ dist_tensor[0] = sksq_weighted
+ dist_tensor[1] = gsq_weighted
+ dist_tensor[2] = sk_l1
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_sksq_weighted = dist_tensor[0]
+ global_gsq_weighted = dist_tensor[1]
+ global_sk_l1 = dist_tensor[2]
+ else:
+ global_sksq_weighted = sksq_weighted
+ global_gsq_weighted = gsq_weighted
+ global_sk_l1 = sk_l1
+
+ d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
+ d = max(d, min(d_hat, d*growth_rate))
+
+ if log_every > 0 and k % log_every == 0:
+ logger.info(
+ f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
+ f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
+ f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
+
+ for group in self.param_groups:
+ group['gsq_weighted'] = gsq_weighted
+ group['d'] = d
+
+ group_lr = group['lr']
+ decay = group['weight_decay']
+ k = group['k']
+ eps = group['eps']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+
+ state = self.state[p]
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+ state['step'] += 1
+
+ denom = exp_avg_sq.sqrt().add_(eps)
+ denom = denom.type(p.type())
+
+ # Apply weight decay (decoupled variant)
+ if decay != 0 and decouple and group_lr > 0:
+ p.data.add_(p.data, alpha=-decay * dlr)
+
+ # Take step
+ p.data.addcdiv_(exp_avg, denom, value=-1)
+
+ group['k'] = k + 1
+
+ return loss
diff --git a/audiocraft/optim/ema.py b/audiocraft/optim/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..4337eaff066a8ca124dca3e3e63ee36e417c055c
--- /dev/null
+++ b/audiocraft/optim/ema.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# ModelEMA implementation is taken from
+# https://github.com/facebookresearch/demucs
+
+from collections import defaultdict
+import typing as tp
+
+import torch
+import torch.nn as nn
+
+
+def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
+ names: set = set()
+ for (name, sub_module) in module.named_modules():
+ if name == '':
+ buffer_names = module._non_persistent_buffers_set
+ buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name
+ for buff_name in buffer_names}
+ names.update(buffer_names)
+ else:
+ sub_name = f"{root}.{name}" if len(root) > 0 else name
+ sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name)
+ names.update(sub_buffer_names)
+ return names
+
+
+def _get_named_tensors(module: nn.Module):
+ non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module)
+ named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers()
+ if name not in non_persistent_buffers_set]
+ named_parameters = list(module.named_parameters())
+ return named_parameters + named_buffers
+
+
+class ModuleDictEMA:
+ """Exponential Moving Average over a nn.ModuleDict.
+
+ You can switch to the EMA weights temporarily.
+ """
+ def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
+ unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
+ self.decay = decay
+ self.module_dict = module_dict
+ self.state: dict = defaultdict(dict)
+ self.count = 0
+ self.device = device
+ self.unbias = unbias
+ self._init()
+
+ def _init(self):
+ for module_name, module in self.module_dict.items():
+ for key, val in _get_named_tensors(module):
+ if not val.is_floating_point():
+ continue
+ device = self.device or val.device
+ if key not in self.state[module_name]:
+ self.state[module_name][key] = val.detach().to(device, copy=True)
+
+ def step(self):
+ if self.unbias:
+ self.count = self.count * self.decay + 1
+ w = 1 / self.count
+ else:
+ w = 1 - self.decay
+ for module_name, module in self.module_dict.items():
+ for key, val in _get_named_tensors(module):
+ if not val.is_floating_point():
+ continue
+ device = self.device or val.device
+ self.state[module_name][key].mul_(1 - w)
+ self.state[module_name][key].add_(val.detach().to(device), alpha=w)
+
+ def state_dict(self):
+ return {'state': self.state, 'count': self.count}
+
+ def load_state_dict(self, state):
+ self.count = state['count']
+ for module_name, module in state['state'].items():
+ for key, val in module.items():
+ self.state[module_name][key].copy_(val)
diff --git a/audiocraft/optim/fsdp.py b/audiocraft/optim/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c1a55b6bf1a33092a021c5cefbbb2ae848918a
--- /dev/null
+++ b/audiocraft/optim/fsdp.py
@@ -0,0 +1,195 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Wrapper around FSDP for more convenient use in the training loops.
+"""
+
+from contextlib import contextmanager
+import typing as tp
+import dora
+import torch
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import (
+ MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+
+
+def is_fsdp_used() -> bool:
+ """Return whether we are using FSDP."""
+ # A bit of a hack but should work from anywhere.
+ if dora.is_xp():
+ cfg = dora.get_xp().cfg
+ if hasattr(cfg, 'fsdp'):
+ return cfg.fsdp.use
+ return False
+
+
+def is_sharded_tensor(x: tp.Any) -> bool:
+ return isinstance(x, ShardedTensor)
+
+
+@contextmanager
+def switch_to_full_state_dict(models: tp.List[FSDP]):
+ # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
+ # so let's do thing manually.
+ for model in models:
+ FSDP.set_state_dict_type( # type: ignore
+ model, StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
+ try:
+ yield
+ finally:
+ for model in models:
+ FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT) # type: ignore
+
+
+def wrap_with_fsdp(cfg, model: torch.nn.Module,
+ block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
+ """Wraps a model with FSDP."""
+ # Some of the typing is disabled until this gets integrated
+ # into the stable version of PyTorch.
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy # type: ignore
+
+ # we import this here to prevent circular import.
+ from ..modules.transformer import StreamingTransformerLayer
+ from ..modules.conditioners import ConditioningProvider
+
+ _fix_post_backward_hook()
+
+ assert cfg.use
+ sharding_strategy_dict = {
+ "no_shard": ShardingStrategy.NO_SHARD,
+ "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
+ "full_shard": ShardingStrategy.FULL_SHARD,
+ }
+
+ dtype_dict = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ }
+
+ mixed_precision_config = MixedPrecision(
+ param_dtype=dtype_dict[cfg.param_dtype],
+ reduce_dtype=dtype_dict[cfg.reduce_dtype],
+ buffer_dtype=dtype_dict[cfg.buffer_dtype],
+ )
+
+ sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
+ # The following is going to require being a bit smart
+ # when doing LM, because this would flush the weights for every time step
+ # during generation. One possiblity is to use hybrid sharding:
+ # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
+ assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
+ "Not supported at the moment, requires a bit more work."
+
+ local_rank = dora.distrib.get_distrib_spec().local_rank
+ assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"
+
+ auto_wrap_policy = None
+ if block_classes is None:
+ block_classes = {StreamingTransformerLayer, ConditioningProvider}
+ if cfg.per_block:
+ auto_wrap_policy = ModuleWrapPolicy(block_classes)
+ wrapped = _FSDPFixStateDict(
+ model,
+ sharding_strategy=sharding_strategy_config,
+ mixed_precision=mixed_precision_config,
+ device_id=local_rank,
+ sync_module_states=True,
+ use_orig_params=True,
+ auto_wrap_policy=auto_wrap_policy,
+ ) # type: ignore
+ FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT) # type: ignore
+
+ # Let the wrapped model know about the wrapping!
+ # We use __dict__ to avoid it going into the state dict.
+ # This is a bit dirty, but needed during generation, as otherwise
+ # the wrapped model would call itself and bypass FSDP.
+ for module in FSDP.fsdp_modules(wrapped):
+ original = module._fsdp_wrapped_module
+ original.__dict__['_fsdp'] = module
+ return wrapped
+
+
+def purge_fsdp(model: FSDP):
+ """Purge the FSDP cached shard inside the model. This should
+ allow setting the best state or switching to the EMA.
+ """
+ from torch.distributed.fsdp._runtime_utils import _reshard # type: ignore
+ for module in FSDP.fsdp_modules(model):
+ handles = module._handles
+ if not handles:
+ continue
+ handle = handles[0]
+ unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+ storage_size: int = unsharded_flat_param._typed_storage()._size() # type: ignore
+ if storage_size == 0:
+ continue
+ true_list = [True for h in handles]
+ _reshard(module, handles, true_list)
+
+
+class _FSDPFixStateDict(FSDP):
+ @staticmethod
+ def _name_without_fsdp_prefix(name: str) -> str:
+ from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE # type: ignore
+ parts = name.split('.')
+ new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
+ return '.'.join(new_parts)
+
+ def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore
+ state = dict(super().state_dict())
+ for key, value in list(state.items()):
+ if is_sharded_tensor(value):
+ del state[key]
+ return state
+
+ def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore
+ if self._state_dict_type is StateDictType.FULL_STATE_DICT:
+ super().load_state_dict(state)
+ purge_fsdp(self)
+ return
+ # Fix FSDP load state dict in all situation.
+ # Use this only with LOCAL_STATE_DICT !!!
+ current_state = dict(super().state_dict())
+ for key, value in state.items():
+ key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
+ if key not in current_state:
+ # Emulate strict loading manually.
+ raise RuntimeError(f"Unknown state key {key}")
+ current_state[key].copy_(value)
+
+ # Purging cached weights from previous forward.
+ purge_fsdp(self)
+
+
+_hook_fixed = False
+
+
+def _fix_post_backward_hook():
+ global _hook_fixed
+ if _hook_fixed:
+ return
+ _hook_fixed = True
+
+ from torch.distributed.fsdp import _runtime_utils
+ from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
+ old_hook = _runtime_utils._post_backward_hook
+
+ def _post_backward_hook(state, handle, *args, **kwargs):
+ checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
+ if checkpointed:
+ # there will be one more forward in the backward with checkpointing and that will
+ # massively confuse FSDP, so we have to make it think everything
+ # is going according to the plan.
+ state.training_state = TrainingState.FORWARD_BACKWARD
+ handle._training_state = HandleTrainingState.BACKWARD_PRE
+ old_hook(state, handle, *args, **kwargs)
+
+ _runtime_utils._post_backward_hook = _post_backward_hook
diff --git a/audiocraft/optim/inverse_sqrt_lr_scheduler.py b/audiocraft/optim/inverse_sqrt_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..920192e8842c5635bf6f7f76618fa4a6f4b0114a
--- /dev/null
+++ b/audiocraft/optim/inverse_sqrt_lr_scheduler.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class InverseSquareRootLRScheduler(_LRScheduler):
+ """Inverse square root LR scheduler.
+
+ Args:
+ optimizer (Optimizer): Torch optimizer.
+ warmup_steps (int): Number of warmup steps.
+ warmup_init_lr (tp.Optional[float]): Initial learning rate
+ during warmup phase. When not set, use the provided learning rate.
+ """
+ def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+ self.warmup_steps = warmup_steps
+ self.warmup_init_lr = warmup_init_lr
+ super().__init__(optimizer)
+
+ def _get_sched_lr(self, lr: float, step: int):
+ if step < self.warmup_steps:
+ warmup_init_lr = self.warmup_init_lr or 0
+ lr_step = (lr - warmup_init_lr) / self.warmup_steps
+ lr = warmup_init_lr + step * lr_step
+ else:
+ decay_factor = lr * self.warmup_steps**0.5
+ lr = decay_factor * step**-0.5
+ return lr
+
+ def get_lr(self):
+ return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
diff --git a/audiocraft/optim/linear_warmup_lr_scheduler.py b/audiocraft/optim/linear_warmup_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..03274a1ae52b6f20473973b77619f34b2bddd6a1
--- /dev/null
+++ b/audiocraft/optim/linear_warmup_lr_scheduler.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class LinearWarmupLRScheduler(_LRScheduler):
+ """Inverse square root LR scheduler.
+
+ Args:
+ optimizer (Optimizer): Torch optimizer.
+ warmup_steps (int): Number of warmup steps.
+ warmup_init_lr (tp.Optional[float]): Initial learning rate
+ during warmup phase. When not set, use the provided learning rate.
+ """
+ def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+ self.warmup_steps = warmup_steps
+ self.warmup_init_lr = warmup_init_lr
+ super().__init__(optimizer)
+
+ def _get_sched_lr(self, lr: float, step: int):
+ if step < self.warmup_steps:
+ warmup_init_lr = self.warmup_init_lr or 0
+ lr_step = (lr - warmup_init_lr) / self.warmup_steps
+ lr = warmup_init_lr + step * lr_step
+ return lr
+
+ def get_lr(self):
+ return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
diff --git a/audiocraft/optim/polynomial_decay_lr_scheduler.py b/audiocraft/optim/polynomial_decay_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5ea30b094538269dbb0055ab3163f84d1cf6e90
--- /dev/null
+++ b/audiocraft/optim/polynomial_decay_lr_scheduler.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolynomialDecayLRScheduler(_LRScheduler):
+ """Polynomial decay LR scheduler.
+
+ Args:
+ optimizer (Optimizer): Torch optimizer.
+ warmup_steps (int): Number of warmup steps.
+ total_steps (int): Total number of steps.
+ end_lr (float): Final learning rate to achieve over total number of steps.
+ zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
+ power (float): Decay exponent.
+ """
+ def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
+ end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.end_lr = end_lr
+ self.zero_lr_warmup_steps = zero_lr_warmup_steps
+ self.power = power
+ super().__init__(optimizer)
+
+ def _get_sched_lr(self, lr: float, step: int):
+ if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
+ lr = 0
+ elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
+ lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
+ lr = lr_ratio * lr
+ elif step >= self.total_steps:
+ lr = self.end_lr
+ else:
+ total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
+ lr_range = lr - self.end_lr
+ pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
+ lr = lr_range * pct_remaining ** self.power + self.end_lr
+ return lr
+
+ def get_lr(self):
+ return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
diff --git a/audiocraft/py.typed b/audiocraft/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audiocraft/quantization/__init__.py b/audiocraft/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0c7e429ab96d67be667e23bf7a0ffa389c036b
--- /dev/null
+++ b/audiocraft/quantization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""RVQ."""
+# flake8: noqa
+from .vq import ResidualVectorQuantizer
+from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
diff --git a/audiocraft/quantization/__pycache__/__init__.cpython-310.pyc b/audiocraft/quantization/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f649051a17b492e4b558d31a65fe1738f28eb37f
Binary files /dev/null and b/audiocraft/quantization/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/quantization/__pycache__/base.cpython-310.pyc b/audiocraft/quantization/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c9f10ee3741cc4f4d02ee026d37ec4e1f8009ab
Binary files /dev/null and b/audiocraft/quantization/__pycache__/base.cpython-310.pyc differ
diff --git a/audiocraft/quantization/__pycache__/core_vq.cpython-310.pyc b/audiocraft/quantization/__pycache__/core_vq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7967963440d6c92dcb4ef38bc5cfaa57ddbbcd8
Binary files /dev/null and b/audiocraft/quantization/__pycache__/core_vq.cpython-310.pyc differ
diff --git a/audiocraft/quantization/__pycache__/vq.cpython-310.pyc b/audiocraft/quantization/__pycache__/vq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68425f7de959173edc618cc44789ee5d29cecbc5
Binary files /dev/null and b/audiocraft/quantization/__pycache__/vq.cpython-310.pyc differ
diff --git a/audiocraft/quantization/base.py b/audiocraft/quantization/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a77fefb98e62a5bbc6385910261ffdde2ffa5a25
--- /dev/null
+++ b/audiocraft/quantization/base.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Base class for all quantizers.
+"""
+
+from dataclasses import dataclass, field
+import typing as tp
+
+import torch
+from torch import nn
+
+
+@dataclass
+class QuantizedResult:
+ x: torch.Tensor
+ codes: torch.Tensor
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+ penalty: tp.Optional[torch.Tensor] = None
+ metrics: dict = field(default_factory=dict)
+
+
+class BaseQuantizer(nn.Module):
+ """Base class for quantizers.
+ """
+
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+ """
+ Given input tensor x, returns first the quantized (or approximately quantized)
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
+ Finally, this returns a dict of metrics to update logging etc.
+ Frame rate must be passed so that the bandwidth is properly computed.
+ """
+ raise NotImplementedError()
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
+ raise NotImplementedError()
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation."""
+ raise NotImplementedError()
+
+ @property
+ def total_codebooks(self):
+ """Total number of codebooks."""
+ raise NotImplementedError()
+
+ @property
+ def num_codebooks(self):
+ """Number of active codebooks."""
+ raise NotImplementedError()
+
+ def set_num_codebooks(self, n: int):
+ """Set the number of active codebooks."""
+ raise NotImplementedError()
+
+
+class DummyQuantizer(BaseQuantizer):
+ """Fake quantizer that actually does not perform any quantization.
+ """
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor, frame_rate: int):
+ q = x.unsqueeze(1)
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
+ In the case of the DummyQuantizer, the codes are actually identical
+ to the input and resulting quantized representation as no quantization is done.
+ """
+ return x.unsqueeze(1)
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ In the case of the DummyQuantizer, the codes are actually identical
+ to the input and resulting quantized representation as no quantization is done.
+ """
+ return codes.squeeze(1)
+
+ @property
+ def total_codebooks(self):
+ """Total number of codebooks."""
+ return 1
+
+ @property
+ def num_codebooks(self):
+ """Total number of codebooks."""
+ return self.total_codebooks
+
+ def set_num_codebooks(self, n: int):
+ """Set the number of active codebooks."""
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..da02a6ce3a7de15353f0fba9e826052beb67c436
--- /dev/null
+++ b/audiocraft/quantization/core_vq.py
@@ -0,0 +1,400 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from einops import rearrange, repeat
+import flashy
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(val: tp.Optional[tp.Any]) -> bool:
+ return val is not None
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+ return val if exists(val) else d
+
+
+def l2norm(t):
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay: float):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+ t = torch.empty(shape)
+ nn.init.kaiming_uniform_(t)
+ return t
+
+
+def sample_vectors(samples, num: int):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+ dim, dtype = samples.shape[-1], samples.dtype
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
+ means, "c d -> () c d"
+ )
+ dists = -(diffs ** 2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+def orthogonal_loss_fn(t):
+ # eq (2) from https://arxiv.org/abs/2112.00384
+ n = t.shape[0]
+ normed_codes = l2norm(t)
+ identity = torch.eye(n, device=t.device)
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+class EuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance.
+
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: int = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.8,
+ epsilon: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.decay = decay
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+ embed = init_fn(codebook_size, dim)
+
+ self.codebook_size = codebook_size
+
+ self.kmeans_iters = kmeans_iters
+ self.epsilon = epsilon
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.inited:
+ return
+
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed.clone())
+ self.cluster_size.data.copy_(cluster_size)
+ self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
+ flashy.distrib.broadcast_tensors(self.buffers())
+
+ def replace_(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace_(batch_samples, mask=expired_codes)
+ flashy.distrib.broadcast_tensors(self.buffers())
+
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def postprocess_emb(self, embed_ind, shape):
+ return embed_ind.view(*shape[:-1])
+
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
+
+ def encode(self, x):
+ shape = x.shape
+ # pre-process
+ x = self.preprocess(x)
+ # quantize
+ embed_ind = self.quantize(x)
+ # post-process
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ x = self.preprocess(x)
+ self.init_embed_(x)
+
+ embed_ind = self.quantize(x)
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ quantize = self.dequantize(embed_ind)
+
+ if self.training:
+ # We do the expiry of code at that point as buffers are in sync
+ # and all the workers will take the same decision.
+ self.expire_codes_(x)
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ embed_sum = x.t() @ embed_onehot
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ cluster_size = (
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+ * self.cluster_size.sum()
+ )
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ self.embed.data.copy_(embed_normalized)
+
+ return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int):
+ channels_last (bool): Channels are the last dimension in the input tensors.
+ commitment_weight (float): Weight for commitment loss.
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+ for orthogonal regularization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ codebook_dim: tp.Optional[int] = None,
+ decay: float = 0.8,
+ epsilon: float = 1e-5,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ threshold_ema_dead_code: int = 2,
+ channels_last: bool = False,
+ commitment_weight: float = 1.,
+ orthogonal_reg_weight: float = 0.0,
+ orthogonal_reg_active_codes_only: bool = False,
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
+ ):
+ super().__init__()
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+ self.epsilon = epsilon
+ self.commitment_weight = commitment_weight
+
+ self.orthogonal_reg_weight = orthogonal_reg_weight
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+ decay=decay, epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code)
+ self.codebook_size = codebook_size
+
+ self.channels_last = channels_last
+
+ @property
+ def codebook(self):
+ return self._codebook.embed
+
+ @property
+ def inited(self):
+ return self._codebook.inited
+
+ def _preprocess(self, x):
+ if not self.channels_last:
+ x = rearrange(x, "b d n -> b n d")
+ return x
+
+ def _postprocess(self, quantize):
+ if not self.channels_last:
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize
+
+ def encode(self, x):
+ x = self._preprocess(x)
+ x = self.project_in(x)
+ embed_in = self._codebook.encode(x)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self._codebook.decode(embed_ind)
+ quantize = self.project_out(quantize)
+ quantize = self._postprocess(quantize)
+ return quantize
+
+ def forward(self, x):
+ device = x.device
+ x = self._preprocess(x)
+
+ x = self.project_in(x)
+ quantize, embed_ind = self._codebook(x)
+
+ if self.training:
+ quantize = x + (quantize - x).detach()
+
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+ if self.training:
+ if self.commitment_weight > 0:
+ commit_loss = F.mse_loss(quantize.detach(), x)
+ loss = loss + commit_loss * self.commitment_weight
+
+ if self.orthogonal_reg_weight > 0:
+ codebook = self.codebook
+
+ if self.orthogonal_reg_active_codes_only:
+ # only calculate orthogonal loss for the activated codes for this batch
+ unique_code_ids = torch.unique(embed_ind)
+ codebook = codebook[unique_code_ids]
+
+ num_codes = codebook.shape[0]
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+ codebook = codebook[rand_ids]
+
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+ quantize = self.project_out(quantize)
+ quantize = self._postprocess(quantize)
+
+ return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+
+ def forward(self, x, n_q: tp.Optional[int] = None):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+
+ n_q = n_q or len(self.layers)
+
+ for i, layer in enumerate(self.layers[:n_q]):
+ quantized, indices, loss = layer(residual)
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
+ all_indices.append(indices)
+ all_losses.append(loss)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses
+
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ for layer in self.layers[:n_q]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
diff --git a/audiocraft/quantization/vq.py b/audiocraft/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa57bea59db95ddae35e0657f723ca3a29ee943b
--- /dev/null
+++ b/audiocraft/quantization/vq.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer, QuantizedResult
+from .core_vq import ResidualVectorQuantization
+
+
+class ResidualVectorQuantizer(BaseQuantizer):
+ """Residual Vector Quantizer.
+
+ Args:
+ dimension (int): Dimension of the codebooks.
+ n_q (int): Number of residual vector quantizers used.
+ q_dropout (bool): Random quantizer drop out at train time.
+ bins (int): Codebook size.
+ decay (float): Decay for exponential moving average over the codebooks.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
+ for orthogonal regularization.
+ """
+ def __init__(
+ self,
+ dimension: int = 256,
+ n_q: int = 8,
+ q_dropout: bool = False,
+ bins: int = 1024,
+ decay: float = 0.99,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 10,
+ threshold_ema_dead_code: int = 2,
+ orthogonal_reg_weight: float = 0.0,
+ orthogonal_reg_active_codes_only: bool = False,
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
+ ):
+ super().__init__()
+ self.max_n_q = n_q
+ self.n_q = n_q
+ self.q_dropout = q_dropout
+ self.dimension = dimension
+ self.bins = bins
+ self.decay = decay
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+ self.orthogonal_reg_weight = orthogonal_reg_weight
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+ self.vq = ResidualVectorQuantization(
+ dim=self.dimension,
+ codebook_size=self.bins,
+ num_quantizers=self.n_q,
+ decay=self.decay,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+ channels_last=False
+ )
+
+ def forward(self, x: torch.Tensor, frame_rate: int):
+ n_q = self.n_q
+ if self.training and self.q_dropout:
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+ bw_per_q = math.log2(self.bins) * frame_rate / 1000
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+ codes = codes.transpose(0, 1)
+ # codes is [B, K, T], with T frames, K nb of codebooks.
+ bw = torch.tensor(n_q * bw_per_q).to(x)
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
+ The RVQ encode method sets the appropriate number of quantizer to use
+ and returns indices for each quantizer.
+ """
+ n_q = self.n_q
+ codes = self.vq.encode(x, n_q=n_q)
+ codes = codes.transpose(0, 1)
+ # codes is [B, K, T], with T frames, K nb of codebooks.
+ return codes
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation."""
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+ codes = codes.transpose(0, 1)
+ quantized = self.vq.decode(codes)
+ return quantized
+
+ @property
+ def total_codebooks(self):
+ return self.max_n_q
+
+ @property
+ def num_codebooks(self):
+ return self.n_q
+
+ def set_num_codebooks(self, n: int):
+ assert n > 0 and n <= self.max_n_q
+ self.n_q = n
diff --git a/audiocraft/solvers/__init__.py b/audiocraft/solvers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae19f3a8c51abf469697d6affa91449d668716ba
--- /dev/null
+++ b/audiocraft/solvers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Solvers. A Solver is a training recipe, combining the dataloaders, models,
+optimizer, losses etc into a single convenient object.
+"""
+
+# flake8: noqa
+from .audiogen import AudioGenSolver
+from .builders import get_solver
+from .base import StandardSolver
+from .compression import CompressionSolver
+from .musicgen import MusicGenSolver
+from .diffusion import DiffusionSolver
diff --git a/audiocraft/solvers/__pycache__/__init__.cpython-310.pyc b/audiocraft/solvers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e203e6270d827c17050af24e5cb98687478c9713
Binary files /dev/null and b/audiocraft/solvers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/audiogen.cpython-310.pyc b/audiocraft/solvers/__pycache__/audiogen.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37ac9e25b5f35d849e9a5418dfecc6e78834c813
Binary files /dev/null and b/audiocraft/solvers/__pycache__/audiogen.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/base.cpython-310.pyc b/audiocraft/solvers/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e029719d3a11a74d71343904ec1ea00252ed79f
Binary files /dev/null and b/audiocraft/solvers/__pycache__/base.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/builders.cpython-310.pyc b/audiocraft/solvers/__pycache__/builders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56185350e8455eb9483574b8eec3e51d9dadd88d
Binary files /dev/null and b/audiocraft/solvers/__pycache__/builders.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/compression.cpython-310.pyc b/audiocraft/solvers/__pycache__/compression.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d44696a5bddbac5c4690afc0ab795b21fabbd09
Binary files /dev/null and b/audiocraft/solvers/__pycache__/compression.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/diffusion.cpython-310.pyc b/audiocraft/solvers/__pycache__/diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91983e732777e62c6bcf90623a85b7eee1499fe6
Binary files /dev/null and b/audiocraft/solvers/__pycache__/diffusion.cpython-310.pyc differ
diff --git a/audiocraft/solvers/__pycache__/musicgen.cpython-310.pyc b/audiocraft/solvers/__pycache__/musicgen.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de0840f985b2d5f5685d7cf03fba950781437ee8
Binary files /dev/null and b/audiocraft/solvers/__pycache__/musicgen.cpython-310.pyc differ
diff --git a/audiocraft/solvers/audiogen.py b/audiocraft/solvers/audiogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..1568f97fe7b84b90c7ef760ef5606fe0a475545a
--- /dev/null
+++ b/audiocraft/solvers/audiogen.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import builders, musicgen
+
+
+class AudioGenSolver(musicgen.MusicGenSolver):
+ """Solver for AudioGen re-implementation training task.
+
+ Note that this implementation does not strictly follows
+ the method proposed in https://arxiv.org/abs/2209.15352
+ but is derived from MusicGen's training pipeline.
+
+ More information can be found in the AudioGen model card.
+ """
+ DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
diff --git a/audiocraft/solvers/base.py b/audiocraft/solvers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..0432e44a36838c5731711f9d54f81822b21f20bd
--- /dev/null
+++ b/audiocraft/solvers/base.py
@@ -0,0 +1,631 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+import typing as tp
+
+import flashy
+import omegaconf
+import torch
+from torch import nn
+
+from .. import optim
+from ..optim import fsdp
+from ..utils import checkpoint
+from ..utils.autocast import TorchAutocast
+from ..utils.best_state import BestStateDictManager
+from ..utils.deadlock import DeadlockDetect
+from ..utils.profiler import Profiler
+from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng
+
+
+class StandardSolver(ABC, flashy.BaseSolver):
+ """Standard solver for AudioCraft.
+
+ The standard solver implements a base training loop with the following stages:
+ train, valid, evaluate and generate that are expected to be all defined for
+ solvers in AudioCraft. It also provides a nice default management of Dora history replay,
+ checkpoint management across epoch, and logging configuration.
+
+ AudioCraft solvers must inherit from the StandardSolver and define the methods
+ associated to each stage as well as the show, build_model and build_dataloaders methods.
+ """
+ def __init__(self, cfg: omegaconf.DictConfig):
+ super().__init__()
+ self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
+ self.logger.info(f"All XP logs are stored in {self.xp.folder}")
+ self.cfg = cfg
+ self.device = cfg.device
+ self.model: nn.Module
+ self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
+ self._fsdp_modules: tp.List[fsdp.FSDP] = []
+ self._ema_sources: nn.ModuleDict = nn.ModuleDict()
+ self.ema: tp.Optional[optim.ModuleDictEMA] = None
+ self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
+ self._log_updates = self.cfg.logging.get('log_updates', 10)
+ if self.cfg.logging.log_tensorboard:
+ self.init_tensorboard(**self.cfg.get('tensorboard'))
+ if self.cfg.logging.log_wandb and self:
+ self.init_wandb(**self.cfg.get('wandb'))
+ # keep a copy of the best performing state for stateful objects
+ # used for evaluation and generation stages
+ dtype_best: tp.Optional[torch.dtype] = None
+ if self.cfg.fsdp.use:
+ dtype_best = getattr(torch, self.cfg.fsdp.param_dtype) # type: ignore
+ assert isinstance(dtype_best, torch.dtype)
+ elif self.cfg.autocast:
+ dtype_best = getattr(torch, self.cfg.autocast_dtype) # type: ignore
+ assert isinstance(dtype_best, torch.dtype)
+ self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
+ # Hacky support for keeping a copy of the full best state in rank0.
+ self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
+ self.register_stateful('best_state', 'fsdp_best_state') # register best_state object to keep it in state_dict
+ self._new_best_state: bool = False # should save a new checkpoint
+ # instantiate datasets and appropriate number of updates per epoch
+ self.build_dataloaders()
+ if self.cfg.execute_only is None:
+ assert 'train' in self.dataloaders, "The train dataset split must be provided."
+ assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
+ self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
+ if self.cfg.optim.updates_per_epoch:
+ self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
+ self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
+ # instantiate model & exponential moving average on the model
+ self.build_model()
+ self.logger.info("Model hash: %s", model_hash(self.model))
+ assert 'model' in self.stateful.sources, \
+ "Please register the model to stateful with self.register_stateful('model') in build_model."
+ self.profiler = Profiler(self.model, **self.cfg.profiler)
+ self.initialize_ema()
+ self.register_stateful('ema')
+ assert self.ema is None or 'ema' in self.stateful.sources, \
+ "Please register the ema to stateful with self.register_stateful('ema') in build_model."
+ self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
+ # basic statistics on the trained model
+ model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
+ # one copy of grad, one copy of momentum, one copy of denominator and model weights.
+ # and 4 bytes for each float!
+ mem_usage = model_size * 4 * 4 / 1000
+ self.logger.info("Model size: %.2f M params", model_size)
+ self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
+
+ @property
+ def autocast(self):
+ """Convenient autocast (or not) using the solver configuration."""
+ return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
+
+ def _get_state_source(self, name) -> flashy.state.StateDictSource:
+ # Internal utility to get a state source from the solver
+ return self.stateful.sources[name]
+
+ @property
+ def best_metric_name(self) -> tp.Optional[str]:
+ """Metric name used to identify the best state. This metric should be stored in the metrics
+ used on the stage for best state identification (most likely, `valid`). If None, then
+ no best state is saved.
+ """
+ return None
+
+ def register_best_state(self, *args: str):
+ """Register state sources in `BestStateDictManager` to keep their best states along with their
+ latest states. The best state will be used at evaluation stages instead of the latest states.
+
+ Shortcut around `BestStateDictManager.register` method. You can pass any number of
+ attribute, included nested attributes and those will be included into the checkpoints
+ and automatically restored when `BaseSolver.restore` is called.
+ """
+ for name in args:
+ state_source = self._get_state_source(name)
+ assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
+ self.best_state.register(name, state_source)
+
+ def register_ema(self, *args: str):
+ """Register state sources for exponential moving average.
+
+ The registered sources are used to instantiate a ModuleDictEMA instance.
+ The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
+ and swapped with the original state sources with self.swap_ema_state() method.
+
+ Usage:
+ self.register_ema('model')
+ """
+ assert self.ema is None, "Cannot register state source to already instantiated EMA."
+ for name in args:
+ self._ema_sources[name] = getattr(self, name)
+
+ def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
+ model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
+ if isinstance(model, fsdp.FSDP):
+ self._fsdp_modules.append(model)
+ return model
+
+ def update_best_state_from_stage(self, stage_name: str = 'valid'):
+ """Update latest best state based on pending metrics of a given stage. This method relies
+ on the `BestStateDictManager.update` method to update the best state_dict with latest weights
+ if the registered states happen to match to the best performing setup.
+ """
+ if self.best_metric_name is None:
+ # when no best metric is defined, the last state is always the best
+ self._new_best_state = True
+ self.logger.info("Updating best state with current state.")
+ else:
+ assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
+ assert self.best_metric_name in self._pending_metrics[stage_name], \
+ f"Best metric not found in {stage_name} metrics. Cannot register best state"
+ current_score = self._pending_metrics[stage_name][self.best_metric_name]
+ all_best_metric_scores = [
+ past_metrics[stage_name][self.best_metric_name]
+ for past_metrics in self.history
+ ]
+ all_best_metric_scores.append(current_score)
+ best_score = min(all_best_metric_scores)
+ self._new_best_state = current_score == best_score
+ if self._new_best_state:
+ old_best = min(all_best_metric_scores[:-1] + [float('inf')])
+ self.logger.info(
+ f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")
+
+ if self._new_best_state:
+ if self.cfg.fsdp.use:
+ # this will give an empty state dict on all ranks but the rank 0
+ # which will have a copy in memory of the full model.
+ with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+ for name in self.best_state.states.keys():
+ state_source = self._get_state_source(name)
+ self.best_state.update(name, state_source)
+ # we save to a different dict.
+ self.fsdp_best_state.update(self.best_state.state_dict())
+ # We cannot efficiently load fsdp_best_state when using FSDP,
+ # so we have do do a second pass, with the local shards.
+ for name in self.best_state.states.keys():
+ state_source = self._get_state_source(name)
+ self.best_state.update(name, state_source)
+
+ def _load_new_state_dict(self, state_dict: dict) -> dict:
+ old_states = {}
+ for name, new_state in state_dict.items():
+ state_source = self._get_state_source(name)
+ old_states[name] = copy_state(state_source.state_dict())
+ state_source.load_state_dict(new_state)
+ return old_states
+
+ @contextmanager
+ def swap_best_state(self):
+ self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
+ old_states = self._load_new_state_dict(self.best_state.state_dict())
+ try:
+ yield
+ finally:
+ self.logger.debug("Swapping back from best to original state")
+ for name, old_state in old_states.items():
+ state_source = self._get_state_source(name)
+ state_source.load_state_dict(old_state)
+
+ @contextmanager
+ def swap_ema_state(self):
+ if self.ema is None:
+ yield
+ else:
+ ema_state_dict = self.ema.state_dict()['state']
+ self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
+ old_states = self._load_new_state_dict(ema_state_dict)
+ try:
+ yield
+ finally:
+ self.logger.debug("Swapping back from EMA state to original state")
+ for name, old_state in old_states.items():
+ state_source = self._get_state_source(name)
+ state_source.load_state_dict(old_state)
+
+ @property
+ def is_training(self):
+ return self.current_stage == 'train'
+
+ def log_model_summary(self, model: nn.Module):
+ """Log model summary, architecture and size of the model."""
+ self.logger.info(model)
+ mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
+ self.logger.info("Size: %.1f MB", mb)
+
+ @abstractmethod
+ def build_model(self):
+ """Method to implement to initialize model."""
+ ...
+
+ def initialize_ema(self):
+ """Initialize exponential moving average with the registered sources.
+ EMA object is created if the optim.ema.model.decay value is non-null.
+ """
+ from .builders import get_ema
+ self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
+ if self.ema is None:
+ self.logger.info('No EMA on the model.')
+ else:
+ assert self.cfg.optim.ema.updates > 0
+ self.logger.info(
+ f'Initializing EMA on the model with decay = {self.ema.decay}'
+ f' every {self.cfg.optim.ema.updates} updates'
+ )
+
+ @abstractmethod
+ def build_dataloaders(self):
+ """Method to implement to initialize dataloaders."""
+ ...
+
+ @abstractmethod
+ def show(self):
+ """Method to log any information without running the job."""
+ ...
+
+ @property
+ def log_updates(self):
+ # convenient access to log updates
+ return self._log_updates
+
+ def checkpoint_path(self, **kwargs):
+ kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+ return self.folder / checkpoint.checkpoint_name(**kwargs)
+
+ def epoch_checkpoint_path(self, epoch: int, **kwargs):
+ kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+ return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)
+
+ def checkpoint_path_with_name(self, name: str, **kwargs):
+ kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+ return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)
+
+ def save_checkpoints(self):
+ """Save checkpoint, optionally keeping a copy for a given epoch."""
+ is_sharded = self.cfg.fsdp.use
+ if not flashy.distrib.is_rank_zero() and not is_sharded:
+ return
+ self.logger.info("Model hash: %s", model_hash(self.model))
+ state = self.state_dict()
+ epoch = self.epoch - 1 # pushing metrics will increase the epoch in Flashy, so we do -1 here
+
+ # save minimal state_dict as new checkpoint every X epoch
+ if self.cfg.checkpoint.save_every:
+ if epoch % self.cfg.checkpoint.save_every == 0:
+ minimal_state = state
+ if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
+ minimal_state = {
+ name: source for name, source in state.items()
+ if name in self.cfg.checkpoint.keep_every_states
+ }
+ epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
+ checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)
+
+ # save checkpoint as latest checkpoint
+ if self.cfg.checkpoint.save_last:
+ last_checkpoint_path = self.checkpoint_path()
+ checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)
+
+ # flush any stale checkpoint to reduce disk footprint
+ checkpoint.flush_stale_checkpoints(self.checkpoint_path())
+
+ def load_from_pretrained(self, name: str) -> dict:
+ raise NotImplementedError("Solver does not provide a way to load pretrained models.")
+
+ def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
+ """Load last checkpoint or the one specified in continue_from.
+
+ Args:
+ load_best (bool): Whether to load from best state dict or not.
+ Best state dict is always used when not loading the current xp.
+ ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
+ Returns:
+ state (dict, optional): The loaded state dictionary.
+ """
+ # load checkpoints from xp folder or cfg.continue_from
+ is_sharded = self.cfg.fsdp.use
+ load_from_path: tp.Optional[Path] = None
+ checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None
+
+ if load_best:
+ self.logger.info("Trying to load state_dict from best state.")
+
+ state: tp.Optional[dict] = None
+ rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
+ current_checkpoint_path = self.checkpoint_path()
+ _pretrained_prefix = '//pretrained/'
+ continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
+ if rank0_checkpoint_path.exists():
+ self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
+ load_from_path = current_checkpoint_path
+ checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
+ checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
+ elif self.cfg.continue_from and not continue_pretrained:
+ self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
+ # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
+ load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
+ if load_from_path is None:
+ self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
+ raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
+ checkpoint_source = checkpoint.CheckpointSource.OTHER
+
+ if load_from_path is not None:
+ state = checkpoint.load_checkpoint(load_from_path, is_sharded)
+ elif continue_pretrained:
+ self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
+ state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
+ checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
+ load_best = True
+
+ # checkpoints are not from the current xp, we only retrieve the best state
+ if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
+ assert state is not None
+ self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
+ load_best = True
+ state = {key: state[key] for key in self._continue_best_source_keys if key in state}
+ # loaded checkpoints are FSDP checkpoints: we're reading the best state
+ # from FSDP and we drop the regular best_state
+ if 'fsdp_best_state' in state and state['fsdp_best_state']:
+ state.pop('best_state', None)
+ self.logger.info("... Loaded checkpoint has FSDP best state")
+ # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
+ # then we're initializing FSDP best state with the regular best state
+ elif self.cfg.fsdp.use:
+ if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
+ # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
+ state['fsdp_best_state'] = state.pop('best_state')
+ self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")
+
+ if state is not None:
+ if load_best:
+ self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
+ for key in set(ignore_state_keys):
+ if key in state:
+ state.pop(key)
+ has_best_state = 'best_state' in state or 'fsdp_best_state' in state
+ assert has_best_state, ("Trying to load best state but neither 'best_state'",
+ " or 'fsdp_best_state' found in checkpoints.")
+ self.load_state_dict(state)
+
+ # for FSDP, let's make extra sure nothing bad happened with out of sync
+ # checkpoints across workers.
+ epoch = float(self.epoch)
+ avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
+ if avg_epoch != epoch:
+ raise RuntimeError(
+ f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
+ f"but average of epochs is {avg_epoch}, at least one gpu must have a "
+ "different epoch number.")
+
+ # on load_best, properly reinitialize state_dict, best states and ema
+ # otherwise we load from the current xp and don't alter anything
+ if load_best:
+ self.logger.info("Loading state_dict from best state.")
+ if not self.cfg.fsdp.use and self.fsdp_best_state:
+ # loading from an FSDP checkpoint but with FSDP deactivated
+ self.logger.info("... Loading from FSDP best state dict.")
+ self.best_state.load_state_dict(self.fsdp_best_state)
+
+ # if load_best, we permanently override the regular state_dict with the best state
+ if self.cfg.fsdp.use:
+ self.logger.info("FSDP is used, loading from FSDP best state.")
+ with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+ # this might be really fragile but okay for now.
+ self.load_state_dict(self.fsdp_best_state)
+ else:
+ # we permanently swap the stateful objects to their best state
+ self._load_new_state_dict(self.best_state.state_dict())
+
+ # the EMA modules should also be instantiated with best state.
+ # the easiest way to do so is to reinitialize a new EMA with best state loaded.
+ if self.ema is not None:
+ self.logger.info("Re-initializing EMA from best state")
+ self.initialize_ema()
+
+ if self.cfg.fsdp.use:
+ self.logger.info("Re-initializing best state after using FSDP best state.")
+ for name in self.best_state.states.keys():
+ state_source = self._get_state_source(name)
+ self.best_state.update(name, state_source)
+
+ return state
+
+ def restore(self, load_best: bool = False, replay_metrics: bool = False,
+ ignore_state_keys: tp.List[str] = []) -> bool:
+ """Restore the status of a solver for a given xp.
+
+ Args:
+ load_best (bool): if `True`, load the best state from the checkpoint.
+ replay_metrics (bool): if `True`, logs all the metrics from past epochs.
+ ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
+ """
+ self.logger.info("Restoring weights and history.")
+ restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
+
+ self.logger.info("Model hash: %s", model_hash(self.model))
+
+ if replay_metrics and len(self.history) > 0:
+ self.logger.info("Replaying past metrics...")
+ for epoch, stages in enumerate(self.history):
+ for stage_name, metrics in stages.items():
+ # We manually log the metrics summary to the result logger
+ # as we don't want to add them to the pending metrics
+ self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
+ formatter=self.get_formatter(stage_name))
+ return restored_checkpoints is not None
+
+ def commit(self, save_checkpoints: bool = True):
+ """Commit metrics to dora and save checkpoints at the end of an epoch."""
+ # we override commit to introduce more complex checkpoint saving behaviors
+ self.history.append(self._pending_metrics) # This will increase self.epoch
+ if save_checkpoints:
+ self.save_checkpoints()
+ self._start_epoch()
+ if flashy.distrib.is_rank_zero():
+ self.xp.link.update_history(self.history)
+
+ def run_epoch(self):
+ """Run a single epoch with all stages.
+
+ Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
+ Children solvers can extend this method with custom behavior, e.g.:
+
+ def run_epoch(self):
+ ... # custom code
+ super().run_epoch()
+ ... # custom code
+ """
+ self.run_stage('train', self.train)
+ with torch.no_grad():
+ with self.swap_ema_state():
+ self.run_stage('valid', self.valid)
+ # the best state is updated with EMA states if available
+ self.update_best_state_from_stage('valid')
+ with self.swap_best_state():
+ if self.should_run_stage('evaluate'):
+ self.run_stage('evaluate', self.evaluate)
+ if self.should_run_stage('generate'):
+ self.run_stage('generate', with_rank_rng()(self.generate))
+
+ def run(self):
+ """Training loop."""
+ assert len(self.state_dict()) > 0
+ self.restore(replay_metrics=True) # load checkpoint and replay history
+ self.log_hyperparams(dict_from_config(self.cfg))
+ for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
+ if self.should_stop_training():
+ return
+ self.run_epoch()
+ # Commit will send the metrics to Dora and save checkpoints by default.
+ self.commit()
+
+ def should_stop_training(self) -> bool:
+ """Check whether we should stop training or not."""
+ return self.epoch > self.cfg.optim.epochs
+
+ def should_run_stage(self, stage_name) -> bool:
+ """Check whether we want to run the specified stages."""
+ stage_every = self.cfg[stage_name].get('every', None)
+ is_last_epoch = self.epoch == self.cfg.optim.epochs
+ is_epoch_every = (stage_every and self.epoch % stage_every == 0)
+ return is_last_epoch or is_epoch_every
+
+ @abstractmethod
+ def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+ """Perform one training or valid step on a given batch."""
+ ...
+
+ def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
+ """Common logic for train and valid stages."""
+ self.model.train(self.is_training)
+
+ loader = self.dataloaders[dataset_split]
+ # get a different order for distributed training, otherwise this will get ignored
+ if flashy.distrib.world_size() > 1 \
+ and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
+ loader.sampler.set_epoch(self.epoch)
+ updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
+ if self.cfg.benchmark_no_load:
+ self.logger.warning("Fake loading for benchmarking: re-using first batch")
+ batch = next(iter(loader))
+ loader = [batch] * updates_per_epoch # type: ignore
+ lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
+ average = flashy.averager() # epoch wise average
+ instant_average = flashy.averager() # average between two logging
+ metrics: dict = {}
+
+ with self.profiler, self.deadlock_detect: # profiler will only run for the first 20 updates.
+ for idx, batch in enumerate(lp):
+ self.deadlock_detect.update('batch')
+ if idx >= updates_per_epoch:
+ break
+ metrics = {}
+ metrics = self.run_step(idx, batch, metrics)
+ self.deadlock_detect.update('step')
+ # run EMA step
+ if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
+ self.logger.debug("EMA model step")
+ self.ema.step()
+ self.deadlock_detect.update('ema')
+ self.profiler.step()
+ instant_metrics = instant_average(metrics)
+ if lp.update(**instant_metrics):
+ instant_average = flashy.averager() # reset averager between two logging
+ metrics = average(metrics) # epoch wise average
+ self.deadlock_detect.update('end_batch')
+
+ metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
+ return metrics
+
+ def train(self):
+ """Train stage."""
+ return self.common_train_valid('train')
+
+ def valid(self):
+ """Valid stage."""
+ return self.common_train_valid('valid')
+
+ @abstractmethod
+ def evaluate(self):
+ """Evaluate stage."""
+ ...
+
+ @abstractmethod
+ def generate(self):
+ """Generate stage."""
+ ...
+
+ def run_one_stage(self, stage_name: str):
+ """Run only the specified stage.
+ This method is useful to only generate samples from a trained experiment
+ or rerun the validation or evaluation stages.
+ """
+ fn = {
+ 'generate': with_rank_rng()(self.generate),
+ 'evaluate': self.evaluate,
+ 'valid': self.valid,
+ }
+ if stage_name not in fn:
+ raise ValueError(f'Trying to run stage {stage_name} is not supported.')
+ assert len(self.state_dict()) > 0
+ self._start_epoch()
+ with torch.no_grad(), self.swap_best_state():
+ self.run_stage(stage_name, fn[stage_name])
+ if not self.cfg.execute_inplace:
+ self.commit(save_checkpoints=False)
+
+ @staticmethod
+ def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+ device: tp.Optional[str] = None, autocast: bool = True,
+ batch_size: tp.Optional[int] = None,
+ override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+ **kwargs):
+ """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
+ populating all the proper param, deactivating EMA, FSDP, loading the best state,
+ basically all you need to get a solver ready to "play" with in single GPU mode
+ and with minimal memory overhead.
+
+ Args:
+ sig (str): signature to load.
+ dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+ device (str or None): potential device, as a string, i.e. 'cuda'.
+ override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+ """
+ from audiocraft import train
+ our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+ our_override_cfg['autocast'] = autocast
+ if dtype is not None:
+ our_override_cfg['dtype'] = dtype
+ if device is not None:
+ our_override_cfg['device'] = device
+ if batch_size is not None:
+ our_override_cfg['dataset'] = {'batch_size': batch_size}
+ if override_cfg is None:
+ override_cfg = {}
+ override_cfg = omegaconf.OmegaConf.merge(
+ omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore
+ solver = train.get_solver_from_sig(
+ sig, override_cfg=override_cfg,
+ load_best=True, disable_fsdp=True,
+ ignore_state_keys=['optimizer', 'ema'], **kwargs)
+ solver.model.eval()
+ return solver
diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py
new file mode 100644
index 0000000000000000000000000000000000000000..304d8f08d33a70e8be9388c855b2ae43bdf2683b
--- /dev/null
+++ b/audiocraft/solvers/builders.py
@@ -0,0 +1,363 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant solvers and used objects
+from the Hydra config.
+"""
+
+from enum import Enum
+import logging
+import typing as tp
+
+import dora
+import flashy
+import omegaconf
+import torch
+from torch import nn
+from torch.optim import Optimizer
+# LRScheduler was renamed in some torch versions
+try:
+ from torch.optim.lr_scheduler import LRScheduler # type: ignore
+except ImportError:
+ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+from .base import StandardSolver
+from .. import adversarial, data, losses, metrics, optim
+from ..utils.utils import dict_from_config, get_loader
+
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetType(Enum):
+ AUDIO = "audio"
+ MUSIC = "music"
+ SOUND = "sound"
+
+
+def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
+ """Instantiate solver from config."""
+ from .audiogen import AudioGenSolver
+ from .compression import CompressionSolver
+ from .musicgen import MusicGenSolver
+ from .diffusion import DiffusionSolver
+ klass = {
+ 'compression': CompressionSolver,
+ 'musicgen': MusicGenSolver,
+ 'audiogen': AudioGenSolver,
+ 'lm': MusicGenSolver, # backward compatibility
+ 'diffusion': DiffusionSolver,
+ 'sound_lm': AudioGenSolver, # backward compatibility
+ }[cfg.solver]
+ return klass(cfg) # type: ignore
+
+
+def get_optim_parameter_groups(model: nn.Module):
+ """Create parameter groups for the model using the appropriate method
+ if defined for each modules, to create the different groups.
+
+ Args:
+ model (nn.Module): torch model
+ Returns:
+ List of parameter groups
+ """
+ seen_params: tp.Set[nn.parameter.Parameter] = set()
+ other_params = []
+ groups = []
+ for name, module in model.named_modules():
+ if hasattr(module, 'make_optim_group'):
+ group = module.make_optim_group()
+ params = set(group['params'])
+ assert params.isdisjoint(seen_params)
+ seen_params |= set(params)
+ groups.append(group)
+ for param in model.parameters():
+ if param not in seen_params:
+ other_params.append(param)
+ groups.insert(0, {'params': other_params})
+ parameters = groups
+ return parameters
+
+
+def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
+ """Build torch optimizer from config and set of parameters.
+ Supported optimizers: Adam, AdamW
+
+ Args:
+ params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
+ cfg (DictConfig): Optimization-related configuration.
+ Returns:
+ torch.optim.Optimizer.
+ """
+ if 'optimizer' not in cfg:
+ if getattr(cfg, 'optim', None) is not None:
+ raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
+ else:
+ raise KeyError("Optimizer not found in config.")
+
+ parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
+ optimizer: torch.optim.Optimizer
+ if cfg.optimizer == 'adam':
+ optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
+ elif cfg.optimizer == 'adamw':
+ optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
+ elif cfg.optimizer == 'dadam':
+ optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
+ else:
+ raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+ return optimizer
+
+
+def get_lr_scheduler(optimizer: torch.optim.Optimizer,
+ cfg: omegaconf.DictConfig,
+ total_updates: int) -> tp.Optional[LRScheduler]:
+ """Build torch learning rate scheduler from config and associated optimizer.
+ Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler
+
+ Args:
+ optimizer (torch.optim.Optimizer): Optimizer.
+ cfg (DictConfig): Schedule-related configuration.
+ total_updates (int): Total number of updates.
+ Returns:
+ torch.optim.Optimizer.
+ """
+ if 'lr_scheduler' not in cfg:
+ raise KeyError("LR Scheduler not found in config")
+
+ lr_sched: tp.Optional[LRScheduler] = None
+ if cfg.lr_scheduler == 'step':
+ lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
+ elif cfg.lr_scheduler == 'exponential':
+ lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
+ elif cfg.lr_scheduler == 'cosine':
+ kwargs = dict_from_config(cfg.cosine)
+ warmup_steps = kwargs.pop('warmup')
+ lr_sched = optim.CosineLRScheduler(
+ optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+ elif cfg.lr_scheduler == 'polynomial_decay':
+ kwargs = dict_from_config(cfg.polynomial_decay)
+ warmup_steps = kwargs.pop('warmup')
+ lr_sched = optim.PolynomialDecayLRScheduler(
+ optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+ elif cfg.lr_scheduler == 'inverse_sqrt':
+ kwargs = dict_from_config(cfg.inverse_sqrt)
+ warmup_steps = kwargs.pop('warmup')
+ lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+ elif cfg.lr_scheduler == 'linear_warmup':
+ kwargs = dict_from_config(cfg.linear_warmup)
+ warmup_steps = kwargs.pop('warmup')
+ lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+ elif cfg.lr_scheduler is not None:
+ raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+ return lr_sched
+
+
+def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
+ """Initialize Exponential Moving Average.
+
+ Args:
+ module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
+ cfg (omegaconf.DictConfig): Optim EMA configuration.
+ Returns:
+ optim.ModuleDictEMA: EMA version of the ModuleDict.
+ """
+ kw: tp.Dict[str, tp.Any] = dict(cfg)
+ use = kw.pop('use', False)
+ decay = kw.pop('decay', None)
+ device = kw.pop('device', None)
+ if not use:
+ return None
+ if len(module_dict) == 0:
+ raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
+ ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device)
+ return ema_module
+
+
+def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
+ """Instantiate loss from configuration."""
+ klass = {
+ 'l1': torch.nn.L1Loss,
+ 'l2': torch.nn.MSELoss,
+ 'mel': losses.MelSpectrogramL1Loss,
+ 'mrstft': losses.MRSTFTLoss,
+ 'msspec': losses.MultiScaleMelSpectrogramLoss,
+ 'sisnr': losses.SISNR,
+ }[loss_name]
+ kwargs = dict(getattr(cfg, loss_name))
+ return klass(**kwargs)
+
+
+def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
+ """Instantiate loss balancer from configuration for the provided weights."""
+ kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
+ return losses.Balancer(loss_weights, **kwargs)
+
+
+def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
+ """Initialize adversary from config."""
+ klass = {
+ 'msd': adversarial.MultiScaleDiscriminator,
+ 'mpd': adversarial.MultiPeriodDiscriminator,
+ 'msstftd': adversarial.MultiScaleSTFTDiscriminator,
+ }[name]
+ adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
+ return klass(**adv_cfg)
+
+
+def get_adversarial_losses(cfg) -> nn.ModuleDict:
+ """Initialize dict of adversarial losses from config."""
+ device = cfg.device
+ adv_cfg = getattr(cfg, 'adversarial')
+ adversaries = adv_cfg.get('adversaries', [])
+ adv_loss_name = adv_cfg['adv_loss']
+ feat_loss_name = adv_cfg.get('feat_loss')
+ normalize = adv_cfg.get('normalize', True)
+ feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
+ if feat_loss_name:
+ assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found."
+ loss = get_loss(feat_loss_name, cfg)
+ feat_loss = adversarial.FeatureMatchingLoss(loss, normalize)
+ loss = adversarial.get_adv_criterion(adv_loss_name)
+ loss_real = adversarial.get_real_criterion(adv_loss_name)
+ loss_fake = adversarial.get_fake_criterion(adv_loss_name)
+ adv_losses = nn.ModuleDict()
+ for adv_name in adversaries:
+ adversary = get_adversary(adv_name, cfg).to(device)
+ optimizer = get_optimizer(adversary.parameters(), cfg.optim)
+ adv_loss = adversarial.AdversarialLoss(
+ adversary,
+ optimizer,
+ loss=loss,
+ loss_real=loss_real,
+ loss_fake=loss_fake,
+ loss_feat=feat_loss,
+ normalize=normalize
+ )
+ adv_losses[adv_name] = adv_loss
+ return adv_losses
+
+
+def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
+ """Instantiate ViSQOL metric from config."""
+ kwargs = dict_from_config(cfg)
+ return metrics.ViSQOL(**kwargs)
+
+
+def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
+ """Instantiate Frechet Audio Distance metric from config."""
+ kwargs = dict_from_config(cfg.tf)
+ xp = dora.get_xp()
+ kwargs['log_folder'] = xp.folder
+ return metrics.FrechetAudioDistanceMetric(**kwargs)
+
+
+def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
+ """Instantiate KL-Divergence metric from config."""
+ kld_metrics = {
+ 'passt': metrics.PasstKLDivergenceMetric,
+ }
+ klass = kld_metrics[cfg.model]
+ kwargs = dict_from_config(cfg.get(cfg.model))
+ return klass(**kwargs)
+
+
+def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
+ """Instantiate Text Consistency metric from config."""
+ text_consistency_metrics = {
+ 'clap': metrics.CLAPTextConsistencyMetric
+ }
+ klass = text_consistency_metrics[cfg.model]
+ kwargs = dict_from_config(cfg.get(cfg.model))
+ return klass(**kwargs)
+
+
+def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
+ """Instantiate Chroma Cosine Similarity metric from config."""
+ assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric"
+ kwargs = dict_from_config(cfg.get(cfg.model))
+ return metrics.ChromaCosineSimilarityMetric(**kwargs)
+
+
+def get_audio_datasets(cfg: omegaconf.DictConfig,
+ dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
+ """Build AudioDataset from configuration.
+
+ Args:
+ cfg (omegaconf.DictConfig): Configuration.
+ dataset_type: The type of dataset to create.
+ Returns:
+ dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
+ """
+ dataloaders: dict = {}
+
+ sample_rate = cfg.sample_rate
+ channels = cfg.channels
+ seed = cfg.seed
+ max_sample_rate = cfg.datasource.max_sample_rate
+ max_channels = cfg.datasource.max_channels
+
+ assert cfg.dataset is not None, "Could not find dataset definition in config"
+
+ dataset_cfg = dict_from_config(cfg.dataset)
+ splits_cfg: dict = {}
+ splits_cfg['train'] = dataset_cfg.pop('train')
+ splits_cfg['valid'] = dataset_cfg.pop('valid')
+ splits_cfg['evaluate'] = dataset_cfg.pop('evaluate')
+ splits_cfg['generate'] = dataset_cfg.pop('generate')
+ execute_only_stage = cfg.get('execute_only', None)
+
+ for split, path in cfg.datasource.items():
+ if not isinstance(path, str):
+ continue # skipping this as not a path
+ if execute_only_stage is not None and split != execute_only_stage:
+ continue
+ logger.info(f"Loading audio data split {split}: {str(path)}")
+ assert (
+ cfg.sample_rate <= max_sample_rate
+ ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
+ assert (
+ cfg.channels <= max_channels
+ ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."
+
+ split_cfg = splits_cfg[split]
+ split_kwargs = {k: v for k, v in split_cfg.items()}
+ kwargs = {**dataset_cfg, **split_kwargs} # split kwargs overrides default dataset_cfg
+ kwargs['sample_rate'] = sample_rate
+ kwargs['channels'] = channels
+
+ if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
+ kwargs['num_samples'] = (
+ flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)
+
+ num_samples = kwargs['num_samples']
+ shuffle = kwargs['shuffle']
+
+ return_info = kwargs.pop('return_info')
+ batch_size = kwargs.pop('batch_size', None)
+ num_workers = kwargs.pop('num_workers')
+
+ if dataset_type == DatasetType.MUSIC:
+ dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
+ elif dataset_type == DatasetType.SOUND:
+ dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
+ elif dataset_type == DatasetType.AUDIO:
+ dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
+ else:
+ raise ValueError(f"Dataset type is unsupported: {dataset_type}")
+
+ loader = get_loader(
+ dataset,
+ num_samples,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ seed=seed,
+ collate_fn=dataset.collater if return_info else None,
+ shuffle=shuffle,
+ )
+ dataloaders[split] = loader
+
+ return dataloaders
diff --git a/audiocraft/solvers/compression.py b/audiocraft/solvers/compression.py
new file mode 100644
index 0000000000000000000000000000000000000000..b757503472a3bfbf90e1636999e64913848a7474
--- /dev/null
+++ b/audiocraft/solvers/compression.py
@@ -0,0 +1,328 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import multiprocessing
+from pathlib import Path
+import typing as tp
+
+import flashy
+import omegaconf
+import torch
+from torch import nn
+
+from . import base, builders
+from .. import models, quantization
+from ..utils import checkpoint
+from ..utils.samples.manager import SampleManager
+from ..utils.utils import get_pool_executor
+
+
+logger = logging.getLogger(__name__)
+
+
+class CompressionSolver(base.StandardSolver):
+ """Solver for compression task.
+
+ The compression task combines a set of perceptual and objective losses
+ to train an EncodecModel (composed of an encoder-decoder and a quantizer)
+ to perform high fidelity audio reconstruction.
+ """
+ def __init__(self, cfg: omegaconf.DictConfig):
+ super().__init__(cfg)
+ self.rng: torch.Generator # set at each epoch
+ self.adv_losses = builders.get_adversarial_losses(self.cfg)
+ self.aux_losses = nn.ModuleDict()
+ self.info_losses = nn.ModuleDict()
+ assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
+ loss_weights = dict()
+ for loss_name, weight in self.cfg.losses.items():
+ if loss_name in ['adv', 'feat']:
+ for adv_name, _ in self.adv_losses.items():
+ loss_weights[f'{loss_name}_{adv_name}'] = weight
+ elif weight > 0:
+ self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+ loss_weights[loss_name] = weight
+ else:
+ self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+ self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
+ self.register_stateful('adv_losses')
+
+ @property
+ def best_metric_name(self) -> tp.Optional[str]:
+ # best model is the last for the compression model
+ return None
+
+ def build_model(self):
+ """Instantiate model and optimizer."""
+ # Model and optimizer
+ self.model = models.builders.get_compression_model(self.cfg).to(self.device)
+ self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+ self.register_stateful('model', 'optimizer')
+ self.register_best_state('model')
+ self.register_ema('model')
+
+ def build_dataloaders(self):
+ """Instantiate audio dataloaders for each stage."""
+ self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+ def show(self):
+ """Show the compression model and employed adversarial loss."""
+ self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
+ self.log_model_summary(self.model)
+ self.logger.info("Adversarial loss:")
+ self.log_model_summary(self.adv_losses)
+ self.logger.info("Auxiliary losses:")
+ self.logger.info(self.aux_losses)
+ self.logger.info("Info losses:")
+ self.logger.info(self.info_losses)
+
+ def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+ """Perform one training or valid step on a given batch."""
+ x = batch.to(self.device)
+ y = x.clone()
+
+ qres = self.model(x)
+ assert isinstance(qres, quantization.QuantizedResult)
+ y_pred = qres.x
+ # Log bandwidth in kb/s
+ metrics['bandwidth'] = qres.bandwidth.mean()
+
+ if self.is_training:
+ d_losses: dict = {}
+ if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
+ for adv_name, adversary in self.adv_losses.items():
+ disc_loss = adversary.train_adv(y_pred, y)
+ d_losses[f'd_{adv_name}'] = disc_loss
+ metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
+ metrics.update(d_losses)
+
+ balanced_losses: dict = {}
+ other_losses: dict = {}
+
+ # penalty from quantization
+ if qres.penalty is not None and qres.penalty.requires_grad:
+ other_losses['penalty'] = qres.penalty # penalty term from the quantizer
+
+ # adversarial losses
+ for adv_name, adversary in self.adv_losses.items():
+ adv_loss, feat_loss = adversary(y_pred, y)
+ balanced_losses[f'adv_{adv_name}'] = adv_loss
+ balanced_losses[f'feat_{adv_name}'] = feat_loss
+
+ # auxiliary losses
+ for loss_name, criterion in self.aux_losses.items():
+ loss = criterion(y_pred, y)
+ balanced_losses[loss_name] = loss
+
+ # weighted losses
+ metrics.update(balanced_losses)
+ metrics.update(other_losses)
+ metrics.update(qres.metrics)
+
+ if self.is_training:
+ # backprop losses that are not handled by balancer
+ other_loss = torch.tensor(0., device=self.device)
+ if 'penalty' in other_losses:
+ other_loss += other_losses['penalty']
+ if other_loss.requires_grad:
+ other_loss.backward(retain_graph=True)
+ ratio1 = sum(p.grad.data.norm(p=2).pow(2)
+ for p in self.model.parameters() if p.grad is not None)
+ assert isinstance(ratio1, torch.Tensor)
+ metrics['ratio1'] = ratio1.sqrt()
+
+ # balancer losses backward, returns effective training loss
+ # with effective weights at the current batch.
+ metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
+ # add metrics corresponding to weight ratios
+ metrics.update(self.balancer.metrics)
+ ratio2 = sum(p.grad.data.norm(p=2).pow(2)
+ for p in self.model.parameters() if p.grad is not None)
+ assert isinstance(ratio2, torch.Tensor)
+ metrics['ratio2'] = ratio2.sqrt()
+
+ # optim
+ flashy.distrib.sync_model(self.model)
+ if self.cfg.optim.max_norm:
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.cfg.optim.max_norm
+ )
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ # informative losses only
+ info_losses: dict = {}
+ with torch.no_grad():
+ for loss_name, criterion in self.info_losses.items():
+ loss = criterion(y_pred, y)
+ info_losses[loss_name] = loss
+
+ metrics.update(info_losses)
+
+ # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
+ adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
+ if len(adv_losses) > 0:
+ metrics['adv'] = torch.sum(torch.stack(adv_losses))
+ feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
+ if len(feat_losses) > 0:
+ metrics['feat'] = torch.sum(torch.stack(feat_losses))
+
+ return metrics
+
+ def run_epoch(self):
+ # reset random seed at the beginning of the epoch
+ self.rng = torch.Generator()
+ self.rng.manual_seed(1234 + self.epoch)
+ # run epoch
+ super().run_epoch()
+
+ def evaluate(self):
+ """Evaluate stage. Runs audio reconstruction evaluation."""
+ self.model.eval()
+ evaluate_stage_name = str(self.current_stage)
+
+ loader = self.dataloaders['evaluate']
+ updates = len(loader)
+ lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+ average = flashy.averager()
+
+ pendings = []
+ ctx = multiprocessing.get_context('spawn')
+ with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
+ for idx, batch in enumerate(lp):
+ x = batch.to(self.device)
+ with torch.no_grad():
+ qres = self.model(x)
+
+ y_pred = qres.x.cpu()
+ y = batch.cpu() # should already be on CPU but just in case
+ pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
+
+ metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
+ for pending in metrics_lp:
+ metrics = pending.result()
+ metrics = average(metrics)
+
+ metrics = flashy.distrib.average_metrics(metrics, len(loader))
+ return metrics
+
+ def generate(self):
+ """Generate stage."""
+ self.model.eval()
+ sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
+ generate_stage_name = str(self.current_stage)
+
+ loader = self.dataloaders['generate']
+ updates = len(loader)
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+ for batch in lp:
+ reference, _ = batch
+ reference = reference.to(self.device)
+ with torch.no_grad():
+ qres = self.model(reference)
+ assert isinstance(qres, quantization.QuantizedResult)
+
+ reference = reference.cpu()
+ estimate = qres.x.cpu()
+ sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+
+ flashy.distrib.barrier()
+
+ def load_from_pretrained(self, name: str) -> dict:
+ model = models.CompressionModel.get_pretrained(name)
+ if isinstance(model, models.DAC):
+ raise RuntimeError("Cannot fine tune a DAC model.")
+ elif isinstance(model, models.HFEncodecCompressionModel):
+ self.logger.warning('Trying to automatically convert a HuggingFace model '
+ 'to AudioCraft, this might fail!')
+ state = model.model.state_dict()
+ new_state = {}
+ for k, v in state.items():
+ if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
+ # We need to determine if this a convtr or a regular conv.
+ layer = int(k.split('.')[2])
+ if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
+
+ k = k.replace('.conv.', '.convtr.')
+ k = k.replace('encoder.layers.', 'encoder.model.')
+ k = k.replace('decoder.layers.', 'decoder.model.')
+ k = k.replace('conv.', 'conv.conv.')
+ k = k.replace('convtr.', 'convtr.convtr.')
+ k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
+ k = k.replace('.codebook.', '._codebook.')
+ new_state[k] = v
+ state = new_state
+ elif isinstance(model, models.EncodecModel):
+ state = model.state_dict()
+ else:
+ raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
+ return {
+ 'best_state': {'model': state}
+ }
+
+ @staticmethod
+ def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
+ device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+ """Instantiate a CompressionModel from a given checkpoint path or dora sig.
+ This method is a convenient endpoint to load a CompressionModel to use in other solvers.
+
+ Args:
+ checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+ This also supports pre-trained models by using a path of the form //pretrained/NAME.
+ See `model_from_pretrained` for a list of supported pretrained models.
+ use_ema (bool): Use EMA variant of the model instead of the actual model.
+ device (torch.device or str): Device on which the model is loaded.
+ """
+ checkpoint_path = str(checkpoint_path)
+ if checkpoint_path.startswith('//pretrained/'):
+ name = checkpoint_path.split('/', 3)[-1]
+ return models.CompressionModel.get_pretrained(name, device)
+ logger = logging.getLogger(__name__)
+ logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
+ _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
+ assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
+ state = checkpoint.load_checkpoint(_checkpoint_path)
+ assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
+ cfg = state['xp.cfg']
+ cfg.device = device
+ compression_model = models.builders.get_compression_model(cfg).to(device)
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
+
+ assert 'best_state' in state and state['best_state'] != {}
+ assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
+ compression_model.load_state_dict(state['best_state']['model'])
+ compression_model.eval()
+ logger.info("Compression model loaded!")
+ return compression_model
+
+ @staticmethod
+ def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
+ checkpoint_path: tp.Union[Path, str],
+ device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+ """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
+
+ Args:
+ cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
+ checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+ use_ema (bool): Use EMA variant of the model instead of the actual model.
+ device (torch.device or str): Device on which the model is loaded.
+ """
+ compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
+ compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
+ return compression_model
+
+
+def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
+ """Audio reconstruction evaluation method that can be conveniently pickled."""
+ metrics = {}
+ if cfg.evaluate.metrics.visqol:
+ visqol = builders.get_visqol(cfg.metrics.visqol)
+ metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
+ sisnr = builders.get_loss('sisnr', cfg)
+ metrics['sisnr'] = sisnr(y_pred, y)
+ return metrics
diff --git a/audiocraft/solvers/diffusion.py b/audiocraft/solvers/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..93dea2520836f458ab1b8514dca952b51d113ec2
--- /dev/null
+++ b/audiocraft/solvers/diffusion.py
@@ -0,0 +1,279 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import flashy
+import julius
+import omegaconf
+import torch
+import torch.nn.functional as F
+
+from . import builders
+from . import base
+from .. import models
+from ..modules.diffusion_schedule import NoiseSchedule
+from ..metrics import RelativeVolumeMel
+from ..models.builders import get_processor
+from ..utils.samples.manager import SampleManager
+from ..solvers.compression import CompressionSolver
+
+
+class PerStageMetrics:
+ """Handle prompting the metrics per stage.
+ It outputs the metrics per range of diffusion states.
+ e.g. avg loss when t in [250, 500]
+ """
+ def __init__(self, num_steps: int, num_stages: int = 4):
+ self.num_steps = num_steps
+ self.num_stages = num_stages
+
+ def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
+ if type(step) is int:
+ stage = int((step / self.num_steps) * self.num_stages)
+ return {f"{name}_{stage}": loss for name, loss in losses.items()}
+ elif type(step) is torch.Tensor:
+ stage_tensor = ((step / self.num_steps) * self.num_stages).long()
+ out: tp.Dict[str, float] = {}
+ for stage_idx in range(self.num_stages):
+ mask = (stage_tensor == stage_idx)
+ N = mask.sum()
+ stage_out = {}
+ if N > 0: # pass if no elements in the stage
+ for name, loss in losses.items():
+ stage_loss = (mask * loss).sum() / N
+ stage_out[f"{name}_{stage_idx}"] = stage_loss
+ out = {**out, **stage_out}
+ return out
+
+
+class DataProcess:
+ """Apply filtering or resampling.
+
+ Args:
+ initial_sr (int): Initial sample rate.
+ target_sr (int): Target sample rate.
+ use_resampling: Whether to use resampling or not.
+ use_filter (bool):
+ n_bands (int): Number of bands to consider.
+ idx_band (int):
+ device (torch.device or str):
+ cutoffs ():
+ boost (bool):
+ """
+ def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
+ use_filter: bool = False, n_bands: int = 4,
+ idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
+ """Apply filtering or resampling
+ Args:
+ initial_sr (int): sample rate of the dataset
+ target_sr (int): sample rate after resampling
+ use_resampling (bool): whether or not performs resampling
+ use_filter (bool): when True filter the data to keep only one frequency band
+ n_bands (int): Number of bands used
+ cuts (none or list): The cutoff frequencies of the band filtering
+ if None then we use mel scale bands.
+ idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs
+ boost (bool): make the data scale match our music dataset.
+ """
+ assert idx_band < n_bands
+ self.idx_band = idx_band
+ if use_filter:
+ if cutoffs is not None:
+ self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
+ else:
+ self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
+ self.use_filter = use_filter
+ self.use_resampling = use_resampling
+ self.target_sr = target_sr
+ self.initial_sr = initial_sr
+ self.boost = boost
+
+ def process_data(self, x, metric=False):
+ if x is None:
+ return None
+ if self.boost:
+ x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
+ x * 0.22
+ if self.use_filter and not metric:
+ x = self.filter(x)[self.idx_band]
+ if self.use_resampling:
+ x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
+ return x
+
+ def inverse_process(self, x):
+ """Upsampling only."""
+ if self.use_resampling:
+ x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr)
+ return x
+
+
+class DiffusionSolver(base.StandardSolver):
+ """Solver for compression task.
+
+ The diffusion task allows for MultiBand diffusion model training.
+
+ Args:
+ cfg (DictConfig): Configuration.
+ """
+ def __init__(self, cfg: omegaconf.DictConfig):
+ super().__init__(cfg)
+ self.cfg = cfg
+ self.device = cfg.device
+ self.sample_rate: int = self.cfg.sample_rate
+ self.codec_model = CompressionSolver.model_from_checkpoint(
+ cfg.compression_model_checkpoint, device=self.device)
+
+ self.codec_model.set_num_codebooks(cfg.n_q)
+ assert self.codec_model.sample_rate == self.cfg.sample_rate, (
+ f"Codec model sample rate is {self.codec_model.sample_rate} but "
+ f"Solver sample rate is {self.cfg.sample_rate}."
+ )
+ assert self.codec_model.sample_rate == self.sample_rate, \
+ f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \
+ "don't match."
+
+ self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
+ self.register_stateful('sample_processor')
+ self.sample_processor.to(self.device)
+
+ self.schedule = NoiseSchedule(
+ **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
+
+ self.eval_metric: tp.Optional[torch.nn.Module] = None
+
+ self.rvm = RelativeVolumeMel()
+ self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
+ use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
+ use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
+ idx_band=cfg.filter.idx_band, device=self.device)
+
+ @property
+ def best_metric_name(self) -> tp.Optional[str]:
+ if self._current_stage == "evaluate":
+ return 'rvm'
+ else:
+ return 'loss'
+
+ @torch.no_grad()
+ def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
+ codes, scale = self.codec_model.encode(wav)
+ assert scale is None, "Scaled compression models not supported."
+ emb = self.codec_model.decode_latent(codes)
+ return emb
+
+ def build_model(self):
+ """Build model and optimizer as well as optional Exponential Moving Average of the model.
+ """
+ # Model and optimizer
+ self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
+ self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+ self.register_stateful('model', 'optimizer')
+ self.register_best_state('model')
+ self.register_ema('model')
+
+ def build_dataloaders(self):
+ """Build audio dataloaders for each stage."""
+ self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+ def show(self):
+ # TODO
+ raise NotImplementedError()
+
+ def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+ """Perform one training or valid step on a given batch."""
+ x = batch.to(self.device)
+ loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss
+
+ condition = self.get_condition(x) # [bs, 128, T/hop, n_emb]
+ sample = self.data_processor.process_data(x)
+
+ input_, target, step = self.schedule.get_training_item(sample,
+ tensor_step=self.cfg.schedule.variable_step_batch)
+ out = self.model(input_, step, condition=condition).sample
+
+ base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
+ reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
+ loss = base_loss / reference_loss ** self.cfg.loss.norm_power
+
+ if self.is_training:
+ loss.mean().backward()
+ flashy.distrib.sync_model(self.model)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ metrics = {
+ 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
+ }
+ metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
+ metrics.update({
+ 'std_in': input_.std(), 'std_out': out.std()})
+ return metrics
+
+ def run_epoch(self):
+ # reset random seed at the beginning of the epoch
+ self.rng = torch.Generator()
+ self.rng.manual_seed(1234 + self.epoch)
+ self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
+ # run epoch
+ super().run_epoch()
+
+ def evaluate(self):
+ """Evaluate stage.
+ Runs audio reconstruction evaluation.
+ """
+ self.model.eval()
+ evaluate_stage_name = f'{self.current_stage}'
+ loader = self.dataloaders['evaluate']
+ updates = len(loader)
+ lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
+
+ metrics = {}
+ n = 1
+ for idx, batch in enumerate(lp):
+ x = batch.to(self.device)
+ with torch.no_grad():
+ y_pred = self.regenerate(x)
+
+ y_pred = y_pred.cpu()
+ y = batch.cpu() # should already be on CPU but just in case
+ rvm = self.rvm(y_pred, y)
+ lp.update(**rvm)
+ if len(metrics) == 0:
+ metrics = rvm
+ else:
+ for key in rvm.keys():
+ metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
+ metrics = flashy.distrib.average_metrics(metrics)
+ return metrics
+
+ @torch.no_grad()
+ def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
+ """Regenerate the given waveform."""
+ condition = self.get_condition(wav)
+ initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes.
+ result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
+ step_list=step_list)
+ result = self.data_processor.inverse_process(result)
+ return result
+
+ def generate(self):
+ """Generate stage."""
+ sample_manager = SampleManager(self.xp)
+ self.model.eval()
+ generate_stage_name = f'{self.current_stage}'
+
+ loader = self.dataloaders['generate']
+ updates = len(loader)
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+ for batch in lp:
+ reference, _ = batch
+ reference = reference.to(self.device)
+ estimate = self.regenerate(reference)
+ reference = reference.cpu()
+ estimate = estimate.cpu()
+ sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+ flashy.distrib.barrier()
diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb615abf448f9dd07490aaabf3fff9b861a1b2cb
--- /dev/null
+++ b/audiocraft/solvers/musicgen.py
@@ -0,0 +1,699 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import time
+import typing as tp
+
+import flashy
+import math
+import omegaconf
+import torch
+from torch.nn import functional as F
+
+from . import base, builders
+from .compression import CompressionSolver
+from .. import metrics as eval_metrics
+from .. import models
+from ..data.audio_dataset import AudioDataset
+from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
+from ..data.audio_utils import normalize_audio
+from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
+from ..utils.cache import CachedBatchWriter, CachedBatchLoader
+from ..utils.samples.manager import SampleManager
+from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
+
+
+class MusicGenSolver(base.StandardSolver):
+ """Solver for MusicGen training task.
+
+ Used in: https://arxiv.org/abs/2306.05284
+ """
+ DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
+
+ def __init__(self, cfg: omegaconf.DictConfig):
+ super().__init__(cfg)
+ # easier access to sampling parameters
+ self.generation_params = {
+ 'use_sampling': self.cfg.generate.lm.use_sampling,
+ 'temp': self.cfg.generate.lm.temp,
+ 'top_k': self.cfg.generate.lm.top_k,
+ 'top_p': self.cfg.generate.lm.top_p,
+ }
+ self._best_metric_name: tp.Optional[str] = 'ce'
+
+ self._cached_batch_writer = None
+ self._cached_batch_loader = None
+ if cfg.cache.path:
+ if cfg.cache.write:
+ self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
+ if self.cfg.cache.write_num_shards:
+ self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
+ self._best_metric_name = None
+ else:
+ self._cached_batch_loader = CachedBatchLoader(
+ Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
+ min_length=self.cfg.optim.updates_per_epoch or 1)
+ self.dataloaders['original_train'] = self.dataloaders['train']
+ self.dataloaders['train'] = self._cached_batch_loader # type: ignore
+
+ @staticmethod
+ def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+ device: tp.Optional[str] = None, autocast: bool = True,
+ batch_size: tp.Optional[int] = None,
+ override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+ **kwargs):
+ """Mostly a convenience function around magma.train.get_solver_from_sig,
+ populating all the proper param, deactivating EMA, FSDP, loading the best state,
+ basically all you need to get a solver ready to "play" with in single GPU mode
+ and with minimal memory overhead.
+
+ Args:
+ sig (str): signature to load.
+ dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+ device (str or None): potential device, as a string, i.e. 'cuda'.
+ override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+ """
+ from audiocraft import train
+ our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+ our_override_cfg['autocast'] = autocast
+ if dtype is not None:
+ our_override_cfg['dtype'] = dtype
+ if device is not None:
+ our_override_cfg['device'] = device
+ if batch_size is not None:
+ our_override_cfg['dataset'] = {'batch_size': batch_size}
+ if override_cfg is None:
+ override_cfg = {}
+ override_cfg = omegaconf.OmegaConf.merge(
+ omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore
+ solver = train.get_solver_from_sig(
+ sig, override_cfg=override_cfg,
+ load_best=True, disable_fsdp=True,
+ ignore_state_keys=['optimizer', 'ema'], **kwargs)
+ solver.model.eval()
+ return solver
+
+ def get_formatter(self, stage_name: str) -> flashy.Formatter:
+ return flashy.Formatter({
+ 'lr': '.2E',
+ 'ce': '.3f',
+ 'ppl': '.3f',
+ 'grad_norm': '.3E',
+ }, exclude_keys=['ce_q*', 'ppl_q*'])
+
+ @property
+ def best_metric_name(self) -> tp.Optional[str]:
+ return self._best_metric_name
+
+ def build_model(self) -> None:
+ """Instantiate models and optimizer."""
+ # we can potentially not use all quantizers with which the EnCodec model was trained
+ # (e.g. we trained the model with quantizers dropout)
+ self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
+ self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
+ assert self.compression_model.sample_rate == self.cfg.sample_rate, (
+ f"Compression model sample rate is {self.compression_model.sample_rate} but "
+ f"Solver sample rate is {self.cfg.sample_rate}."
+ )
+ # ensure we have matching configuration between LM and compression model
+ assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
+ "Cardinalities of the LM and compression model don't match: ",
+ f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
+ f"compression model cardinality is {self.compression_model.cardinality}"
+ )
+ assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
+ "Numbers of codebooks of the LM and compression models don't match: ",
+ f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
+ f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
+ )
+ self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
+ self.compression_model.num_codebooks, self.compression_model.cardinality,
+ self.compression_model.frame_rate)
+ # instantiate LM model
+ self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
+ if self.cfg.fsdp.use:
+ assert not self.cfg.autocast, "Cannot use autocast with fsdp"
+ self.model = self.wrap_with_fsdp(self.model)
+ self.register_ema('model')
+ # initialize optimization
+ self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
+ self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
+ self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
+ self.register_best_state('model')
+ self.autocast_dtype = {
+ 'float16': torch.float16, 'bfloat16': torch.bfloat16
+ }[self.cfg.autocast_dtype]
+ self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
+ if self.cfg.fsdp.use:
+ need_scaler = self.cfg.fsdp.param_dtype == 'float16'
+ else:
+ need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
+ if need_scaler:
+ if self.cfg.fsdp.use:
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+ self.scaler = ShardedGradScaler() # type: ignore
+ else:
+ self.scaler = torch.cuda.amp.GradScaler()
+ self.register_stateful('scaler')
+
+ def build_dataloaders(self) -> None:
+ """Instantiate audio dataloaders for each stage."""
+ self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
+
+ def show(self) -> None:
+ """Show the compression model and LM model."""
+ self.logger.info("Compression model:")
+ self.log_model_summary(self.compression_model)
+ self.logger.info("LM model:")
+ self.log_model_summary(self.model)
+
+ def load_state_dict(self, state: dict) -> None:
+ if 'condition_provider' in state:
+ model_state = state['model']
+ condition_provider_state = state.pop('condition_provider')
+ prefix = 'condition_provider.'
+ for key, value in condition_provider_state.items():
+ key = prefix + key
+ assert key not in model_state
+ model_state[key] = value
+ super().load_state_dict(state)
+
+ def load_from_pretrained(self, name: str):
+ # TODO: support native HF versions of MusicGen.
+ lm_pkg = models.loaders.load_lm_model_ckpt(name)
+ state: dict = {
+ 'best_state': {
+ 'model': lm_pkg['best_state'],
+ },
+ }
+ return state
+
+ def _compute_cross_entropy(
+ self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+ ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
+ """Compute cross entropy between multi-codebook targets and model's logits.
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
+ Valid timesteps for each of the codebook are pulled from the mask, where invalid
+ timesteps are set to 0.
+
+ Args:
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
+ Returns:
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
+ """
+ B, K, T = targets.shape
+ assert logits.shape[:-1] == targets.shape
+ assert mask.shape == targets.shape
+ ce = torch.zeros([], device=targets.device)
+ ce_per_codebook: tp.List[torch.Tensor] = []
+ for k in range(K):
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
+ ce_targets = targets_k[mask_k]
+ ce_logits = logits_k[mask_k]
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
+ ce += q_ce
+ ce_per_codebook.append(q_ce.detach())
+ # average cross entropy across codebooks
+ ce = ce / K
+ return ce, ce_per_codebook
+
+ @torch.no_grad()
+ def _prepare_tokens_and_attributes(
+ self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+ check_synchronization_points: bool = False
+ ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
+ """Prepare input batchs for language model training.
+
+ Args:
+ batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
+ and corresponding metadata as SegmentWithAttributes (with B items).
+ check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
+ Returns:
+ Condition tensors (dict[str, any]): Preprocessed condition attributes.
+ Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
+ with B the batch size, K the number of codebooks, T_s the token timesteps.
+ Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
+ """
+ if self._cached_batch_loader is None or self.current_stage != "train":
+ audio, infos = batch
+ audio = audio.to(self.device)
+ audio_tokens = None
+ assert audio.size(0) == len(infos), (
+ f"Mismatch between number of items in audio batch ({audio.size(0)})",
+ f" and in metadata ({len(infos)})"
+ )
+ else:
+ audio = None
+ # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
+ infos, = batch # type: ignore
+ assert all([isinstance(info, AudioInfo) for info in infos])
+ assert all([info.audio_tokens is not None for info in infos]) # type: ignore
+ audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore
+ audio_tokens = audio_tokens.long()
+ for info in infos:
+ if isinstance(info, MusicInfo):
+ # Careful here, if you want to use this condition_wav (e.b. chroma conditioning),
+ # then you must be using the chroma cache! otherwise the code will try
+ # to use this segment and fail (by that I mean you will see NaN everywhere).
+ info.self_wav = WavCondition(
+ torch.full([1, info.channels, info.total_frames], float('NaN')),
+ length=torch.tensor([info.n_frames]),
+ sample_rate=[info.sample_rate],
+ path=[info.meta.path],
+ seek_time=[info.seek_time])
+ dataset = get_dataset_from_loader(self.dataloaders['original_train'])
+ assert isinstance(dataset, MusicDataset), type(dataset)
+ if dataset.paraphraser is not None and info.description is not None:
+ # Hackingly reapplying paraphraser when using cache.
+ info.description = dataset.paraphraser.sample_paraphrase(
+ info.meta.path, info.description)
+ # prepare attributes
+ attributes = [info.to_condition_attributes() for info in infos]
+ attributes = self.model.cfg_dropout(attributes)
+ attributes = self.model.att_dropout(attributes)
+ tokenized = self.model.condition_provider.tokenize(attributes)
+
+ # Now we should be synchronization free.
+ if self.device == "cuda" and check_synchronization_points:
+ torch.cuda.set_sync_debug_mode("warn")
+
+ if audio_tokens is None:
+ with torch.no_grad():
+ audio_tokens, scale = self.compression_model.encode(audio)
+ assert scale is None, "Scaled compression model not supported with LM."
+
+ with self.autocast:
+ condition_tensors = self.model.condition_provider(tokenized)
+
+ # create a padding mask to hold valid vs invalid positions
+ padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
+ # replace encodec tokens from padded audio with special_token_id
+ if self.cfg.tokens.padding_with_special_token:
+ audio_tokens = audio_tokens.clone()
+ padding_mask = padding_mask.clone()
+ token_sample_rate = self.compression_model.frame_rate
+ B, K, T_s = audio_tokens.shape
+ for i in range(B):
+ n_samples = infos[i].n_frames
+ audio_sample_rate = infos[i].sample_rate
+ # take the last token generated from actual audio frames (non-padded audio)
+ valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
+ audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
+ padding_mask[i, :, valid_tokens:] = 0
+
+ if self.device == "cuda" and check_synchronization_points:
+ torch.cuda.set_sync_debug_mode("default")
+
+ if self._cached_batch_writer is not None and self.current_stage == 'train':
+ assert self._cached_batch_loader is None
+ assert audio_tokens is not None
+ for info, one_audio_tokens in zip(infos, audio_tokens):
+ assert isinstance(info, AudioInfo)
+ if isinstance(info, MusicInfo):
+ assert not info.joint_embed, "joint_embed and cache not supported yet."
+ info.self_wav = None
+ assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
+ info.audio_tokens = one_audio_tokens.short().cpu()
+ self._cached_batch_writer.save(infos)
+
+ return condition_tensors, audio_tokens, padding_mask
+
+ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
+ """Perform one training or valid step on a given batch."""
+ check_synchronization_points = idx == 1 and self.device == 'cuda'
+
+ condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
+ batch, check_synchronization_points)
+
+ self.deadlock_detect.update('tokens_and_conditions')
+
+ if check_synchronization_points:
+ torch.cuda.set_sync_debug_mode('warn')
+
+ with self.autocast:
+ model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore
+ logits = model_output.logits
+ mask = padding_mask & model_output.mask
+ ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
+ loss = ce
+ self.deadlock_detect.update('loss')
+
+ if check_synchronization_points:
+ torch.cuda.set_sync_debug_mode('default')
+
+ if self.is_training:
+ metrics['lr'] = self.optimizer.param_groups[0]['lr']
+ if self.scaler is not None:
+ loss = self.scaler.scale(loss)
+ self.deadlock_detect.update('scale')
+ if self.cfg.fsdp.use:
+ loss.backward()
+ flashy.distrib.average_tensors(self.model.buffers())
+ elif self.cfg.optim.eager_sync:
+ with flashy.distrib.eager_sync_model(self.model):
+ loss.backward()
+ else:
+ # this should always be slower but can be useful
+ # for weird use cases like multiple backwards.
+ loss.backward()
+ flashy.distrib.sync_model(self.model)
+ self.deadlock_detect.update('backward')
+
+ if self.scaler is not None:
+ self.scaler.unscale_(self.optimizer)
+ if self.cfg.optim.max_norm:
+ if self.cfg.fsdp.use:
+ metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore
+ else:
+ metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.cfg.optim.max_norm
+ )
+ if self.scaler is None:
+ self.optimizer.step()
+ else:
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ if self.lr_scheduler:
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+ self.deadlock_detect.update('optim')
+ if self.scaler is not None:
+ scale = self.scaler.get_scale()
+ metrics['grad_scale'] = scale
+ if not loss.isfinite().all():
+ raise RuntimeError("Model probably diverged.")
+
+ metrics['ce'] = ce
+ metrics['ppl'] = torch.exp(ce)
+ for k, ce_q in enumerate(ce_per_codebook):
+ metrics[f'ce_q{k + 1}'] = ce_q
+ metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
+
+ return metrics
+
+ @torch.no_grad()
+ def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+ gen_duration: float, prompt_duration: tp.Optional[float] = None,
+ remove_prompt: bool = False,
+ **generation_params) -> dict:
+ """Run generate step on a batch of optional audio tensor and corresponding attributes.
+
+ Args:
+ batch (tuple[torch.Tensor, list[SegmentWithAttributes]]):
+ use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch.
+ gen_duration (float): Target audio duration for the generation.
+ prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
+ remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
+ generation_params: Additional generation parameters.
+ Returns:
+ gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
+ and the prompt along with additional information.
+ """
+ bench_start = time.time()
+ audio, meta = batch
+ assert audio.size(0) == len(meta), (
+ f"Mismatch between number of items in audio batch ({audio.size(0)})",
+ f" and in metadata ({len(meta)})"
+ )
+ # prepare attributes
+ attributes = [x.to_condition_attributes() for x in meta]
+ # TODO: Add dropout for chroma?
+
+ # prepare audio prompt
+ if prompt_duration is None:
+ prompt_audio = None
+ else:
+ assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
+ prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
+ prompt_audio = audio[..., :prompt_audio_frames]
+
+ # get audio tokens from compression model
+ if prompt_audio is None or prompt_audio.nelement() == 0:
+ num_samples = len(attributes)
+ prompt_tokens = None
+ else:
+ num_samples = None
+ prompt_audio = prompt_audio.to(self.device)
+ prompt_tokens, scale = self.compression_model.encode(prompt_audio)
+ assert scale is None, "Compression model in MusicGen should not require rescaling."
+
+ # generate by sampling from the LM
+ with self.autocast:
+ total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
+ gen_tokens = self.model.generate(
+ prompt_tokens, attributes, max_gen_len=total_gen_len,
+ num_samples=num_samples, **self.generation_params)
+
+ # generate audio from tokens
+ assert gen_tokens.dim() == 3
+ gen_audio = self.compression_model.decode(gen_tokens, None)
+
+ bench_end = time.time()
+ gen_outputs = {
+ 'rtf': (bench_end - bench_start) / gen_duration,
+ 'ref_audio': audio,
+ 'gen_audio': gen_audio,
+ 'gen_tokens': gen_tokens,
+ 'prompt_audio': prompt_audio,
+ 'prompt_tokens': prompt_tokens,
+ }
+ return gen_outputs
+
+ def generate_audio(self) -> dict:
+ """Audio generation stage."""
+ generate_stage_name = f'{self.current_stage}'
+ sample_manager = SampleManager(self.xp)
+ self.logger.info(f"Generating samples in {sample_manager.base_folder}")
+ loader = self.dataloaders['generate']
+ updates = len(loader)
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+ dataset = get_dataset_from_loader(loader)
+ dataset_duration = dataset.segment_duration
+ assert dataset_duration is not None
+ assert isinstance(dataset, AudioDataset)
+ target_duration = self.cfg.generate.lm.gen_duration
+ prompt_duration = self.cfg.generate.lm.prompt_duration
+ if target_duration is None:
+ target_duration = dataset_duration
+ if prompt_duration is None:
+ prompt_duration = dataset_duration / 4
+ assert prompt_duration < dataset_duration, (
+ f"Specified prompt duration ({prompt_duration}s) is longer",
+ f" than reference audio duration ({dataset_duration}s)"
+ )
+
+ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
+ hydrated_conditions = []
+ for sample in [x.to_condition_attributes() for x in meta]:
+ cond_dict = {}
+ for cond_type in sample.__annotations__.keys():
+ for cond_key, cond_val in getattr(sample, cond_type).items():
+ if cond_key not in self.model.condition_provider.conditioners.keys():
+ continue
+ if is_jsonable(cond_val):
+ cond_dict[cond_key] = cond_val
+ elif isinstance(cond_val, WavCondition):
+ cond_dict[cond_key] = cond_val.path
+ elif isinstance(cond_val, JointEmbedCondition):
+ cond_dict[cond_key] = cond_val.text # only support text at inference for now
+ else:
+ # if we reached this point, it is not clear how to log the condition
+ # so we just log the type.
+ cond_dict[cond_key] = str(type(cond_val))
+ continue
+ hydrated_conditions.append(cond_dict)
+ return hydrated_conditions
+
+ metrics: dict = {}
+ average = flashy.averager()
+ for batch in lp:
+ audio, meta = batch
+ # metadata for sample manager
+ hydrated_conditions = get_hydrated_conditions(meta)
+ sample_generation_params = {
+ **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
+ **self.generation_params
+ }
+ if self.cfg.generate.lm.unprompted_samples:
+ if self.cfg.generate.lm.gen_gt_samples:
+ # get the ground truth instead of generation
+ self.logger.warn(
+ "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
+ gen_unprompted_audio = audio
+ rtf = 1.
+ else:
+ gen_unprompted_outputs = self.run_generate_step(
+ batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+ **self.generation_params)
+ gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
+ rtf = gen_unprompted_outputs['rtf']
+ sample_manager.add_samples(
+ gen_unprompted_audio, self.epoch, hydrated_conditions,
+ ground_truth_wavs=audio, generation_args=sample_generation_params)
+
+ if self.cfg.generate.lm.prompted_samples:
+ gen_outputs = self.run_generate_step(
+ batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+ **self.generation_params)
+ gen_audio = gen_outputs['gen_audio'].cpu()
+ prompt_audio = gen_outputs['prompt_audio'].cpu()
+ sample_manager.add_samples(
+ gen_audio, self.epoch, hydrated_conditions,
+ prompt_wavs=prompt_audio, ground_truth_wavs=audio,
+ generation_args=sample_generation_params)
+
+ metrics['rtf'] = rtf
+ metrics = average(metrics)
+
+ flashy.distrib.barrier()
+ return metrics
+
+ def generate(self) -> dict:
+ """Generate stage."""
+ self.model.eval()
+ with torch.no_grad():
+ return self.generate_audio()
+
+ def run_epoch(self):
+ if self.cfg.cache.write:
+ if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
+ return
+ super().run_epoch()
+
+ def train(self):
+ """Train stage.
+ """
+ if self._cached_batch_writer is not None:
+ self._cached_batch_writer.start_epoch(self.epoch)
+ if self._cached_batch_loader is None:
+ dataset = get_dataset_from_loader(self.dataloaders['train'])
+ assert isinstance(dataset, AudioDataset)
+ dataset.current_epoch = self.epoch
+ else:
+ self._cached_batch_loader.start_epoch(self.epoch)
+ return super().train()
+
+ def evaluate_audio_generation(self) -> dict:
+ """Evaluate audio generation with off-the-shelf metrics."""
+ evaluate_stage_name = f'{self.current_stage}_generation'
+ # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
+ fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
+ kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
+ text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
+ chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
+ should_run_eval = False
+ eval_chroma_wavs: tp.Optional[torch.Tensor] = None
+ if self.cfg.evaluate.metrics.fad:
+ fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
+ should_run_eval = True
+ if self.cfg.evaluate.metrics.kld:
+ kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
+ should_run_eval = True
+ if self.cfg.evaluate.metrics.text_consistency:
+ text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
+ should_run_eval = True
+ if self.cfg.evaluate.metrics.chroma_cosine:
+ chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
+ # if we have predefind wavs for chroma we should purge them for computing the cosine metric
+ has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
+ self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
+ if has_predefined_eval_chromas:
+ warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
+ 'Resetting eval chromas to None for evaluation.')
+ eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore
+ self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore
+ should_run_eval = True
+
+ def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
+ audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
+ compressed_audio = self.compression_model.decode(audio_tokens, scale)
+ return compressed_audio[..., :audio.shape[-1]]
+
+ metrics: dict = {}
+ if should_run_eval:
+ loader = self.dataloaders['evaluate']
+ updates = len(loader)
+ lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+ average = flashy.averager()
+ dataset = get_dataset_from_loader(loader)
+ assert isinstance(dataset, AudioDataset)
+ self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")
+
+ for idx, batch in enumerate(lp):
+ audio, meta = batch
+ assert all([self.cfg.sample_rate == m.sample_rate for m in meta])
+
+ target_duration = audio.shape[-1] / self.cfg.sample_rate
+ if self.cfg.evaluate.fixed_generation_duration:
+ target_duration = self.cfg.evaluate.fixed_generation_duration
+
+ gen_outputs = self.run_generate_step(
+ batch, gen_duration=target_duration,
+ **self.generation_params
+ )
+ y_pred = gen_outputs['gen_audio'].detach()
+ y_pred = y_pred[..., :audio.shape[-1]]
+
+ normalize_kwargs = dict(self.cfg.generate.audio)
+ normalize_kwargs.pop('format', None)
+ y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
+ y = audio.cpu() # should already be on CPU but just in case
+ sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding
+ sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples
+ audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]
+
+ if fad is not None:
+ if self.cfg.metrics.fad.use_gt:
+ y_pred = get_compressed_audio(y).cpu()
+ fad.update(y_pred, y, sizes, sample_rates, audio_stems)
+ if kldiv is not None:
+ if self.cfg.metrics.kld.use_gt:
+ y_pred = get_compressed_audio(y).cpu()
+ kldiv.update(y_pred, y, sizes, sample_rates)
+ if text_consistency is not None:
+ texts = [m.description for m in meta]
+ if self.cfg.metrics.text_consistency.use_gt:
+ y_pred = y
+ text_consistency.update(y_pred, texts, sizes, sample_rates)
+ if chroma_cosine is not None:
+ if self.cfg.metrics.chroma_cosine.use_gt:
+ y_pred = get_compressed_audio(y).cpu()
+ chroma_cosine.update(y_pred, y, sizes, sample_rates)
+ # restore chroma conditioner's eval chroma wavs
+ if eval_chroma_wavs is not None:
+ self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)
+
+ flashy.distrib.barrier()
+ if fad is not None:
+ metrics['fad'] = fad.compute()
+ if kldiv is not None:
+ kld_metrics = kldiv.compute()
+ metrics.update(kld_metrics)
+ if text_consistency is not None:
+ metrics['text_consistency'] = text_consistency.compute()
+ if chroma_cosine is not None:
+ metrics['chroma_cosine'] = chroma_cosine.compute()
+ metrics = average(metrics)
+ metrics = flashy.distrib.average_metrics(metrics, len(loader))
+
+ return metrics
+
+ def evaluate(self) -> dict:
+ """Evaluate stage."""
+ self.model.eval()
+ with torch.no_grad():
+ metrics: dict = {}
+ if self.cfg.evaluate.metrics.base:
+ metrics.update(self.common_train_valid('evaluate'))
+ gen_metrics = self.evaluate_audio_generation()
+ return {**metrics, **gen_metrics}
diff --git a/audiocraft/train.py b/audiocraft/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..22dd117830bb403829d0a60b1b95e120d1e6978b
--- /dev/null
+++ b/audiocraft/train.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Entry point for dora to launch solvers for running training loops.
+See more info on how to use dora: https://github.com/facebookresearch/dora
+"""
+
+import logging
+import multiprocessing
+import os
+import sys
+import typing as tp
+
+from dora import git_save, hydra_main, XP
+import flashy
+import hydra
+import omegaconf
+
+from .environment import AudioCraftEnvironment
+from .utils.cluster import get_slurm_parameters
+
+logger = logging.getLogger(__name__)
+
+
+def resolve_config_dset_paths(cfg):
+ """Enable Dora to load manifest from git clone repository."""
+ # manifest files for the different splits
+ for key, value in cfg.datasource.items():
+ if isinstance(value, str):
+ cfg.datasource[key] = git_save.to_absolute_path(value)
+
+
+def get_solver(cfg):
+ from . import solvers
+ # Convert batch size to batch size for each GPU
+ assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
+ cfg.dataset.batch_size //= flashy.distrib.world_size()
+ for split in ['train', 'valid', 'evaluate', 'generate']:
+ if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
+ assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
+ cfg.dataset[split].batch_size //= flashy.distrib.world_size()
+ resolve_config_dset_paths(cfg)
+ solver = solvers.get_solver(cfg)
+ return solver
+
+
+def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+ restore: bool = True, load_best: bool = True,
+ ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
+ """Given a XP, return the Solver object.
+
+ Args:
+ xp (XP): Dora experiment for which to retrieve the solver.
+ override_cfg (dict or None): If not None, should be a dict used to
+ override some values in the config of `xp`. This will not impact
+ the XP signature or folder. The format is different
+ than the one used in Dora grids, nested keys should actually be nested dicts,
+ not flattened, e.g. `{'optim': {'batch_size': 32}}`.
+ restore (bool): If `True` (the default), restore state from the last checkpoint.
+ load_best (bool): If `True` (the default), load the best state from the checkpoint.
+ ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
+ disable_fsdp (bool): if True, disables FSDP entirely. This will
+ also automatically skip loading the EMA. For solver specific
+ state sources, like the optimizer, you might want to
+ use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
+ """
+ logger.info(f"Loading solver from XP {xp.sig}. "
+ f"Overrides used: {xp.argv}")
+ cfg = xp.cfg
+ if override_cfg is not None:
+ cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
+ if disable_fsdp and cfg.fsdp.use:
+ cfg.fsdp.use = False
+ assert load_best is True
+ # ignoring some keys that were FSDP sharded like model, ema, and best_state.
+ # fsdp_best_state will be used in that case. When using a specific solver,
+ # one is responsible for adding the relevant keys, e.g. 'optimizer'.
+ # We could make something to automatically register those inside the solver, but that
+ # seem overkill at this point.
+ ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']
+
+ try:
+ with xp.enter():
+ solver = get_solver(cfg)
+ if restore:
+ solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
+ return solver
+ finally:
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+
+
+def get_solver_from_sig(sig: str, *args, **kwargs):
+ """Return Solver object from Dora signature, i.e. to play with it from a notebook.
+ See `get_solver_from_xp` for more information.
+ """
+ xp = main.get_xp_from_sig(sig)
+ return get_solver_from_xp(xp, *args, **kwargs)
+
+
+def init_seed_and_system(cfg):
+ import numpy as np
+ import torch
+ import random
+ from audiocraft.modules.transformer import set_efficient_attention_backend
+
+ multiprocessing.set_start_method(cfg.mp_start_method)
+ logger.debug('Setting mp start method to %s', cfg.mp_start_method)
+ random.seed(cfg.seed)
+ np.random.seed(cfg.seed)
+ # torch also initialize cuda seed if available
+ torch.manual_seed(cfg.seed)
+ torch.set_num_threads(cfg.num_threads)
+ os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
+ os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
+ logger.debug('Setting num threads to %d', cfg.num_threads)
+ set_efficient_attention_backend(cfg.efficient_attention_backend)
+ logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
+
+
+@hydra_main(config_path='../config', config_name='config', version_base='1.1')
+def main(cfg):
+ init_seed_and_system(cfg)
+
+ # Setup logging both to XP specific folder, and to stderr.
+ log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}'
+ flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
+ # Initialize distributed training, no need to specify anything when using Dora.
+ flashy.distrib.init()
+ solver = get_solver(cfg)
+ if cfg.show:
+ solver.show()
+ return
+
+ if cfg.execute_only:
+ assert cfg.execute_inplace or cfg.continue_from is not None, \
+ "Please explicitly specify the checkpoint to continue from with continue_from= " + \
+ "when running with execute_only or set execute_inplace to True."
+ solver.restore(replay_metrics=False) # load checkpoint
+ solver.run_one_stage(cfg.execute_only)
+ return
+
+ return solver.run()
+
+
+main.dora.dir = AudioCraftEnvironment.get_dora_dir()
+main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)
+
+if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
+ print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
+ main.dora.shared = None
+
+if __name__ == '__main__':
+ main()
diff --git a/audiocraft/utils/__init__.py b/audiocraft/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75e25a0212f98e4a18d97c86c6cda225636a3215
--- /dev/null
+++ b/audiocraft/utils/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utilities."""
diff --git a/audiocraft/utils/__pycache__/__init__.cpython-310.pyc b/audiocraft/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2fdf8739bcfa1d4d1e0a442529b6513a3ca165d
Binary files /dev/null and b/audiocraft/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/autocast.cpython-310.pyc b/audiocraft/utils/__pycache__/autocast.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab01a2a0693e3d7dcd1fd8a5af32dd09c9bddfbc
Binary files /dev/null and b/audiocraft/utils/__pycache__/autocast.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/best_state.cpython-310.pyc b/audiocraft/utils/__pycache__/best_state.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82c73f7d83af2158e3894987eb334ac8c02aac60
Binary files /dev/null and b/audiocraft/utils/__pycache__/best_state.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/cache.cpython-310.pyc b/audiocraft/utils/__pycache__/cache.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c76bbdfb0b94c2d8c9db263b3da298174610293
Binary files /dev/null and b/audiocraft/utils/__pycache__/cache.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/checkpoint.cpython-310.pyc b/audiocraft/utils/__pycache__/checkpoint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9897c55611ff26e6ea37f80a822dabb3528fc55
Binary files /dev/null and b/audiocraft/utils/__pycache__/checkpoint.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/cluster.cpython-310.pyc b/audiocraft/utils/__pycache__/cluster.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44dc18deeac308d0d0213e1750a3e4399ff3835a
Binary files /dev/null and b/audiocraft/utils/__pycache__/cluster.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/deadlock.cpython-310.pyc b/audiocraft/utils/__pycache__/deadlock.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ff3d1d194727b54dd73691fdd8df691ea10f740
Binary files /dev/null and b/audiocraft/utils/__pycache__/deadlock.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/profiler.cpython-310.pyc b/audiocraft/utils/__pycache__/profiler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25947b53c6783c5166249a03821990a30f0943fd
Binary files /dev/null and b/audiocraft/utils/__pycache__/profiler.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/ui.cpython-310.pyc b/audiocraft/utils/__pycache__/ui.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afea185c9fca2558ffecfc3ea9f9ca905da6717e
Binary files /dev/null and b/audiocraft/utils/__pycache__/ui.cpython-310.pyc differ
diff --git a/audiocraft/utils/__pycache__/utils.cpython-310.pyc b/audiocraft/utils/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a261ed941ec215bfa209d8809597a1bae80e9b8
Binary files /dev/null and b/audiocraft/utils/__pycache__/utils.cpython-310.pyc differ
diff --git a/audiocraft/utils/autocast.py b/audiocraft/utils/autocast.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed644843bb37cf8a92a20fbd51d6cebaa43b9a08
--- /dev/null
+++ b/audiocraft/utils/autocast.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+ """TorchAutocast utility class.
+ Allows you to enable and disable autocast. This is specially useful
+ when dealing with different architectures and clusters with different
+ levels of support.
+
+ Args:
+ enabled (bool): Whether to enable torch.autocast or not.
+ args: Additional args for torch.autocast.
+ kwargs: Additional kwargs for torch.autocast
+ """
+ def __init__(self, enabled: bool, *args, **kwargs):
+ self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+ def __enter__(self):
+ if self.autocast is None:
+ return
+ try:
+ self.autocast.__enter__()
+ except RuntimeError:
+ device = self.autocast.device
+ dtype = self.autocast.fast_dtype
+ raise RuntimeError(
+ f"There was an error autocasting with dtype={dtype} device={device}\n"
+ "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+ )
+
+ def __exit__(self, *args, **kwargs):
+ if self.autocast is None:
+ return
+ self.autocast.__exit__(*args, **kwargs)
diff --git a/audiocraft/utils/best_state.py b/audiocraft/utils/best_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ad551432ad5cb0f83278b5d2100f9aa287958b
--- /dev/null
+++ b/audiocraft/utils/best_state.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+import typing as tp
+
+import flashy
+import torch
+
+from ..optim import ModuleDictEMA
+from .utils import copy_state
+
+
+logger = logging.getLogger(__name__)
+
+
+class BestStateDictManager(flashy.state.StateDictSource):
+ """BestStateDictManager maintains a copy of best state_dict() for registered sources.
+
+ BestStateDictManager has two main attributes:
+ states (dict): State dict of the registered StateDictSource.
+ param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.
+
+ When registering new sources, the BestStateDictManager will ensure two conflicting sources between
+ ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
+ what to consider for best state.
+
+ Args:
+ device (torch.device or str): Device on which we keep the copy.
+ dtype (torch.dtype): Data type for the state parameters.
+ """
+ def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
+ dtype: tp.Optional[torch.dtype] = None):
+ self.device = device
+ self.states: dict = {}
+ self.param_ids: dict = defaultdict(dict)
+ self.dtype = dtype
+
+ def _get_parameter_ids(self, state_dict):
+ return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}
+
+ def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
+ for registered_name, registered_param_ids in self.param_ids.items():
+ if registered_name != name:
+ overlap = set.intersection(registered_param_ids.keys(), param_ids.keys())
+ assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
+ f" in {name} and already registered {registered_name}: {' '.join(overlap)}"
+
+ def update(self, name: str, source: flashy.state.StateDictSource):
+ if name not in self.states:
+ raise ValueError(f"{name} missing from registered states.")
+ self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+ def register(self, name: str, source: flashy.state.StateDictSource):
+ if name in self.states:
+ raise ValueError(f"{name} already present in states.")
+ # Registering parameter ids for EMA and non-EMA states allows us to check that
+ # there is no overlap that would create ambiguity about how to handle the best state
+ param_ids = self._get_parameter_ids(source.state_dict())
+ if isinstance(source, ModuleDictEMA):
+ logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
+ self._validate_no_parameter_ids_overlap(name, param_ids)
+ self.param_ids[name] = param_ids
+ else:
+ logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
+ self._validate_no_parameter_ids_overlap('base', param_ids)
+ self.param_ids['base'].update(param_ids)
+ # Register state
+ self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+ def state_dict(self) -> flashy.state.StateDict:
+ return self.states
+
+ def load_state_dict(self, state: flashy.state.StateDict):
+ for name, sub_state in state.items():
+ for k, v in sub_state.items():
+ self.states[name][k].copy_(v)
diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fccc0acda4027b0bd36756a29b2d5cee318294d
--- /dev/null
+++ b/audiocraft/utils/cache.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ThreadPoolExecutor
+from collections import deque
+from functools import partial
+from hashlib import sha1
+import logging
+from pathlib import Path
+import sys
+import typing as tp
+import zipfile
+
+import flashy
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
+ """Utility function for the EmbeddingCache, returning the full embedding without any chunking.
+ This method can be used in case there is no need in extracting a chunk of the full embedding
+ read from the cache.
+
+ Args:
+ full_embed (torch.Tensor): The full embedding.
+ x (any): Batch object from which the full embedding is derived.
+ idx (torch.Tensor): Index of object to consider in the batch object.
+ Returns:
+ full_embed (torch.Tensor): The full embedding
+ """
+ return full_embed.to(device)
+
+
+class EmbeddingCache:
+ """Cache around embeddings computation for faster execution.
+ The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
+ to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
+ using a user-provided function. When the cache is warm (all embeddings are pre-computed),
+ the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
+ Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
+ and synchronization points in the forward calls.
+
+ Args:
+ cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
+ device (str or torch.device): Device on which the embedding is returned.
+ compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
+ the embedding from a given object and path. This user provided function can compute the
+ embedding from the provided object or using the provided path as entry point. The last parameter
+ specify the index corresponding to the current embedding in the object that can represent batch metadata.
+ extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
+ the desired embedding chunk from the full embedding loaded from the cache. The last parameter
+ specify the index corresponding to the current embedding in the object that can represent batch metadata.
+ If not specified, will return the full embedding unmodified.
+ """
+ def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, torch.device],
+ compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
+ extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
+ self.cache_path = Path(cache_path)
+ self.device = device
+ self._compute_embed_fn = compute_embed_fn
+ self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
+ if extract_embed_fn is not None:
+ self._extract_embed_fn = extract_embed_fn
+ else:
+ self._extract_embed_fn = partial(get_full_embed, device=device)
+ if self.cache_path is not None:
+ self.cache_path.mkdir(exist_ok=True, parents=True)
+ logger.info(f"Cache instantiated at: {self.cache_path}")
+ self.pool = ThreadPoolExecutor(8)
+ self.pool.__enter__()
+ self._current_batch_cache: dict = {}
+ self._memory_cache: dict = {}
+
+ def _get_cache_path(self, path: tp.Union[Path, str]):
+ """Get cache path for the given file path."""
+ sig = sha1(str(path).encode()).hexdigest()
+ return self.cache_path / sig
+
+ @staticmethod
+ def _get_full_embed_from_cache(cache: Path):
+ """Loads full pre-computed embedding from the cache."""
+ try:
+ embed = torch.load(cache, 'cpu')
+ except Exception as exc:
+ logger.error("Error loading %s: %r", cache, exc)
+ embed = None
+ return embed
+
+ def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
+ """Get embedding from cache, computing and storing it to cache if not already cached.
+ The EmbeddingCache first tries to load the embedding from the in-memory cache
+ containing the pre-computed chunks populated through `populate_embed_cache`.
+ If not found, the full embedding is computed and stored on disk to be later accessed
+ to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
+
+ Args:
+ paths (list[Path or str]): List of paths from where the embeddings can be loaded.
+ x (any): Object from which the embedding is extracted.
+ """
+ embeds = []
+ for idx, path in enumerate(paths):
+ cache = self._get_cache_path(path)
+ if cache in self._current_batch_cache:
+ embed = self._current_batch_cache[cache]
+ else:
+ full_embed = self._compute_embed_fn(path, x, idx)
+ try:
+ with flashy.utils.write_and_rename(cache, pid=True) as f:
+ torch.save(full_embed.cpu(), f)
+ except Exception as exc:
+ logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
+ else:
+ logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
+ embed = self._extract_embed_fn(full_embed, x, idx)
+ embeds.append(embed)
+ embed = torch.stack(embeds, dim=0)
+ return embed
+
+ def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
+ """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
+ The in-memory caches consist in a cache for the full embedding and another cache for the
+ final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
+ and reduce the IO footprint and synchronization points during forward passes.
+
+ Args:
+ paths (list[Path]): List of paths from where the embeddings can be loaded.
+ x (any): Object from which the embedding is extracted.
+ """
+ self._current_batch_cache.clear()
+ if self.cache_path is not None:
+ futures: list = []
+ for path in paths:
+ assert path is not None, "Path is required for computation from cache"
+ cache = self._get_cache_path(path)
+ if cache in self._memory_cache or not cache.exists():
+ futures.append(None)
+ else:
+ futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
+ for idx, (path, future) in enumerate(zip(paths, futures)):
+ assert path is not None
+ cache = self._get_cache_path(path)
+ full_embed = None
+ if future is None:
+ if cache in self._memory_cache:
+ full_embed = self._memory_cache[cache]
+ else:
+ full_embed = future.result()
+ if full_embed is not None:
+ self._memory_cache[cache] = full_embed
+ full_embed = full_embed.to(self.device)
+ if full_embed is not None:
+ embed = self._extract_embed_fn(full_embed, x, idx)
+ self._current_batch_cache[cache] = embed
+
+
+class CachedBatchWriter:
+ """Write pre computed caches for mini batches. This can
+ make loading a lot more efficient depending on your filesystem.
+
+ Args:
+ cache_folder (Path): folder in which the cached minibatches
+ will be stored.
+
+ Inside cache folder, the structure is the following:
+ `epoch_number / update_number.zip`
+ And the zip file contains one entry per batch item.
+
+ It is possible to use the cache with a batch size smaller than
+ created with but obviously not larger. Make sure to call the
+ `start_epoch(epoch)` method for indicating changes of epochs.
+
+ See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
+ for an example of how to warmup the cache.
+ """
+ def __init__(self, cache_folder: Path):
+ self.cache_folder = cache_folder
+ self._current_epoch: tp.Optional[int] = None
+ self._current_index = 0
+
+ def start_epoch(self, epoch: int):
+ """Call at the beginning of each epoch.
+ """
+ self._current_epoch = epoch
+ self._current_index = 0
+ self._zip_path.parent.mkdir(exist_ok=True, parents=True)
+
+ @staticmethod
+ def _get_zip_path(cache_folder: Path, epoch: int, index: int):
+ return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
+
+ @property
+ def _zip_path(self):
+ assert self._current_epoch is not None
+ return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
+
+ def save(self, *content):
+ """Save one mini batch. This function is distributed-aware
+ and will automatically merge all the items from the different
+ workers.
+ """
+ all_contents = []
+ for rank in range(flashy.distrib.world_size()):
+ their_content = flashy.distrib.broadcast_object(content, src=rank)
+ all_contents.append(their_content)
+
+ if flashy.distrib.is_rank_zero():
+ idx = 0
+ with flashy.utils.write_and_rename(self._zip_path) as tmp:
+ with zipfile.ZipFile(tmp, 'w') as zf:
+ for content in all_contents:
+ for vals in zip(*content):
+ with zf.open(f'{idx}', 'w') as f: # type: ignore
+ torch.save(vals, f)
+ idx += 1
+ flashy.distrib.barrier()
+ self._current_index += 1
+
+
+class CachedBatchLoader:
+ """Loader for cached mini-batches dumped with `CachedBatchWriter`.
+
+ Args:
+ cache_folder (Path): folder in which the cached minibatches are stored.
+ batch_size (int): batch size (per GPU) expected.
+ num_workers (int): number of workers to use for loading.
+ min_length (int): minimum expected length for each epoch. If some
+ mini-batches are missing, and error is raised.
+
+ This is iterable just like a regular DataLoader.
+ """
+
+ def __init__(self, cache_folder: Path, batch_size: int,
+ num_workers: int = 10, min_length: int = 1):
+ self.cache_folder = cache_folder
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.min_length = min_length
+ self._current_epoch: tp.Optional[int] = None
+ self.sampler = None # for compatibility with the regular DataLoader
+
+ def __len__(self):
+ path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
+ return len([p for p in path.iterdir() if p.suffix == ".zip"])
+
+ def start_epoch(self, epoch: int):
+ """Call at the beginning of each epoch.
+ """
+ self._current_epoch = epoch
+
+ def _zip_path(self, index: int):
+ assert self._current_epoch is not None
+ return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
+
+ def _load_one(self, index: int):
+ zip_path = self._zip_path(index)
+ if not zip_path.exists():
+ if index < self.min_length:
+ raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
+
+ return None
+ mode = "rb" if sys.version_info >= (3, 9) else "r"
+ try:
+ with zipfile.ZipFile(zip_path, 'r') as zf:
+ rank = flashy.distrib.rank()
+ world_size = flashy.distrib.world_size()
+ root = zipfile.Path(zf)
+ items = list(root.iterdir())
+ total_batch_size = self.batch_size * world_size
+ if len(items) < total_batch_size:
+ raise RuntimeError(
+ f"The cache can handle a max batch size of {len(items)}, "
+ f"but {total_batch_size} is needed.")
+ start = rank * self.batch_size
+ items = items[start: start + self.batch_size]
+ assert len(items) == self.batch_size
+ entries = []
+ entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore
+ transposed = zip(*entries)
+ out = []
+ for part in transposed:
+ assert len(part) > 0
+ if isinstance(part[0], torch.Tensor):
+ out.append(torch.stack(part))
+ else:
+ out.append(part)
+ return out
+ except Exception:
+ logger.error("Error when reading zip path %s", zip_path)
+ raise
+
+ def __iter__(self):
+ """This will yields tuples, exactly as provided to the
+ `CachedBatchWriter.save` method.
+ """
+ pool = ThreadPoolExecutor(self.num_workers)
+ next_index = 0
+ queue = deque()
+
+ def _get_next():
+ nonlocal next_index
+ r = queue.popleft().result()
+ if r is None:
+ return None
+ else:
+ queue.append(pool.submit(self._load_one, next_index))
+ next_index += 1
+ return r
+
+ with pool:
+ # fill the buffer of fetching jobs.
+ for _ in range(2 * self.num_workers):
+ queue.append(pool.submit(self._load_one, next_index))
+ next_index += 1
+ while True:
+ batch = _get_next()
+ if batch is None:
+ return
+ yield batch
diff --git a/audiocraft/utils/checkpoint.py b/audiocraft/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f871837e09c5cc7832b85b0d80b84f59e87ca0
--- /dev/null
+++ b/audiocraft/utils/checkpoint.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import logging
+from pathlib import Path
+import re
+import typing as tp
+
+import flashy
+import torch
+
+from ..environment import AudioCraftEnvironment
+
+
+logger = logging.getLogger(__name__)
+
+
+class CheckpointSource(Enum):
+ CURRENT_XP = "current_xp"
+ PRETRAINED = "pretrained"
+ OTHER = "other"
+
+
+def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
+ """Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
+ `checkpoint_.th(.)`. By convention, name is expected to be empty for last checkpoint,
+ 'best' for the best checkpoint or the epoch number.
+
+ Args:
+ name (str, optional): Name suffix for the checkpoint file stem.
+ rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+ use_fsdp (bool): Whether the calling solver relies on FSDP.
+ Returns:
+ str: The checkpoint name.
+ """
+ suffix = ''
+ if rank is None:
+ rank = flashy.distrib.rank()
+ if rank > 0 and use_fsdp:
+ suffix = '.' + str(rank)
+ name_part = ''
+ if name is not None:
+ name_part = f'_{name}'
+ return f'checkpoint{name_part}.th{suffix}'
+
+
+def is_sharded_checkpoint(path: Path) -> bool:
+ """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank."""
+ return re.search(r'\.th\.\d+$', path.name) is not None
+
+
+def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
+ use_fsdp: bool = False) -> tp.Optional[Path]:
+ """Resolve a given checkpoint path for a provided dora sig or path.
+
+ Args:
+ sig_or_path (Path or str): Checkpoint path or dora signature.
+ name (str, optional): Name suffix for the checkpoint file stem.
+ rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+ use_fsdp (bool): Whether the calling solver relies on FSDP.
+ Returns:
+ Path, optional: Resolved checkpoint path, if it exists.
+ """
+ from audiocraft import train
+ xps_root = train.main.dora.dir / 'xps'
+ sig_or_path = str(sig_or_path)
+ if sig_or_path.startswith('//sig/'):
+ sig = sig_or_path[len('//sig/'):]
+ path = xps_root / sig
+ else:
+ path = Path(sig_or_path)
+ path = AudioCraftEnvironment.resolve_reference_path(path)
+
+ if path.is_dir():
+ path = path / checkpoint_name(name, use_fsdp=use_fsdp)
+
+ if path.exists():
+ return path
+ else:
+ return None
+
+
+def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
+ """Load state from checkpoints at the specified checkpoint path."""
+ if is_sharded:
+ rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
+ if rank0_checkpoint_path.exists():
+ check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
+ state = torch.load(checkpoint_path, 'cpu')
+ logger.info("Checkpoint loaded from %s", checkpoint_path)
+ return state
+
+
+def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
+ """Save state to disk to the specified checkpoint_path."""
+ _safe_save_checkpoint(state, checkpoint_path, is_sharded)
+ logger.info("Checkpoint saved to %s", checkpoint_path)
+
+
+def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
+ """Flush checkpoints to only keep last N checkpoints."""
+ if keep_last is None or keep_last <= 0:
+ return
+ checkpoint_dir = checkpoint_path.parent
+ suffix = ''
+ if flashy.distrib.rank() > 0:
+ suffix = f'.{flashy.distrib.rank()}'
+ checkpoint_files_with_epoch = []
+ for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
+ epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
+ if epoch_part.isdigit():
+ checkpoint_files_with_epoch.append((path, int(epoch_part)))
+ checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))]
+ total_to_flush = max(0, len(checkpoint_files) - keep_last)
+ files_to_flush = checkpoint_files[:total_to_flush]
+ for path in files_to_flush:
+ logger.debug("Removing checkpoint: %s", str(path))
+ path.unlink(missing_ok=True)
+
+
+def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
+ """Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
+ # Finish the work of a previous run that got interrupted while dumping.
+ old_path = Path(str(checkpoint_path) + '.old')
+ if old_path.exists():
+ raise RuntimeError(
+ f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
+ token = Path(str(rank0_checkpoint_path) + '.tmp.done')
+ tmp_path = Path(str(checkpoint_path) + '.tmp')
+ if token.exists():
+ if tmp_path.exists():
+ tmp_path.rename(checkpoint_path)
+ flashy.distrib.barrier()
+ if flashy.distrib.is_rank_zero() and token.exists():
+ token.unlink()
+
+
+def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
+ """Save checkpoints in a safe manner even with when sharded checkpoints across nodes."""
+ def _barrier_if_sharded():
+ if is_sharded:
+ flashy.distrib.barrier()
+
+ if flashy.distrib.is_rank_zero():
+ token = Path(str(checkpoint_path) + '.tmp.done')
+ if token.exists():
+ token.unlink()
+ _barrier_if_sharded()
+ with flashy.utils.write_and_rename(checkpoint_path) as f:
+ torch.save(state, f)
+ _barrier_if_sharded()
+ if flashy.distrib.is_rank_zero():
+ token.touch()
+ _barrier_if_sharded()
+ _barrier_if_sharded()
+ if flashy.distrib.rank() == 0:
+ token.unlink()
diff --git a/audiocraft/utils/cluster.py b/audiocraft/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3380d031739d473fb859c76b9c25350f47fa77e8
--- /dev/null
+++ b/audiocraft/utils/cluster.py
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions for SLURM configuration and cluster settings.
+"""
+
+from enum import Enum
+import os
+import socket
+import typing as tp
+
+import omegaconf
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+ LOCAL_DARWIN = "darwin"
+ DEFAULT = "default" # used for any other cluster.
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ fqdn = socket.getfqdn()
+ if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
+ return ClusterType.AWS
+
+ if fqdn.endswith(".fair"):
+ return ClusterType.FAIR
+
+ if fqdn.endswith(".facebook.com"):
+ return ClusterType.RSC
+
+ if uname.sysname == "Darwin":
+ return ClusterType.LOCAL_DARWIN
+
+ return ClusterType.DEFAULT
+
+
+def get_cluster_type(
+ cluster_type: tp.Optional[ClusterType] = None,
+) -> tp.Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_slurm_parameters(
+ cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
+) -> omegaconf.DictConfig:
+ """Update SLURM parameters in configuration based on cluster type.
+ If the cluster type is not specify, it infers it automatically.
+ """
+ from ..environment import AudioCraftEnvironment
+ cluster_type = get_cluster_type(cluster_type)
+ # apply cluster-specific adjustments
+ if cluster_type == ClusterType.AWS:
+ cfg["mem_per_gpu"] = None
+ cfg["constraint"] = None
+ cfg["setup"] = []
+ elif cluster_type == ClusterType.RSC:
+ cfg["mem_per_gpu"] = None
+ cfg["setup"] = []
+ cfg["constraint"] = None
+ cfg["partition"] = "learn"
+ slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
+ if slurm_exclude is not None:
+ cfg["exclude"] = slurm_exclude
+ return cfg
diff --git a/audiocraft/utils/deadlock.py b/audiocraft/utils/deadlock.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abd1bbeea5909e664cf816c020bd7c37effdb66
--- /dev/null
+++ b/audiocraft/utils/deadlock.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+from queue import Queue, Empty
+import signal
+import sys
+import threading
+import traceback
+
+logger = logging.getLogger(__name__)
+
+
+class DeadlockDetect:
+ def __init__(self, use: bool = False, timeout: float = 120.):
+ self.use = use
+ self.timeout = timeout
+ self._queue: Queue = Queue()
+
+ def update(self, stage: str):
+ if self.use:
+ self._queue.put(stage)
+
+ def __enter__(self):
+ if self.use:
+ self._thread = threading.Thread(target=self._detector_thread)
+ self._thread.start()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.use:
+ self._queue.put(None)
+ self._thread.join()
+
+ def _detector_thread(self):
+ logger.debug("Deadlock detector started")
+ last_stage = "init"
+ while True:
+ try:
+ stage = self._queue.get(timeout=self.timeout)
+ except Empty:
+ break
+ if stage is None:
+ logger.debug("Exiting deadlock detector thread")
+ return
+ else:
+ last_stage = stage
+ logger.error("Deadlock detector timed out, last stage was %s", last_stage)
+ for th in threading.enumerate():
+ print(th, file=sys.stderr)
+ traceback.print_stack(sys._current_frames()[th.ident])
+ print(file=sys.stderr)
+ sys.stdout.flush()
+ sys.stderr.flush()
+ os.kill(os.getpid(), signal.SIGKILL)
diff --git a/audiocraft/utils/export.py b/audiocraft/utils/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b214017d9ac23934b67e8254a96131cefa6501
--- /dev/null
+++ b/audiocraft/utils/export.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility to export a training checkpoint to a lightweight release checkpoint.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf
+import torch
+
+from audiocraft import __version__
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+ """Export only the best state from the given EnCodec checkpoint. This
+ should be used if you trained your own EnCodec model.
+ """
+ pkg = torch.load(checkpoint_path, 'cpu')
+ new_pkg = {
+ 'best_state': pkg['best_state']['model'],
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+ 'version': __version__,
+ 'exported': True,
+ }
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+ torch.save(new_pkg, out_file)
+ return out_file
+
+
+def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
+ """Export a compression model (potentially EnCodec) from a pretrained model.
+ This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
+ Do not include the //pretrained/ prefix. For instance if you trained a model
+ with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
+
+ In that case, this will not actually include a copy of the model, simply the reference
+ to the model used.
+ """
+ if Path(pretrained_encodec).exists():
+ pkg = torch.load(pretrained_encodec)
+ assert 'best_state' in pkg
+ assert 'xp.cfg' in pkg
+ assert 'version' in pkg
+ assert 'exported' in pkg
+ else:
+ pkg = {
+ 'pretrained': pretrained_encodec,
+ 'exported': True,
+ 'version': __version__,
+ }
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+ torch.save(pkg, out_file)
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+ """Export only the best state from the given MusicGen or AudioGen checkpoint.
+ """
+ pkg = torch.load(checkpoint_path, 'cpu')
+ if pkg['fsdp_best_state']:
+ best_state = pkg['fsdp_best_state']['model']
+ else:
+ assert pkg['best_state']
+ best_state = pkg['best_state']['model']
+ new_pkg = {
+ 'best_state': best_state,
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+ 'version': __version__,
+ 'exported': True,
+ }
+
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+ torch.save(new_pkg, out_file)
+ return out_file
diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f145f3148c3e9fdba436273bc45480fbae6481
--- /dev/null
+++ b/audiocraft/utils/export_legacy.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Legacy functions used at the time of the first release, kept for referencd.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+
+def _clean_lm_cfg(cfg: DictConfig):
+ OmegaConf.set_struct(cfg, False)
+ # This used to be set automatically in the LM solver, need a more robust solution
+ # for the future.
+ cfg['transformer_lm']['card'] = 2048
+ cfg['transformer_lm']['n_q'] = 4
+ # Experimental params no longer supported.
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
+ 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
+ for name in bad_params:
+ del cfg['transformer_lm'][name]
+ OmegaConf.set_struct(cfg, True)
+ return cfg
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+ sig = Path(checkpoint_path).parent.name
+ assert len(sig) == 8, "Not a valid Dora signature"
+ pkg = torch.load(checkpoint_path, 'cpu')
+ new_pkg = {
+ 'best_state': pkg['ema']['state']['model'],
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+ }
+ out_file = Path(out_folder) / f'{sig}.th'
+ torch.save(new_pkg, out_file)
+ return out_file
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+ sig = Path(checkpoint_path).parent.name
+ assert len(sig) == 8, "Not a valid Dora signature"
+ pkg = torch.load(checkpoint_path, 'cpu')
+ new_pkg = {
+ 'best_state': pkg['fsdp_best_state']['model'],
+ 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+ }
+ out_file = Path(out_folder) / f'{sig}.th'
+ torch.save(new_pkg, out_file)
+ return out_file
diff --git a/audiocraft/utils/notebook.py b/audiocraft/utils/notebook.py
new file mode 100644
index 0000000000000000000000000000000000000000..019b9d19e5bef976bedddf428fd25da42a8a9726
--- /dev/null
+++ b/audiocraft/utils/notebook.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+ import IPython.display as ipd # type: ignore
+except ImportError:
+ # Note in a notebook...
+ pass
+
+
+import torch
+
+
+def display_audio(samples: torch.Tensor, sample_rate: int):
+ """Renders an audio player for the given audio samples.
+
+ Args:
+ samples (torch.Tensor): a Tensor of decoded audio samples
+ with shapes [B, C, T] or [C, T]
+ sample_rate (int): sample rate audio should be displayed with.
+ """
+ assert samples.dim() == 2 or samples.dim() == 3
+
+ samples = samples.detach().cpu()
+ if samples.dim() == 2:
+ samples = samples[None, ...]
+
+ for audio in samples:
+ ipd.display(ipd.Audio(audio, rate=sample_rate))
diff --git a/audiocraft/utils/profiler.py b/audiocraft/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45b6d15910b50305c7b212c089ffad3c25b324d
--- /dev/null
+++ b/audiocraft/utils/profiler.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import typing as tp
+
+import dora
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+class Profiler:
+ """Context manager wrapper for xformers profiler.
+ """
+ def __init__(self, module: torch.nn.Module, enabled: bool = False):
+ self.profiler: tp.Optional[tp.Any] = None
+ if enabled:
+ from xformers.profiler import profile
+ output_dir = dora.get_xp().folder / 'profiler_data'
+ logger.info("Profiling activated, results with be saved to %s", output_dir)
+ self.profiler = profile(output_dir=output_dir, module=module)
+
+ def step(self):
+ if self.profiler is not None:
+ self.profiler.step() # type: ignore
+
+ def __enter__(self):
+ if self.profiler is not None:
+ return self.profiler.__enter__() # type: ignore
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ if self.profiler is not None:
+ return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore
diff --git a/audiocraft/utils/samples/__init__.py b/audiocraft/utils/samples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/audiocraft/utils/samples/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/audiocraft/utils/samples/__pycache__/__init__.cpython-310.pyc b/audiocraft/utils/samples/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b420f4a97da11c1050dab1f6387d5551f4027274
Binary files /dev/null and b/audiocraft/utils/samples/__pycache__/__init__.cpython-310.pyc differ
diff --git a/audiocraft/utils/samples/__pycache__/manager.cpython-310.pyc b/audiocraft/utils/samples/__pycache__/manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5b35ad585599c5302f6dea9dcd0fe5442314d18
Binary files /dev/null and b/audiocraft/utils/samples/__pycache__/manager.cpython-310.pyc differ
diff --git a/audiocraft/utils/samples/manager.py b/audiocraft/utils/samples/manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf0fb21b2d2867c03f7cce6f27d9524fdb89b51d
--- /dev/null
+++ b/audiocraft/utils/samples/manager.py
@@ -0,0 +1,386 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+API that can manage the storage and retrieval of generated samples produced by experiments.
+
+It offers the following benefits:
+* Samples are stored in a consistent way across epoch
+* Metadata about the samples can be stored and retrieved
+* Can retrieve audio
+* Identifiers are reliable and deterministic for prompted and conditioned samples
+* Can request the samples for multiple XPs, grouped by sample identifier
+* For no-input samples (not prompt and no conditions), samples across XPs are matched
+ by sorting their identifiers
+"""
+
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import asdict, dataclass
+from functools import lru_cache
+import hashlib
+import json
+import logging
+from pathlib import Path
+import re
+import typing as tp
+import unicodedata
+import uuid
+
+import dora
+import torch
+
+from ...data.audio import audio_read, audio_write
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ReferenceSample:
+ id: str
+ path: str
+ duration: float
+
+
+@dataclass
+class Sample:
+ id: str
+ path: str
+ epoch: int
+ duration: float
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]]
+ prompt: tp.Optional[ReferenceSample]
+ reference: tp.Optional[ReferenceSample]
+ generation_args: tp.Optional[tp.Dict[str, tp.Any]]
+
+ def __hash__(self):
+ return hash(self.id)
+
+ def audio(self) -> tp.Tuple[torch.Tensor, int]:
+ return audio_read(self.path)
+
+ def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+ return audio_read(self.prompt.path) if self.prompt is not None else None
+
+ def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+ return audio_read(self.reference.path) if self.reference is not None else None
+
+
+class SampleManager:
+ """Audio samples IO handling within a given dora xp.
+
+ The sample manager handles the dumping and loading logic for generated and
+ references samples across epochs for a given xp, providing a simple API to
+ store, retrieve and compare audio samples.
+
+ Args:
+ xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
+ where all outputs are stored and the configuration of the experiment,
+ which is useful to retrieve audio-related parameters.
+ map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
+ instead of generating a dedicated hash id. This is useful to allow easier comparison
+ with ground truth sample from the files directly without having to read the JSON metadata
+ to do the mapping (at the cost of potentially dumping duplicate prompts/references
+ depending on the task).
+ """
+ def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
+ self.xp = xp
+ self.base_folder: Path = xp.folder / xp.cfg.generate.path
+ self.reference_folder = self.base_folder / 'reference'
+ self.map_reference_to_sample_id = map_reference_to_sample_id
+ self.samples: tp.List[Sample] = []
+ self._load_samples()
+
+ @property
+ def latest_epoch(self):
+ """Latest epoch across all samples."""
+ return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
+
+ def _load_samples(self):
+ """Scan the sample folder and load existing samples."""
+ jsons = self.base_folder.glob('**/*.json')
+ with ThreadPoolExecutor(6) as pool:
+ self.samples = list(pool.map(self._load_sample, jsons))
+
+ @staticmethod
+ @lru_cache(2**26)
+ def _load_sample(json_file: Path) -> Sample:
+ with open(json_file, 'r') as f:
+ data: tp.Dict[str, tp.Any] = json.load(f)
+ # fetch prompt data
+ prompt_data = data.get('prompt')
+ prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
+ duration=prompt_data['duration']) if prompt_data else None
+ # fetch reference data
+ reference_data = data.get('reference')
+ reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
+ duration=reference_data['duration']) if reference_data else None
+ # build sample object
+ return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
+ prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
+ generation_args=data.get('generation_args'))
+
+ def _init_hash(self):
+ return hashlib.sha1()
+
+ def _get_tensor_id(self, tensor: torch.Tensor) -> str:
+ hash_id = self._init_hash()
+ hash_id.update(tensor.numpy().data)
+ return hash_id.hexdigest()
+
+ def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
+ conditions: tp.Optional[tp.Dict[str, str]]) -> str:
+ """Computes an id for a sample given its input data.
+ This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
+ Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.
+
+ Args:
+ index (int): Batch index, Helpful to differentiate samples from the same batch.
+ prompt_wav (torch.Tensor): Prompt used during generation.
+ conditions (dict[str, str]): Conditioning used during generation.
+ """
+ # For totally unconditioned generations we will just use a random UUID.
+ # The function get_samples_for_xps will do a simple ordered match with a custom key.
+ if prompt_wav is None and not conditions:
+ return f"noinput_{uuid.uuid4().hex}"
+
+ # Human readable portion
+ hr_label = ""
+ # Create a deterministic id using hashing
+ hash_id = self._init_hash()
+ hash_id.update(f"{index}".encode())
+ if prompt_wav is not None:
+ hash_id.update(prompt_wav.numpy().data)
+ hr_label += "_prompted"
+ else:
+ hr_label += "_unprompted"
+ if conditions:
+ encoded_json = json.dumps(conditions, sort_keys=True).encode()
+ hash_id.update(encoded_json)
+ cond_str = "-".join([f"{key}={slugify(value)}"
+ for key, value in sorted(conditions.items())])
+ cond_str = cond_str[:100] # some raw text might be too long to be a valid filename
+ cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
+ hr_label += f"_{cond_str}"
+ else:
+ hr_label += "_unconditioned"
+
+ return hash_id.hexdigest() + hr_label
+
+ def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
+ """Stores the audio with the given stem path using the XP's configuration.
+
+ Args:
+ wav (torch.Tensor): Audio to store.
+ stem_path (Path): Path in sample output directory with file stem to use.
+ overwrite (bool): When False (default), skips storing an existing audio file.
+ Returns:
+ Path: The path at which the audio is stored.
+ """
+ existing_paths = [
+ path for path in stem_path.parent.glob(stem_path.stem + '.*')
+ if path.suffix != '.json'
+ ]
+ exists = len(existing_paths) > 0
+ if exists and overwrite:
+ logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
+ elif exists:
+ return existing_paths[0]
+
+ audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
+ return audio_path
+
+ def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
+ conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
+ ground_truth_wav: tp.Optional[torch.Tensor] = None,
+ generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
+ """Adds a single sample.
+ The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
+ Each sample is assigned an id which is computed using the input data. In addition to the
+ sample itself, a json file containing associated metadata is stored next to it.
+
+ Args:
+ sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
+ epoch (int): current training epoch.
+ index (int): helpful to differentiate samples from the same batch.
+ conditions (dict[str, str], optional): conditioning used during generation.
+ prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
+ ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
+ Tensor of shape [channels, shape].
+ generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
+ Returns:
+ Sample: The saved sample.
+ """
+ sample_id = self._get_sample_id(index, prompt_wav, conditions)
+ reuse_id = self.map_reference_to_sample_id
+ prompt, ground_truth = None, None
+ if prompt_wav is not None:
+ prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
+ prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
+ prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
+ prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
+ if ground_truth_wav is not None:
+ ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
+ ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
+ ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
+ ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
+ sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
+ duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
+ sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
+ self.samples.append(sample)
+ with open(sample_path.with_suffix('.json'), 'w') as f:
+ json.dump(asdict(sample), f, indent=2)
+ return sample
+
+ def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
+ conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
+ prompt_wavs: tp.Optional[torch.Tensor] = None,
+ ground_truth_wavs: tp.Optional[torch.Tensor] = None,
+ generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
+ """Adds a batch of samples.
+ The samples are stored in the XP's sample output directory, under a corresponding
+ epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
+ In addition to the sample itself, a json file containing associated metadata is stored next to it.
+
+ Args:
+ sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
+ epoch (int): Current training epoch.
+ conditioning (list of dict[str, str], optional): List of conditions used during generation,
+ one per sample in the batch.
+ prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
+ [batch_size, channels, shape].
+ ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
+ Tensor of shape [batch_size, channels, shape].
+ generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
+ Returns:
+ samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
+ """
+ samples = []
+ for idx, wav in enumerate(samples_wavs):
+ prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
+ gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
+ conditions = conditioning[idx] if conditioning is not None else None
+ samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
+ return samples
+
+ def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
+ exclude_unprompted: bool = False, exclude_conditioned: bool = False,
+ exclude_unconditioned: bool = False) -> tp.Set[Sample]:
+ """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
+ Please note that existing samples are loaded during the manager's initialization, and added samples through this
+ manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
+ is the only way detect them.
+
+ Args:
+ epoch (int): If provided, only return samples corresponding to this epoch.
+ max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
+ exclude_prompted (bool): If True, does not include samples that used a prompt.
+ exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+ exclude_conditioned (bool): If True, excludes samples that used conditioning.
+ exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+ Returns:
+ Samples (set of Sample): The retrieved samples matching the provided filters.
+ """
+ if max_epoch >= 0:
+ samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
+ else:
+ samples_epoch = self.latest_epoch if epoch < 0 else epoch
+ samples = {
+ sample
+ for sample in self.samples
+ if (
+ (sample.epoch == samples_epoch) and
+ (not exclude_prompted or sample.prompt is None) and
+ (not exclude_unprompted or sample.prompt is not None) and
+ (not exclude_conditioned or not sample.conditioning) and
+ (not exclude_unconditioned or sample.conditioning)
+ )
+ }
+ return samples
+
+
+def slugify(value: tp.Any, allow_unicode: bool = False):
+ """Process string for safer file naming.
+
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
+
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+ dashes to single dashes. Remove characters that aren't alphanumerics,
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
+ trailing whitespace, dashes, and underscores.
+ """
+ value = str(value)
+ if allow_unicode:
+ value = unicodedata.normalize("NFKC", value)
+ else:
+ value = (
+ unicodedata.normalize("NFKD", value)
+ .encode("ascii", "ignore")
+ .decode("ascii")
+ )
+ value = re.sub(r"[^\w\s-]", "", value.lower())
+ return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
+ # Create a dictionary of stable id -> sample per XP
+ stable_samples_per_xp = [{
+ sample.id: sample for sample in samples
+ if sample.prompt is not None or sample.conditioning
+ } for samples in samples_per_xp]
+ # Set of all stable ids
+ stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
+ # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
+ stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
+ # Filter out ids that contain None values (we only want matched samples after all)
+ # cast is necessary to avoid mypy linter errors.
+ return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}
+
+
+def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
+ # For unstable ids, we use a sorted list since we'll match them in order
+ unstable_samples_per_xp = [[
+ sample for sample in sorted(samples, key=lambda x: x.id)
+ if sample.prompt is None and not sample.conditioning
+ ] for samples in samples_per_xp]
+ # Trim samples per xp so all samples can have a match
+ min_len = min([len(samples) for samples in unstable_samples_per_xp])
+ unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
+ # Dictionary of index -> list of matched samples
+ return {
+ f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
+ }
+
+
+def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
+ """Gets a dictionary of matched samples across the given XPs.
+ Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
+ will always match the number of XPs provided and will correspond to each XP in the same order given.
+ In other words, only samples that can be match across all provided XPs will be returned
+ in order to satisfy this rule.
+
+ There are two types of ids that can be returned: stable and unstable.
+ * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
+ (prompts/conditioning). This is why we can match them across XPs.
+ * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
+ that used non-deterministic, random ids. This is the case for samples that did not use prompts or
+ conditioning for their generation. This function will sort these samples by their id and match them
+ by their index.
+
+ Args:
+ xps: a list of XPs to match samples from.
+ start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
+ end_epoch (int): If provided, only return samples corresponding to this epoch or older.
+ exclude_prompted (bool): If True, does not include samples that used a prompt.
+ exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+ exclude_conditioned (bool): If True, excludes samples that used conditioning.
+ exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+ """
+ managers = [SampleManager(xp) for xp in xps]
+ samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
+ stable_samples = _match_stable_samples(samples_per_xp)
+ unstable_samples = _match_unstable_samples(samples_per_xp)
+ return dict(stable_samples, **unstable_samples)
diff --git a/audiocraft/utils/ui.py b/audiocraft/utils/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..68fcbe0af257bdbaad767708843b545064d9b219
--- /dev/null
+++ b/audiocraft/utils/ui.py
@@ -0,0 +1,34 @@
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+refresh_symbol = '\U0001f504' # 🔄
+
+class ToolButton(gr.Button, gr.components.IOComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class, scale=1, size="sm", container=False)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
\ No newline at end of file
diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3135d70e949a058095ef84dd87b49384546c465c
--- /dev/null
+++ b/audiocraft/utils/utils.py
@@ -0,0 +1,298 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ProcessPoolExecutor
+from contextlib import contextmanager
+from functools import wraps, lru_cache
+import hashlib
+import json
+import logging
+from pathlib import Path
+import typing as tp
+
+import flashy
+import flashy.distrib
+import omegaconf
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = logging.getLogger(__name__)
+
+
+def model_hash(model: torch.nn.Module) -> str:
+ """Return a model hash. This should allow us to track regressions in model init
+ from the logs of past experiments.
+ """
+ hasher = hashlib.sha1()
+ for p in model.parameters():
+ hasher.update(p.data.cpu().numpy().tobytes())
+ return hasher.hexdigest()
+
+
+def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+ """Convenience function to map an omegaconf configuration to a dictionary.
+
+ Args:
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
+ Returns:
+ dict: Config as dictionary object.
+ """
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+ assert isinstance(dct, dict)
+ return dct
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+ if max_samples >= len(dataset):
+ return dataset
+
+ generator = torch.Generator().manual_seed(seed)
+ perm = torch.randperm(len(dataset), generator=generator)
+ return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+ num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+ """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+ Args:
+ dataset: Dataset to load.
+ num_samples (Optional[int]): Number of samples to limit subset size.
+ batch_size (int): Batch size.
+ num_workers (int): Number of workers for data loading.
+ seed (int): Random seed.
+ """
+ if num_samples is not None:
+ dataset = random_subset(dataset, num_samples, seed)
+
+ dataloader = flashy.distrib.loader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **kwargs
+ )
+ return dataloader
+
+
+def get_dataset_from_loader(dataloader):
+ dataset = dataloader.dataset
+ if isinstance(dataset, torch.utils.data.Subset):
+ return dataset.dataset
+ else:
+ return dataset
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+ Args:
+ input (torch.Tensor): The input tensor containing probabilities.
+ num_samples (int): Number of samples to draw.
+ replacement (bool): Whether to draw with replacement or not.
+ Keywords args:
+ generator (torch.Generator): A pseudorandom number generator for sampling.
+ Returns:
+ torch.Tensor: Last dimension contains num_samples indices
+ sampled from the multinomial probability distribution
+ located in the last dimension of tensor input.
+ """
+ input_ = input.reshape(-1, input.shape[-1])
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+ output = output_.reshape(*list(input.shape[:-1]), -1)
+ return output
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+ """Sample next token from top K values along the last dimension of the input probs tensor.
+
+ Args:
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+ k (int): The k in “top-k”.
+ Returns:
+ torch.Tensor: Sampled tokens.
+ """
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
+ min_value_top_k = top_k_value[..., [-1]]
+ probs *= (probs >= min_value_top_k).float()
+ probs.div_(probs.sum(dim=-1, keepdim=True))
+ next_token = multinomial(probs, num_samples=1)
+ return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+ Args:
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+ p (int): The p in “top-p”.
+ Returns:
+ torch.Tensor: Sampled tokens.
+ """
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
+ mask = probs_sum - probs_sort > p
+ probs_sort *= (~mask).float()
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+ next_token = multinomial(probs_sort, num_samples=1)
+ next_token = torch.gather(probs_idx, -1, next_token)
+ return next_token
+
+
+class DummyPoolExecutor:
+ """Dummy pool executor to use when we actually have only 1 worker.
+ (e.g. instead of ProcessPoolExecutor).
+ """
+ class DummyResult:
+ def __init__(self, func, *args, **kwargs):
+ self.func = func
+ self.args = args
+ self.kwargs = kwargs
+
+ def result(self):
+ return self.func(*self.args, **self.kwargs)
+
+ def __init__(self, workers, mp_context=None):
+ pass
+
+ def submit(self, func, *args, **kwargs):
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ return
+
+
+def get_pool_executor(num_workers: int, mp_context=None):
+ return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+ """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+ For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+ Args:
+ lengths (torch.Tensor): tensor with lengths
+ max_len (int): can set the max length manually. Defaults to None.
+ Returns:
+ torch.Tensor: mask with 0s where there is pad tokens else 1s
+ """
+ assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+ final_length = lengths.max().item() if not max_len else max_len
+ final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
+ return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+def hash_trick(word: str, vocab_size: int) -> int:
+ """Hash trick to pair each word with an index
+
+ Args:
+ word (str): word we wish to convert to an index
+ vocab_size (int): size of the vocabulary
+ Returns:
+ int: index of the word in the embedding LUT
+ """
+ hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+ return hash % vocab_size
+
+
+def with_rank_rng(base_seed: int = 1234):
+ """Decorator for a function so that the function will use a Random Number Generator
+ whose state depend on the GPU rank. The original RNG state is restored upon returning.
+
+ Args:
+ base_seed (int): Random seed.
+ """
+ def _decorator(fun: tp.Callable):
+ @wraps(fun)
+ def _decorated(*args, **kwargs):
+ state = torch.get_rng_state()
+ seed = base_seed ^ flashy.distrib.rank()
+ torch.manual_seed(seed)
+ logger.debug('Rank dependent seed set to %d', seed)
+ try:
+ return fun(*args, **kwargs)
+ finally:
+ torch.set_rng_state(state)
+ logger.debug('RNG state restored.')
+ return _decorated
+ return _decorator
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Get a list of tensors and collate them to a single tensor. according to the following logic:
+ - `dim` specifies the time dimension which will be stacked and padded.
+ - The output will contain 1 new dimension (dimension index 0) which will be the size of
+ of the original list.
+
+ Args:
+ tensors (tp.List[torch.Tensor]): List of tensors to collate.
+ dim (int): Dimension which will be stacked and padded.
+ Returns:
+ tp.Tuple[torch.Tensor, torch.Tensor]:
+ torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+ (dimension index 0) which will be the size of the original list.
+ torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+ """
+ tensors = [x.transpose(0, dim) for x in tensors]
+ lens = torch.LongTensor([len(x) for x in tensors])
+ padded_tensors = pad_sequence(tensors)
+ padded_tensors = padded_tensors.transpose(0, 1)
+ padded_tensors = padded_tensors.transpose(1, dim + 1)
+ return padded_tensors, lens
+
+
+# TODO: Move to flashy?
+def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
+ dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
+ if isinstance(state, torch.Tensor):
+ if dtype is None or not state.is_floating_point():
+ dtype = state.dtype
+ return state.detach().to(device=device, dtype=dtype, copy=True)
+ elif isinstance(state, dict):
+ return {k: copy_state(v, device, dtype) for k, v in state.items()}
+ elif isinstance(state, list):
+ return [copy_state(v, device, dtype) for v in state]
+
+
+# TODO: Move to flashy?
+@contextmanager
+def swap_state(model, state, **kwargs):
+ old_state = copy_state(model.state_dict())
+ model.load_state_dict(state, **kwargs)
+ try:
+ yield
+ finally:
+ model.load_state_dict(old_state)
+
+
+@lru_cache(None)
+def warn_once(logger, msg):
+ """Warn about a given message only once."""
+ logger.warning(msg)
+
+
+def is_jsonable(x: tp.Any):
+ """Check if an object can be serialized into a json:"""
+ try:
+ json.dumps(x)
+ return True
+ except (TypeError, OverflowError):
+ return False
+
+
+def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
+ """Wrapper around state dict loading of CLAP model
+ addressing compatibility issues between CLAP and AudioCraft
+ HuggingFace transformer version.
+ See: https://github.com/LAION-AI/CLAP/issues/118
+ """
+ from clap_module.factory import load_state_dict # type: ignore
+ pkg = load_state_dict(path)
+ pkg.pop('text_branch.embeddings.position_ids', None)
+ clap_model.model.load_state_dict(pkg)
diff --git a/config/conditioner/chroma2music.yaml b/config/conditioner/chroma2music.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91d37e758ef183678cff3f7a880b6bab2e36b03c
--- /dev/null
+++ b/config/conditioner/chroma2music.yaml
@@ -0,0 +1,46 @@
+# @package __global__
+
+classifier_free_guidance:
+ training_dropout: 0.2
+ inference_coef: 3.0
+
+attribute_dropout:
+ args:
+ active_on_eval: false
+ text: {}
+ wav:
+ self_wav: 0.5
+
+fuser:
+ cross_attention_pos_emb: false
+ cross_attention_pos_emb_scale: 1
+ sum: []
+ prepend: [self_wav, description]
+ cross: []
+ input_interpolate: []
+
+conditioners:
+ self_wav:
+ model: chroma_stem
+ chroma_stem:
+ sample_rate: ${sample_rate}
+ n_chroma: 12
+ radix2_exp: 14
+ argmax: true
+ match_len_on_eval: false
+ eval_wavs: null
+ n_eval_wavs: 100
+ cache_path: null
+ description:
+ model: t5
+ t5:
+ name: t5-base
+ finetune: false
+ word_dropout: 0.2
+ normalize_text: false
+
+dataset:
+ train:
+ merge_text_p: 0.25
+ drop_desc_p: 0.5
+ drop_other_p: 0.5
diff --git a/config/conditioner/clapemb2music.yaml b/config/conditioner/clapemb2music.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8500a826e7379b4a8baaf67570e233f7bac7e5da
--- /dev/null
+++ b/config/conditioner/clapemb2music.yaml
@@ -0,0 +1,44 @@
+# @package __global__
+
+classifier_free_guidance:
+ training_dropout: 0.3
+ inference_coef: 3.0
+
+attribute_dropout:
+ text: {}
+ wav: {}
+
+fuser:
+ cross_attention_pos_emb: false
+ cross_attention_pos_emb_scale: 1
+ sum: []
+ prepend: []
+ cross: [description]
+ input_interpolate: []
+
+conditioners:
+ description:
+ model: clap
+ clap:
+ checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
+ model_arch: 'HTSAT-base'
+ enable_fusion: false
+ sample_rate: 44100
+ max_audio_length: 10
+ audio_stride: 1
+ dim: 512
+ attribute: description
+ normalize: true
+ quantize: true # use RVQ quantization
+ n_q: 12
+ bins: 1024
+ kmeans_iters: 50
+ text_p: 0. # probability of using text embed at train time
+ cache_path: null
+
+dataset:
+ joint_embed_attributes: [description]
+ train:
+ merge_text_p: 0.25
+ drop_desc_p: 0.5
+ drop_other_p: 0.5
diff --git a/config/conditioner/none.yaml b/config/conditioner/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6055dc910cad46d80609aae57bb46b81f2663d70
--- /dev/null
+++ b/config/conditioner/none.yaml
@@ -0,0 +1,19 @@
+# @package __global__
+
+# No conditioning
+
+classifier_free_guidance:
+ training_dropout: 0
+ inference_coef: 1
+
+attribute_dropout:
+ text: {}
+ wav: {}
+
+fuser:
+ sum: []
+ prepend: []
+ cross: []
+ input_interpolate: []
+
+conditioners: null
diff --git a/config/conditioner/text2music.yaml b/config/conditioner/text2music.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d0fe6cfa3fb33bcdb4f9fd16bd5ab4034c68b7b
--- /dev/null
+++ b/config/conditioner/text2music.yaml
@@ -0,0 +1,30 @@
+# @package __global__
+
+classifier_free_guidance:
+ training_dropout: 0.3
+ inference_coef: 3.0
+
+attribute_dropout: {}
+
+fuser:
+ cross_attention_pos_emb: false
+ cross_attention_pos_emb_scale: 1
+ sum: []
+ prepend: []
+ cross: [description]
+ input_interpolate: []
+
+conditioners:
+ description:
+ model: t5
+ t5:
+ name: t5-base
+ finetune: false
+ word_dropout: 0.3
+ normalize_text: false
+
+dataset:
+ train:
+ merge_text_p: 0.25
+ drop_desc_p: 0.5
+ drop_other_p: 0.5
diff --git a/config/conditioner/text2sound.yaml b/config/conditioner/text2sound.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..555d4b7c3cecf0ec06c8cb25440b2f426c098ad2
--- /dev/null
+++ b/config/conditioner/text2sound.yaml
@@ -0,0 +1,24 @@
+# @package __global__
+
+classifier_free_guidance:
+ training_dropout: 0.1
+ inference_coef: 3.0
+
+attribute_dropout: {}
+
+fuser:
+ cross_attention_pos_emb: false
+ cross_attention_pos_emb_scale: 1
+ sum: []
+ prepend: []
+ cross: [description]
+ input_interpolate: []
+
+conditioners:
+ description:
+ model: t5
+ t5:
+ name: t5-large
+ finetune: false
+ word_dropout: 0.
+ normalize_text: false
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6b0b7866eafac173fe7b056ad5920be1df57a947
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,75 @@
+# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft
+# Please don't update this file directly. Instead use distinct configuration files
+# to override the below configuration.
+defaults:
+ - _self_
+ - dset: default
+ - solver: default
+
+device: cuda
+dtype: float32
+autocast: false
+autocast_dtype: bfloat16
+seed: 2036
+show: false # just show the model and its size and exit
+continue_from: # continue from a given sig or path
+execute_only: # can be set to generate/evaluate/valid to run that stage
+execute_inplace: false # don't enforce continue_from to be set
+ # to enable inplace execution of the stage. This assume
+ # that you know what you are doing and execute stage
+ # preserving the original xp sig.
+benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them
+
+efficient_attention_backend: torch # can be torch or xformers.
+num_threads: 1 # called with torch.set_num_thread.
+mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server).
+
+
+label: # use this if you want twice the same exp, with a name.
+
+# logging parameters
+logging:
+ level: INFO
+ log_updates: 10
+ log_tensorboard: false
+ log_wandb: false
+tensorboard:
+ with_media_logging: false
+ name: # optional name for the experiment
+ sub_dir: # optional sub directory to store tensorboard data
+wandb:
+ with_media_logging: true
+ project: # project name
+ name: # optional name for the experiment
+ group: # optional group
+
+# SLURM launcher configuration.
+slurm:
+ gpus: 4 # convenience parameter, number of GPUs to use.
+ mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`.
+ time: 3600
+ constraint:
+ partition:
+ comment:
+ setup: []
+ exclude: ''
+
+# dora parameters
+dora:
+ # Output folder for all artifacts of an experiment.
+ dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
+ # The following entries will be ignored by dora when computing the unique XP signature.
+ # Note that slurm.* and dora.* are automatically ignored.
+ exclude: [
+ 'device', 'wandb.*', 'tensorboard.*', 'logging.*',
+ 'dataset.num_workers', 'eval.num_workers', 'special.*',
+ 'metrics.visqol.bin', 'metrics.fad.bin',
+ 'execute_only', 'execute_best', 'generate.every',
+ 'optim.eager_sync', 'profiler.*', 'deadlock.*',
+ 'efficient_attention_backend', 'num_threads', 'mp_start_method',
+ ]
+ use_rendezvous: false
+ # for grids, always run from a clean repo, allowing reliable runs and storing
+ # the exact commit. Your repo must be absolutely pristine clean.
+ # Local `dora run` are not impacted for easier debugging.
+ git_save: true
diff --git a/config/dset/audio/audiocaps_16khz.yaml b/config/dset/audio/audiocaps_16khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..14f5d6a4fcbf4426b7987d4427ca2d98d17d6c5b
--- /dev/null
+++ b/config/dset/audio/audiocaps_16khz.yaml
@@ -0,0 +1,11 @@
+# @package __global__
+
+# AudioCaps dataset
+datasource:
+ max_sample_rate: 16000
+ max_channels: 1
+
+ train: null # only evaluation set
+ valid: null # only evaluation set
+ evaluate: egs/audiocaps/audiocaps_16khz
+ generate: egs/audiocaps/audiocaps_16khz # identical to evaluate
diff --git a/config/dset/audio/default.yaml b/config/dset/audio/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..80be23e999c6366cc89ebcf55af6b958c0e45158
--- /dev/null
+++ b/config/dset/audio/default.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+datasource:
+ max_sample_rate: ???
+ max_channels: ???
+
+ train: ???
+ valid: ???
+ evaluate: ???
+ generate: null
diff --git a/config/dset/audio/example.yaml b/config/dset/audio/example.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d559d6d79a1cc05a82bb09f267c446258ef9ca55
--- /dev/null
+++ b/config/dset/audio/example.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+datasource:
+ max_sample_rate: 44100
+ max_channels: 2
+
+ train: egs/example
+ valid: egs/example
+ evaluate: egs/example
+ generate: egs/example
diff --git a/config/dset/audio/musiccaps_32khz.yaml b/config/dset/audio/musiccaps_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9d4eea0f7a521a47b9f673fecab075c5223d2b07
--- /dev/null
+++ b/config/dset/audio/musiccaps_32khz.yaml
@@ -0,0 +1,12 @@
+# @package __global__
+
+# total samples obtained from MusicCaps = 5469
+# (out of 5521 due to AudioSet corrupted samples)
+datasource:
+ max_sample_rate: 32000
+ max_channels: 2
+
+ train: null # only evaluation set
+ valid: null # only evaluation set
+ evaluate: egs/musiccaps/musiccaps_32khz
+ generate: egs/musiccaps/musiccaps_32khz # identical to evaluate
diff --git a/config/dset/default.yaml b/config/dset/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5d730130e090b38a42984a8a87e1eea01cbf031
--- /dev/null
+++ b/config/dset/default.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
+# Please don't update this file directly. Instead use distinct configuration files
+# to override the below configuration.
+datasource:
+ train: ???
+ valid: ???
+ evaluate: ???
+ generate: ???
diff --git a/config/dset/internal/music_10k_32khz.yaml b/config/dset/internal/music_10k_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..036628abfeaa89279790547bbb5b3ee9dd69cea3
--- /dev/null
+++ b/config/dset/internal/music_10k_32khz.yaml
@@ -0,0 +1,11 @@
+# @package __global__
+
+# high quality music dataset with no artist overlap between splits
+datasource:
+ max_sample_rate: 32000
+ max_channels: 1
+
+ train: egs/music/music_10k_32khz/train
+ valid: egs/music/music_10k_32khz/valid
+ evaluate: egs/music/music_10k_32khz/test
+ generate: egs/music/music_10k_32khz/test # identical to evaluate
diff --git a/config/dset/internal/music_400k_32khz.yaml b/config/dset/internal/music_400k_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7786880ab9c0464a0423d906c18d62bdf7194463
--- /dev/null
+++ b/config/dset/internal/music_400k_32khz.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+datasource:
+ max_sample_rate: 32000
+ max_channels: 1
+
+ train: egs/music/music_400k_32khz/train
+ valid: egs/music/music_400k_32khz/valid
+ evaluate: egs/music/music_400k_32khz/test
+ generate: egs/music/music_400k_32khz/test # identical to evaluate
diff --git a/config/dset/internal/sounds_16khz.yaml b/config/dset/internal/sounds_16khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f3401a1b44ce300e22f3f64ef9c54d5c013c153
--- /dev/null
+++ b/config/dset/internal/sounds_16khz.yaml
@@ -0,0 +1,12 @@
+# @package __global__
+
+# environmental sounds dataset compiling all datasets
+# with applied filters on tags
+datasource:
+ max_sample_rate: 16000
+ max_channels: 1
+
+ train: egs/sound/sounds_16khz/train
+ valid: egs/sound/sounds_16khz/valid
+ evaluate: egs/sound/sounds_16khz/test
+ generate: egs/sound/sounds_16khz/test # identical to evaluate
diff --git a/config/model/encodec/default.yaml b/config/model/encodec/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec62c6c8ef9a686890bdca8b8f27a2f1c232205d
--- /dev/null
+++ b/config/model/encodec/default.yaml
@@ -0,0 +1,54 @@
+# @package __global__
+
+compression_model: encodec
+
+encodec:
+ autoencoder: seanet
+ quantizer: rvq
+ sample_rate: ${sample_rate}
+ channels: ${channels}
+ causal: false
+ renormalize: false
+
+seanet:
+ dimension: 128
+ channels: ${channels}
+ causal: ${encodec.causal}
+ n_filters: 32
+ n_residual_layers: 1
+ ratios: [8, 5, 4, 2]
+ activation: ELU
+ activation_params: {"alpha": 1.}
+ norm: weight_norm
+ norm_params: {}
+ kernel_size: 7
+ residual_kernel_size: 3
+ last_kernel_size: 7
+ dilation_base: 2
+ pad_mode: constant
+ true_skip: true
+ compress: 2
+ lstm: 2
+ disable_norm_outer_blocks: 0
+ # Specific encoder or decoder params.
+ # You can also override any param for the encoder or decoder only
+ # by using Hydra `+param=` syntax, i.e.`
+ # `+seanet.decoder.n_filters=64`.
+ decoder:
+ trim_right_ratio: 1.0
+ final_activation: null
+ final_activation_params: null
+ encoder: {}
+
+rvq:
+ n_q: 8
+ q_dropout: false
+ bins: 1024
+ decay: 0.99
+ kmeans_init: true
+ kmeans_iters: 50
+ threshold_ema_dead_code: 2
+ orthogonal_reg_weight: 0.0
+ orthogonal_reg_active_codes_only: false
+
+no_quant: {}
diff --git a/config/model/encodec/encodec_base_causal.yaml b/config/model/encodec/encodec_base_causal.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ca555bcdc69433f172915400bb71c3b63e68681
--- /dev/null
+++ b/config/model/encodec/encodec_base_causal.yaml
@@ -0,0 +1,11 @@
+# @package __global__
+
+defaults:
+ - encodec/default
+
+encodec:
+ causal: true
+
+rvq:
+ n_q: 32
+ q_dropout: true
diff --git a/config/model/encodec/encodec_large_nq4_s320.yaml b/config/model/encodec/encodec_large_nq4_s320.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5f2d77590afd8a81185358c705a6e42853e257c3
--- /dev/null
+++ b/config/model/encodec/encodec_large_nq4_s320.yaml
@@ -0,0 +1,13 @@
+# @package __global__
+
+defaults:
+ - encodec/default
+
+seanet:
+ # default ratios are [8, 5, 4, 2]
+ n_filters: 64
+
+rvq:
+ bins: 2048
+ n_q: 4
+ q_dropout: false
diff --git a/config/model/encodec/encodec_large_nq4_s640.yaml b/config/model/encodec/encodec_large_nq4_s640.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3fcb7e87f4f700554164b0a58e9927b2f96a2c5a
--- /dev/null
+++ b/config/model/encodec/encodec_large_nq4_s640.yaml
@@ -0,0 +1,13 @@
+# @package __global__
+
+defaults:
+ - encodec/default
+
+seanet:
+ ratios: [8, 5, 4, 4]
+ n_filters: 64
+
+rvq:
+ bins: 2048
+ n_q: 4
+ q_dropout: false
diff --git a/config/model/lm/audiogen_lm.yaml b/config/model/lm/audiogen_lm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..696f74620af193c12208ce66fdb93a37f8ea9d80
--- /dev/null
+++ b/config/model/lm/audiogen_lm.yaml
@@ -0,0 +1,36 @@
+# @package __global__
+
+defaults:
+ - lm/default
+ - override /conditioner: text2sound
+ - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly
+
+lm_model: transformer_lm
+
+codebooks_pattern:
+ modeling: delay
+ delay:
+ delays: [0, 1, 2, 3]
+ flatten_first: 0
+ empty_initial: 0
+ unroll:
+ flattening: [0, 1, 2, 3]
+ delays: [0, 0, 0, 0]
+ music_lm:
+ group_by: 2
+ valle:
+ delays: [0, 0, 0]
+
+transformer_lm:
+ n_q: 4
+ card: 2048
+ memory_efficient: true
+ bias_proj: false
+ bias_ff: false
+ bias_attn: false
+ norm_first: true
+ layer_scale: null
+ weight_init: gaussian
+ depthwise_init: current
+ zero_bias_init: true
+ attention_as_float32: false
diff --git a/config/model/lm/default.yaml b/config/model/lm/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d256ad14ef69d25d62c19b73599937c8546e79b
--- /dev/null
+++ b/config/model/lm/default.yaml
@@ -0,0 +1,47 @@
+# @package __global__
+defaults:
+ - _self_
+ - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly
+
+lm_model: transformer_lm
+
+codebooks_pattern:
+ modeling: parallel
+
+transformer_lm:
+ dim: 512
+ num_heads: 8
+ num_layers: 8
+ hidden_scale: 4
+ n_q: 8 # number of streams to model
+ card: 1024
+ dropout: 0.
+ emb_lr: null
+ activation: gelu
+ norm_first: false # use pre-norm instead of post-norm
+ bias_ff: true # use bias for the feedforward
+ bias_attn: true # use bias for the attention
+ bias_proj: true # use bias for the output projections
+ past_context: null
+ causal: true
+ custom: false # use custom MHA implementation
+ memory_efficient: false # use flash attention
+ attention_as_float32: false # use float32 for the attention part,
+ # recommended at the moment when memory_efficient is True.
+ layer_scale: null
+ positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope).
+ xpos: false # apply xpos decay (rope only).
+ checkpointing: none # layer checkpointing method, can be none, torch, xformers_default.
+ # torch is the slowest but uses the least memory,
+ # xformers_default is somewhere in between.
+ weight_init: null # weight initialization (null, gaussian or uniform)
+ depthwise_init: null # perform depthwise initialization (null, current, global)
+ zero_bias_init: false # initialize bias to zero if bias in linears and
+ # if a weight_init method is used.
+ norm: layer_norm # normalization method to use in transformer.
+ cross_attention: false
+ qk_layer_norm: false
+ qk_layer_norm_cross: false
+ attention_dropout: null
+ kv_repeat: 1
+ two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not...
diff --git a/config/model/lm/model_scale/base.yaml b/config/model/lm/model_scale/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3da88d2305e4c380435de1a3eecfe311ecfc82f9
--- /dev/null
+++ b/config/model/lm/model_scale/base.yaml
@@ -0,0 +1,3 @@
+# @package __global__
+
+# overrides nothing because default is already transformer base (~ 60M params)
diff --git a/config/model/lm/model_scale/large.yaml b/config/model/lm/model_scale/large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d355bfb93618003ac8994bc093eb7bc96ac60114
--- /dev/null
+++ b/config/model/lm/model_scale/large.yaml
@@ -0,0 +1,7 @@
+# @package _global_
+
+# gpt2 inspired, even bigger (~3.3B params)
+transformer_lm:
+ dim: 2048
+ num_heads: 32
+ num_layers: 48
diff --git a/config/model/lm/model_scale/medium.yaml b/config/model/lm/model_scale/medium.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c825d1ff6c3b8cc9ae4959a898e14b40409d95e8
--- /dev/null
+++ b/config/model/lm/model_scale/medium.yaml
@@ -0,0 +1,7 @@
+# @package _global_
+
+# gpt2 like (~1.5B params)
+transformer_lm:
+ dim: 1536
+ num_heads: 24
+ num_layers: 48
diff --git a/config/model/lm/model_scale/small.yaml b/config/model/lm/model_scale/small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..88d89cb5ac1b183fb3a9092834cea83aa16c70a8
--- /dev/null
+++ b/config/model/lm/model_scale/small.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+
+# 300M Param.
+
+transformer_lm:
+ dim: 1024
+ num_heads: 16
+ num_layers: 24
diff --git a/config/model/lm/model_scale/xsmall.yaml b/config/model/lm/model_scale/xsmall.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e98d4370d4fe7497f12aeb58f092a88797d1afa1
--- /dev/null
+++ b/config/model/lm/model_scale/xsmall.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+# just used for debugging or when we just want to populate the cache
+# and do not care about training.
+
+transformer_lm:
+ dim: 64
+ num_heads: 2
+ num_layers: 2
diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5bc87a628789a34e381e2aa8ba5ef6ed780669d7
--- /dev/null
+++ b/config/model/lm/musicgen_lm.yaml
@@ -0,0 +1,36 @@
+# @package __global__
+
+defaults:
+ - lm/default
+ - override /conditioner: text2music
+ - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly
+
+lm_model: transformer_lm
+
+codebooks_pattern:
+ modeling: delay
+ delay:
+ delays: [0, 1, 2, 3]
+ flatten_first: 0
+ empty_initial: 0
+ unroll:
+ flattening: [0, 1, 2, 3]
+ delays: [0, 0, 0, 0]
+ music_lm:
+ group_by: 2
+ valle:
+ delays: [0, 0, 0]
+
+transformer_lm:
+ n_q: 4
+ card: 2048
+ memory_efficient: true
+ bias_proj: false
+ bias_ff: false
+ bias_attn: false
+ norm_first: true
+ layer_scale: null
+ weight_init: gaussian
+ depthwise_init: current
+ zero_bias_init: true
+ attention_as_float32: false
diff --git a/config/model/none.yaml b/config/model/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d4169f468d462c794ee6ed25017c3d78ae45d06
--- /dev/null
+++ b/config/model/none.yaml
@@ -0,0 +1,4 @@
+# @package __global__
+
+# This file exist so that model is recognized as a config group
+# by Hydra, and Dora. A bit weird we might need a better fix someday.
diff --git a/config/model/score/basic.yaml b/config/model/score/basic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..75fbc3783942602beaddaa38d0aca977aeee2dda
--- /dev/null
+++ b/config/model/score/basic.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+
+diffusion_unet:
+ hidden: 48
+ depth: 4
+ res_blocks: 1
+ norm_groups: 4
+ kernel: 8
+ stride: 4
+ growth: 4
+ max_channels: 10_000
+ dropout: 0.
+ emb_all_layers: true
+ bilstm: false
+ codec_dim: null
+ transformer: false
+ cross_attention: false
\ No newline at end of file
diff --git a/config/solver/audiogen/audiogen_base_16khz.yaml b/config/solver/audiogen/audiogen_base_16khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd6aee785c74db19ce9d6f488e68e6eeb471c026
--- /dev/null
+++ b/config/solver/audiogen/audiogen_base_16khz.yaml
@@ -0,0 +1,70 @@
+# @package __global__
+
+# This is the training loop solver
+# for the base AudioGen model (text-to-sound)
+# on monophonic audio sampled at 16 kHz
+# using a similar EnCodec+LM setup to MusicGen
+defaults:
+ - audiogen/default
+ - /model: lm/audiogen_lm
+ - override /dset: audio/default
+ - _self_
+
+autocast: true
+autocast_dtype: float16
+
+# EnCodec large trained on mono-channel music audio sampled at 16khz
+# with a total stride of 320 leading to 50 frames/s.
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout
+# (transformer_lm card and n_q must be compatible)
+compression_model_checkpoint: //reference/bd44a852/checkpoint.th
+
+channels: 1
+sample_rate: 16000
+
+deadlock:
+ use: true # deadlock detection
+
+dataset:
+ batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128)
+ num_workers: 10
+ segment_duration: 10
+ min_segment_ratio: 1.0
+ sample_on_weight: false # Uniform sampling all the way
+ sample_on_duration: false # Uniform sampling all the way
+ external_metadata_source: null
+ # sample mixing augmentation at train time
+ train:
+ batch_size: 256 # matching AudioGen paper setup
+ aug_p: 0.5 # perform audio mixing 50% of the time
+ mix_p: 0.5 # proportion of batch items mixed together
+ # important: note that this will reduce the
+ # actual batch size used at train time
+ # which will be equal to mix_p * batch_size
+ mix_snr_low: -5
+ mix_snr_high: 5
+ mix_min_overlap: 0.5
+
+generate:
+ lm:
+ use_sampling: true
+ top_k: 250
+ top_p: 0.0
+
+optim:
+ epochs: 100
+ optimizer: adamw
+ lr: 5e-4
+ ema:
+ use: true
+ updates: 10
+ device: cuda
+
+logging:
+ log_tensorboard: true
+
+schedule:
+ lr_scheduler: inverse_sqrt
+ inverse_sqrt:
+ warmup: 3000
+ warmup_init_lr: 0.0
diff --git a/config/solver/audiogen/debug.yaml b/config/solver/audiogen/debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fbda8281c6d552d9445e04fee498641a26549aa5
--- /dev/null
+++ b/config/solver/audiogen/debug.yaml
@@ -0,0 +1,52 @@
+# @package __global__
+
+# This is a minimal debugging configuration
+# for MusicGen training solver
+defaults:
+ - audiogen/default
+ - /model: lm/audiogen_lm
+ - override /model/lm/model_scale: xsmall
+ - override /dset: audio/example
+ - _self_
+
+autocast: false
+compression_model_checkpoint: null
+
+codebooks_pattern:
+ modeling: parallel
+
+channels: 1
+sample_rate: 16000
+
+deadlock:
+ use: false # deadlock detection
+
+dataset:
+ batch_size: 4
+ segment_duration: 5
+ sample_on_weight: false # Uniform sampling all the way
+ sample_on_duration: false # Uniform sampling all the way
+
+generate:
+ audio:
+ strategy: peak
+ lm:
+ use_sampling: false
+ top_k: 0
+ top_p: 0.0
+
+checkpoint:
+ save_every: 0
+ keep_last: 0
+
+optim:
+ epochs: 2
+ updates_per_epoch: 10
+ optimizer: adamw
+ lr: 1e-4
+
+logging:
+ log_tensorboard: true
+
+schedule:
+ lr_scheduler: null
diff --git a/config/solver/audiogen/default.yaml b/config/solver/audiogen/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..afee63c65e0dd7350e3e89d2133bbca221d17631
--- /dev/null
+++ b/config/solver/audiogen/default.yaml
@@ -0,0 +1,40 @@
+# @package __global__
+
+defaults:
+ - /solver/musicgen/default
+ - _self_
+ - /solver/audiogen/evaluation: none
+ - override /dset: audio/default
+
+# See config/solver/musicgen/default.yaml for a list of possible values.
+# We only keep the most important here.
+
+autocast: true
+autocast_dtype: float16
+
+solver: audiogen
+sample_rate: ???
+channels: ???
+compression_model_checkpoint: ???
+
+tokens:
+ padding_with_special_token: false
+
+dataset:
+ batch_size: 128
+ segment_duration: 10
+ min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence.
+
+optim:
+ epochs: 100
+ updates_per_epoch: 2000
+ lr: 1e-4
+ optimizer: adamw
+ max_norm: 1.0
+ adam:
+ betas: [0.9, 0.95]
+ weight_decay: 0.1
+ eps: 1e-8
+
+schedule:
+ lr_scheduler: null
diff --git a/config/solver/audiogen/evaluation/none.yaml b/config/solver/audiogen/evaluation/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a
--- /dev/null
+++ b/config/solver/audiogen/evaluation/none.yaml
@@ -0,0 +1,5 @@
+# @package __global__
+
+dataset:
+ evaluate:
+ num_samples: 10000
diff --git a/config/solver/audiogen/evaluation/objective_eval.yaml b/config/solver/audiogen/evaluation/objective_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6ffeed7fb8a8c0c9c55683200a5d25d9a22aa8f1
--- /dev/null
+++ b/config/solver/audiogen/evaluation/objective_eval.yaml
@@ -0,0 +1,24 @@
+# @package __global__
+
+# Setup for execute only on audiocaps for audio generation
+# evaluation with objective metrics
+# execute_only=evaluate
+
+dataset:
+ max_audio_duration: null
+ # ensure the proper values are broadcasted here for evaluate
+ evaluate:
+ min_audio_duration: 1. # some metrics requires a minimum audio length
+ max_audio_duration: null # all samples from audiocaps should be ~10s
+ num_samples: null
+ segment_duration: null
+ generate:
+ min_audio_duration: 1.
+ max_audio_duration: null
+ num_samples: 500
+
+evaluate:
+ metrics:
+ fad: true
+ kld: true
+ text_consistency: true
diff --git a/config/solver/compression/debug.yaml b/config/solver/compression/debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..54dac175278d4ff509b0e44905d6b6195441f2c6
--- /dev/null
+++ b/config/solver/compression/debug.yaml
@@ -0,0 +1,55 @@
+# @package __global__
+
+defaults:
+ - compression/default
+ - /model: encodec/encodec_base_causal
+ - override /dset: audio/example
+ - _self_
+
+channels: 1
+sample_rate: 16000
+
+# debug config uses just L1
+losses:
+ adv: 0.
+ feat: 0.
+ l1: 1.
+ mel: 0.
+ msspec: 0.
+# no balancer
+balancer:
+ balance_grads: false
+ ema_decay: 1.
+ total_norm: 1.
+ per_batch_item: false
+# no adversaries
+adversarial:
+ adversaries: []
+ adv_loss: hinge
+ feat_loss: l1
+
+# faster model for local dev
+seanet:
+ dimension: 16
+ n_filters: 4
+
+# very small dataset
+dataset:
+ batch_size: 8
+ num_workers: 10
+ num_samples: 100
+ segment_duration: 1
+ evaluate:
+ batch_size: 32
+ generate:
+ batch_size: 1
+ num_samples: 5
+ segment_duration: 10
+
+# limited training
+evaluate:
+ every: 5
+generate:
+ every: 5
+optim:
+ epochs: 50
diff --git a/config/solver/compression/default.yaml b/config/solver/compression/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..41c812ba9ff8afe7ee10302ad5b9f05b745877d9
--- /dev/null
+++ b/config/solver/compression/default.yaml
@@ -0,0 +1,160 @@
+# @package __global__
+
+defaults:
+ - ../default
+ - override /dset: audio/default
+ - _self_
+
+solver: compression
+sample_rate: ???
+channels: ???
+
+# loss balancing
+losses:
+ adv: 4.
+ feat: 4.
+ l1: 0.1
+ mel: 0.
+ msspec: 2.
+ sisnr: 0.
+balancer:
+ balance_grads: true
+ ema_decay: 0.999
+ per_batch_item: true
+ total_norm: 1.
+
+adversarial:
+ every: 1
+ adversaries: [msstftd]
+ adv_loss: hinge
+ feat_loss: l1
+
+# losses hyperparameters
+l1: {}
+l2: {}
+mrstft:
+ factor_sc: .5
+ factor_mag: .5
+ normalized: false
+mel:
+ sample_rate: ${sample_rate}
+ n_fft: 1024
+ hop_length: 256
+ win_length: 1024
+ n_mels: 64
+ f_min: 64
+ f_max: null
+ normalized: false
+ floor_level: 1e-5
+sisnr:
+ sample_rate: ${sample_rate}
+ segment: 5.
+msspec:
+ sample_rate: ${sample_rate}
+ range_start: 6
+ range_end: 11
+ n_mels: 64
+ f_min: 64
+ f_max: null
+ normalized: true
+ alphas: false
+ floor_level: 1e-5
+
+# metrics
+metrics:
+ visqol:
+ mode: audio
+ bin: null # path to visqol install
+ model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3
+
+# adversaries hyperparameters
+msstftd:
+ in_channels: 1
+ out_channels: 1
+ filters: 32
+ norm: weight_norm
+ n_ffts: [1024, 2048, 512, 256, 128]
+ hop_lengths: [256, 512, 128, 64, 32]
+ win_lengths: [1024, 2048, 512, 256, 128]
+ activation: LeakyReLU
+ activation_params: {negative_slope: 0.3}
+msd:
+ in_channels: 1
+ out_channels: 1
+ scale_norms: [spectral_norm, weight_norm, weight_norm]
+ kernel_sizes: [5, 3]
+ filters: 16
+ max_filters: 1024
+ downsample_scales: [4, 4, 4, 4]
+ inner_kernel_sizes: null
+ groups: [4, 4, 4, 4]
+ strides: null
+ paddings: null
+ activation: LeakyReLU
+ activation_params: {negative_slope: 0.3}
+mpd:
+ in_channels: 1
+ out_channels: 1
+ periods: [2, 3, 5, 7, 11]
+ n_layers: 5
+ kernel_size: 5
+ stride: 3
+ filters: 8
+ filter_scales: 4
+ max_filters: 1024
+ activation: LeakyReLU
+ activation_params: {negative_slope: 0.3}
+ norm: weight_norm
+
+# data hyperparameters
+dataset:
+ batch_size: 64
+ num_workers: 10
+ segment_duration: 1
+ train:
+ num_samples: 500000
+ valid:
+ num_samples: 10000
+ evaluate:
+ batch_size: 32
+ num_samples: 10000
+ generate:
+ batch_size: 32
+ num_samples: 50
+ segment_duration: 10
+
+# solver hyperparameters
+evaluate:
+ every: 25
+ num_workers: 5
+ metrics:
+ visqol: false
+ sisnr: true
+generate:
+ every: 25
+ num_workers: 5
+ audio:
+ sample_rate: ${sample_rate}
+
+# checkpointing schedule
+checkpoint:
+ save_last: true
+ save_every: 25
+ keep_last: 10
+ keep_every_states: null
+
+# optimization hyperparameters
+optim:
+ epochs: 200
+ updates_per_epoch: 2000
+ lr: 3e-4
+ max_norm: 0.
+ optimizer: adam
+ adam:
+ betas: [0.5, 0.9]
+ weight_decay: 0.
+ ema:
+ use: true # whether to use EMA or not
+ updates: 1 # update at every step
+ device: ${device} # device for EMA, can be put on GPU if more frequent updates
+ decay: 0.99 # EMA decay value, if null, no EMA is used
diff --git a/config/solver/compression/encodec_audiogen_16khz.yaml b/config/solver/compression/encodec_audiogen_16khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..654deaa01ba9cace3f7144cc91921791c081b32a
--- /dev/null
+++ b/config/solver/compression/encodec_audiogen_16khz.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+defaults:
+ - compression/default
+ - /model: encodec/encodec_large_nq4_s320
+ - override /dset: audio/default
+ - _self_
+
+channels: 1
+sample_rate: 16000
diff --git a/config/solver/compression/encodec_base_24khz.yaml b/config/solver/compression/encodec_base_24khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..018ad1cd61af84b616ad3088f055e8eaa36729eb
--- /dev/null
+++ b/config/solver/compression/encodec_base_24khz.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+defaults:
+ - compression/default
+ - /model: encodec/encodec_base_causal
+ - override /dset: audio/default
+ - _self_
+
+channels: 1
+sample_rate: 24000
diff --git a/config/solver/compression/encodec_musicgen_32khz.yaml b/config/solver/compression/encodec_musicgen_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eca4b90fb221372dace164fe59bb15822207a980
--- /dev/null
+++ b/config/solver/compression/encodec_musicgen_32khz.yaml
@@ -0,0 +1,10 @@
+# @package __global__
+
+defaults:
+ - compression/default
+ - /model: encodec/encodec_large_nq4_s640
+ - override /dset: audio/default
+ - _self_
+
+channels: 1
+sample_rate: 32000
diff --git a/config/solver/default.yaml b/config/solver/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7452ea1e415516dceaaae86d692cbb8c811bd57
--- /dev/null
+++ b/config/solver/default.yaml
@@ -0,0 +1,108 @@
+# @package __global__
+
+# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
+# Please don't update this file directly. Instead use distinct configuration files
+# to override the below configuration.
+solver: ???
+
+fsdp:
+ use: false # should we use FSDP.
+ param_dtype: float16 # equivalent to autocast_dtype for FSDP.
+ reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability.
+ buffer_dtype: float32 # dtype used for buffers, we don't have much buffers, so let's leave it.
+ sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard.
+ # full_shard will use less memory but slower ??
+ per_block: true # If True, uses nested FSDP.
+
+profiler:
+ enabled: false
+
+deadlock:
+ use: false
+ timeout: 600
+
+dataset:
+ batch_size: ???
+ num_workers: 10
+ segment_duration: null
+ num_samples: null
+ return_info: false
+ shuffle: false
+ sample_on_duration: true
+ sample_on_weight: true
+ min_segment_ratio: 0.5
+ train:
+ num_samples: null
+ shuffle: true
+ shuffle_seed: 0 # if you want to sample the data differently.
+ permutation_on_files: false
+ valid:
+ num_samples: null
+ evaluate:
+ num_samples: null
+ generate:
+ num_samples: null
+ return_info: true
+
+checkpoint:
+ save_last: true
+ save_every: null
+ keep_last: null
+ keep_every_states: null
+
+generate:
+ every: null
+ path: 'samples'
+ audio:
+ format: 'mp3'
+ strategy: 'clip'
+ sample_rate: null
+ lm:
+ use_sampling: false
+ temp: 1.0
+ top_k: 0
+ top_p: 0.0
+evaluate:
+ every: null
+ num_workers: 5
+ truncate_audio: null
+ fixed_generation_duration: null # in secs
+ metrics:
+ base: true # run default evaluation (e.g. like train/valid stage)
+
+optim:
+ epochs: ???
+ updates_per_epoch: null
+ lr: ???
+ optimizer: ???
+ adam:
+ betas: [0.9, 0.999]
+ weight_decay: 0.
+ ema:
+ use: false # whether to use EMA or not
+ updates: ${optim.updates_per_epoch} # frequency of updates of the EMA
+ device: cpu # device for EMA, can be put on GPU if more frequent updates
+ decay: 0.99 # EMA decay value, if null, no EMA is used
+
+schedule:
+ lr_scheduler: null
+ step:
+ step_size: null
+ gamma: null
+ exponential:
+ lr_decay: null
+ cosine:
+ warmup: null
+ lr_min_ratio: 0.0
+ cycle_length: 1.0
+ polynomial_decay:
+ warmup: null
+ zero_lr_warmup_steps: 0
+ end_lr: 0.0
+ power: 1
+ inverse_sqrt:
+ warmup: null
+ warmup_init_lr: 0.0
+ linear_warmup:
+ warmup: null
+ warmup_init_lr: 0.0
diff --git a/config/solver/diffusion/debug.yaml b/config/solver/diffusion/debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bc27c53486f7215a080d167032972402b90f5c77
--- /dev/null
+++ b/config/solver/diffusion/debug.yaml
@@ -0,0 +1,106 @@
+# @package __global__
+
+defaults:
+ - /solver/default
+ - /model: score/basic
+ - override /dset: audio/default
+ - _self_
+
+solver: diffusion
+
+sample_rate: 16000
+channels: 1
+compression_model_checkpoint: //sig/5091833e
+n_q: 2 # number of codebooks to keep
+
+dataset:
+ batch_size: 8
+ num_workers: 10
+ segment_duration: 1
+ train:
+ num_samples: 100
+ valid:
+ num_samples: 100
+ evaluate:
+ batch_size: 8
+ num_samples: 10
+ generate:
+ batch_size: 8
+ num_samples: 10
+ segment_duration: 10
+
+loss:
+ kind: mse
+ norm_power: 0.
+
+valid:
+ every: 1
+
+evaluate:
+ every: 5
+ num_workers: 5
+ metrics:
+ visqol: false
+ sisnr: false
+ rvm: true
+
+generate:
+ every: 5
+ num_workers: 5
+ audio:
+ sample_rate: ${sample_rate}
+
+checkpoint:
+ save_last: true
+ save_every: 25
+ keep_last: 10
+ keep_every_states: null
+
+
+optim:
+ epochs: 50
+ updates_per_epoch: 2000
+ lr: 2e-4
+ max_norm: 0
+ optimizer: adam
+ adam:
+ betas: [0.9, 0.999]
+ weight_decay: 0.
+ ema:
+ use: true # whether to use EMA or not
+ updates: 1 # update at every step
+ device: ${device} # device for EMA, can be put on GPU if more frequent updates
+ decay: 0.99 # EMA decay value, if null, no EMA is used
+
+processor:
+ name: multi_band_processor
+ use: false
+ n_bands: 8
+ num_samples: 10_000
+ power_std: 1.
+
+resampling:
+ use: false
+ target_sr: 16000
+
+filter:
+ use: false
+ n_bands: 4
+ idx_band: 0
+ cutoffs: null
+
+schedule:
+ repartition: "power"
+ variable_step_batch: true
+ beta_t0: 1.0e-5
+ beta_t1: 2.9e-2
+ beta_exp: 7.5
+ num_steps: 1000
+ variance: 'beta'
+ clip: 5.
+ rescale: 1.
+ n_bands: null
+ noise_scale: 1.0
+
+metrics:
+ num_stage: 4
diff --git a/config/solver/diffusion/default.yaml b/config/solver/diffusion/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3793d4d08d912db575c022a6803a8909c2b25273
--- /dev/null
+++ b/config/solver/diffusion/default.yaml
@@ -0,0 +1,107 @@
+# @package __global__
+
+defaults:
+ - /solver/default
+ - /model: score/basic
+ - override /dset: audio/default
+ - _self_
+
+solver: diffusion
+
+sample_rate: ???
+channels: ???
+compression_model_checkpoint: ???
+n_q: ??? # number of codebooks to keep
+
+
+dataset:
+ batch_size: 128
+ num_workers: 10
+ segment_duration: 1
+ train:
+ num_samples: 500000
+ valid:
+ num_samples: 10000
+ evaluate:
+ batch_size: 16
+ num_samples: 10000
+ generate:
+ batch_size: 32
+ num_samples: 50
+ segment_duration: 10
+ audio:
+ sample_rate: ${sample_rate}
+
+loss:
+ kind: mse
+ norm_power: 0.
+
+valid:
+ every: 1
+
+evaluate:
+ every: 20
+ num_workers: 5
+ metrics:
+ visqol: false
+ sisnr: false
+ rvm: true
+
+generate:
+ every: 25
+ num_workers: 5
+
+checkpoint:
+ save_last: true
+ save_every: 25
+ keep_last: 10
+ keep_every_states: null
+
+
+optim:
+ epochs: 20000
+ updates_per_epoch: 2000
+ lr: 2e-4
+ max_norm: 0
+ optimizer: adam
+ adam:
+ betas: [0.9, 0.999]
+ weight_decay: 0.
+ ema:
+ use: true # whether to use EMA or not
+ updates: 1 # update at every step
+ device: ${device} # device for EMA, can be put on GPU if more frequent updates
+ decay: 0.99 # EMA decay value, if null, no EMA is used
+
+processor:
+ name: multi_band_processor
+ use: false
+ n_bands: 8
+ num_samples: 10_000
+ power_std: 1.
+
+resampling:
+ use: false
+ target_sr: 16000
+
+filter:
+ use: false
+ n_bands: 4
+ idx_band: 0
+ cutoffs: null
+
+schedule:
+ repartition: "power"
+ variable_step_batch: true
+ beta_t0: 1.0e-5
+ beta_t1: 2.9e-2
+ beta_exp: 7.5
+ num_steps: 1000
+ variance: 'beta'
+ clip: 5.
+ rescale: 1.
+ n_bands: null
+ noise_scale: 1.0
+
+metrics:
+ num_stage: 4
diff --git a/config/solver/diffusion/encodec_24khz.yaml b/config/solver/diffusion/encodec_24khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..774e88f43d54980daef0c68d11717ddb7a214db1
--- /dev/null
+++ b/config/solver/diffusion/encodec_24khz.yaml
@@ -0,0 +1,11 @@
+# @package __global__
+
+defaults:
+ - diffusion/default
+ - _self_
+
+
+sample_rate: 24000
+channels: 1
+compression_model_checkpoint: //pretrained/facebook/encodec_24khz
+n_q: 4 # num quantizers, 3kbps
diff --git a/config/solver/musicgen/debug.yaml b/config/solver/musicgen/debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec658f9d2fb0262cc8eab19d0cf333963c646a98
--- /dev/null
+++ b/config/solver/musicgen/debug.yaml
@@ -0,0 +1,55 @@
+# @package __global__
+
+# This is a minimal debugging configuration
+# for MusicGen training solver
+defaults:
+ - musicgen/default
+ - /model: lm/musicgen_lm
+ - override /model/lm/model_scale: xsmall
+ - override /dset: audio/example
+ - _self_
+
+autocast: false
+compression_model_checkpoint: //pretrained/debug_compression_model
+transformer_lm:
+ n_q: 4
+ card: 400
+
+codebooks_pattern:
+ modeling: parallel
+
+channels: 1
+sample_rate: 32000
+
+deadlock:
+ use: false # deadlock detection
+
+dataset:
+ batch_size: 4
+ segment_duration: 5
+ sample_on_weight: false # Uniform sampling all the way
+ sample_on_duration: false # Uniform sampling all the way
+
+generate:
+ audio:
+ strategy: peak
+ lm:
+ use_sampling: false
+ top_k: 0
+ top_p: 0.0
+
+checkpoint:
+ save_every: 0
+ keep_last: 0
+
+optim:
+ epochs: 2
+ updates_per_epoch: 10
+ optimizer: adamw
+ lr: 1e-4
+
+logging:
+ log_tensorboard: true
+
+schedule:
+ lr_scheduler: null
diff --git a/config/solver/musicgen/default.yaml b/config/solver/musicgen/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..59e011376fb2b909fe599bc86bf0ef4029ce5d6e
--- /dev/null
+++ b/config/solver/musicgen/default.yaml
@@ -0,0 +1,119 @@
+# @package __global__
+
+defaults:
+ - /solver/default
+ - /conditioner: none
+ - _self_
+ - /solver/musicgen/evaluation: none
+ - override /dset: audio/default
+
+autocast: true
+autocast_dtype: float16
+
+solver: musicgen
+sample_rate: ???
+channels: ???
+compression_model_checkpoint: ???
+
+tokens:
+ padding_with_special_token: false
+
+cache:
+ path:
+ write: false
+ write_shard: 0
+ write_num_shards: 1
+
+
+dataset:
+ batch_size: 128
+ num_workers: 10
+ segment_duration: 30
+ min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence.
+ return_info: true
+ train:
+ num_samples: 1000000 # need a randomly large number here for AudioDataset
+ valid:
+ num_samples: 10000
+ generate:
+ num_samples: 50
+
+metrics:
+ fad:
+ use_gt: false
+ model: tf
+ tf:
+ bin: null # path to local frechet_audio_distance code
+ model_path: //reference/fad/vggish_model.ckpt
+ kld:
+ use_gt: false
+ model: passt
+ passt:
+ pretrained_length: 20
+ text_consistency:
+ use_gt: false
+ model: clap
+ clap:
+ model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
+ model_arch: 'HTSAT-base'
+ enable_fusion: false
+ chroma_cosine:
+ use_gt: false
+ model: chroma_base
+ chroma_base:
+ sample_rate: ${sample_rate}
+ n_chroma: 12
+ radix2_exp: 14
+ argmax: true
+
+generate:
+ every: 25
+ num_workers: 5
+ path: samples
+ audio:
+ format: wav
+ strategy: loudness
+ sample_rate: ${sample_rate}
+ loudness_headroom_db: 14
+ lm:
+ prompted_samples: true
+ unprompted_samples: true
+ gen_gt_samples: false
+ prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4
+ gen_duration: null # if not set, will use dataset.generate.segment_duration
+ remove_prompts: false
+ # generation params
+ use_sampling: false
+ temp: 1.0
+ top_k: 0
+ top_p: 0.0
+evaluate:
+ every: 25
+ num_workers: 5
+ metrics:
+ base: false
+ fad: false
+ kld: false
+ text_consistency: false
+ chroma_cosine: false
+
+checkpoint:
+ save_last: true
+ save_every: 50
+ keep_last: 10
+ keep_every_states: null
+
+optim:
+ epochs: 200
+ updates_per_epoch: 2000
+ lr: 1e-4
+ optimizer: adamw
+ max_norm: 1.0
+ eager_sync: true
+ adam:
+ betas: [0.9, 0.95]
+ weight_decay: 0.1
+ eps: 1e-8
+
+schedule:
+ lr_scheduler: null
diff --git a/config/solver/musicgen/evaluation/none.yaml b/config/solver/musicgen/evaluation/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e739995ed6488700527529862a7a24f1afdcc7a
--- /dev/null
+++ b/config/solver/musicgen/evaluation/none.yaml
@@ -0,0 +1,5 @@
+# @package __global__
+
+dataset:
+ evaluate:
+ num_samples: 10000
diff --git a/config/solver/musicgen/evaluation/objective_eval.yaml b/config/solver/musicgen/evaluation/objective_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4881e9d86cddf36b306a75fb498253e1e12ec5be
--- /dev/null
+++ b/config/solver/musicgen/evaluation/objective_eval.yaml
@@ -0,0 +1,24 @@
+# @package __global__
+
+# Setup for execute only on musiccaps for audio generation
+# evaluation with objective metrics
+# execute_only=evaluate
+
+dataset:
+ max_audio_duration: null
+ # ensure the proper values are broadcasted here for evaluate
+ evaluate:
+ min_audio_duration: 1. # some metrics requires a minimum audio length
+ max_audio_duration: null # all samples from musiccaps should be < 20s
+ num_samples: null
+ segment_duration: null
+ generate:
+ min_audio_duration: 1.
+ max_audio_duration: null
+ num_samples: 500
+
+evaluate:
+ metrics:
+ fad: true
+ kld: true
+ text_consistency: true
diff --git a/config/solver/musicgen/musicgen_base_32khz.yaml b/config/solver/musicgen/musicgen_base_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b32c9c898a70718f91af862caa79f5553a5107e1
--- /dev/null
+++ b/config/solver/musicgen/musicgen_base_32khz.yaml
@@ -0,0 +1,55 @@
+# @package __global__
+
+# This is the training loop solver
+# for the base MusicGen model (text-to-music)
+# on monophonic audio sampled at 32 kHz
+defaults:
+ - musicgen/default
+ - /model: lm/musicgen_lm
+ - override /dset: audio/default
+ - _self_
+
+autocast: true
+autocast_dtype: float16
+
+# EnCodec large trained on mono-channel music audio sampled at 32khz
+# with a total stride of 640 leading to 50 frames/s.
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout
+# (transformer_lm card and n_q must be compatible)
+compression_model_checkpoint: //pretrained/facebook/encodec_32khz
+
+channels: 1
+sample_rate: 32000
+
+deadlock:
+ use: true # deadlock detection
+
+dataset:
+ batch_size: 192 # 32 GPUs
+ sample_on_weight: false # Uniform sampling all the way
+ sample_on_duration: false # Uniform sampling all the way
+
+generate:
+ lm:
+ use_sampling: true
+ top_k: 250
+ top_p: 0.0
+
+optim:
+ epochs: 500
+ optimizer: dadam
+ lr: 1
+ ema:
+ use: true
+ updates: 10
+ device: cuda
+
+logging:
+ log_tensorboard: true
+
+schedule:
+ lr_scheduler: cosine
+ cosine:
+ warmup: 4000
+ lr_min_ratio: 0.0
+ cycle_length: 1.0
diff --git a/config/solver/musicgen/musicgen_melody_32khz.yaml b/config/solver/musicgen/musicgen_melody_32khz.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ad3e0aeeb9583887d6e8ecd6d32a3dc69e102ed
--- /dev/null
+++ b/config/solver/musicgen/musicgen_melody_32khz.yaml
@@ -0,0 +1,56 @@
+# @package __global__
+
+# This is the training loop solver
+# for the melody MusicGen model (text+chroma to music)
+# on monophonic audio sampled at 32 kHz
+defaults:
+ - musicgen/default
+ - /model: lm/musicgen_lm
+ - override /conditioner: chroma2music
+ - override /dset: audio/default
+ - _self_
+
+autocast: true
+autocast_dtype: float16
+
+# EnCodec large trained on mono-channel music audio sampled at 32khz
+# with a total stride of 640 leading to 50 frames/s.
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout
+# (transformer_lm card and n_q must be compatible)
+compression_model_checkpoint: //pretrained/facebook/encodec_32khz
+
+channels: 1
+sample_rate: 32000
+
+deadlock:
+ use: true # deadlock detection
+
+dataset:
+ batch_size: 192 # 32 GPUs
+ sample_on_weight: false # Uniform sampling all the way
+ sample_on_duration: false # Uniform sampling all the way
+
+generate:
+ lm:
+ use_sampling: true
+ top_k: 250
+ top_p: 0.0
+
+optim:
+ epochs: 500
+ optimizer: dadam
+ lr: 1
+ ema:
+ use: true
+ updates: 10
+ device: cuda
+
+logging:
+ log_tensorboard: true
+
+schedule:
+ lr_scheduler: cosine
+ cosine:
+ warmup: 4000
+ lr_min_ratio: 0.0
+ cycle_length: 1.0
diff --git a/config/teams/default.yaml b/config/teams/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..407066df1e154208af2823a6e46d16df381c5d42
--- /dev/null
+++ b/config/teams/default.yaml
@@ -0,0 +1,12 @@
+default:
+ dora_dir: /tmp/audiocraft_${oc.env:USER}
+ partitions:
+ global: debug
+ team: debug
+ reference_dir: /tmp
+darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc.
+ dora_dir: /tmp/audiocraft_${oc.env:USER}
+ partitions:
+ global: debug
+ team: debug
+ reference_dir: /tmp
diff --git a/config/teams/labs.yaml b/config/teams/labs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..da350a94bc5758531ced5d9e4332624fe86f3d57
--- /dev/null
+++ b/config/teams/labs.yaml
@@ -0,0 +1,28 @@
+aws:
+ dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs
+ partitions:
+ global: learnlab
+ team: learnlab
+ reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference
+ dataset_mappers:
+ "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm"
+fair:
+ dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
+ partitions:
+ global: learnlab
+ team: learnlab
+ reference_dir: /large_experiments/audiocraft/reference
+ dataset_mappers:
+ "^/datasets01/datasets01": "/datasets01"
+darwin:
+ dora_dir: /tmp/audiocraft_${oc.env:USER}
+ partitions:
+ global: debug
+ team: debug
+ reference_dir: /tmp
+rsc:
+ dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs
+ partitions:
+ global: learn
+ team: learn
+ reference_dir: /checkpoint/audiocraft/shared/reference
diff --git a/dataset/example/electro_1.json b/dataset/example/electro_1.json
new file mode 100644
index 0000000000000000000000000000000000000000..eeffc95038a1e031fad5598f822ddf2538d7f4da
--- /dev/null
+++ b/dataset/example/electro_1.json
@@ -0,0 +1 @@
+{"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]}
diff --git a/dataset/example/electro_1.mp3 b/dataset/example/electro_1.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..8fa509266df4ee76519b82bfbea247cb0b18bcda
Binary files /dev/null and b/dataset/example/electro_1.mp3 differ
diff --git a/dataset/example/electro_2.json b/dataset/example/electro_2.json
new file mode 100644
index 0000000000000000000000000000000000000000..3ee91c89c1d4b603f3e4d3fcc029618dc110e730
--- /dev/null
+++ b/dataset/example/electro_2.json
@@ -0,0 +1 @@
+{"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []}
diff --git a/dataset/example/electro_2.mp3 b/dataset/example/electro_2.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..01ab323e4322d08546635861959b868c3d7b416b
Binary files /dev/null and b/dataset/example/electro_2.mp3 differ
diff --git a/demos/audiogen_demo.ipynb b/demos/audiogen_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d3ad73fbcf172ea291ee4d73729b35af75ccedaa
--- /dev/null
+++ b/demos/audiogen_demo.ipynb
@@ -0,0 +1,175 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# AudioGen\n",
+ "Welcome to AudioGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use AudioGen in different settings.\n",
+ "\n",
+ "First, we start by initializing AudioGen. For now, we provide only a medium sized model for AudioGen: `facebook/audiogen-medium` - 1.5B transformer decoder. \n",
+ "\n",
+ "**Important note:** This variant is different from the original AudioGen model presented at [\"AudioGen: Textually-guided audio generation\"](https://arxiv.org/abs/2209.15352) as the model architecture is similar to MusicGen with a smaller frame rate and multiple streams of tokens, allowing to reduce generation time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from audiocraft.models import AudioGen\n",
+ "\n",
+ "model = AudioGen.get_pretrained('facebook/audiogen-medium')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
+ "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
+ "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
+ "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
+ "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
+ "* `duration` (float, optional): duration of the generated waveform. Defaults to 10.0.\n",
+ "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
+ "\n",
+ "When left unchanged, AudioGen will revert to its default parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.set_generation_params(\n",
+ " use_sampling=True,\n",
+ " top_k=250,\n",
+ " duration=5\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we can go ahead and start generating sound using one of the following modes:\n",
+ "* Audio continuation using `model.generate_continuation`\n",
+ "* Text-conditional samples using `model.generate`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Audio Continuation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import torchaudio\n",
+ "import torch\n",
+ "from audiocraft.utils.notebook import display_audio\n",
+ "\n",
+ "def get_bip_bip(bip_duration=0.125, frequency=440,\n",
+ " duration=0.5, sample_rate=16000, device=\"cuda\"):\n",
+ " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
+ " t = torch.arange(\n",
+ " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
+ " wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
+ " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
+ " envelope = (tp >= 0.5).float()\n",
+ " return wav * envelope"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Here we use a synthetic signal to prompt the generated audio.\n",
+ "res = model.generate_continuation(\n",
+ " get_bip_bip(0.125).expand(2, -1, -1), \n",
+ " 16000, ['Whistling with wind blowing', \n",
+ " 'Typing on a typewriter'], \n",
+ " progress=True)\n",
+ "display_audio(res, 16000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
+ "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/sirens_and_a_humming_engine_approach_and_pass.mp3\")\n",
+ "prompt_duration = 2\n",
+ "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
+ "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n",
+ "display_audio(output, sample_rate=16000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Text-conditional Generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from audiocraft.utils.notebook import display_audio\n",
+ "\n",
+ "output = model.generate(\n",
+ " descriptions=[\n",
+ " 'Subway train blowing its horn',\n",
+ " 'A cat meowing',\n",
+ " ],\n",
+ " progress=True\n",
+ ")\n",
+ "display_audio(output, sample_rate=16000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/demos/musicgen_demo.ipynb b/demos/musicgen_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a57bc79778e3356425a63cfc99193709641a87c8
--- /dev/null
+++ b/demos/musicgen_demo.ipynb
@@ -0,0 +1,232 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MusicGen\n",
+ "Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n",
+ "\n",
+ "First, we start by initializing MusicGen, you can choose a model from the following selection:\n",
+ "1. `facebook/musicgen-small` - 300M transformer decoder.\n",
+ "2. `facebook/musicgen-medium` - 1.5B transformer decoder.\n",
+ "3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.\n",
+ "4. `facebook/musicgen-large` - 3.3B transformer decoder.\n",
+ "\n",
+ "We will use the `facebook/musicgen-small` variant for the purpose of this demonstration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from audiocraft.models import MusicGen\n",
+ "from audiocraft.models import MultiBandDiffusion\n",
+ "import torch\n",
+ "USE_DIFFUSION_DECODER = False\n",
+ "# Using small model, better results would be obtained with `medium` or `large`.\n",
+ "model = MusicGen.get_pretrained('facebook/musicgen-large')\n",
+ "if USE_DIFFUSION_DECODER:\n",
+ " mbd = MultiBandDiffusion.get_mbd_musicgen()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
+ "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
+ "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
+ "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
+ "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
+ "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
+ "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
+ "\n",
+ "When left unchanged, MusicGen will revert to its default parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.set_generation_params(\n",
+ " use_sampling=True,\n",
+ " top_k=250,\n",
+ " duration=30\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we can go ahead and start generating music using one of the following modes:\n",
+ "* Unconditional samples using `model.generate_unconditional`\n",
+ "* Music continuation using `model.generate_continuation`\n",
+ "* Text-conditional samples using `model.generate`\n",
+ "* Melody-conditional samples using `model.generate_with_chroma`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Music Continuation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import torchaudio\n",
+ "import torch\n",
+ "from audiocraft.utils.notebook import display_audio\n",
+ "\n",
+ "def get_bip_bip(bip_duration=0.125, frequency=440,\n",
+ " duration=0.5, sample_rate=32000, device=\"cuda\"):\n",
+ " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
+ " t = torch.arange(\n",
+ " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
+ " wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
+ " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
+ " envelope = (tp >= 0.5).float()\n",
+ " return wav * envelope"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Here we use a synthetic signal to prompt both the tonality and the BPM\n",
+ "# of the generated audio.\n",
+ "res = model.generate_continuation(\n",
+ " get_bip_bip(0.125).expand(2, -1, -1), \n",
+ " 32000, ['Jazz jazz and only jazz', \n",
+ " 'Heartful EDM with beautiful synths and chords'], \n",
+ " progress=True)\n",
+ "display_audio(res, 32000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
+ "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/bach.mp3\")\n",
+ "prompt_duration = 2\n",
+ "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
+ "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=True)\n",
+ "display_audio(output[0], sample_rate=32000)\n",
+ "if USE_DIFFUSION_DECODER:\n",
+ " out_diffusion = mbd.tokens_to_wav(output[1])\n",
+ " display_audio(out_diffusion, sample_rate=32000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Text-conditional Generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from audiocraft.utils.notebook import display_audio\n",
+ "\n",
+ "output = model.generate(\n",
+ " descriptions=[\n",
+ " #'80s pop track with bassy drums and synth',\n",
+ " #'90s rock song with loud guitars and heavy drums',\n",
+ " #'Progressive rock drum and bass solo',\n",
+ " #'Punk Rock song with loud drum and power guitar',\n",
+ " #'Bluesy guitar instrumental with soulful licks and a driving rhythm section',\n",
+ " #'Jazz Funk song with slap bass and powerful saxophone',\n",
+ " 'drum and bass beat with intense percussions'\n",
+ " ],\n",
+ " progress=True, return_tokens=True\n",
+ ")\n",
+ "display_audio(output[0], sample_rate=32000)\n",
+ "if USE_DIFFUSION_DECODER:\n",
+ " out_diffusion = mbd.tokens_to_wav(output[1])\n",
+ " display_audio(out_diffusion, sample_rate=32000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Melody-conditional Generation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torchaudio\n",
+ "from audiocraft.utils.notebook import display_audio\n",
+ "\n",
+ "model = MusicGen.get_pretrained('facebook/musicgen-melody')\n",
+ "model.set_generation_params(duration=8)\n",
+ "\n",
+ "melody_waveform, sr = torchaudio.load(\"../assets/bach.mp3\")\n",
+ "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n",
+ "output = model.generate_with_chroma(\n",
+ " descriptions=[\n",
+ " '80s pop track with bassy drums and synth',\n",
+ " '90s rock song with loud guitars and heavy drums',\n",
+ " ],\n",
+ " melody_wavs=melody_waveform,\n",
+ " melody_sample_rate=sr,\n",
+ " progress=True, return_tokens=True\n",
+ ")\n",
+ "display_audio(output[0], sample_rate=32000)\n",
+ "if USE_DIFFUSION_DECODER:\n",
+ " out_diffusion = mbd.tokens_to_wav(output[1])\n",
+ " display_audio(out_diffusion, sample_rate=32000)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dockerignore b/dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..6e25fa8f106e51b7dd538d35b1291301f821b2ad
--- /dev/null
+++ b/dockerignore
@@ -0,0 +1 @@
+cache/
\ No newline at end of file
diff --git a/docs/AUDIOGEN.md b/docs/AUDIOGEN.md
new file mode 100644
index 0000000000000000000000000000000000000000..a0ff481190fb52fe865aa66aaaa10176f7cf995c
--- /dev/null
+++ b/docs/AUDIOGEN.md
@@ -0,0 +1,158 @@
+# AudioGen: Textually-guided audio generation
+
+AudioCraft provides the code and a model re-implementing AudioGen, a [textually-guided audio generation][audiogen_arxiv]
+model that performs text-to-sound generation.
+
+The provided AudioGen reimplementation follows the LM model architecture introduced in [MusicGen][musicgen_arxiv]
+and is a single stage auto-regressive Transformer model trained over a 16kHz
+EnCodec tokenizer with 4 codebooks sampled at 50 Hz.
+This model variant reaches similar audio quality than the original implementation introduced in the AudioGen publication
+while providing faster generation speed given the smaller frame rate.
+
+**Important note:** The provided models are NOT the original models used to report numbers in the
+[AudioGen publication][audiogen_arxiv]. Refer to the model card to learn more about architectural changes.
+
+Listen to samples from the **original AudioGen implementation** in our [sample page][audiogen_samples].
+
+
+## Model Card
+
+See [the model card](../model_cards/AUDIOGEN_MODEL_CARD.md).
+
+
+## Installation
+
+Please follow the AudioCraft installation instructions from the [README](../README.md).
+
+AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters).
+
+## API and usage
+
+We provide a simple API and 1 pre-trained models for AudioGen:
+
+`facebook/audiogen-medium`: 1.5B model, text to sound - [🤗 Hub](https://huggingface.co/facebook/audiogen-medium)
+
+You can play with AudioGen by running the jupyter notebook at [`demos/audiogen_demo.ipynb`](../demos/audiogen_demo.ipynb) locally (if you have a GPU).
+
+See after a quick example for using the API.
+
+```python
+import torchaudio
+from audiocraft.models import AudioGen
+from audiocraft.data.audio import audio_write
+
+model = AudioGen.get_pretrained('facebook/audiogen-medium')
+model.set_generation_params(duration=5) # generate 5 seconds.
+descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor']
+wav = model.generate(descriptions) # generates 3 samples.
+
+for idx, one_wav in enumerate(wav):
+ # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
+ audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+## Training
+
+The [AudioGenSolver](../audiocraft/solvers/audiogen.py) implements the AudioGen's training pipeline
+used to develop the released model. Note that this may not fully reproduce the results presented in the paper.
+Similarly to MusicGen, it defines an autoregressive language modeling task over multiple streams of
+discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md)
+for more details on how to train such model) with dataset-specific changes for environmental sound
+processing.
+
+Note that **we do NOT provide any of the datasets** used for training AudioGen.
+
+### Example configurations and grids
+
+We provide configurations to reproduce the released models and our research.
+AudioGen solvers configuration are available in [config/solver/audiogen](../config/solver/audiogen).
+The base training configuration used for the released models is the following:
+[`solver=audiogen/audiogen_base_16khz`](../config/solver/audiogen/audiogen_base_16khz.yaml)
+
+Please find some example grids to train AudioGen at
+[audiocraft/grids/audiogen](../audiocraft/grids/audiogen/).
+
+```shell
+# text-to-sound
+dora grid audiogen.audiogen_base_16khz
+```
+
+### Sound dataset and metadata
+
+AudioGen's underlying dataset is an AudioDataset augmented with description metadata.
+The AudioGen dataset implementation expects the metadata to be available as `.json` files
+at the same location as the audio files or through specified external folder.
+Learn more in the [datasets section](./DATASETS.md).
+
+### Evaluation stage
+
+By default, evaluation stage is also computing the cross-entropy and the perplexity over the
+evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run
+or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md)
+for more details on the requirements for each metric.
+
+We provide an off-the-shelf configuration to enable running the objective metrics
+for audio generation in
+[config/solver/audiogen/evaluation/objective_eval](../config/solver/audiogen/evaluation/objective_eval.yaml).
+
+One can then activate evaluation the following way:
+```shell
+# using the configuration
+dora run solver=audiogen/debug solver/audiogen/evaluation=objective_eval
+# specifying each of the fields, e.g. to activate KL computation
+dora run solver=audiogen/debug evaluate.metrics.kld=true
+```
+
+See [an example evaluation grid](../audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py).
+
+### Generation stage
+
+The generation stage allows to generate samples conditionally and/or unconditionally and to perform
+audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling
+from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples
+generated and the batch size used are controlled by the `dataset.generate` configuration
+while the other generation parameters are defined in `generate.lm`.
+
+```shell
+# control sampling parameters
+dora run solver=audiogen/debug generate.lm.gen_duration=5 generate.lm.use_sampling=true generate.lm.top_k=15
+```
+
+## More information
+
+Refer to [MusicGen's instructions](./MUSICGEN.md).
+
+### Learn more
+
+Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
+
+
+## Citation
+
+AudioGen
+```
+@article{kreuk2022audiogen,
+ title={Audiogen: Textually guided audio generation},
+ author={Kreuk, Felix and Synnaeve, Gabriel and Polyak, Adam and Singer, Uriel and D{\'e}fossez, Alexandre and Copet, Jade and Parikh, Devi and Taigman, Yaniv and Adi, Yossi},
+ journal={arXiv preprint arXiv:2209.15352},
+ year={2022}
+}
+```
+
+MusicGen
+```
+@article{copet2023simple,
+ title={Simple and Controllable Music Generation},
+ author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ year={2023},
+ journal={arXiv preprint arXiv:2306.05284},
+}
+```
+
+## License
+
+See license information in the [model card](../model_cards/AUDIOGEN_MODEL_CARD.md).
+
+[audiogen_arxiv]: https://arxiv.org/abs/2209.15352
+[musicgen_arxiv]: https://arxiv.org/abs/2306.05284
+[audiogen_samples]: https://felixkreuk.github.io/audiogen/
diff --git a/docs/CONDITIONING.md b/docs/CONDITIONING.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e356cb8e9912d3e18fc84598c1acf77c6e7abc5
--- /dev/null
+++ b/docs/CONDITIONING.md
@@ -0,0 +1,146 @@
+# AudioCraft conditioning modules
+
+AudioCraft provides a
+[modular implementation of conditioning modules](../audiocraft/modules/conditioners.py)
+that can be used with the language model to condition the generation.
+The codebase was developed in order to easily extend the set of modules
+currently supported to easily develop new ways of controlling the generation.
+
+
+## Conditioning methods
+
+For now, we support 3 main types of conditioning within AudioCraft:
+* Text-based conditioning methods
+* Waveform-based conditioning methods
+* Joint embedding conditioning methods for text and audio projected in a shared latent space.
+
+The Language Model relies on 2 core components that handle processing information:
+* The `ConditionProvider` class, that maps metadata to processed conditions leveraging
+all the defined conditioners for the given task.
+* The `ConditionFuser` class, that takes preprocessed conditions and properly fuse the
+conditioning embedding to the language model inputs following a given fusing strategy.
+
+Different conditioners (for text, waveform, joint embeddings...) are provided as torch
+modules in AudioCraft and are used internally in the language model to process the
+conditioning signals and feed them to the language model.
+
+
+## Core concepts
+
+### Conditioners
+
+The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft.
+
+Each conditioner is expected to implement 2 methods:
+* The `tokenize` method that is used as a preprocessing method that contains all processing
+that can lead to synchronization points (e.g. BPE tokenization with transfer to the GPU).
+The output of the tokenize method will then be used to feed the forward method.
+* The `forward` method that takes the output of the tokenize method and contains the core computation
+to obtain the conditioning embedding along with a mask indicating valid indices (e.g. padding tokens).
+
+### ConditionProvider
+
+The ConditionProvider prepares and provides conditions given a dictionary of conditioners.
+
+Conditioners are specified as a dictionary of attributes and the corresponding conditioner
+providing the processing logic for the given attribute.
+
+Similarly to the conditioners, the condition provider works in two steps to avoid sychronization points:
+* A `tokenize` method that takes a list of conditioning attributes for the batch,
+and run all tokenize steps for the set of conditioners.
+* A `forward` method that takes the output of the tokenize step and run all the forward steps
+for the set of conditioners.
+
+The list of conditioning attributes is passed as a list of `ConditioningAttributes`
+that is presented just below.
+
+### ConditionFuser
+
+Once all conditioning signals have been extracted and processed by the `ConditionProvider`
+as dense embeddings, they remain to be passed to the language model along with the original
+language model inputs.
+
+The `ConditionFuser` handles specifically the logic to combine the different conditions
+to the actual model input, supporting different strategies to combine them.
+
+One can therefore define different strategies to combine or fuse the condition to the input, in particular:
+* Prepending the conditioning signal to the input with the `prepend` strategy,
+* Summing the conditioning signal to the input with the `sum` strategy,
+* Combining the conditioning relying on a cross-attention mechanism with the `cross` strategy,
+* Using input interpolation with the `input_interpolate` strategy.
+
+### SegmentWithAttributes and ConditioningAttributes: From metadata to conditions
+
+The `ConditioningAttributes` dataclass is the base class for metadata
+containing all attributes used for conditioning the language model.
+
+It currently supports the following types of attributes:
+* Text conditioning attributes: Dictionary of textual attributes used for text-conditioning.
+* Wav conditioning attributes: Dictionary of waveform attributes used for waveform-based
+conditioning such as the chroma conditioning.
+* JointEmbed conditioning attributes: Dictionary of text and waveform attributes
+that are expected to be represented in a shared latent space.
+
+These different types of attributes are the attributes that are processed
+by the different conditioners.
+
+`ConditioningAttributes` are extracted from metadata loaded along the audio in the datasets,
+provided that the metadata used by the dataset implements the `SegmentWithAttributes` abstraction.
+
+All metadata-enabled datasets to use for conditioning in AudioCraft inherits
+the [`audiocraft.data.info_dataset.InfoAudioDataset`](../audiocraft/data/info_audio_dataset.py) class
+and the corresponding metadata inherits and implements the `SegmentWithAttributes` abstraction.
+Refer to the [`audiocraft.data.music_dataset.MusicAudioDataset`](../audiocraft/data/music_dataset.py)
+class as an example.
+
+
+## Available conditioners
+
+### Text conditioners
+
+All text conditioners are expected to inherit from the `TextConditioner` class.
+
+AudioCraft currently provides two text conditioners:
+* The `LUTConditioner` that relies on look-up-table of embeddings learned at train time,
+and relying on either no tokenizer or a spacy tokenizer. This conditioner is particularly
+useful for simple experiments and categorical labels.
+* The `T5Conditioner` that relies on a
+[pre-trained T5 model](https://huggingface.co/docs/transformers/model_doc/t5)
+frozen or fine-tuned at train time to extract the text embeddings.
+
+### Waveform conditioners
+
+All waveform conditioners are expected to inherit from the `WaveformConditioner` class and
+consists of conditioning method that takes a waveform as input. The waveform conditioner
+must implement the logic to extract the embedding from the waveform and define the downsampling
+factor from the waveform to the resulting embedding.
+
+The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features
+conditioning used by MusicGen. It takes a given waveform, extract relevant stems for melody
+(namely all non drums and bass stems) using a
+[pre-trained Demucs model](https://github.com/facebookresearch/demucs)
+and then extract the chromagram bins from the remaining mix of stems.
+
+### Joint embeddings conditioners
+
+We finally provide support for conditioning based on joint text and audio embeddings through
+the `JointEmbeddingConditioner` class and the `CLAPEmbeddingConditioner` that implements such
+a conditioning method relying on a [pretrained CLAP model](https://github.com/LAION-AI/CLAP).
+
+## Classifier Free Guidance
+
+We provide a Classifier Free Guidance implementation in AudioCraft. With the classifier free
+guidance dropout, all attributes are dropped with the same probability.
+
+## Attribute Dropout
+
+We further provide an attribute dropout strategy. Unlike the classifier free guidance dropout,
+the attribute dropout drops given attributes with a defined probability, allowing the model
+not to expect all conditioning signals to be provided at once.
+
+## Faster computation of conditions
+
+Conditioners that require some heavy computation on the waveform can be cached, in particular
+the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the
+`cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly.
+An example is provied in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py).
\ No newline at end of file
diff --git a/docs/DATASETS.md b/docs/DATASETS.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0890c03cf732450eb498559638c6b45d50e40c3
--- /dev/null
+++ b/docs/DATASETS.md
@@ -0,0 +1,82 @@
+# AudioCraft datasets
+
+Our dataset manifest files consist in 1-json-per-line files, potentially gzipped,
+as `data.jsons` or `data.jsons.gz` files. This JSON contains the path to the audio
+file and associated metadata. The manifest files are then provided in the configuration,
+as `datasource` sub-configuration. A datasource contains the pointers to the paths of
+the manifest files for each AudioCraft stage (or split) along with additional information
+(eg. maximum sample rate to use against this dataset). All the datasources are under the
+`dset` group config, with a dedicated configuration file for each dataset.
+
+## Getting started
+
+### Example
+
+See the provided example in the directory that provides a manifest to use the example dataset
+provided under the [dataset folder](../dataset/example).
+
+The manifest files are stored in the [egs folder](../egs/example).
+
+```shell
+egs/
+ example/data.json.gz
+```
+
+A datasource is defined in the configuration folder, in the dset group config for this dataset
+at [config/dset/audio/example](../config/dset/audio/example.yaml):
+
+```shell
+# @package __global__
+
+datasource:
+ max_sample_rate: 44100
+ max_channels: 2
+
+ train: egs/example
+ valid: egs/example
+ evaluate: egs/example
+ generate: egs/example
+```
+
+For proper dataset, one should create manifest for each of the splits and specify the correct path
+to the given manifest in the datasource for each split.
+
+Then, using a dataset through the configuration can be done pointing to the
+corresponding dataset configuration:
+```shell
+dset= # should match the yaml file name
+
+# for example
+dset=audio/example
+```
+
+### Creating manifest files
+
+Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use
+the following command to create new manifest files from a given folder containing audio files:
+
+```shell
+python -m audiocraft.data.audio_dataset egs/my_dataset/my_dataset_split/data.jsonl.gz
+
+# For example to generate the manifest for dset=audio/example
+# note: we don't use any split and we don't compress the jsonl file for this dummy example
+python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl
+
+# More info with: python -m audiocraft.data.audio_dataset --help
+```
+
+## Additional information
+
+### MusicDataset and metadata
+
+The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects
+the additional metadata to be stored in a JSON file that has the same path as the corresponding
+audio file, but with a `.json` extension.
+
+### SoundDataset and metadata
+
+The SoundDataset is an AudioDataset with descriptions metadata. Similarly to the MusicDataset,
+the SoundDataset expects the additional metadata to be stored in a JSON file that has the same
+path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset
+supports an additional parameter pointing to an extra folder `external_metadata_source` containing
+all the JSON metadata files given they have the same filename as the audio file.
diff --git a/docs/ENCODEC.md b/docs/ENCODEC.md
new file mode 100644
index 0000000000000000000000000000000000000000..efc2bcc7ec50190b907c887b920b70fd799c6953
--- /dev/null
+++ b/docs/ENCODEC.md
@@ -0,0 +1,179 @@
+# EnCodec: High Fidelity Neural Audio Compression
+
+AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning
+based audio codec supporting both mono stereo audio, presented in the
+[High Fidelity Neural Audio Compression][arxiv] paper.
+Check out our [sample page][encodec_samples].
+
+## Original EnCodec models
+
+The EnCodec models presented in High Fidelity Neural Audio Compression can be accessed
+and used with the [EnCodec repository](https://github.com/facebookresearch/encodec).
+
+**Note**: We do not guarantee compatibility between the AudioCraft and EnCodec codebases
+and released checkpoints at this stage.
+
+
+## Installation
+
+Please follow the AudioCraft installation instructions from the [README](../README.md).
+
+
+## Training
+
+The [CompressionSolver](../audiocraft/solvers/compression.py) implements the audio reconstruction
+task to train an EnCodec model. Specifically, it trains an encoder-decoder with a quantization
+bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec -
+using a combination of objective and perceptual losses in the forms of discriminators.
+
+The default configuration matches a causal EnCodec training with at a single bandwidth.
+
+### Example configuration and grids
+
+We provide sample configuration and grids for training EnCodec models.
+
+The compression configuration are defined in
+[config/solver/compression](../config/solver/compression).
+
+The example grids are available at
+[audiocraft/grids/compression](../audiocraft/grids/compression).
+
+```shell
+# base causal encodec on monophonic audio sampled at 24 khz
+dora grid compression.encodec_base_24khz
+# encodec model used for MusicGen on monophonic audio sampled at 32 khz
+dora grid compression.encodec_musicgen_32khz
+```
+
+### Training and valid stages
+
+The model is trained using a combination of objective and perceptual losses.
+More specifically, EnCodec is trained with the MS-STFT discriminator along with
+objective losses through the use of a loss balancer to effectively weight
+the different losses, in an intuitive manner.
+
+### Evaluation stage
+
+Evaluations metrics for audio generation:
+* SI-SNR: Scale-Invariant Signal-to-Noise Ratio.
+* ViSQOL: Virtual Speech Quality Objective Listener.
+
+Note: Path to the ViSQOL binary (compiled with bazel) needs to be provided in
+order to run the ViSQOL metric on the reference and degraded signals.
+The metric is disabled by default.
+Please refer to the [metrics documentation](../METRICS.md) to learn more.
+
+### Generation stage
+
+The generation stage consists in generating the reconstructed audio from samples
+with the current model. The number of samples generated and the batch size used are
+controlled by the `dataset.generate` configuration. The output path and audio formats
+are defined in the generate stage configuration.
+
+```shell
+# generate samples every 5 epoch
+dora run solver=compression/encodec_base_24khz generate.every=5
+# run with a different dset
+dora run solver=compression/encodec_base_24khz generate.path=
+# limit the number of samples or use a different batch size
+dora grid solver=compression/encodec_base_24khz dataset.generate.num_samples=10 dataset.generate.batch_size=4
+```
+
+### Playing with the model
+
+Once you have a model trained, it is possible to get the entire solver, or just
+the trained model with the following functions:
+
+```python
+from audiocraft.solvers import CompressionSolver
+
+# If you trained a custom model with signature SIG.
+model = CompressionSolver.model_from_checkpoint('//sig/SIG')
+# If you want to get one of the pretrained models with the `//pretrained/` prefix.
+model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz')
+# Or load from a custom checkpoint path
+model = CompressionSolver.model_from_checkpoint('/my_checkpoints/foo/bar/checkpoint.th')
+
+
+# If you only want to use a pretrained model, you can also directly get it
+# from the CompressionModel base model class.
+from audiocraft.models import CompressionModel
+
+# Here do not put the `//pretrained/` prefix!
+model = CompressionModel.get_pretrained('facebook/encodec_32khz')
+model = CompressionModel.get_pretrained('dac_44khz')
+
+# Finally, you can also retrieve the full Solver object, with its dataloader etc.
+from audiocraft import train
+from pathlib import Path
+import logging
+import os
+import sys
+
+# uncomment the following line if you want some detailed logs when loading a Solver.
+logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+# You must always run the following function from the root directory.
+os.chdir(Path(train.__file__).parent.parent)
+
+
+# You can also get the full solver (only for your own experiments).
+# You can provide some overrides to the parameters to make things more convenient.
+solver = train.get_solver_from_sig('SIG', {'device': 'cpu', 'dataset': {'batch_size': 8}})
+solver.model
+solver.dataloaders
+```
+
+### Importing / Exporting models
+
+At the moment we do not have a definitive workflow for exporting EnCodec models, for
+instance to Hugging Face (HF). We are working on supporting automatic convertion between
+AudioCraft and Hugging Face implementations.
+
+We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft,
+using for instance `continue_from=//pretrained/facebook/encodec_32k`.
+
+An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.)
+using `audiocraft.utils.export.export_encodec`. For instance, you could run
+
+```python
+from audiocraft.utils import export
+from audiocraft import train
+xp = train.main.get_xp_from_sig('SIG')
+export.export_encodec(
+ xp.folder / 'checkpoint.th',
+ '/checkpoints/my_audio_lm/compression_state_dict.bin')
+
+
+from audiocraft.models import CompressionModel
+model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin')
+
+from audiocraft.solvers import CompressionSolver
+# The two are strictly equivalent, but this function supports also loading from non already exported models.
+model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin')
+```
+
+We will see then how to use this model as a tokenizer for MusicGen/Audio gen in the
+[MusicGen documentation](./MUSICGEN.md).
+
+### Learn more
+
+Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
+
+
+## Citation
+```
+@article{defossez2022highfi,
+ title={High Fidelity Neural Audio Compression},
+ author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
+ journal={arXiv preprint arXiv:2210.13438},
+ year={2022}
+}
+```
+
+
+## License
+
+See license information in the [README](../README.md).
+
+[arxiv]: https://arxiv.org/abs/2210.13438
+[encodec_samples]: https://ai.honu.io/papers/encodec/samples.html
diff --git a/docs/MBD.md b/docs/MBD.md
new file mode 100644
index 0000000000000000000000000000000000000000..296d08407bac9155380a48bdc9faa5798db32bcb
--- /dev/null
+++ b/docs/MBD.md
@@ -0,0 +1,117 @@
+# MultiBand Diffusion
+
+AudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fidelity Audio using MultiBand Diffusion][arxiv].
+MultiBand diffusion is a collection of 4 models that can decode tokens from
+EnCodec tokenizer into waveform audio.
+
+
+
+
+
+
+
+## Installation
+
+Please follow the AudioCraft installation instructions from the [README](../README.md).
+
+
+## Usage
+
+We offer a number of way to use MultiBand Diffusion:
+1. The MusicGen demo includes a toggle to try diffusion decoder. You can use the demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py), or through the [MusicGen Colab](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing).
+2. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU).
+
+## API
+
+We provide a simple API and pre-trained models for MusicGen and for EnCodec at 24 khz for 3 bitrates (1.5 kbps, 3 kbps and 6 kbps).
+
+See after a quick example for using MultiBandDiffusion with the MusicGen API:
+
+```python
+import torchaudio
+from audiocraft.models import MusicGen, MultiBandDiffusion
+from audiocraft.data.audio import audio_write
+
+model = MusicGen.get_pretrained('facebook/musicgen-melody')
+mbd = MultiBandDiffusion.get_mbd_musicgen()
+model.set_generation_params(duration=8) # generate 8 seconds.
+wav, tokens = model.generate_unconditional(4, return_tokens=True) # generates 4 unconditional audio samples and keep the tokens for MBD generation
+descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
+wav_diffusion = mbd.tokens_to_wav(tokens)
+wav, tokens = model.generate(descriptions, return_tokens=True) # generates 3 samples and keep the tokens.
+wav_diffusion = mbd.tokens_to_wav(tokens)
+melody, sr = torchaudio.load('./assets/bach.mp3')
+# Generates using the melody from the given audio and the provided descriptions, returns audio and audio tokens.
+wav, tokens = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr, return_tokens=True)
+wav_diffusion = mbd.tokens_to_wav(tokens)
+
+for idx, one_wav in enumerate(wav):
+ # Will save under {idx}.wav and {idx}_diffusion.wav, with loudness normalization at -14 db LUFS for comparing the methods.
+ audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+ audio_write(f'{idx}_diffusion', wav_diffusion[idx].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+For the compression task (and to compare with [EnCodec](https://github.com/facebookresearch/encodec)):
+
+```python
+import torch
+from audiocraft.models import MultiBandDiffusion
+from encodec import EncodecModel
+from audiocraft.data.audio import audio_read, audio_write
+
+bandwidth = 3.0 # 1.5, 3.0, 6.0
+mbd = MultiBandDiffusion.get_mbd_24khz(bw=bandwidth)
+encodec = EncodecModel.get_encodec_24khz()
+
+somepath = ''
+wav, sr = audio_read(somepath)
+with torch.no_grad():
+ compressed_encodec = encodec(wav)
+ compressed_diffusion = mbd.regenerate(wav, sample_rate=sr)
+
+audio_write('sample_encodec', compressed_encodec.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True)
+audio_write('sample_diffusion', compressed_diffusion.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+
+## Training
+
+The [DiffusionSolver](../audiocraft/solvers/diffusion.py) implements our diffusion training pipeline.
+It generates waveform audio conditioned on the embeddings extracted from a pre-trained EnCodec model
+(see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model).
+
+Note that **we do NOT provide any of the datasets** used for training our diffusion models.
+We provide a dummy dataset containing just a few examples for illustrative purposes.
+
+### Example configurations and grids
+
+One can train diffusion models as described in the paper by using this [dora grid](../audiocraft/grids/diffusion/4_bands_base_32khz.py).
+```shell
+# 4 bands MBD trainning
+dora grid diffusion.4_bands_base_32khz
+```
+
+### Learn more
+
+Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
+
+
+## Citation
+
+```
+@article{sanroman2023fromdi,
+ title={From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion},
+ author={San Roman, Robin and Adi, Yossi and Deleforge, Antoine and Serizel, Romain and Synnaeve, Gabriel and Défossez, Alexandre},
+ journal={arXiv preprint arXiv:},
+ year={2023}
+}
+```
+
+
+## License
+
+See license information in the [README](../README.md).
+
+
+[arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf
+[mbd_samples]: https://ai.honu.io/papers/mbd/
diff --git a/docs/METRICS.md b/docs/METRICS.md
new file mode 100644
index 0000000000000000000000000000000000000000..e2ae9a184cbccb8bfefb4ce77afa5ddab743a051
--- /dev/null
+++ b/docs/METRICS.md
@@ -0,0 +1,127 @@
+# AudioCraft objective metrics
+
+In addition to training losses, AudioCraft provides a set of objective metrics
+for audio synthesis and audio generation. As these metrics may require
+extra dependencies and can be costly to train, they are often disabled by default.
+This section provides guidance for setting up and using these metrics in
+the AudioCraft training pipelines.
+
+## Available metrics
+
+### Audio synthesis quality metrics
+
+#### SI-SNR
+
+We provide an implementation of the Scale-Invariant Signal-to-Noise Ratio in PyTorch.
+No specific requirement is needed for this metric. Please activate the metric at the
+evaluation stage with the appropriate flag:
+
+```shell
+dora run <...> evaluate.metrics.sisnr=true
+```
+
+#### ViSQOL
+
+We provide a Python wrapper around the ViSQOL [official implementation](https://github.com/google/visqol)
+to conveniently run ViSQOL within the training pipelines.
+
+One must specify the path to the ViSQOL installation through the configuration in order
+to enable ViSQOL computations in AudioCraft:
+
+```shell
+# the first parameter is used to activate visqol computation while the second specify
+# the path to visqol's library to be used by our python wrapper
+dora run <...> evaluate.metrics.visqol=true metrics.visqol.bin=
+```
+
+See an example grid: [Compression with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py)
+
+To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
+instructions available in the [open source repository](https://github.com/google/visqol).
+
+### Audio generation metrics
+
+#### Frechet Audio Distance
+
+Similarly to ViSQOL, we use a Python wrapper around the Frechet Audio Distance
+[official implementation](https://github.com/google-research/google-research/tree/master/frechet_audio_distance)
+in TensorFlow.
+
+Note that we had to make several changes to the actual code in order to make it work.
+Please refer to the [FrechetAudioDistanceMetric](../audiocraft/metrics/fad.py) class documentation
+for more details. We do not plan to provide further support in obtaining a working setup for the
+Frechet Audio Distance at this stage.
+
+```shell
+# the first parameter is used to activate FAD metric computation while the second specify
+# the path to FAD library to be used by our python wrapper
+dora run <...> evaluate.metrics.fad=true metrics.fad.bin=
+```
+
+See an example grid: [Evaluation with FAD](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py)
+
+#### Kullback-Leibler Divergence
+
+We provide a PyTorch implementation of the Kullback-Leibler Divergence computed over the probabilities
+of the labels obtained by a state-of-the-art audio classifier. We provide our implementation of the KLD
+using the [PaSST classifier](https://github.com/kkoutini/PaSST).
+
+In order to use the KLD metric over PaSST, you must install the PaSST library as an extra dependency:
+```shell
+pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
+```
+
+Then similarly, you can use the metric activating the corresponding flag:
+
+```shell
+# one could extend the kld metric with additional audio classifier models that can then be picked through the configuration
+dora run <...> evaluate.metrics.kld=true metrics.kld.model=passt
+```
+
+#### Text consistency
+
+We provide a text-consistency metric, similarly to the MuLan Cycle Consistency from
+[MusicLM](https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in
+[Make-An-Audio](https://arxiv.org/pdf/2301.12661v1.pdf).
+More specifically, we provide a PyTorch implementation of a Text consistency metric
+relying on a pre-trained [Contrastive Language-Audio Pretraining (CLAP)](https://github.com/LAION-AI/CLAP).
+
+Please install the CLAP library as an extra dependency prior to using the metric:
+```shell
+pip install laion_clap
+```
+
+Then similarly, you can use the metric activating the corresponding flag:
+
+```shell
+# one could extend the text consistency metric with additional audio classifier models that can then be picked through the configuration
+dora run ... evaluate.metrics.text_consistency=true metrics.text_consistency.model=clap
+```
+
+Note that the text consistency metric based on CLAP will require the CLAP checkpoint to be
+provided in the configuration.
+
+#### Chroma cosine similarity
+
+Finally, as introduced in MusicGen, we provide a Chroma Cosine Similarity metric in PyTorch.
+No specific requirement is needed for this metric. Please activate the metric at the
+evaluation stage with the appropriate flag:
+
+```shell
+dora run ... evaluate.metrics.chroma_cosine=true
+```
+
+#### Comparing against reconstructed audio
+
+For all the above audio generation metrics, we offer the option to compute the metric on the reconstructed audio
+fed in EnCodec instead of the generated sample using the flag `.use_gt=true`.
+
+## Example usage
+
+You will find example of configuration for the different metrics introduced above in:
+* The [musicgen's default solver](../config/solver/musicgen/default.yaml) for all audio generation metrics
+* The [compression's default solver](../config/solver/compression/default.yaml) for all audio synthesis metrics
+
+Similarly, we provide different examples in our grids:
+* [Evaluation with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py)
+* [Evaluation with FAD and others](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py)
diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md
new file mode 100644
index 0000000000000000000000000000000000000000..606ce85808a428432f4e77564fb97dcade3851a3
--- /dev/null
+++ b/docs/MUSICGEN.md
@@ -0,0 +1,362 @@
+# MusicGen: Simple and Controllable Music Generation
+
+AudioCraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv].
+MusicGen is a single stage auto-regressive Transformer model trained over a 32kHz
+EnCodec tokenizer with 4 codebooks sampled at 50 Hz.
+Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require
+a self-supervised semantic representation, and it generates all 4 codebooks in one pass. By introducing
+a small delay between the codebooks, we show we can predict them in parallel, thus having only 50 auto-regressive
+steps per second of audio.
+Check out our [sample page][musicgen_samples] or test the available demo!
+
+
+
+
+
+
+
+
+
+We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset
+of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
+
+
+## Model Card
+
+See [the model card](../model_cards/MUSICGEN_MODEL_CARD.md).
+
+
+## Installation
+
+Please follow the AudioCraft installation instructions from the [README](../README.md).
+
+AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters).
+
+## Usage
+
+We offer a number of way to interact with MusicGen:
+1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen)
+(huge thanks to all the HF team for their support).
+2. You can run the extended demo on a Colab:
+[colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing)
+3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py).
+4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU).
+5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab)
+which is regularly updated with contributions from @camenduru and the community.
+
+
+## API
+
+We provide a simple API and 4 pre-trained models. The pre trained models are:
+- `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
+- `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
+- `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
+- `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
+
+We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model.
+In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
+GPUs will be able to generate short sequences, or longer sequences with the `facebook/musicgen-small` model.
+
+See after a quick example for using the API.
+
+```python
+import torchaudio
+from audiocraft.models import MusicGen
+from audiocraft.data.audio import audio_write
+
+model = MusicGen.get_pretrained('facebook/musicgen-melody')
+model.set_generation_params(duration=8) # generate 8 seconds.
+wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
+descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
+wav = model.generate(descriptions) # generates 3 samples.
+
+melody, sr = torchaudio.load('./assets/bach.mp3')
+# generates using the melody from the given audio and the provided descriptions.
+wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
+
+for idx, one_wav in enumerate(wav):
+ # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
+ audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+## 🤗 Transformers Usage
+
+MusicGen is available in the 🤗 Transformers library from version 4.31.0 onwards, requiring minimal dependencies
+and additional packages. Steps to get started:
+
+1. First install the 🤗 [Transformers library](https://github.com/huggingface/transformers) from main:
+
+```shell
+pip install git+https://github.com/huggingface/transformers.git
+```
+
+2. Run the following Python code to generate text-conditional audio samples:
+
+```py
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+
+processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+inputs = processor(
+ text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ padding=True,
+ return_tensors="pt",
+)
+
+audio_values = model.generate(**inputs, max_new_tokens=256)
+```
+
+3. Listen to the audio samples either in an ipynb notebook:
+
+```py
+from IPython.display import Audio
+
+sampling_rate = model.config.audio_encoder.sampling_rate
+Audio(audio_values[0].numpy(), rate=sampling_rate)
+```
+
+Or save them as a `.wav` file using a third-party library, e.g. `scipy`:
+
+```py
+import scipy
+
+sampling_rate = model.config.audio_encoder.sampling_rate
+scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
+```
+
+For more details on using the MusicGen model for inference using the 🤗 Transformers library, refer to the
+[MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on
+[Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb).
+
+
+## Training
+
+The [MusicGenSolver](../audiocraft/solvers/musicgen.py) implements MusicGen's training pipeline.
+It defines an autoregressive language modeling task over multiple streams of discrete tokens
+extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md)
+for more details on how to train such model).
+
+Note that **we do NOT provide any of the datasets** used for training MusicGen.
+We provide a dummy dataset containing just a few examples for illustrative purposes.
+
+Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section.
+
+### Example configurations and grids
+
+We provide configurations to reproduce the released models and our research.
+MusicGen solvers configuration are available in [config/solver/musicgen](../config/solver/musicgen),
+in particular:
+* MusicGen base model for text-to-music:
+[`solver=musicgen/musicgen_base_32khz`](../config/solver/musicgen/musicgen_base_32khz.yaml)
+* MusicGen model with chromagram-conditioning support:
+[`solver=musicgen/musicgen_melody_32khz`](../config/solver/musicgen/musicgen_melody_32khz.yaml)
+
+We provide 3 different scales, e.g. `model/lm/model_scale=small` (300M), or `medium` (1.5B), and `large` (3.3B).
+
+Please find some example grids to train MusicGen at
+[audiocraft/grids/musicgen](../audiocraft/grids/musicgen/).
+
+```shell
+# text-to-music
+dora grid musicgen.musicgen_base_32khz --dry_run --init
+# melody-guided music generation
+dora grid musicgen.musicgen_melody_base_32khz --dry_run --init
+# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup.
+```
+
+### Music dataset and metadata
+
+MusicGen's underlying dataset is an AudioDataset augmented with music-specific metadata.
+The MusicGen dataset implementation expects the metadata to be available as `.json` files
+at the same location as the audio files. Learn more in the [datasets section](./DATASETS.md).
+
+
+### Audio tokenizers
+
+We support a number of audio tokenizers: either pretrained EnCodec models, [DAC](https://github.com/descriptinc/descript-audio-codec), or your own models.
+The tokenizer is controlled with the setting `compression_model_checkpoint`.
+For instance,
+
+```bash
+# Using the 32kHz EnCodec trained on music
+dora run solver=musicgen/debug \
+ compression_model_checkpoint=//pretrained/facebook/encodec_32khz \
+ transformer_lm.n_q=4 transformer_lm.card=2048
+
+# Using DAC
+dora run solver=musicgen/debug \
+ compression_model_checkpoint=//pretrained/dac_44khz \
+ transformer_lm.n_q=9 transformer_lm.card=1024 \
+ 'codebooks_pattern.delay.delays=[0,1,2,3,4,5,6,7,8]'
+
+# Using your own model after export (see ENCODEC.md)
+dora run solver=musicgen/debug \
+ compression_model_checkpoint=//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin \
+ transformer_lm.n_q=... transformer_lm.card=...
+
+# Using your own model from its training checkpoint.
+dora run solver=musicgen/debug \
+ compression_model_checkpoint=//sig/SIG \ # where SIG is the Dora signature of the EnCodec XP.
+ transformer_lm.n_q=... transformer_lm.card=...
+```
+
+**Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the codebook_pattern to match `n_q` as shown in the example for using DAC. .
+
+
+### Fine tuning existing models
+
+You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular
+
+```bash
+# Using pretrained MusicGen model.
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/musicgen-medium conditioner=text2music
+
+# Using another model you already trained with a Dora signature SIG.
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music
+
+# Or providing manually a path
+dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th
+```
+
+**Warning:** You are responsible for selecting the other parameters accordingly, in a way that make it compatible
+ with the model you are fine tuning. Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`.
+
+**Warning:** We currently do not support fine tuning a model with slightly different layers. If you decide
+ to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`.
+ If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
+ `{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
+
+### Caching of EnCodec tokens
+
+It is possible to precompute the EnCodec tokens and other metadata.
+An example of generating and using this cache provided in the [musicgen.musicgen_base_cached_32khz grid](../audiocraft/grids/musicgen/musicgen_base_cached_32khz.py).
+
+### Evaluation stage
+
+By default, evaluation stage is also computing the cross-entropy and the perplexity over the
+evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run
+or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md)
+for more details on the requirements for each metric.
+
+We provide an off-the-shelf configuration to enable running the objective metrics
+for audio generation in
+[config/solver/musicgen/evaluation/objective_eval](../config/solver/musicgen/evaluation/objective_eval.yaml).
+
+One can then activate evaluation the following way:
+```shell
+# using the configuration
+dora run solver=musicgen/debug solver/musicgen/evaluation=objective_eval
+# specifying each of the fields, e.g. to activate KL computation
+dora run solver=musicgen/debug evaluate.metrics.kld=true
+```
+
+See [an example evaluation grid](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py).
+
+### Generation stage
+
+The generation stage allows to generate samples conditionally and/or unconditionally and to perform
+audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling
+from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples
+generated and the batch size used are controlled by the `dataset.generate` configuration
+while the other generation parameters are defined in `generate.lm`.
+
+```shell
+# control sampling parameters
+dora run solver=musicgen/debug generate.lm.gen_duration=10 generate.lm.use_sampling=true generate.lm.top_k=15
+```
+
+#### Listening to samples
+
+Note that generation happens automatically every 25 epochs. You can easily access and
+compare samples between models (as long as they are trained) on the same dataset using the
+MOS tool. For that first `pip install Flask gunicorn`. Then
+```
+gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile -
+```
+And access the tool at [https://127.0.0.1:8895](https://127.0.0.1:8895).
+
+### Playing with the model
+
+Once you have launched some experiments, you can easily get access
+to the Solver with the latest trained model using the following snippet.
+
+```python
+from audiocraft.solvers.musicgen import MusicGen
+
+solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
+solver.model
+solver.dataloaders
+```
+
+### Importing / Exporting models
+
+We do not support currently loading a model from the Hugging Face implementation or exporting to it.
+If you want to export your model in a way that is compatible with `audiocraft.models.MusicGen`
+API, you can run:
+
+```python
+from audiocraft.utils import export
+from audiocraft import train
+xp = train.main.get_xp_from_sig('SIG_OF_LM')
+export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin')
+# You also need to bundle the EnCodec model you used !!
+## Case 1) you trained your own
+xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC')
+export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin')
+## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix.
+## This will actually not dump the actual model, simply a pointer to the right model to download.
+export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin')
+```
+
+Now you can load your custom model with:
+```python
+import audiocraft.models
+musicgen = audiocraft.models.MusicGen.get_pretrained('/checkpoints/my_audio_lm/')
+```
+
+
+### Learn more
+
+Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
+
+## FAQ
+
+#### I need help on Windows
+
+@FurkanGozukara made a complete tutorial for [AudioCraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
+
+#### I need help for running the demo on Colab
+
+Check [@camenduru tutorial on YouTube](https://www.youtube.com/watch?v=EGfxuTy9Eeo).
+
+#### What are top-k, top-p, temperature and classifier-free guidance?
+
+Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt).
+
+#### Should I use FSDP or autocast ?
+
+The two are mutually exclusive (because FSDP does autocast on its own).
+You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU.
+FSDP makes everything more complex but will free up some memory for the actual
+activations by sharding the optimizer state.
+
+## Citation
+```
+@article{copet2023simple,
+ title={Simple and Controllable Music Generation},
+ author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ year={2023},
+ journal={arXiv preprint arXiv:2306.05284},
+}
+```
+
+
+## License
+
+See license information in the [model card](../model_cards/MUSICGEN_MODEL_CARD.md).
+
+
+[arxiv]: https://arxiv.org/abs/2306.05284
+[musicgen_samples]: https://ai.honu.io/papers/musicgen/
diff --git a/docs/TRAINING.md b/docs/TRAINING.md
new file mode 100644
index 0000000000000000000000000000000000000000..148de295f2ddfed2e4e893576bf31e1485038b8e
--- /dev/null
+++ b/docs/TRAINING.md
@@ -0,0 +1,312 @@
+# AudioCraft training pipelines
+
+AudioCraft training pipelines are built on top of PyTorch as our core deep learning library
+and [Flashy](https://github.com/facebookresearch/flashy) as our training pipeline design library,
+and [Dora](https://github.com/facebookresearch/dora) as our experiment manager.
+AudioCraft training pipelines are designed to be research and experiment-friendly.
+
+
+## Environment setup
+
+For the base installation, follow the instructions from the [README.md](../README.md).
+Below are some additional instructions for setting up environment to train new models.
+
+### Team and cluster configuration
+
+In order to support multiple teams and clusters, AudioCraft uses an environment configuration.
+The team configuration allows to specify cluster-specific configurations (e.g. SLURM configuration),
+or convenient mapping of paths between the supported environments.
+
+Each team can have a yaml file under the [configuration folder](../config). To select a team set the
+`AUDIOCRAFT_TEAM` environment variable to a valid team name (e.g. `labs` or `default`):
+```shell
+conda env config vars set AUDIOCRAFT_TEAM=default
+```
+
+Alternatively, you can add it to your `.bashrc`:
+```shell
+export AUDIOCRAFT_TEAM=default
+```
+
+If not defined, the environment will default to the `default` team.
+
+The cluster is automatically detected, but it is also possible to override it by setting
+the `AUDIOCRAFT_CLUSTER` environment variable.
+
+Based on this team and cluster, the environment is then configured with:
+* The dora experiment outputs directory.
+* The available slurm partitions: categorized by global and team.
+* A shared reference directory: In order to facilitate sharing research models while remaining
+agnostic to the used compute cluster, we created the `//reference` symbol that can be used in
+YAML config to point to a defined reference folder containing shared checkpoints
+(e.g. baselines, models for evaluation...).
+
+**Important:** The default output dir for trained models and checkpoints is under `/tmp/`. This is suitable
+only for quick testing. If you are doing anything serious you MUST edit the file `default.yaml` and
+properly set the `dora_dir` entries.
+
+#### Overriding environment configurations
+
+You can set the following environmet variables to bypass the team's environment configuration:
+* `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file.
+* `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory.
+* `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory.
+
+## Training pipelines
+
+Each task supported in AudioCraft has its own training pipeline and dedicated solver.
+Learn more about solvers and key designs around AudioCraft training pipeline below.
+Please refer to the documentation of each task and model for specific information on a given task.
+
+
+### Solvers
+
+The core training component in AudioCraft is the solver. A solver holds the definition
+of how to solve a given task: It implements the training pipeline logic, combining the datasets,
+model, optimization criterion and components and the full training loop. We refer the reader
+to [Flashy](https://github.com/facebookresearch/flashy) for core principles around solvers.
+
+AudioCraft proposes an initial solver, the `StandardSolver` that is used as the base implementation
+for downstream solvers. This standard solver provides a nice base management of logging,
+checkpoints loading/saving, xp restoration, etc. on top of the base Flashy implementation.
+In AudioCraft, we made the assumption that all tasks are following the same set of stages:
+train, valid, evaluate and generation, each relying on a dedicated dataset.
+
+Each solver is responsible for defining the task to solve and the associated stages
+of the training loop in order to leave the full ownership of the training pipeline
+to the researchers. This includes loading the datasets, building the model and
+optimisation components, registering them and defining the execution of each stage.
+To create a new solver for a given task, one should extend the StandardSolver
+and define each stage of the training loop. One can further customise its own solver
+starting from scratch instead of inheriting from the standard solver.
+
+```python
+from . import base
+from .. import optim
+
+
+class MyNewSolver(base.StandardSolver):
+
+ def __init__(self, cfg: omegaconf.DictConfig):
+ super().__init__(cfg)
+ # one can add custom attributes to the solver
+ self.criterion = torch.nn.L1Loss()
+
+ def best_metric(self):
+ # here optionally specify which metric to use to keep track of best state
+ return 'loss'
+
+ def build_model(self):
+ # here you can instantiate your models and optimization related objects
+ # this method will be called by the StandardSolver init method
+ self.model = ...
+ # the self.cfg attribute contains the raw configuration
+ self.optimizer = optim.build_optimizer(self.model.parameters(), self.cfg.optim)
+ # don't forget to register the states you'd like to include in your checkpoints!
+ self.register_stateful('model', 'optimizer')
+ # keep the model best state based on the best value achieved at validation for the given best_metric
+ self.register_best('model')
+ # if you want to add EMA around the model
+ self.register_ema('model')
+
+ def build_dataloaders(self):
+ # here you can instantiate your dataloaders
+ # this method will be called by the StandardSolver init method
+ self.dataloaders = ...
+
+ ...
+
+ # For both train and valid stages, the StandardSolver relies on
+ # a share common_train_valid implementation that is in charge of
+ # accessing the appropriate loader, iterate over the data up to
+ # the specified number of updates_per_epoch, run the ``run_step``
+ # function that you need to implement to specify the behavior
+ # and finally update the EMA and collect the metrics properly.
+ @abstractmethod
+ def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+ """Perform one training or valid step on a given batch.
+ """
+ ... # provide your implementation of the solver over a batch
+
+ def train(self):
+ """Train stage.
+ """
+ return self.common_train_valid('train')
+
+ def valid(self):
+ """Valid stage.
+ """
+ return self.common_train_valid('valid')
+
+ @abstractmethod
+ def evaluate(self):
+ """Evaluate stage.
+ """
+ ... # provide your implementation here!
+
+ @abstractmethod
+ def generate(self):
+ """Generate stage.
+ """
+ ... # provide your implementation here!
+```
+
+### About Epochs
+
+AudioCraft Solvers uses the concept of Epoch. One epoch doesn't necessarily mean one pass over the entire
+dataset, but instead represent the smallest amount of computation that we want to work with before checkpointing.
+Typically, we find that having an Epoch time around 30min is ideal both in terms of safety (checkpointing often enough)
+and getting updates often enough. One Epoch is at least a `train` stage that lasts for `optim.updates_per_epoch` (2000 by default),
+and a `valid` stage. You can control how long the valid stage takes with `dataset.valid.num_samples`.
+Other stages (`evaluate`, `generate`) will only happen every X epochs, as given by `evaluate.every` and `generate.every`).
+
+
+### Models
+
+In AudioCraft, a model is a container object that wraps one or more torch modules together
+with potential processing logic to use in a solver. For example, a model would wrap an encoder module,
+a quantisation bottleneck module, a decoder and some tensor processing logic. Each of the previous components
+can be considered as a small « model unit » on its own but the container model is a practical component
+to manipulate and train a set of modules together.
+
+### Datasets
+
+See the [dedicated documentation on datasets](./DATASETS.md).
+
+### Metrics
+
+See the [dedicated documentation on metrics](./METRICS.md).
+
+### Conditioners
+
+AudioCraft language models can be conditioned in various ways and the codebase offers a modular implementation
+of different conditioners that can be potentially combined together.
+Learn more in the [dedicated documentation on conditioning](./CONDITIONING.md).
+
+### Configuration
+
+AudioCraft's configuration is defined in yaml files and the framework relies on
+[hydra](https://hydra.cc/docs/intro/) and [omegaconf](https://omegaconf.readthedocs.io/) to parse
+and manipulate the configuration through Dora.
+
+##### :warning: Important considerations around configurations
+
+Our configuration management relies on Hydra and the concept of group configs to structure
+and compose configurations. Updating the root default configuration files will then have
+an impact on all solvers and tasks.
+**One should never change the default configuration files. Instead they should use Hydra config groups in order to store custom configuration.**
+Once this configuration is created and used for running experiments, you should not edit it anymore.
+
+Note that as we are using Dora as our experiment manager, all our experiment tracking is based on
+signatures computed from delta between configurations.
+**One must therefore ensure backward compatibilty of the configuration at all time.**
+See [Dora's README](https://github.com/facebookresearch/dora) and the
+[section below introduction Dora](#running-experiments-with-dora).
+
+##### Configuration structure
+
+The configuration is organized in config groups:
+* `conditioner`: default values for conditioning modules.
+* `dset`: contains all data source related information (paths to manifest files
+and metadata for a given dataset).
+* `model`: contains configuration for each model defined in AudioCraft and configurations
+for different variants of models.
+* `solver`: contains the default configuration for each solver as well as configuration
+for each solver task, combining all the above components.
+* `teams`: contains the cluster configuration per teams. See environment setup for more details.
+
+The `config.yaml` file is the main configuration that composes the above groups
+and contains default configuration for AudioCraft.
+
+##### Solver's core configuration structure
+
+The core configuration structure shared across solver is available in `solvers/default.yaml`.
+
+##### Other configuration modules
+
+AudioCraft configuration contains the different setups we used for our research and publications.
+
+## Running experiments with Dora
+
+### Launching jobs
+
+Try launching jobs for different tasks locally with dora run:
+
+```shell
+# run compression task with lightweight encodec
+dora run solver=compression/debug
+```
+
+Most of the time, the jobs are launched through dora grids, for example:
+
+```shell
+# run compression task through debug grid
+dora grid compression.debug
+```
+
+Learn more about running experiments with Dora below.
+
+### A small introduction to Dora
+
+[Dora](https://github.com/facebookresearch/dora) is the experiment manager tool used in AudioCraft.
+Check out the README to learn how Dora works. Here is a quick summary of what to know:
+* An XP is a unique set of hyper-parameters with a given signature. The signature is a hash
+of those hyper-parameters. We always refer to an XP with its signature, e.g. 9357e12e. We will see
+after that one can retrieve the hyper-params and re-rerun it in a single command.
+* In fact, the hash is defined as a delta between the base config and the one obtained
+with the config overrides you passed from the command line. This means you must never change
+the `conf/**.yaml` files directly., except for editing things like paths. Changing the default values
+in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused.
+I know, this is annoying, but the reason is that otherwise, any change to the config file would mean
+that all XPs ran so far would see their signature change.
+
+#### Dora commands
+
+```shell
+dora info -f 81de367c # this will show the hyper-parameter used by a specific XP.
+ # Be careful some overrides might present twice, and the right most one
+ # will give you the right value for it.
+
+dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c.
+ # `-d` is for distributed, it will use all available GPUs.
+
+dora run -d -f 81de367c dataset.batch_size=32 # start from the config of XP 81de367c but change some hyper-params.
+ # This will give you a new XP with a new signature (e.g. 3fe9c332).
+
+dora info -f SIG -t # will tail the log (if the XP has scheduled).
+# if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main
+# process, then use `dora info -f SIG` to get the main log name (finished into something like `/5037674_0_0_log.out`)
+# and worker K can accessed as `/5037674_0_{K}_log.out`.
+# This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder,
+# and look for `worker_{K}.log` logs.
+```
+
+An XP runs from a specific folder based on its signature, under the
+`//experiments/audiocraft/outputs/` folder.
+You can safely interrupt a training and resume it, it will reuse any existing checkpoint,
+as it will reuse the same folder. If you made some change to the code and need to ignore
+a previous checkpoint you can use `dora run --clear [RUN ARGS]`.
+
+If you have a Slurm cluster, you can also use the dora grid command, e.g.
+
+```shell
+# run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py`
+dora grid my_grid_folder.my_grid_name
+# Run the following will simply display the grid and also initialized the Dora experiments database.
+# You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`).
+dora grid my_grid_folder.my_grid_name --dry_run --init
+```
+
+Please refer to the [Dora documentation](https://github.com/facebookresearch/dora) for more information.
+
+
+#### Clearing up past experiments
+
+```shell
+# This will cancel all the XPs and delete their folder and checkpoints.
+# It will then reschedule them starting from scratch.
+dora grid my_grid_folder.my_grid_name --clear
+# The following will delete the folder and checkpoint for a single XP,
+# and then run it afresh.
+dora run [-f BASE_SIG] [ARGS] --clear
+```
diff --git a/egs/example/data.jsonl b/egs/example/data.jsonl
new file mode 100644
index 0000000000000000000000000000000000000000..63c3c333daa3418f52f952f9d018ccedee017899
--- /dev/null
+++ b/egs/example/data.jsonl
@@ -0,0 +1,2 @@
+{"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null}
+{"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null}
diff --git a/model_cards/AUDIOGEN_MODEL_CARD.md b/model_cards/AUDIOGEN_MODEL_CARD.md
new file mode 100644
index 0000000000000000000000000000000000000000..92decf5e16e05ce0c2e72af8aa6728b5186c6882
--- /dev/null
+++ b/model_cards/AUDIOGEN_MODEL_CARD.md
@@ -0,0 +1,79 @@
+# AudioGen Model Card
+
+## Model details
+**Organization developing the model:** The FAIR team of Meta AI.
+
+**Model date:** This version of AudioGen was trained between July 2023 and August 2023.
+
+**Model version:** This is version 2 of the model, not to be confused with the original AudioGen model published in ["AudioGen: Textually Guided Audio Generation"][audiogen].
+In this version (v2), AudioGen was trained on the same data, but with some other differences:
+1. This model was trained on 10 seconds (vs. 5 seconds in v1).
+2. The discrete representation used under the hood is extracted using a retrained EnCodec model on the environmental sound data, following the EnCodec setup detailed in the ["Simple and Controllable Music Generation" paper][musicgen].
+3. No audio mixing augmentations.
+
+**Model type:** AudioGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for audio modeling. The released model has 1.5B parameters.
+
+**Paper or resource for more information:** More information can be found in the paper [AudioGen: Textually Guided Audio Generation](https://arxiv.org/abs/2209.15352).
+
+**Citation details:** See [AudioGen paper][audiogen]
+
+**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
+
+**Where to send questions or comments about the model:** Questions and comments about AudioGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
+
+## Intended use
+**Primary intended use:** The primary use of AudioGen is research on AI-based audio generation, including:
+- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
+- Generation of sound guided by text to understand current abilities of generative AI models by machine learning amateurs
+
+**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models.
+
+**Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate audio pieces that create hostile or alienating environments for people. This includes generating audio that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+## Metrics
+
+**Models performance measures:** We used the following objective measure to evaluate the model on a standard audio benchmark:
+- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
+- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
+
+Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
+- Overall quality of the audio samples;
+- Text relevance to the provided text input;
+
+More details on performance measures and human studies can be found in the paper.
+
+**Decision thresholds:** Not applicable.
+
+## Evaluation datasets
+
+The model was evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/).
+
+## Training datasets
+
+The model was trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects).
+
+## Evaluation results
+
+Below are the objective metrics obtained with the released model on AudioCaps (consisting of 10-second long samples). Note that the model differs from the original AudioGen model introduced in the paper, hence the difference in the metrics.
+
+| Model | Frechet Audio Distance | KLD | Text consistency |
+|---|---|---|---|
+| facebook/audiogen-medium | 1.77 | 1.41 | 0.299 |
+
+More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section.
+
+## Limitations and biases
+
+**Limitations:**
+- The model is not able to generate realistic vocals.
+- The model has been trained with English descriptions and will not perform as well in other languages.
+- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
+
+**Biases:** The datasets used for training may be lacking of diversity and are not representative of all possible sound events. The generated samples from the model will reflect the biases from the training data.
+
+**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data.
+
+**Use cases:** Users must be aware of the biases, limitations and risks of the model. AudioGen is a model developed for artificial intelligence research on audio generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
+
+[musicgen]: https://arxiv.org/abs/2306.05284
+[audiogen]: https://arxiv.org/abs/2209.15352
diff --git a/model_cards/MUSICGEN_MODEL_CARD.md b/model_cards/MUSICGEN_MODEL_CARD.md
new file mode 100644
index 0000000000000000000000000000000000000000..10ba9f9790841be06cd3e459cf667c1af6291343
--- /dev/null
+++ b/model_cards/MUSICGEN_MODEL_CARD.md
@@ -0,0 +1,90 @@
+# MusicGen Model Card
+
+## Model details
+
+**Organization developing the model:** The FAIR team of Meta AI.
+
+**Model date:** MusicGen was trained between April 2023 and May 2023.
+
+**Model version:** This is the version 1 of the model.
+
+**Model type:** MusicGen consists of an EnCodec model for audio tokenization, an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters ; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation.
+
+**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
+
+**Citation details:** See [our paper][arxiv]
+
+**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
+
+**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
+
+## Intended use
+**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
+
+- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
+- Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs
+
+**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models.
+
+**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+## Metrics
+
+**Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark:
+
+- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
+- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
+- CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model
+
+Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
+
+- Overall quality of the music samples;
+- Text relevance to the provided text input;
+- Adherence to the melody for melody-guided music generation.
+
+More details on performance measures and human studies can be found in the paper.
+
+**Decision thresholds:** Not applicable.
+
+## Evaluation datasets
+
+The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
+
+## Training datasets
+
+The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
+
+## Evaluation results
+
+Below are the objective metrics obtained on MusicCaps with the released model. Note that for the publicly released models, we had all the datasets go through a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only the instrumental part. This explains the difference in objective metrics with the models used in the paper.
+
+| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity |
+|---|---|---|---|---|
+| facebook/musicgen-small | 4.88 | 1.28 | 0.27 | - |
+| facebook/musicgen-medium | 5.14 | 1.24 | 0.28 | - |
+| facebook/musicgen-large | 5.48 | 1.22 | 0.28 | - |
+| facebook/musicgen-melody | 4.93 | 1.26 | 0.27 | 0.44 |
+
+More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section.
+
+## Limitations and biases
+
+**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
+
+**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
+
+**Limitations:**
+
+- The model is not able to generate realistic vocals.
+- The model has been trained with English descriptions and will not perform as well in other languages.
+- The model does not perform equally well for all music styles and cultures.
+- The model sometimes generates end of songs, collapsing to silence.
+- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
+
+**Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
+
+**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data.
+
+**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
+
+[arxiv]: https://arxiv.org/abs/2306.05284
diff --git a/models/Put your models here.txt b/models/Put your models here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a0c20400cdb62f7975c0090bfc83a7259d114eb1
--- /dev/null
+++ b/models/Put your models here.txt
@@ -0,0 +1 @@
+nothing here
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000000000000000000000000000000000000..6ab60f2fd7545c803fca221614704a075b8f2188
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,4 @@
+[mypy]
+
+[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*]
+ignore_missing_imports = True
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..74c2820afd5c9749ce25892b07209b4032318b16
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+# please make sure you have already a pytorch install that is cuda enabled!
+av
+einops
+flashy>=0.0.1
+hydra-core>=1.1
+hydra_colorlog
+julius
+num2words
+numpy
+sentencepiece
+spacy==3.5.2
+torch>=2.0.0
+torchaudio>=2.0.0
+huggingface_hub
+tqdm
+transformers>=4.31.0 # need Encodec there.
+xformers
+demucs
+librosa
+gradio
+torchmetrics
+encodec
+pytaglib
\ No newline at end of file
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/scripts/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/scripts/mos.py b/scripts/mos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a711c9ece23e72ed3a07032c7834ef7c56ab4f11
--- /dev/null
+++ b/scripts/mos.py
@@ -0,0 +1,286 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+"""
+To run this script, from the root of the repo. Make sure to have Flask installed
+
+ FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567
+ # or if you have gunicorn
+ gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile -
+
+"""
+from collections import defaultdict
+from functools import wraps
+from hashlib import sha1
+import json
+import math
+from pathlib import Path
+import random
+import typing as tp
+
+from flask import Flask, redirect, render_template, request, session, url_for
+
+from audiocraft import train
+from audiocraft.utils.samples.manager import get_samples_for_xps
+
+
+SAMPLES_PER_PAGE = 8
+MAX_RATING = 5
+storage = Path(train.main.dora.dir / 'mos_storage')
+storage.mkdir(exist_ok=True)
+surveys = storage / 'surveys'
+surveys.mkdir(exist_ok=True)
+magma_root = Path(train.__file__).parent.parent
+app = Flask('mos', static_folder=str(magma_root / 'scripts/static'),
+ template_folder=str(magma_root / 'scripts/templates'))
+app.secret_key = b'audiocraft makes the best songs'
+
+
+def normalize_path(path: Path):
+ """Just to make path a bit nicer, make them relative to the Dora root dir.
+ """
+ path = path.resolve()
+ dora_dir = train.main.dora.dir.resolve() / 'xps'
+ return path.relative_to(dora_dir)
+
+
+def get_full_path(normalized_path: Path):
+ """Revert `normalize_path`.
+ """
+ return train.main.dora.dir.resolve() / 'xps' / normalized_path
+
+
+def get_signature(xps: tp.List[str]):
+ """Return a signature for a list of XP signatures.
+ """
+ return sha1(json.dumps(xps).encode()).hexdigest()[:10]
+
+
+def ensure_logged(func):
+ """Ensure user is logged in.
+ """
+ @wraps(func)
+ def _wrapped(*args, **kwargs):
+ user = session.get('user')
+ if user is None:
+ return redirect(url_for('login', redirect_to=request.url))
+ return func(*args, **kwargs)
+ return _wrapped
+
+
+@app.route('/login', methods=['GET', 'POST'])
+def login():
+ """Login user if not already, then redirect.
+ """
+ user = session.get('user')
+ if user is None:
+ error = None
+ if request.method == 'POST':
+ user = request.form['user']
+ if not user:
+ error = 'User cannot be empty'
+ if user is None or error:
+ return render_template('login.html', error=error)
+ assert user
+ session['user'] = user
+ redirect_to = request.args.get('redirect_to')
+ if redirect_to is None:
+ redirect_to = url_for('index')
+ return redirect(redirect_to)
+
+
+@app.route('/', methods=['GET', 'POST'])
+@ensure_logged
+def index():
+ """Offer to create a new study.
+ """
+ errors = []
+ if request.method == 'POST':
+ xps_or_grids = [part.strip() for part in request.form['xps'].split()]
+ xps = set()
+ for xp_or_grid in xps_or_grids:
+ xp_path = train.main.dora.dir / 'xps' / xp_or_grid
+ if xp_path.exists():
+ xps.add(xp_or_grid)
+ continue
+ grid_path = train.main.dora.dir / 'grids' / xp_or_grid
+ if grid_path.exists():
+ for child in grid_path.iterdir():
+ if child.is_symlink():
+ xps.add(child.name)
+ continue
+ errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
+ assert xps or errors
+ blind = 'true' if request.form.get('blind') == 'on' else 'false'
+ xps = list(xps)
+ if not errors:
+ signature = get_signature(xps)
+ manifest = {
+ 'xps': xps,
+ }
+ survey_path = surveys / signature
+ survey_path.mkdir(exist_ok=True)
+ with open(survey_path / 'manifest.json', 'w') as f:
+ json.dump(manifest, f, indent=2)
+ return redirect(url_for('survey', blind=blind, signature=signature))
+ return render_template('index.html', errors=errors)
+
+
+@app.route('/survey/', methods=['GET', 'POST'])
+@ensure_logged
+def survey(signature):
+ success = request.args.get('success', False)
+ seed = int(request.args.get('seed', 4321))
+ blind = request.args.get('blind', 'false') in ['true', 'on', 'True']
+ exclude_prompted = request.args.get('exclude_prompted', 'false') in ['true', 'on', 'True']
+ exclude_unprompted = request.args.get('exclude_unprompted', 'false') in ['true', 'on', 'True']
+ max_epoch = int(request.args.get('max_epoch', '-1'))
+ survey_path = surveys / signature
+ assert survey_path.exists(), survey_path
+
+ user = session['user']
+ result_folder = survey_path / 'results'
+ result_folder.mkdir(exist_ok=True)
+ result_file = result_folder / f'{user}_{seed}.json'
+
+ with open(survey_path / 'manifest.json') as f:
+ manifest = json.load(f)
+
+ xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
+ names, ref_name = train.main.get_names(xps)
+
+ samples_kwargs = {
+ 'exclude_prompted': exclude_prompted,
+ 'exclude_unprompted': exclude_unprompted,
+ 'max_epoch': max_epoch,
+ }
+ matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs) # fetch latest epoch
+ models_by_id = {
+ id: [{
+ 'xp': xps[idx],
+ 'xp_name': names[idx],
+ 'model_id': f'{xps[idx].sig}-{sample.id}',
+ 'sample': sample,
+ 'is_prompted': sample.prompt is not None,
+ 'errors': [],
+ } for idx, sample in enumerate(samples)]
+ for id, samples in matched_samples.items()
+ }
+ experiments = [
+ {'xp': xp, 'name': names[idx], 'epoch': list(matched_samples.values())[0][idx].epoch}
+ for idx, xp in enumerate(xps)
+ ]
+
+ keys = list(matched_samples.keys())
+ keys.sort()
+ rng = random.Random(seed)
+ rng.shuffle(keys)
+ model_ids = keys[:SAMPLES_PER_PAGE]
+
+ if blind:
+ for key in model_ids:
+ rng.shuffle(models_by_id[key])
+
+ ok = True
+ if request.method == 'POST':
+ all_samples_results = []
+ for id in model_ids:
+ models = models_by_id[id]
+ result = {
+ 'id': id,
+ 'is_prompted': models[0]['is_prompted'],
+ 'models': {}
+ }
+ all_samples_results.append(result)
+ for model in models:
+ rating = request.form[model['model_id']]
+ if rating:
+ rating = int(rating)
+ assert rating <= MAX_RATING and rating >= 1
+ result['models'][model['xp'].sig] = rating
+ model['rating'] = rating
+ else:
+ ok = False
+ model['errors'].append('Please rate this model.')
+ if ok:
+ result = {
+ 'results': all_samples_results,
+ 'seed': seed,
+ 'user': user,
+ 'blind': blind,
+ 'exclude_prompted': exclude_prompted,
+ 'exclude_unprompted': exclude_unprompted,
+ }
+ print(result)
+ with open(result_file, 'w') as f:
+ json.dump(result, f)
+ seed = seed + 1
+ return redirect(url_for(
+ 'survey', signature=signature, blind=blind, seed=seed,
+ exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
+ max_epoch=max_epoch, success=True))
+
+ ratings = list(range(1, MAX_RATING + 1))
+ return render_template(
+ 'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
+ exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
+ experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
+ ref_name=ref_name, already_filled=result_file.exists())
+
+
+@app.route('/audio/')
+def audio(path: str):
+ full_path = Path('/') / path
+ assert full_path.suffix in [".mp3", ".wav"]
+ return full_path.read_bytes(), {'Content-Type': 'audio/mpeg'}
+
+
+def mean(x):
+ return sum(x) / len(x)
+
+
+def std(x):
+ m = mean(x)
+ return math.sqrt(sum((i - m)**2 for i in x) / len(x))
+
+
+@app.route('/results/')
+@ensure_logged
+def results(signature):
+
+ survey_path = surveys / signature
+ assert survey_path.exists(), survey_path
+ result_folder = survey_path / 'results'
+ result_folder.mkdir(exist_ok=True)
+
+ # ratings per model, then per user.
+ ratings_per_model = defaultdict(list)
+ users = []
+ for result_file in result_folder.iterdir():
+ if result_file.suffix != '.json':
+ continue
+ with open(result_file) as f:
+ results = json.load(f)
+ users.append(results['user'])
+ for result in results['results']:
+ for sig, rating in result['models'].items():
+ ratings_per_model[sig].append(rating)
+
+ fmt = '{:.2f}'
+ models = []
+ for model in sorted(ratings_per_model.keys()):
+ ratings = ratings_per_model[model]
+
+ models.append({
+ 'sig': model,
+ 'samples': len(ratings),
+ 'mean_rating': fmt.format(mean(ratings)),
+ # the value 1.96 was probably chosen to achieve some
+ # confidence interval assuming gaussianity.
+ 'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5),
+ })
+ return render_template('results.html', signature=signature, models=models, users=users)
diff --git a/scripts/resample_dataset.py b/scripts/resample_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..af5288712b8d2cde2d9814c747275e69f6e970c8
--- /dev/null
+++ b/scripts/resample_dataset.py
@@ -0,0 +1,207 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Resampling script.
+"""
+import argparse
+from pathlib import Path
+import shutil
+import typing as tp
+
+import submitit
+import tqdm
+
+from audiocraft.data.audio import audio_read, audio_write
+from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
+from audiocraft.data.audio_utils import convert_audio
+from audiocraft.environment import AudioCraftEnvironment
+
+
+def read_txt_files(path: tp.Union[str, Path]):
+ with open(args.files_path) as f:
+ lines = [line.rstrip() for line in f]
+ print(f"Read {len(lines)} in .txt")
+ lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
+ print(f"Filtered and keep {len(lines)} from .txt")
+ return lines
+
+
+def read_egs_files(path: tp.Union[str, Path]):
+ path = Path(path)
+ if path.is_dir():
+ if (path / 'data.jsonl').exists():
+ path = path / 'data.jsonl'
+ elif (path / 'data.jsonl.gz').exists():
+ path = path / 'data.jsonl.gz'
+ else:
+ raise ValueError("Don't know where to read metadata from in the dir. "
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+ meta = load_audio_meta(path)
+ return [m.path for m in meta]
+
+
+def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
+ if task_index is None:
+ env = submitit.JobEnvironment()
+ task_index = env.global_rank
+ shard_index = node_index * args.tasks_per_node + task_index
+
+ if args.files_path is None:
+ lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
+ else:
+ files_path = Path(args.files_path)
+ if files_path.suffix == '.txt':
+ print(f"Reading file list from .txt file: {args.files_path}")
+ lines = read_txt_files(args.files_path)
+ else:
+ print(f"Reading file list from egs: {args.files_path}")
+ lines = read_egs_files(args.files_path)
+
+ total_files = len(lines)
+ print(
+ f"Total of {total_files} processed with {n_shards} shards. " +
+ f"Current idx = {shard_index} -> {total_files // n_shards} files to process"
+ )
+ for idx, line in tqdm.tqdm(enumerate(lines)):
+
+ # skip if not part of this shard
+ if idx % n_shards != shard_index:
+ continue
+
+ path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
+ root_path = str(args.root_path)
+ if not root_path.endswith('/'):
+ root_path += '/'
+ assert path.startswith(str(root_path)), \
+ f"Mismatch between path and provided root: {path} VS {root_path}"
+
+ try:
+ metadata_path = Path(path).with_suffix('.json')
+ out_path = args.out_path / path[len(root_path):]
+ out_metadata_path = out_path.with_suffix('.json')
+ out_done_token = out_path.with_suffix('.done')
+
+ # don't reprocess existing files
+ if out_done_token.exists():
+ continue
+
+ print(idx, out_path, path)
+ mix, sr = audio_read(path)
+ mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0)
+ # enforce simple stereo
+ out_channels = mix_channels
+ if out_channels > 2:
+ print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
+ out_channels = 2
+ out_sr = args.sample_rate if args.sample_rate is not None else sr
+ out_wav = convert_audio(mix, sr, out_sr, out_channels)
+ audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
+ format=args.format, normalize=False, strategy='clip')
+ if metadata_path.exists():
+ shutil.copy(metadata_path, out_metadata_path)
+ else:
+ print(f"No metadata found at {str(metadata_path)}")
+ out_done_token.touch()
+ except Exception as e:
+ print(f"Error processing file line: {line}, {e}")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
+ parser.add_argument(
+ "--log_root",
+ type=Path,
+ default=Path.home() / 'tmp' / 'resample_logs',
+ )
+ parser.add_argument(
+ "--files_path",
+ type=Path,
+ help="List of files to process, either .txt (one file per line) or a jsonl[.gz].",
+ )
+ parser.add_argument(
+ "--root_path",
+ type=Path,
+ required=True,
+ help="When rewriting paths, this will be the prefix to remove.",
+ )
+ parser.add_argument(
+ "--out_path",
+ type=Path,
+ required=True,
+ help="When rewriting paths, `root_path` will be replaced by this.",
+ )
+ parser.add_argument("--xp_name", type=str, default="shutterstock")
+ parser.add_argument(
+ "--nodes",
+ type=int,
+ default=4,
+ )
+ parser.add_argument(
+ "--tasks_per_node",
+ type=int,
+ default=20,
+ )
+ parser.add_argument(
+ "--cpus_per_task",
+ type=int,
+ default=4,
+ )
+ parser.add_argument(
+ "--memory_gb",
+ type=int,
+ help="Memory in GB."
+ )
+ parser.add_argument(
+ "--format",
+ type=str,
+ default="wav",
+ )
+ parser.add_argument(
+ "--sample_rate",
+ type=int,
+ default=32000,
+ )
+ parser.add_argument(
+ "--channels",
+ type=int,
+ )
+ parser.add_argument(
+ "--partition",
+ default='learnfair',
+ )
+ parser.add_argument("--qos")
+ parser.add_argument("--account")
+ parser.add_argument("--timeout", type=int, default=4320)
+ parser.add_argument('--debug', action='store_true', help='debug mode (local run)')
+ args = parser.parse_args()
+ n_shards = args.tasks_per_node * args.nodes
+ if args.files_path is None:
+ print("Warning: --files_path not provided, not recommended when processing more than 10k files.")
+ if args.debug:
+ print("Debugging mode")
+ process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
+ else:
+
+ log_folder = Path(args.log_root) / args.xp_name / '%j'
+ print(f"Logging to: {log_folder}")
+ log_folder.parent.mkdir(parents=True, exist_ok=True)
+ executor = submitit.AutoExecutor(folder=str(log_folder))
+ if args.qos:
+ executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account)
+ else:
+ executor.update_parameters(slurm_partition=args.partition)
+ executor.update_parameters(
+ slurm_job_name=args.xp_name, timeout_min=args.timeout,
+ cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
+ if args.memory_gb:
+ executor.update_parameters(mem=f'{args.memory_gb}GB')
+ jobs = []
+ with executor.batch():
+ for node_index in range(args.nodes):
+ job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
+ jobs.append(job)
+ for job in jobs:
+ print(f"Waiting on job {job.job_id}")
+ job.results()
diff --git a/scripts/static/style.css b/scripts/static/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..a0df7c63a0d2dd9a79f33f5d869ca31c9da87e8d
--- /dev/null
+++ b/scripts/static/style.css
@@ -0,0 +1,113 @@
+body {
+ background-color: #fbfbfb;
+ margin: 0;
+}
+
+select, input {
+ font-size: 1em;
+ max-width: 100%;
+}
+
+.xp_name {
+ font-family: monospace;
+}
+
+.simple_form {
+ background-color: #dddddd;
+ padding: 1em;
+ margin: 0.5em;
+}
+
+textarea {
+ margin-top: 0.5em;
+ margin-bottom: 0.5em;
+}
+
+.rating {
+ background-color: grey;
+ padding-top: 5px;
+ padding-bottom: 5px;
+ padding-left: 8px;
+ padding-right: 8px;
+ margin-right: 2px;
+ cursor:pointer;
+}
+
+.rating_selected {
+ background-color: purple;
+}
+
+.content {
+ font-family: sans-serif;
+ background-color: #f6f6f6;
+ padding: 40px;
+ margin: 0 auto;
+ max-width: 1000px;
+}
+
+.track label {
+ padding-top: 10px;
+ padding-bottom: 10px;
+}
+.track {
+ padding: 15px;
+ margin: 5px;
+ background-color: #c8c8c8;
+}
+
+.submit-big {
+ width:400px;
+ height:30px;
+ font-size: 20px;
+}
+
+.error {
+ color: red;
+}
+
+.ratings {
+ margin-left: 10px;
+}
+
+.important {
+ font-weight: bold;
+}
+
+.survey {
+ margin-bottom: 100px;
+}
+
+.success {
+ color: #25901b;
+ font-weight: bold;
+}
+.warning {
+ color: #8a1f19;
+ font-weight: bold;
+}
+.track>section {
+ display: flex;
+ align-items: center;
+}
+
+.prompt {
+ display: flex;
+ align-items: center;
+}
+
+.track>section>div {
+ padding-left: 10px;
+}
+
+audio {
+ max-width: 280px;
+ max-height: 40px;
+ margin-left: 10px;
+ margin-right: 10px;
+}
+
+.special {
+ font-weight: bold;
+ color: #2c2c2c;
+}
+
diff --git a/scripts/templates/base.html b/scripts/templates/base.html
new file mode 100644
index 0000000000000000000000000000000000000000..f74668c19ecb83090a8a2d82c026bf417190ec6d
--- /dev/null
+++ b/scripts/templates/base.html
@@ -0,0 +1,16 @@
+
+
+
+ {% block head %}
+
+
+ AudioCraft — MOS
+ {% endblock %}
+
+
+
+ Welcome {{session['user']}} to the internal MOS assistant for AudioCraft.
+ You can create custom surveys between your models, that you can
+ evaluate yourself, or with the help of your teammates, by simply
+ sharing a link!
+
+
+{% for error in errors %}
+
{{error}}
+{% endfor %}
+
+
+
Samples
+
+
+
+{% endblock %}
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..a00890009a88752714357210a73709a83b395849
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,14 @@
+[pep8]
+max-line-length = 120
+
+[flake8]
+max-line-length = 120
+
+[coverage:report]
+include = audiocraft/*
+omit =
+ audiocraft/environment.py
+ audiocraft/solvers/*
+ audiocraft/utils/*
+ audiocraft/*/loaders.py
+ audiocraft/*/builders.py
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..64e7d6fcb1092748f8151f6d3ed1767d3be1b34b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+
+from setuptools import setup, find_packages
+
+
+NAME = 'audiocraft'
+DESCRIPTION = 'Audio generation research library for PyTorch'
+
+URL = 'https://github.com/facebookresearch/audiocraft'
+AUTHOR = 'FAIR Speech & Audio'
+EMAIL = 'defossez@meta.com, jadecopet@meta.com'
+REQUIRES_PYTHON = '>=3.8.0'
+
+for line in open('audiocraft/__init__.py'):
+ line = line.strip()
+ if '__version__' in line:
+ context = {}
+ exec(line, context)
+ VERSION = context['__version__']
+
+HERE = Path(__file__).parent
+
+try:
+ with open(HERE / "README.md", encoding='utf-8') as f:
+ long_description = '\n' + f.read()
+except FileNotFoundError:
+ long_description = DESCRIPTION
+
+REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')]
+
+setup(
+ name=NAME,
+ version=VERSION,
+ description=DESCRIPTION,
+ author_email=EMAIL,
+ long_description=long_description,
+ long_description_content_type='text/markdown',
+ author=AUTHOR,
+ url=URL,
+ python_requires=REQUIRES_PYTHON,
+ install_requires=REQUIRED,
+ extras_require={
+ 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
+ },
+ packages=find_packages(),
+ package_data={'audiocraft': ['py.typed']},
+ include_package_data=True,
+ license='MIT License',
+ classifiers=[
+ # Trove classifiers
+ # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
+ 'License :: OSI Approved :: MIT License',
+ 'Topic :: Multimedia :: Sound/Audio',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ ],
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/adversarial/__init__.py b/tests/adversarial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/adversarial/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/adversarial/test_discriminators.py b/tests/adversarial/test_discriminators.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad89a0ae4534dc7967b6ccda194b9fd1dedbffe
--- /dev/null
+++ b/tests/adversarial/test_discriminators.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+
+import torch
+
+from audiocraft.adversarial.discriminators import (
+ MultiPeriodDiscriminator,
+ MultiScaleDiscriminator,
+ MultiScaleSTFTDiscriminator
+)
+
+
+class TestMultiPeriodDiscriminator:
+
+ def test_mpd_discriminator(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+ periods = [1, 2, 3]
+ mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C)
+ logits, fmaps = mpd(t0)
+
+ assert len(logits) == len(periods)
+ assert len(fmaps) == len(periods)
+ assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
+ assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
+
+
+class TestMultiScaleDiscriminator:
+
+ def test_msd_discriminator(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+
+ scale_norms = ['weight_norm', 'weight_norm']
+ msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C)
+ logits, fmaps = msd(t0)
+
+ assert len(logits) == len(scale_norms)
+ assert len(fmaps) == len(scale_norms)
+ assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits])
+ assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
+
+
+class TestMultiScaleStftDiscriminator:
+
+ def test_msstftd_discriminator(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+
+ n_filters = 4
+ n_ffts = [128, 256, 64]
+ hop_lengths = [32, 64, 16]
+ win_lengths = [128, 256, 64]
+
+ msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths,
+ win_lengths=win_lengths, in_channels=C)
+ logits, fmaps = msstftd(t0)
+
+ assert len(logits) == len(n_ffts)
+ assert len(fmaps) == len(n_ffts)
+ assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
+ assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
diff --git a/tests/adversarial/test_losses.py b/tests/adversarial/test_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e30bc3a6dde00003e13c00f15e977e39425063c
--- /dev/null
+++ b/tests/adversarial/test_losses.py
@@ -0,0 +1,159 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import random
+
+import torch
+
+from audiocraft.adversarial import (
+ AdversarialLoss,
+ get_adv_criterion,
+ get_real_criterion,
+ get_fake_criterion,
+ FeatureMatchingLoss,
+ MultiScaleDiscriminator,
+)
+
+
+class TestAdversarialLoss:
+
+ def test_adversarial_single_multidiscriminator(self):
+ adv = MultiScaleDiscriminator()
+ optimizer = torch.optim.Adam(
+ adv.parameters(),
+ lr=1e-4,
+ )
+ loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
+ adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake)
+
+ B, C, T = 4, 1, random.randint(1000, 5000)
+ real = torch.randn(B, C, T)
+ fake = torch.randn(B, C, T)
+
+ disc_loss = adv_loss.train_adv(fake, real)
+ assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float)
+
+ loss, loss_feat = adv_loss(fake, real)
+ assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
+ # we did not specify feature loss
+ assert loss_feat.item() == 0.
+
+ def test_adversarial_feat_loss(self):
+ adv = MultiScaleDiscriminator()
+ optimizer = torch.optim.Adam(
+ adv.parameters(),
+ lr=1e-4,
+ )
+ loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
+ feat_loss = FeatureMatchingLoss()
+ adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake, feat_loss)
+
+ B, C, T = 4, 1, random.randint(1000, 5000)
+ real = torch.randn(B, C, T)
+ fake = torch.randn(B, C, T)
+
+ loss, loss_feat = adv_loss(fake, real)
+
+ assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
+ assert isinstance(loss_feat, torch.Tensor) and isinstance(loss.item(), float)
+
+
+class TestGeneratorAdversarialLoss:
+
+ def test_hinge_generator_adv_loss(self):
+ adv_loss = get_adv_criterion(loss_type='hinge')
+
+ t0 = torch.randn(1, 2, 0)
+ t1 = torch.FloatTensor([1.0, 2.0, 3.0])
+
+ assert adv_loss(t0).item() == 0.0
+ assert adv_loss(t1).item() == -2.0
+
+ def test_mse_generator_adv_loss(self):
+ adv_loss = get_adv_criterion(loss_type='mse')
+
+ t0 = torch.randn(1, 2, 0)
+ t1 = torch.FloatTensor([1.0, 1.0, 1.0])
+ t2 = torch.FloatTensor([2.0, 5.0, 5.0])
+
+ assert adv_loss(t0).item() == 0.0
+ assert adv_loss(t1).item() == 0.0
+ assert adv_loss(t2).item() == 11.0
+
+
+class TestDiscriminatorAdversarialLoss:
+
+ def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor):
+ disc_loss_real = get_real_criterion(loss_type)
+ disc_loss_fake = get_fake_criterion(loss_type)
+
+ loss = disc_loss_fake(fake) + disc_loss_real(real)
+ return loss
+
+ def test_hinge_discriminator_adv_loss(self):
+ loss_type = 'hinge'
+ t0 = torch.FloatTensor([0.0, 0.0, 0.0])
+ t1 = torch.FloatTensor([1.0, 2.0, 3.0])
+
+ assert self._disc_loss(loss_type, t0, t0).item() == 2.0
+ assert self._disc_loss(loss_type, t1, t1).item() == 3.0
+
+ def test_mse_discriminator_adv_loss(self):
+ loss_type = 'mse'
+
+ t0 = torch.FloatTensor([0.0, 0.0, 0.0])
+ t1 = torch.FloatTensor([1.0, 1.0, 1.0])
+
+ assert self._disc_loss(loss_type, t0, t0).item() == 1.0
+ assert self._disc_loss(loss_type, t1, t0).item() == 2.0
+
+
+class TestFeatureMatchingLoss:
+
+ def test_features_matching_loss_base(self):
+ ft_matching_loss = FeatureMatchingLoss()
+ length = random.randrange(1, 100_000)
+ t1 = torch.randn(1, 2, length)
+
+ loss = ft_matching_loss([t1], [t1])
+ assert isinstance(loss, torch.Tensor)
+ assert loss.item() == 0.0
+
+ def test_features_matching_loss_raises_exception(self):
+ ft_matching_loss = FeatureMatchingLoss()
+ length = random.randrange(1, 100_000)
+ t1 = torch.randn(1, 2, length)
+ t2 = torch.randn(1, 2, length + 1)
+
+ with pytest.raises(AssertionError):
+ ft_matching_loss([], [])
+
+ with pytest.raises(AssertionError):
+ ft_matching_loss([t1], [t1, t1])
+
+ with pytest.raises(AssertionError):
+ ft_matching_loss([t1], [t2])
+
+ def test_features_matching_loss_output(self):
+ loss_nonorm = FeatureMatchingLoss(normalize=False)
+ loss_layer_normed = FeatureMatchingLoss(normalize=True)
+
+ length = random.randrange(1, 100_000)
+ t1 = torch.randn(1, 2, length)
+ t2 = torch.randn(1, 2, length)
+
+ assert loss_nonorm([t1, t2], [t1, t2]).item() == 0.0
+ assert loss_layer_normed([t1, t2], [t1, t2]).item() == 0.0
+
+ t3 = torch.FloatTensor([1.0, 2.0, 3.0])
+ t4 = torch.FloatTensor([2.0, 10.0, 3.0])
+
+ assert loss_nonorm([t3], [t4]).item() == 3.0
+ assert loss_nonorm([t3, t3], [t4, t4]).item() == 6.0
+
+ assert loss_layer_normed([t3], [t4]).item() == 3.0
+ assert loss_layer_normed([t3, t3], [t4, t4]).item() == 3.0
diff --git a/tests/common_utils/__init__.py b/tests/common_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ffcfef96fec35c99b2a1a053a61f44f7a8bbe9
--- /dev/null
+++ b/tests/common_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .temp_utils import TempDirMixin
+from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
diff --git a/tests/common_utils/temp_utils.py b/tests/common_utils/temp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45d896836799edcf1fee271409b390b3b6e4127
--- /dev/null
+++ b/tests/common_utils/temp_utils.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import tempfile
+
+
+class TempDirMixin:
+ """Mixin to provide easy access to temp dir.
+ """
+
+ temp_dir_ = None
+
+ @classmethod
+ def get_base_temp_dir(cls):
+ # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
+ # this is handy for debugging.
+ key = "AUDIOCRAFT_TEST_DIR"
+ if key in os.environ:
+ return os.environ[key]
+ if cls.temp_dir_ is None:
+ cls.temp_dir_ = tempfile.TemporaryDirectory()
+ return cls.temp_dir_.name
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.temp_dir_ is not None:
+ try:
+ cls.temp_dir_.cleanup()
+ cls.temp_dir_ = None
+ except PermissionError:
+ # On Windows there is a know issue with `shutil.rmtree`,
+ # which fails intermittently.
+ # https://github.com/python/cpython/issues/74168
+ # Following the above thread, we ignore it.
+ pass
+ super().tearDownClass()
+
+ @property
+ def id(self):
+ return self.__class__.__name__
+
+ def get_temp_path(self, *paths):
+ temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
+ path = os.path.join(temp_dir, *paths)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ return path
+
+ def get_temp_dir(self, *paths):
+ temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
+ path = os.path.join(temp_dir, *paths)
+ os.makedirs(path, exist_ok=True)
+ return path
diff --git a/tests/common_utils/wav_utils.py b/tests/common_utils/wav_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a563ee1749a58217ece55c9a08b8d93c0fc386
--- /dev/null
+++ b/tests/common_utils/wav_utils.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import typing as tp
+
+import torch
+import torchaudio
+
+
+def get_white_noise(chs: int = 1, num_frames: int = 1):
+ wav = torch.randn(chs, num_frames)
+ return wav
+
+
+def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
+ wav = torch.randn(bs, chs, num_frames)
+ return wav
+
+
+def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
+ fp = Path(path)
+ kwargs: tp.Dict[str, tp.Any] = {}
+ if fp.suffix == '.wav':
+ kwargs['encoding'] = 'PCM_S'
+ kwargs['bits_per_sample'] = 16
+ elif fp.suffix == '.mp3':
+ kwargs['compression'] = 320
+ torchaudio.save(str(fp), wav, sample_rate, **kwargs)
diff --git a/tests/data/__init__.py b/tests/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/data/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/data/test_audio.py b/tests/data/test_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c0d5ed69eff92a766dc6d176e532f0df6c2b5e
--- /dev/null
+++ b/tests/data/test_audio.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+import random
+
+import numpy as np
+import torch
+import torchaudio
+
+from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
+
+from ..common_utils import TempDirMixin, get_white_noise, save_wav
+
+
+class TestInfo(TempDirMixin):
+
+ def test_info_mp3(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ wav = get_white_noise(ch, int(sample_rate * duration))
+ path = self.get_temp_path('sample_wav.mp3')
+ save_wav(path, wav, sample_rate)
+ info = audio_info(path)
+ assert info.sample_rate == sample_rate
+ assert info.channels == ch
+ # we cannot trust torchaudio for num_frames, so we don't check
+
+ def _test_info_format(self, ext: str):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames)
+ path = self.get_temp_path(f'sample_wav{ext}')
+ save_wav(path, wav, sample_rate)
+ info = audio_info(path)
+ assert info.sample_rate == sample_rate
+ assert info.channels == ch
+ assert np.isclose(info.duration, duration, atol=1e-5)
+
+ def test_info_wav(self):
+ self._test_info_format('.wav')
+
+ def test_info_flac(self):
+ self._test_info_format('.flac')
+
+ def test_info_ogg(self):
+ self._test_info_format('.ogg')
+
+ def test_info_m4a(self):
+ # TODO: generate m4a file programmatically
+ # self._test_info_format('.m4a')
+ pass
+
+
+class TestRead(TempDirMixin):
+
+ def test_read_full_wav(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
+ path = self.get_temp_path('sample_wav.wav')
+ save_wav(path, wav, sample_rate)
+ read_wav, read_sr = audio_read(path)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[1] == wav.shape[1]
+ assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04)
+
+ def test_read_partial_wav(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ read_duration = torch.rand(1).item()
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ read_frames = int(sample_rate * read_duration)
+ wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
+ path = self.get_temp_path('sample_wav.wav')
+ save_wav(path, wav, sample_rate)
+ read_wav, read_sr = audio_read(path, 0, read_duration)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[1] == read_frames
+ assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04)
+
+ def test_read_seek_time_wav(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ read_duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
+ path = self.get_temp_path('sample_wav.wav')
+ save_wav(path, wav, sample_rate)
+ seek_time = torch.rand(1).item()
+ read_wav, read_sr = audio_read(path, seek_time, read_duration)
+ seek_frames = int(sample_rate * seek_time)
+ expected_frames = n_frames - seek_frames
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[1] == expected_frames
+ assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
+
+ def test_read_seek_time_wav_padded(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ read_duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ read_frames = int(sample_rate * read_duration)
+ wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
+ path = self.get_temp_path('sample_wav.wav')
+ save_wav(path, wav, sample_rate)
+ seek_time = torch.rand(1).item()
+ seek_frames = int(sample_rate * seek_time)
+ expected_frames = n_frames - seek_frames
+ read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
+ expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[1] == read_frames
+ assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
+ assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
+
+
+class TestAvRead(TempDirMixin):
+
+ def test_avread_seek_base(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 2.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames)
+ path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
+ save_wav(path, wav, sample_rate)
+ for _ in range(100):
+ # seek will always load a full duration segment in the file
+ seek_time = random.uniform(0.0, 1.0)
+ seek_duration = random.uniform(0.001, 1.0)
+ read_wav, read_sr = _av_read(path, seek_time, seek_duration)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[-1] == int(seek_duration * sample_rate)
+
+ def test_avread_seek_partial(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames)
+ path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
+ save_wav(path, wav, sample_rate)
+ for _ in range(100):
+ # seek will always load a partial segment
+ seek_time = random.uniform(0.5, 1.)
+ seek_duration = 1.
+ expected_num_frames = n_frames - int(seek_time * sample_rate)
+ read_wav, read_sr = _av_read(path, seek_time, seek_duration)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[-1] == expected_num_frames
+
+ def test_avread_seek_outofbound(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(sample_rate * duration)
+ wav = get_white_noise(ch, n_frames)
+ path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
+ save_wav(path, wav, sample_rate)
+ seek_time = 1.5
+ read_wav, read_sr = _av_read(path, seek_time, 1.)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[-1] == 0
+
+ def test_avread_seek_edge(self):
+ sample_rates = [8000, 16_000]
+ # some of these values will have
+ # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
+ n_frames = [1000, 1001, 1002]
+ channels = [1, 2]
+ for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
+ duration = frames / sample_rate
+ wav = get_white_noise(ch, frames)
+ path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
+ save_wav(path, wav, sample_rate)
+ seek_time = (frames - 1) / sample_rate
+ seek_frames = int(seek_time * sample_rate)
+ read_wav, read_sr = _av_read(path, seek_time, duration)
+ assert read_sr == sample_rate
+ assert read_wav.shape[0] == wav.shape[0]
+ assert read_wav.shape[-1] == (frames - seek_frames)
+
+
+class TestAudioWrite(TempDirMixin):
+
+ def test_audio_write_wav(self):
+ torch.manual_seed(1234)
+ sample_rates = [8000, 16_000]
+ n_frames = [1000, 1001, 1002]
+ channels = [1, 2]
+ strategies = ["peak", "clip", "rms"]
+ formats = ["wav", "mp3"]
+ for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
+ for format_, strategy in product(formats, strategies):
+ wav = get_white_noise(ch, frames)
+ path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
+ audio_write(path, wav, sample_rate, format_, strategy=strategy)
+ read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
+ if format_ == "wav":
+ assert read_wav.shape == wav.shape
+
+ if format_ == "wav" and strategy in ["peak", "rms"]:
+ rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
+ # for a Gaussian, the typical max scale will be less than ~5x the std.
+ # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that.
+ # For RMS target, rescaling leaves more headroom by default, leading
+ # to a 20x rescaling typically
+ atol = (5 if strategy == "peak" else 20) / 2**15
+ delta = (rescaled_read_wav - wav).abs().max()
+ assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
+ formats = ["wav"] # faster unit tests
diff --git a/tests/data/test_audio_dataset.py b/tests/data/test_audio_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b591ea6137f48d0d97fcd1243c5f5d258670a474
--- /dev/null
+++ b/tests/data/test_audio_dataset.py
@@ -0,0 +1,352 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+from itertools import product
+import json
+import math
+import os
+import random
+import typing as tp
+
+import pytest
+import torch
+from torch.utils.data import DataLoader
+
+from audiocraft.data.audio_dataset import (
+ AudioDataset,
+ AudioMeta,
+ _get_audio_meta,
+ load_audio_meta,
+ save_audio_meta
+)
+from audiocraft.data.zip import PathInZip
+
+from ..common_utils import TempDirMixin, get_white_noise, save_wav
+
+
+class TestAudioMeta(TempDirMixin):
+
+ def test_get_audio_meta(self):
+ sample_rates = [8000, 16_000]
+ channels = [1, 2]
+ duration = 1.
+ for sample_rate, ch in product(sample_rates, channels):
+ n_frames = int(duration * sample_rate)
+ wav = get_white_noise(ch, n_frames)
+ path = self.get_temp_path('sample.wav')
+ save_wav(path, wav, sample_rate)
+ m = _get_audio_meta(path, minimal=True)
+ assert m.path == path, 'path does not match'
+ assert m.sample_rate == sample_rate, 'sample rate does not match'
+ assert m.duration == duration, 'duration does not match'
+ assert m.amplitude is None
+ assert m.info_path is None
+
+ def test_save_audio_meta(self):
+ audio_meta = [
+ AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
+ AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
+ ]
+ empty_audio_meta = []
+ for idx, meta in enumerate([audio_meta, empty_audio_meta]):
+ path = self.get_temp_path(f'data_{idx}_save.jsonl')
+ save_audio_meta(path, meta)
+ with open(path, 'r') as f:
+ lines = f.readlines()
+ read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines]
+ assert len(read_meta) == len(meta)
+ for m, read_m in zip(meta, read_meta):
+ assert m == read_m
+
+ def test_load_audio_meta(self):
+ try:
+ import dora
+ except ImportError:
+ dora = None # type: ignore
+
+ audio_meta = [
+ AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
+ AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
+ ]
+ empty_meta = []
+ for idx, meta in enumerate([audio_meta, empty_meta]):
+ path = self.get_temp_path(f'data_{idx}_load.jsonl')
+ with open(path, 'w') as f:
+ for m in meta:
+ json_str = json.dumps(m.to_dict()) + '\n'
+ f.write(json_str)
+ read_meta = load_audio_meta(path)
+ assert len(read_meta) == len(meta)
+ for m, read_m in zip(meta, read_meta):
+ if dora:
+ m.path = dora.git_save.to_absolute_path(m.path)
+ assert m == read_m, f'original={m}, read={read_m}'
+
+
+class TestAudioDataset(TempDirMixin):
+
+ def _create_audio_files(self,
+ root_name: str,
+ num_examples: int,
+ durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
+ sample_rate: int = 16_000,
+ channels: int = 1):
+ root_dir = self.get_temp_dir(root_name)
+ for i in range(num_examples):
+ if isinstance(durations, float):
+ duration = durations
+ elif isinstance(durations, tuple) and len(durations) == 1:
+ duration = durations[0]
+ elif isinstance(durations, tuple) and len(durations) == 2:
+ duration = random.uniform(durations[0], durations[1])
+ else:
+ assert False
+ n_frames = int(duration * sample_rate)
+ wav = get_white_noise(channels, n_frames)
+ path = os.path.join(root_dir, f'example_{i}.wav')
+ save_wav(path, wav, sample_rate)
+ return root_dir
+
+ def _create_audio_dataset(self,
+ root_name: str,
+ total_num_examples: int,
+ durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
+ sample_rate: int = 16_000,
+ channels: int = 1,
+ segment_duration: tp.Optional[float] = None,
+ num_examples: int = 10,
+ shuffle: bool = True,
+ return_info: bool = False):
+ root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
+ dataset = AudioDataset.from_path(root_dir,
+ minimal_meta=True,
+ segment_duration=segment_duration,
+ num_samples=num_examples,
+ sample_rate=sample_rate,
+ channels=channels,
+ shuffle=shuffle,
+ return_info=return_info)
+ return dataset
+
+ def test_dataset_full(self):
+ total_examples = 10
+ min_duration, max_duration = 1., 4.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration),
+ sample_rate=sample_rate, channels=channels, segment_duration=None)
+ assert len(dataset) == total_examples
+ assert dataset.sample_rate == sample_rate
+ assert dataset.channels == channels
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ assert sample.shape[0] == channels
+ assert sample.shape[1] <= int(max_duration * sample_rate)
+ assert sample.shape[1] >= int(min_duration * sample_rate)
+
+ def test_dataset_segment(self):
+ total_examples = 10
+ num_samples = 20
+ min_duration, max_duration = 1., 4.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples)
+ assert len(dataset) == num_samples
+ assert dataset.sample_rate == sample_rate
+ assert dataset.channels == channels
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ assert sample.shape[0] == channels
+ assert sample.shape[1] == int(segment_duration * sample_rate)
+
+ def test_dataset_equal_audio_and_segment_durations(self):
+ total_examples = 1
+ num_samples = 2
+ audio_duration = 1.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples)
+ assert len(dataset) == num_samples
+ assert dataset.sample_rate == sample_rate
+ assert dataset.channels == channels
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ assert sample.shape[0] == channels
+ assert sample.shape[1] == int(segment_duration * sample_rate)
+ # the random seek_time adds variability on audio read
+ sample_1 = dataset[0]
+ sample_2 = dataset[1]
+ assert not torch.allclose(sample_1, sample_2)
+
+ def test_dataset_samples(self):
+ total_examples = 1
+ num_samples = 2
+ audio_duration = 1.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+
+ create_dataset = partial(
+ self._create_audio_dataset,
+ 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples,
+ )
+
+ dataset = create_dataset(shuffle=True)
+ # when shuffle = True, we have different inputs for the same index across epoch
+ sample_1 = dataset[0]
+ sample_2 = dataset[0]
+ assert not torch.allclose(sample_1, sample_2)
+
+ dataset_noshuffle = create_dataset(shuffle=False)
+ # when shuffle = False, we have same inputs for the same index across epoch
+ sample_1 = dataset_noshuffle[0]
+ sample_2 = dataset_noshuffle[0]
+ assert torch.allclose(sample_1, sample_2)
+
+ def test_dataset_return_info(self):
+ total_examples = 10
+ num_samples = 20
+ min_duration, max_duration = 1., 4.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
+ assert len(dataset) == num_samples
+ assert dataset.sample_rate == sample_rate
+ assert dataset.channels == channels
+ for idx in range(len(dataset)):
+ sample, segment_info = dataset[idx]
+ assert sample.shape[0] == channels
+ assert sample.shape[1] == int(segment_duration * sample_rate)
+ assert segment_info.sample_rate == sample_rate
+ assert segment_info.total_frames == int(segment_duration * sample_rate)
+ assert segment_info.n_frames <= int(segment_duration * sample_rate)
+ assert segment_info.seek_time >= 0
+
+ def test_dataset_return_info_no_segment_duration(self):
+ total_examples = 10
+ num_samples = 20
+ min_duration, max_duration = 1., 4.
+ segment_duration = None
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
+ assert len(dataset) == total_examples
+ assert dataset.sample_rate == sample_rate
+ assert dataset.channels == channels
+ for idx in range(len(dataset)):
+ sample, segment_info = dataset[idx]
+ assert sample.shape[0] == channels
+ assert sample.shape[1] == segment_info.total_frames
+ assert segment_info.sample_rate == sample_rate
+ assert segment_info.n_frames <= segment_info.total_frames
+
+ def test_dataset_collate_fn(self):
+ total_examples = 10
+ num_samples = 20
+ min_duration, max_duration = 1., 4.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
+ batch_size = 4
+ dataloader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=0
+ )
+ for idx, batch in enumerate(dataloader):
+ assert batch.shape[0] == batch_size
+
+ @pytest.mark.parametrize("segment_duration", [1.0, None])
+ def test_dataset_with_meta_collate_fn(self, segment_duration):
+ total_examples = 10
+ num_samples = 20
+ min_duration, max_duration = 1., 4.
+ segment_duration = 1.
+ sample_rate = 16_000
+ channels = 1
+ dataset = self._create_audio_dataset(
+ 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
+ channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
+ batch_size = 4
+ dataloader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ collate_fn=dataset.collater,
+ num_workers=0
+ )
+ for idx, batch in enumerate(dataloader):
+ wav, infos = batch
+ assert wav.shape[0] == batch_size
+ assert len(infos) == batch_size
+
+ @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
+ [1, True, True, 0.5, 0.5, 0.0],
+ [1, False, True, 0.25, 0.5, 0.25],
+ [1, True, False, 0.666, 0.333, 0.0],
+ [1, False, False, 0.333, 0.333, 0.333],
+ [None, False, False, 0.333, 0.333, 0.333]])
+ def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
+ random.seed(1234)
+ rng = torch.Generator()
+ rng.manual_seed(1234)
+
+ def _get_histogram(dataset, repetitions=20_000):
+ counts = {file_meta.path: 0. for file_meta in meta}
+ for _ in range(repetitions):
+ file_meta = dataset.sample_file(0, rng)
+ counts[file_meta.path] += 1
+ return {name: count / repetitions for name, count in counts.items()}
+
+ meta = [
+ AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
+ AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
+ AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
+ ]
+ dataset = AudioDataset(
+ meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
+ sample_on_duration=sample_on_duration)
+ hist = _get_histogram(dataset)
+ assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
+ assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
+ assert math.isclose(hist['c'], c_hist, abs_tol=0.01)
+
+ def test_meta_duration_filter_all(self):
+ meta = [
+ AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
+ AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
+ AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
+ ]
+ try:
+ AudioDataset(meta, segment_duration=11, min_segment_ratio=1)
+ assert False
+ except AssertionError:
+ assert True
+
+ def test_meta_duration_filter_long(self):
+ meta = [
+ AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
+ AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
+ AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
+ ]
+ dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
+ assert len(dataset) == 2
diff --git a/tests/data/test_audio_utils.py b/tests/data/test_audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0480671bb17281d61ce02bce6373a5ccec89fece
--- /dev/null
+++ b/tests/data/test_audio_utils.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import julius
+import torch
+import pytest
+
+from audiocraft.data.audio_utils import (
+ _clip_wav,
+ convert_audio_channels,
+ convert_audio,
+ normalize_audio
+)
+from ..common_utils import get_batch_white_noise
+
+
+class TestConvertAudioChannels:
+
+ def test_convert_audio_channels_downmix(self):
+ b, c, t = 2, 3, 100
+ audio = get_batch_white_noise(b, c, t)
+ mixed = convert_audio_channels(audio, channels=2)
+ assert list(mixed.shape) == [b, 2, t]
+
+ def test_convert_audio_channels_nochange(self):
+ b, c, t = 2, 3, 100
+ audio = get_batch_white_noise(b, c, t)
+ mixed = convert_audio_channels(audio, channels=c)
+ assert list(mixed.shape) == list(audio.shape)
+
+ def test_convert_audio_channels_upmix(self):
+ b, c, t = 2, 1, 100
+ audio = get_batch_white_noise(b, c, t)
+ mixed = convert_audio_channels(audio, channels=3)
+ assert list(mixed.shape) == [b, 3, t]
+
+ def test_convert_audio_channels_upmix_error(self):
+ b, c, t = 2, 2, 100
+ audio = get_batch_white_noise(b, c, t)
+ with pytest.raises(ValueError):
+ convert_audio_channels(audio, channels=3)
+
+
+class TestConvertAudio:
+
+ def test_convert_audio_channels_downmix(self):
+ b, c, dur = 2, 3, 4.
+ sr = 128
+ audio = get_batch_white_noise(b, c, int(sr * dur))
+ out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
+ assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
+
+ def test_convert_audio_channels_upmix(self):
+ b, c, dur = 2, 1, 4.
+ sr = 128
+ audio = get_batch_white_noise(b, c, int(sr * dur))
+ out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
+ assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
+
+ def test_convert_audio_upsample(self):
+ b, c, dur = 2, 1, 4.
+ sr = 2
+ new_sr = 3
+ audio = get_batch_white_noise(b, c, int(sr * dur))
+ out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
+ out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
+ assert torch.allclose(out, out_j)
+
+ def test_convert_audio_resample(self):
+ b, c, dur = 2, 1, 4.
+ sr = 3
+ new_sr = 2
+ audio = get_batch_white_noise(b, c, int(sr * dur))
+ out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
+ out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
+ assert torch.allclose(out, out_j)
+
+
+class TestNormalizeAudio:
+
+ def test_clip_wav(self):
+ b, c, dur = 2, 1, 4.
+ sr = 3
+ audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
+ _clip_wav(audio)
+ assert audio.abs().max() <= 1
+
+ def test_normalize_audio_clip(self):
+ b, c, dur = 2, 1, 4.
+ sr = 3
+ audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
+ norm_audio = normalize_audio(audio, strategy='clip')
+ assert norm_audio.abs().max() <= 1
+
+ def test_normalize_audio_rms(self):
+ b, c, dur = 2, 1, 4.
+ sr = 3
+ audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
+ norm_audio = normalize_audio(audio, strategy='rms')
+ assert norm_audio.abs().max() <= 1
+
+ def test_normalize_audio_peak(self):
+ b, c, dur = 2, 1, 4.
+ sr = 3
+ audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
+ norm_audio = normalize_audio(audio, strategy='peak')
+ assert norm_audio.abs().max() <= 1
diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/losses/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6681e12c453dea5aeba738ab252d1923b7e0941
--- /dev/null
+++ b/tests/losses/test_losses.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+
+import torch
+
+from audiocraft.losses import (
+ MelSpectrogramL1Loss,
+ MultiScaleMelSpectrogramLoss,
+ MRSTFTLoss,
+ SISNR,
+ STFTLoss,
+)
+
+
+def test_mel_l1_loss():
+ N, C, T = 2, 2, random.randrange(1000, 100_000)
+ t1 = torch.randn(N, C, T)
+ t2 = torch.randn(N, C, T)
+
+ mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050)
+ loss = mel_l1(t1, t2)
+ loss_same = mel_l1(t1, t1)
+
+ assert isinstance(loss, torch.Tensor)
+ assert isinstance(loss_same, torch.Tensor)
+ assert loss_same.item() == 0.0
+
+
+def test_msspec_loss():
+ N, C, T = 2, 2, random.randrange(1000, 100_000)
+ t1 = torch.randn(N, C, T)
+ t2 = torch.randn(N, C, T)
+
+ msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
+ loss = msspec(t1, t2)
+ loss_same = msspec(t1, t1)
+
+ assert isinstance(loss, torch.Tensor)
+ assert isinstance(loss_same, torch.Tensor)
+ assert loss_same.item() == 0.0
+
+
+def test_mrstft_loss():
+ N, C, T = 2, 2, random.randrange(1000, 100_000)
+ t1 = torch.randn(N, C, T)
+ t2 = torch.randn(N, C, T)
+
+ mrstft = MRSTFTLoss()
+ loss = mrstft(t1, t2)
+
+ assert isinstance(loss, torch.Tensor)
+
+
+def test_sisnr_loss():
+ N, C, T = 2, 2, random.randrange(1000, 100_000)
+ t1 = torch.randn(N, C, T)
+ t2 = torch.randn(N, C, T)
+
+ sisnr = SISNR()
+ loss = sisnr(t1, t2)
+
+ assert isinstance(loss, torch.Tensor)
+
+
+def test_stft_loss():
+ N, C, T = 2, 2, random.randrange(1000, 100_000)
+ t1 = torch.randn(N, C, T)
+ t2 = torch.randn(N, C, T)
+
+ mrstft = STFTLoss()
+ loss = mrstft(t1, t2)
+
+ assert isinstance(loss, torch.Tensor)
diff --git a/tests/models/test_audiogen.py b/tests/models/test_audiogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..3850af066cedd5ea38bd9aead9634d6aaf938218
--- /dev/null
+++ b/tests/models/test_audiogen.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from audiocraft.models import AudioGen
+
+
+class TestAudioGenModel:
+ def get_audiogen(self):
+ ag = AudioGen.get_pretrained(name='debug', device='cpu')
+ ag.set_generation_params(duration=2.0, extend_stride=2.)
+ return ag
+
+ def test_base(self):
+ ag = self.get_audiogen()
+ assert ag.frame_rate == 25
+ assert ag.sample_rate == 16000
+ assert ag.audio_channels == 1
+
+ def test_generate_continuation(self):
+ ag = self.get_audiogen()
+ prompt = torch.randn(3, 1, 16000)
+ wav = ag.generate_continuation(prompt, 16000)
+ assert list(wav.shape) == [3, 1, 32000]
+
+ prompt = torch.randn(2, 1, 16000)
+ wav = ag.generate_continuation(
+ prompt, 16000, ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 32000]
+
+ prompt = torch.randn(2, 1, 16000)
+ with pytest.raises(AssertionError):
+ wav = ag.generate_continuation(
+ prompt, 16000, ['youpi', 'lapin dort', 'one too many'])
+
+ def test_generate(self):
+ ag = self.get_audiogen()
+ wav = ag.generate(
+ ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 32000]
+
+ def test_generate_long(self):
+ ag = self.get_audiogen()
+ ag.max_duration = 3.
+ ag.set_generation_params(duration=4., extend_stride=2.)
+ wav = ag.generate(
+ ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 16000 * 4]
diff --git a/tests/models/test_encodec_model.py b/tests/models/test_encodec_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9c1db3f69a45f02451b71da95f44356811acbb
--- /dev/null
+++ b/tests/models/test_encodec_model.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+
+import numpy as np
+import torch
+
+from audiocraft.models import EncodecModel
+from audiocraft.modules import SEANetEncoder, SEANetDecoder
+from audiocraft.quantization import DummyQuantizer
+
+
+class TestEncodecModel:
+
+ def _create_encodec_model(self,
+ sample_rate: int,
+ channels: int,
+ dim: int = 5,
+ n_filters: int = 3,
+ n_residual_layers: int = 1,
+ ratios: list = [5, 4, 3, 2],
+ **kwargs):
+ frame_rate = np.prod(ratios)
+ encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
+ n_residual_layers=n_residual_layers, ratios=ratios)
+ decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
+ n_residual_layers=n_residual_layers, ratios=ratios)
+ quantizer = DummyQuantizer()
+ model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
+ sample_rate=sample_rate, channels=channels, **kwargs)
+ return model
+
+ def test_model(self):
+ random.seed(1234)
+ sample_rate = 24_000
+ channels = 1
+ model = self._create_encodec_model(sample_rate, channels)
+ for _ in range(10):
+ length = random.randrange(1, 10_000)
+ x = torch.randn(2, channels, length)
+ res = model(x)
+ assert res.x.shape == x.shape
+
+ def test_model_renorm(self):
+ random.seed(1234)
+ sample_rate = 24_000
+ channels = 1
+ model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
+ model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
+
+ for _ in range(10):
+ length = random.randrange(1, 10_000)
+ x = torch.randn(2, channels, length)
+ codes, scales = model_nonorm.encode(x)
+ codes, scales = model_renorm.encode(x)
+ assert scales is not None
diff --git a/tests/models/test_multibanddiffusion.py b/tests/models/test_multibanddiffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..2702a3cb5fe402bf96911dbc992d2749cb18a4c0
--- /dev/null
+++ b/tests/models/test_multibanddiffusion.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+
+import numpy as np
+import torch
+from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
+from audiocraft.models import EncodecModel, DiffusionUnet
+from audiocraft.modules import SEANetEncoder, SEANetDecoder
+from audiocraft.modules.diffusion_schedule import NoiseSchedule
+from audiocraft.quantization import DummyQuantizer
+
+
+class TestMBD:
+
+ def _create_mbd(self,
+ sample_rate: int,
+ channels: int,
+ n_filters: int = 3,
+ n_residual_layers: int = 1,
+ ratios: list = [5, 4, 3, 2],
+ num_steps: int = 1000,
+ codec_dim: int = 128,
+ **kwargs):
+ frame_rate = np.prod(ratios)
+ encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
+ n_residual_layers=n_residual_layers, ratios=ratios)
+ decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
+ n_residual_layers=n_residual_layers, ratios=ratios)
+ quantizer = DummyQuantizer()
+ compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
+ sample_rate=sample_rate, channels=channels, **kwargs)
+ diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
+ schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
+ DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
+ mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
+ return mbd
+
+ def test_model(self):
+ random.seed(1234)
+ sample_rate = 24_000
+ channels = 1
+ codec_dim = 128
+ mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
+ for _ in range(10):
+ length = random.randrange(1, 10_000)
+ x = torch.randn(2, channels, length)
+ res = mbd.regenerate(x, sample_rate)
+ assert res.shape == x.shape
diff --git a/tests/models/test_musicgen.py b/tests/models/test_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..65618a9e2ef5bb382694b50b23dd50958d590d4e
--- /dev/null
+++ b/tests/models/test_musicgen.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from audiocraft.models import MusicGen
+
+
+class TestMusicGenModel:
+ def get_musicgen(self):
+ mg = MusicGen.get_pretrained(name='debug', device='cpu')
+ mg.set_generation_params(duration=2.0, extend_stride=2.)
+ return mg
+
+ def test_base(self):
+ mg = self.get_musicgen()
+ assert mg.frame_rate == 25
+ assert mg.sample_rate == 32000
+ assert mg.audio_channels == 1
+
+ def test_generate_unconditional(self):
+ mg = self.get_musicgen()
+ wav = mg.generate_unconditional(3)
+ assert list(wav.shape) == [3, 1, 64000]
+
+ def test_generate_continuation(self):
+ mg = self.get_musicgen()
+ prompt = torch.randn(3, 1, 32000)
+ wav = mg.generate_continuation(prompt, 32000)
+ assert list(wav.shape) == [3, 1, 64000]
+
+ prompt = torch.randn(2, 1, 32000)
+ wav = mg.generate_continuation(
+ prompt, 32000, ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 64000]
+
+ prompt = torch.randn(2, 1, 32000)
+ with pytest.raises(AssertionError):
+ wav = mg.generate_continuation(
+ prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
+
+ def test_generate(self):
+ mg = self.get_musicgen()
+ wav = mg.generate(
+ ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 64000]
+
+ def test_generate_long(self):
+ mg = self.get_musicgen()
+ mg.max_duration = 3.
+ mg.set_generation_params(duration=4., extend_stride=2.)
+ wav = mg.generate(
+ ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 32000 * 4]
diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/modules/test_activations.py b/tests/modules/test_activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..24e30d4cd87683430488bfa442e098b34229a5ee
--- /dev/null
+++ b/tests/modules/test_activations.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+from audiocraft.modules.activations import CustomGLU
+
+
+class TestActivations:
+ def test_custom_glu_calculation(self):
+
+ activation = CustomGLU(nn.Identity())
+
+ initial_shape = (4, 8, 8)
+
+ part_a = torch.ones(initial_shape) * 2
+ part_b = torch.ones(initial_shape) * -1
+ input = torch.cat((part_a, part_b), dim=-1)
+
+ output = activation(input)
+
+ # ensure all dimensions match initial shape
+ assert output.shape == initial_shape
+ # ensure the gating was calculated correctly a * f(b)
+ assert torch.all(output == -2).item()
diff --git a/tests/modules/test_codebooks_patterns.py b/tests/modules/test_codebooks_patterns.py
new file mode 100644
index 0000000000000000000000000000000000000000..b658f4779a369f9ec8dde692a61b7f0fe3485724
--- /dev/null
+++ b/tests/modules/test_codebooks_patterns.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from audiocraft.modules.codebooks_patterns import (
+ DelayedPatternProvider,
+ ParallelPatternProvider,
+ Pattern,
+ UnrolledPatternProvider,
+)
+
+
+class TestParallelPatternProvider:
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
+ def test_get_pattern(self, n_q: int, timesteps: int):
+ provider = ParallelPatternProvider(n_q)
+ pattern = provider.get_pattern(timesteps)
+ # + 1 to account for 1st step
+ assert len(pattern.layout) == timesteps + 1
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [8, 16, 100])
+ def test_pattern_content(self, n_q: int, timesteps: int):
+ provider = ParallelPatternProvider(n_q)
+ pattern = provider.get_pattern(timesteps)
+ for s, v in enumerate(pattern.layout):
+ for i, code in enumerate(v):
+ assert i == code.q
+ assert code.t == s - 1 # account for the 1st empty step
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [8, 16, 100])
+ def test_pattern_max_delay(self, n_q: int, timesteps: int):
+ provider = ParallelPatternProvider(n_q)
+ pattern = provider.get_pattern(timesteps)
+ assert pattern.max_delay == 0
+ assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
+
+
+class TestDelayedPatternProvider:
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
+ def test_get_pattern(self, n_q: int, timesteps: int):
+ delays = [
+ list(range(n_q)),
+ [0] + [1] * (n_q - 1),
+ [0] + [4] * (n_q - 1),
+ ]
+ for delay in delays:
+ provider = DelayedPatternProvider(n_q, delay)
+ pattern = provider.get_pattern(timesteps)
+ # + 1 to account for 1st step
+ assert len(pattern.layout) == timesteps + max(delay) + 1
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [8, 16, 100])
+ def test_pattern_content(self, n_q: int, timesteps: int):
+ provider = DelayedPatternProvider(n_q)
+ pattern = provider.get_pattern(timesteps)
+ for s, v in enumerate(pattern.layout):
+ for i, code in enumerate(v):
+ assert i == code.q
+ assert code.t == max(0, s - code.q - 1)
+
+ @pytest.mark.parametrize("timesteps", [8, 16, 100])
+ @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
+ def test_pattern_max_delay(self, timesteps: int, delay: list):
+ provider = DelayedPatternProvider(len(delay), delay)
+ pattern = provider.get_pattern(timesteps)
+ assert pattern.max_delay == max(delay)
+ assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
+
+
+class TestUnrolledPatternProvider:
+
+ @pytest.mark.parametrize("timesteps", [0, 1, 16])
+ @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
+ @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
+ def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
+ n_q = len(flattening)
+ max_delay = max(delays)
+ provider = UnrolledPatternProvider(n_q, flattening, delays)
+ pattern = provider.get_pattern(timesteps)
+ assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
+
+ @pytest.mark.parametrize("timesteps", [0, 1, 16])
+ @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
+ @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
+ def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
+ n_q = len(flattening)
+ max_delay = max(delays)
+ provider = UnrolledPatternProvider(n_q, flattening, delays)
+ pattern = provider.get_pattern(timesteps)
+ assert pattern.max_delay == max_delay
+
+
+class TestPattern:
+
+ def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
+ """Reference method to build the sequence from the pattern without using fancy scatter."""
+ bs, n_q, T = z.shape
+ z = z.cpu().numpy()
+ assert n_q == pattern.n_q
+ assert T <= pattern.timesteps
+ inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
+ inp[:] = special_token
+ for s, v in enumerate(pattern.layout):
+ for (t, q) in v:
+ if t < T:
+ inp[:, q, s] = z[:, q, t]
+ return torch.from_numpy(inp)
+
+ def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
+ """Reference method to revert the sequence from the pattern without using fancy scatter."""
+ z = z.cpu().numpy()
+ bs, n_q, S = z.shape
+ assert pattern.n_q == n_q
+ inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
+ inp[:] = special_token
+ for s, v in enumerate(pattern.layout):
+ for (t, q) in v:
+ if t < pattern.timesteps:
+ inp[:, q, t] = z[:, q, s]
+ return torch.from_numpy(inp)
+
+ def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
+ """Reference method to revert the logits from the pattern without using fancy scatter."""
+ z = z.cpu().numpy()
+ bs, card, n_q, S = z.shape
+ assert pattern.n_q == n_q
+ ref_layout = pattern.layout
+ inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
+ inp[:] = special_token
+ for s, v in enumerate(ref_layout[1:]):
+ if s < S:
+ for (t, q) in v:
+ if t < pattern.timesteps:
+ inp[:, :, q, t] = z[:, :, q, s]
+ return torch.from_numpy(inp)
+
+ def _get_pattern_providers(self, n_q: int):
+ pattern_provider_1 = ParallelPatternProvider(n_q)
+ pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
+ pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
+ pattern_provider_4 = UnrolledPatternProvider(
+ n_q, flattening=list(range(n_q)), delays=[0] * n_q
+ )
+ pattern_provider_5 = UnrolledPatternProvider(
+ n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
+ )
+ pattern_provider_6 = UnrolledPatternProvider(
+ n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
+ )
+ return [
+ pattern_provider_1,
+ pattern_provider_2,
+ pattern_provider_3,
+ pattern_provider_4,
+ pattern_provider_5,
+ pattern_provider_6,
+ ]
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [16, 72])
+ def test_build_pattern_sequence(self, n_q: int, timesteps: int):
+ bs = 2
+ card = 256
+ special_token = card
+
+ pattern_providers = self._get_pattern_providers(n_q)
+ for pattern_provider in pattern_providers:
+ pattern = pattern_provider.get_pattern(timesteps)
+ # we can correctly build the sequence from the pattern
+ z = torch.randint(0, card, (bs, n_q, timesteps))
+ ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
+ res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
+ assert (res == ref_res).float().mean() == 1.0
+
+ # expected assertion fails on the number of timesteps
+ invalid_timesteps = [timesteps + 1]
+ if pattern.num_sequence_steps != pattern.timesteps:
+ invalid_timesteps.append(pattern.num_sequence_steps)
+ for i_timesteps in invalid_timesteps:
+ z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
+ with pytest.raises(AssertionError):
+ pattern.build_pattern_sequence(z2, special_token)
+
+ # expected assertion fails on the number of codebooks
+ invalid_qs = [0, n_q - 1, n_q + 1]
+ for i_q in invalid_qs:
+ z3 = torch.randint(0, card, (bs, i_q, timesteps))
+ with pytest.raises(AssertionError):
+ pattern.build_pattern_sequence(z3, special_token)
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [16, 72])
+ def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
+ bs = 2
+ card = 256
+ special_token = card
+
+ pattern_providers = self._get_pattern_providers(n_q)
+ for pattern_provider in pattern_providers:
+ pattern = pattern_provider.get_pattern(timesteps)
+ # this works assuming previous tests are successful
+ z = torch.randint(0, card, (bs, n_q, timesteps))
+ s = self.ref_build_pattern_sequence(z, pattern, special_token)
+ ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
+ # ensure our reference script retrieve the original sequence
+ assert z.shape == ref_out.shape
+ assert (z == ref_out).float().mean() == 1.0
+ # now we can test the scatter version
+ out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
+ assert out.shape == ref_out.shape
+ assert (out == ref_out).float().mean() == 1.0
+
+ @pytest.mark.parametrize("n_q", [1, 4, 32])
+ @pytest.mark.parametrize("timesteps", [16, 72])
+ @pytest.mark.parametrize("card", [1, 2, 256, 1024])
+ def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
+ bs = 2
+ special_token = card
+ logits_special_token = float('nan')
+
+ pattern_providers = self._get_pattern_providers(n_q)
+ for pattern_provider in pattern_providers:
+ pattern = pattern_provider.get_pattern(timesteps)
+ # this works assuming previous tests are successful
+ z = torch.randint(0, card, (bs, n_q, timesteps))
+ s = self.ref_build_pattern_sequence(z, pattern, special_token)
+ logits = torch.randn((bs, card, n_q, s.shape[-1]))
+ ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
+ # ensure our reference script retrieve the original sequence
+ assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
+ # now we can test the scatter version
+ out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
+ assert out.shape == ref_out.shape
+ assert (out == ref_out).float().mean() == 1.0
diff --git a/tests/modules/test_conv.py b/tests/modules/test_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..28fbc4f1a0ebaf41b56947b767958ae696e75eec
--- /dev/null
+++ b/tests/modules/test_conv.py
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+import math
+import random
+
+import pytest
+import torch
+from torch import nn
+
+from audiocraft.modules import (
+ NormConv1d,
+ NormConvTranspose1d,
+ StreamableConv1d,
+ StreamableConvTranspose1d,
+ pad1d,
+ unpad1d,
+)
+
+
+def test_get_extra_padding_for_conv1d():
+ # TODO: Implement me!
+ pass
+
+
+def test_pad1d_zeros():
+ x = torch.randn(1, 1, 20)
+
+ xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
+ assert xp1.shape[-1] == 25
+ xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
+ assert xp2.shape[-1] == 30
+ xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
+ assert xp3.shape[-1] == 20
+ xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
+ assert xp4.shape[-1] == 60
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (-1, 0), mode='constant', value=0.)
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (0, -1), mode='constant', value=0.)
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (-1, -1), mode='constant', value=0.)
+
+
+def test_pad1d_reflect():
+ x = torch.randn(1, 1, 20)
+
+ xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
+ assert xp1.shape[-1] == 25
+ xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
+ assert xp2.shape[-1] == 30
+ xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
+ assert xp3.shape[-1] == 20
+ xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
+ assert xp4.shape[-1] == 60
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (-1, 0), mode='reflect', value=0.)
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (0, -1), mode='reflect', value=0.)
+
+ with pytest.raises(AssertionError):
+ pad1d(x, (-1, -1), mode='reflect', value=0.)
+
+
+def test_unpad1d():
+ x = torch.randn(1, 1, 20)
+
+ u1 = unpad1d(x, (5, 5))
+ assert u1.shape[-1] == 10
+ u2 = unpad1d(x, (0, 5))
+ assert u2.shape[-1] == 15
+ u3 = unpad1d(x, (5, 0))
+ assert u3.shape[-1] == 15
+ u4 = unpad1d(x, (0, 0))
+ assert u4.shape[-1] == x.shape[-1]
+
+ with pytest.raises(AssertionError):
+ unpad1d(x, (-1, 0))
+
+ with pytest.raises(AssertionError):
+ unpad1d(x, (0, -1))
+
+ with pytest.raises(AssertionError):
+ unpad1d(x, (-1, -1))
+
+
+class TestNormConv1d:
+
+ def test_norm_conv1d_modules(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+
+ C_out, kernel_size, stride = 1, 4, 1
+ expected_out_length = int((T - kernel_size) / stride + 1)
+ wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
+ gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
+ nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
+
+ assert isinstance(wn_conv.norm, nn.Identity)
+ assert isinstance(wn_conv.conv, nn.Conv1d)
+
+ assert isinstance(gn_conv.norm, nn.GroupNorm)
+ assert isinstance(gn_conv.conv, nn.Conv1d)
+
+ assert isinstance(nn_conv.norm, nn.Identity)
+ assert isinstance(nn_conv.conv, nn.Conv1d)
+
+ for conv_layer in [wn_conv, gn_conv, nn_conv]:
+ out = conv_layer(t0)
+ assert isinstance(out, torch.Tensor)
+ assert list(out.shape) == [N, C_out, expected_out_length]
+
+
+class TestNormConvTranspose1d:
+
+ def test_normalizations(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+
+ C_out, kernel_size, stride = 1, 4, 1
+ expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
+
+ wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
+ gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
+ nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
+
+ assert isinstance(wn_convtr.norm, nn.Identity)
+ assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
+
+ assert isinstance(gn_convtr.norm, nn.GroupNorm)
+ assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
+
+ assert isinstance(nn_convtr.norm, nn.Identity)
+ assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
+
+ for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
+ out = convtr_layer(t0)
+ assert isinstance(out, torch.Tensor)
+ assert list(out.shape) == [N, C_out, expected_out_length]
+
+
+class TestStreamableConv1d:
+
+ def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
+ # StreamableConv1d internally pads to make sure that the last window is full
+ padding_total = (kernel_size - 1) * dilation - (stride - 1)
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length // stride
+
+ def test_streamable_conv1d(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+ C_out = 1
+
+ # conv params are [(kernel_size, stride, dilation)]
+ conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
+ for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
+ expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
+ sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
+ out = sconv(t0)
+ assert isinstance(out, torch.Tensor)
+ print(list(out.shape), [N, C_out, expected_out_length])
+ assert list(out.shape) == [N, C_out, expected_out_length]
+
+
+class TestStreamableConvTranspose1d:
+
+ def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
+ padding_total = (kernel_size - stride)
+ return (length - 1) * stride - padding_total + (kernel_size - 1) + 1
+
+ def test_streamable_convtr1d(self):
+ N, C, T = 2, 2, random.randrange(1, 100_000)
+ t0 = torch.randn(N, C, T)
+
+ C_out = 1
+
+ with pytest.raises(AssertionError):
+ StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5)
+ StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.)
+ StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2)
+
+ # causal params are [(causal, trim_right)]
+ causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
+ # conv params are [(kernel_size, stride)]
+ conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
+ for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
+ expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
+ sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
+ causal=causal, trim_right_ratio=trim_right_ratio)
+ out = sconvtr(t0)
+ assert isinstance(out, torch.Tensor)
+ assert list(out.shape) == [N, C_out, expected_out_length]
diff --git a/tests/modules/test_lstm.py b/tests/modules/test_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1248964c8191e19f27661f0974bef9cc967eb015
--- /dev/null
+++ b/tests/modules/test_lstm.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+
+from audiocraft.modules.lstm import StreamableLSTM
+
+
+class TestStreamableLSTM:
+
+ def test_lstm(self):
+ B, C, T = 4, 2, random.randint(1, 100)
+
+ lstm = StreamableLSTM(C, 3, skip=False)
+ x = torch.randn(B, C, T)
+ y = lstm(x)
+
+ print(y.shape)
+ assert y.shape == torch.Size([B, C, T])
+
+ def test_lstm_skip(self):
+ B, C, T = 4, 2, random.randint(1, 100)
+
+ lstm = StreamableLSTM(C, 3, skip=True)
+ x = torch.randn(B, C, T)
+ y = lstm(x)
+
+ assert y.shape == torch.Size([B, C, T])
diff --git a/tests/modules/test_rope.py b/tests/modules/test_rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..067c6f067acbf27fb0fef5c2b812c22474c4fcd0
--- /dev/null
+++ b/tests/modules/test_rope.py
@@ -0,0 +1,168 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from audiocraft.modules.rope import RotaryEmbedding
+from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend
+
+
+def test_rope():
+ set_efficient_attention_backend('xformers')
+ B, T, H, C = 8, 75, 16, 128
+
+ rope = RotaryEmbedding(dim=C)
+ xq = torch.rand((B, T, H, C))
+ xk = torch.rand((B, T, H, C))
+ xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
+
+ assert list(xq_out.shape) == [B, T, H, C]
+ assert list(xk_out.shape) == [B, T, H, C]
+
+
+def test_rope_io_dtypes():
+ set_efficient_attention_backend('xformers')
+ B, T, H, C = 8, 75, 16, 128
+
+ rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
+ rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
+
+ # Test bfloat16 inputs w/ both 32 and 64 precision rope.
+ xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
+ xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
+ xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
+ assert xq_out.dtype == torch.bfloat16
+ xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
+ assert xq_out.dtype == torch.bfloat16
+
+ # Test float32 inputs w/ both 32 and 64 precision rope.
+ xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
+ xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
+ xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
+ assert xq_out.dtype == torch.float32
+ xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
+ assert xq_out.dtype == torch.float32
+
+
+def test_transformer_with_rope():
+ set_efficient_attention_backend('xformers')
+ torch.manual_seed(1234)
+ for pos in ['rope', 'sin_rope']:
+ tr = StreamingTransformer(
+ 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
+ positional_embedding=pos)
+ tr.eval()
+ steps = 12
+ x = torch.randn(3, steps, 16)
+
+ out = tr(x)
+ assert list(out.shape) == list(x.shape)
+
+
+@torch.no_grad()
+def test_rope_streaming():
+ set_efficient_attention_backend('xformers')
+ torch.manual_seed(1234)
+ tr = StreamingTransformer(
+ 16, 4, 2, causal=True, dropout=0.,
+ custom=True, positional_embedding='rope')
+ tr.eval()
+ steps = 12
+ x = torch.randn(3, steps, 16)
+
+ ref = tr(x)
+
+ with tr.streaming():
+ outs = []
+ frame_sizes = [1] * steps
+
+ for frame_size in frame_sizes:
+ frame = x[:, :frame_size]
+ x = x[:, frame_size:]
+ outs.append(tr(frame))
+
+ out = torch.cat(outs, dim=1)
+ assert list(out.shape) == [3, steps, 16]
+ delta = torch.norm(out - ref) / torch.norm(out)
+ assert delta < 1e-6, delta
+
+
+@torch.no_grad()
+def test_rope_streaming_past_context():
+ set_efficient_attention_backend('xformers')
+ torch.manual_seed(1234)
+
+ for context in [None, 10]:
+ tr = StreamingTransformer(
+ 16, 4, 1 if context else 2,
+ causal=True, past_context=context, custom=True,
+ dropout=0., positional_embedding='rope')
+ tr.eval()
+
+ steps = 20
+ x = torch.randn(3, steps, 16)
+ ref = tr(x)
+
+ with tr.streaming():
+ outs = []
+ frame_sizes = [1] * steps
+
+ for frame_size in frame_sizes:
+ frame = x[:, :frame_size]
+ x = x[:, frame_size:]
+ outs.append(tr(frame))
+
+ out = torch.cat(outs, dim=1)
+ assert list(out.shape) == [3, steps, 16]
+ delta = torch.norm(out - ref) / torch.norm(out)
+ assert delta < 1e-6, delta
+
+
+def test_rope_memory_efficient():
+ set_efficient_attention_backend('xformers')
+ torch.manual_seed(1234)
+ tr = StreamingTransformer(
+ 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
+ positional_embedding='rope')
+ tr_mem_efficient = StreamingTransformer(
+ 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
+ positional_embedding='rope')
+ tr_mem_efficient.load_state_dict(tr.state_dict())
+ tr.eval()
+ steps = 12
+ x = torch.randn(3, steps, 16)
+
+ with torch.no_grad():
+ y = tr(x)
+ y2 = tr_mem_efficient(x)
+ # Check at float precision b/c this is the rope default.
+ assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
+
+
+def test_rope_with_xpos():
+ set_efficient_attention_backend('xformers')
+ B, T, H, C = 8, 75, 16, 128
+
+ rope = RotaryEmbedding(dim=C, xpos=True)
+ xq = torch.rand((B, T, H, C))
+ xk = torch.rand((B, T, H, C))
+ xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
+
+ assert list(xq_out.shape) == [B, T, H, C]
+ assert list(xk_out.shape) == [B, T, H, C]
+
+
+def test_positional_scale():
+ set_efficient_attention_backend('xformers')
+ B, T, H, C = 8, 75, 16, 128
+
+ rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
+ xq = torch.rand((B, T, H, C))
+ xk = torch.rand((B, T, H, C))
+ xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
+
+ assert torch.allclose(xq, xq_out)
+ assert torch.allclose(xk, xk_out)
diff --git a/tests/modules/test_seanet.py b/tests/modules/test_seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c51b340a2f94fb2828b14daf83d5fad645073d
--- /dev/null
+++ b/tests/modules/test_seanet.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+
+import pytest
+import torch
+
+from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
+from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
+
+
+class TestSEANetModel:
+
+ def test_base(self):
+ encoder = SEANetEncoder()
+ decoder = SEANetDecoder()
+
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+ def test_causal(self):
+ encoder = SEANetEncoder(causal=True)
+ decoder = SEANetDecoder(causal=True)
+ x = torch.randn(1, 1, 24000)
+
+ z = encoder(x)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+ def test_conv_skip_connection(self):
+ encoder = SEANetEncoder(true_skip=False)
+ decoder = SEANetDecoder(true_skip=False)
+
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+ def test_seanet_encoder_decoder_final_act(self):
+ encoder = SEANetEncoder(true_skip=False)
+ decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
+
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+ def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
+ n_blocks = 0
+ for layer in encoder.model:
+ if isinstance(layer, StreamableConv1d):
+ n_blocks += 1
+ assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm
+ elif isinstance(layer, SEANetResnetBlock):
+ for resnet_layer in layer.block:
+ if isinstance(resnet_layer, StreamableConv1d):
+ # here we add + 1 to n_blocks as we increment n_blocks just after the block
+ assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm
+
+ def test_encoder_disable_norm(self):
+ n_residuals = [0, 1, 3]
+ disable_blocks = [0, 1, 2, 3, 4, 5, 6]
+ norms = ['weight_norm', 'none']
+ for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
+ encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
+ disable_norm_outer_blocks=disable_blocks)
+ self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
+
+ def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
+ n_blocks = 0
+ for layer in decoder.model:
+ if isinstance(layer, StreamableConv1d):
+ n_blocks += 1
+ assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
+ elif isinstance(layer, StreamableConvTranspose1d):
+ n_blocks += 1
+ assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
+ elif isinstance(layer, SEANetResnetBlock):
+ for resnet_layer in layer.block:
+ if isinstance(resnet_layer, StreamableConv1d):
+ assert resnet_layer.conv.norm_type == 'none' \
+ if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
+
+ def test_decoder_disable_norm(self):
+ n_residuals = [0, 1, 3]
+ disable_blocks = [0, 1, 2, 3, 4, 5, 6]
+ norms = ['weight_norm', 'none']
+ for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
+ decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
+ disable_norm_outer_blocks=disable_blocks)
+ self._check_decoder_blocks_norm(decoder, disable_blocks, norm)
+
+ def test_disable_norm_raises_exception(self):
+ # Invalid disable_norm_outer_blocks values raise exceptions
+ with pytest.raises(AssertionError):
+ SEANetEncoder(disable_norm_outer_blocks=-1)
+
+ with pytest.raises(AssertionError):
+ SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
+
+ with pytest.raises(AssertionError):
+ SEANetDecoder(disable_norm_outer_blocks=-1)
+
+ with pytest.raises(AssertionError):
+ SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
diff --git a/tests/modules/test_transformer.py b/tests/modules/test_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bb79bfd58d535469f9b3c56b8a5fe254db5d8ba
--- /dev/null
+++ b/tests/modules/test_transformer.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+
+import pytest
+import torch
+
+from audiocraft.modules.transformer import (
+ StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)
+
+
+def test_transformer_causal_streaming():
+ torch.manual_seed(1234)
+
+ for context, custom in product([None, 10], [False, True]):
+ # Test that causality and receptive fields are properly handled.
+ # looking at the gradients
+ tr = StreamingTransformer(
+ 16, 4, 1 if context else 2,
+ causal=True, past_context=context, custom=custom,
+ dropout=0.)
+ steps = 20
+ for k in [0, 10, 15, 19]:
+ x = torch.randn(4, steps, 16, requires_grad=True)
+ y = tr(x)
+ y[:, k].abs().sum().backward()
+ if k + 1 < steps:
+ assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
+ assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
+ if context is not None and k > context:
+ limit = k - context - 1
+ assert torch.allclose(x.grad[:, :limit],
+ torch.tensor(0.)), x.grad[:, :limit].norm()
+
+ # Now check that streaming gives the same result at batch eval.
+ x = torch.randn(4, steps, 16)
+ y = tr(x)
+ ys = []
+ with tr.streaming():
+ for k in range(steps):
+ chunk = x[:, k:k + 1, :]
+ ys.append(tr(chunk))
+ y_stream = torch.cat(ys, dim=1)
+ delta = torch.norm(y_stream - y) / torch.norm(y)
+ assert delta < 1e-6, delta
+
+
+def test_transformer_vs_pytorch():
+ torch.manual_seed(1234)
+ # Check that in the non causal setting, we get the same result as
+ # PyTorch Transformer encoder.
+ for custom in [False, True]:
+ tr = StreamingTransformer(
+ 16, 4, 2,
+ causal=False, custom=custom, dropout=0., positional_scale=0.)
+ layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
+ tr_ref = torch.nn.TransformerEncoder(layer, 2)
+ tr.load_state_dict(tr_ref.state_dict())
+
+ x = torch.randn(4, 20, 16)
+ y = tr(x)
+ y2 = tr_ref(x)
+ delta = torch.norm(y2 - y) / torch.norm(y)
+ assert delta < 1e-6, delta
+
+
+def test_streaming_api():
+ tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
+ tr.eval()
+ steps = 12
+ x = torch.randn(1, steps, 16)
+
+ with torch.no_grad():
+ with tr.streaming():
+ _ = tr(x[:, :1])
+ state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
+ y = tr(x[:, 1:2])
+ tr.set_streaming_state(state)
+ y2 = tr(x[:, 1:2])
+ assert torch.allclose(y, y2), (y - y2).norm()
+ assert tr.flush() is None
+
+
+def test_memory_efficient():
+ for backend in ['torch', 'xformers']:
+ torch.manual_seed(1234)
+ set_efficient_attention_backend(backend)
+
+ tr = StreamingTransformer(
+ 16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
+ tr_mem_efficient = StreamingTransformer(
+ 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
+ tr_mem_efficient.load_state_dict(tr.state_dict())
+ tr.eval()
+ steps = 12
+ x = torch.randn(3, steps, 16)
+
+ with torch.no_grad():
+ y = tr(x)
+ y2 = tr_mem_efficient(x)
+ assert torch.allclose(y, y2), ((y - y2).norm(), backend)
+
+
+def test_attention_as_float32():
+ torch.manual_seed(1234)
+ cases = [
+ {'custom': True},
+ {'custom': False},
+ ]
+ for case in cases:
+ tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
+ tr_float32 = StreamingTransformer(
+ 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
+ if not case['custom']:
+ # we are not using autocast here because it doesn't really
+ # work as expected on CPU, so we have to manually cast the weights of the MHA.
+ for layer in tr_float32.layers:
+ layer.self_attn.mha.to(torch.float32)
+ tr_float32.load_state_dict(tr.state_dict())
+ steps = 12
+ x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
+
+ with torch.no_grad():
+ y = tr(x)
+ y2 = tr_float32(x)
+ assert not torch.allclose(y, y2), (y - y2).norm()
+
+
+@torch.no_grad()
+def test_streaming_memory_efficient():
+ for backend in ['torch', 'xformers']:
+ torch.manual_seed(1234)
+ set_efficient_attention_backend(backend)
+ tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
+ tr_mem_efficient = StreamingTransformer(
+ 16, 4, 2, dropout=0., memory_efficient=True, causal=True)
+ tr.load_state_dict(tr_mem_efficient.state_dict())
+ tr.eval()
+ tr_mem_efficient.eval()
+ steps = 12
+ x = torch.randn(3, steps, 16)
+
+ ref = tr(x)
+
+ with tr_mem_efficient.streaming():
+ outs = []
+ # frame_sizes = [2] + [1] * (steps - 2)
+ frame_sizes = [1] * steps
+
+ for frame_size in frame_sizes:
+ frame = x[:, :frame_size]
+ x = x[:, frame_size:]
+ outs.append(tr_mem_efficient(frame))
+
+ out = torch.cat(outs, dim=1)
+ delta = torch.norm(out - ref) / torch.norm(out)
+ assert delta < 1e-6, delta
+
+
+def test_cross_attention():
+ torch.manual_seed(1234)
+ for norm_first in [True, False]:
+ m = StreamingTransformer(
+ 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
+ m_cross = StreamingTransformer(
+ 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
+ m_cross.load_state_dict(m.state_dict(), strict=False)
+ x = torch.randn(2, 5, 16)
+ cross_x = torch.randn(2, 3, 16)
+ y_ref = m(x)
+ y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
+ # With norm_first, the two should be exactly the same,
+ # but with norm_first=False, we get 2 normalization in a row
+ # and the epsilon value leads to a tiny change.
+ atol = 0. if norm_first else 1e-6
+ print((y_ref - y_cross_zero).norm() / y_ref.norm())
+ assert torch.allclose(y_ref, y_cross_zero, atol=atol)
+
+ # We now expect a difference even with a generous atol of 1e-2.
+ y_cross = m_cross(x, cross_attention_src=cross_x)
+ assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)
+
+ with pytest.raises(AssertionError):
+ _ = m_cross(x)
+ _ = m(x, cross_attention_src=cross_x)
+
+
+def test_cross_attention_compat():
+ torch.manual_seed(1234)
+ num_heads = 2
+ dim = num_heads * 64
+ with pytest.raises(AssertionError):
+ StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)
+
+ cross_attn = StreamingMultiheadAttention(
+ dim, num_heads, dropout=0, cross_attention=True, custom=True)
+ ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)
+
+ # We can load the regular attention state dict
+ # so we have compat when loading old checkpoints.
+ cross_attn.load_state_dict(ref_attn.state_dict())
+
+ queries = torch.randn(3, 7, dim)
+ keys = torch.randn(3, 9, dim)
+ values = torch.randn(3, 9, dim)
+
+ y = cross_attn(queries, keys, values)[0]
+ y_ref = ref_attn(queries, keys, values)[0]
+ assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm()
+
+ # Now let's check that streaming is working properly.
+ with cross_attn.streaming():
+ ys = []
+ for step in range(queries.shape[1]):
+ ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0])
+ y_streaming = torch.cat(ys, dim=1)
+ assert torch.allclose(y_streaming, y, atol=1e-7)
+
+
+def test_repeat_kv():
+ torch.manual_seed(1234)
+ num_heads = 8
+ kv_repeat = 4
+ dim = num_heads * 64
+ with pytest.raises(AssertionError):
+ mha = StreamingMultiheadAttention(
+ dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
+ mha = StreamingMultiheadAttention(
+ dim, num_heads, causal=True, kv_repeat=kv_repeat)
+ mha = StreamingMultiheadAttention(
+ dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
+ x = torch.randn(4, 18, dim)
+ y = mha(x, x, x)[0]
+ assert x.shape == y.shape
+
+
+def test_qk_layer_norm():
+ torch.manual_seed(1234)
+ tr = StreamingTransformer(
+ 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
+ steps = 12
+ x = torch.randn(3, steps, 16)
+ y = tr(x)
+
+ tr = StreamingTransformer(
+ 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
+ z = torch.randn(3, 21, 16)
+ y = tr(x, cross_attention_src=z)
+ assert y.shape == x.shape
diff --git a/tests/quantization/test_vq.py b/tests/quantization/test_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c215099fedacae35c6798fdd9b8420a447aa16bb
--- /dev/null
+++ b/tests/quantization/test_vq.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from audiocraft.quantization.vq import ResidualVectorQuantizer
+
+
+class TestResidualVectorQuantizer:
+
+ def test_rvq(self):
+ x = torch.randn(1, 16, 2048)
+ vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
+ res = vq(x, 1.)
+ assert res.x.shape == torch.Size([1, 16, 2048])
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/tests/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.