# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Main model for using MAGNeT. This will combine all the required components
and provide easy access to the generation API.
"""

import typing as tp

import torch

from .genmodel import BaseGenModel
from .loaders import load_compression_model, load_lm_model_magnet

class MAGNeT(BaseGenModel):
    """MAGNeT main model with convenient generation API.

    Args:
        See MusicGen class.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # MAGNeT operates over a fixed sequence length defined in its config.
        self.duration = self.lm.cfg.dataset.segment_duration
        self.set_generation_params()

    @staticmethod
    def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
"""Return pretrained model, we provide six models: | |
- facebook/magnet-small-10secs (300M), text to music, 10-second audio samples. | |
# see: https://huggingface.co/facebook/magnet-small-10secs | |
- facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples. | |
# see: https://huggingface.co/facebook/magnet-medium-10secs | |
- facebook/magnet-small-30secs (300M), text to music, 30-second audio samples. | |
# see: https://huggingface.co/facebook/magnet-small-30secs | |
- facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples. | |
# see: https://huggingface.co/facebook/magnet-medium-30secs | |
- facebook/audio-magnet-small (300M), text to sound-effect (10-second samples). | |
# see: https://huggingface.co/facebook/audio-magnet-small | |
- facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples). | |
# see: https://huggingface.co/facebook/audio-magnet-medium | |
""" | |
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            else:
                device = 'cpu'

        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)
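        # When a waveform ('self_wav') conditioner is present, make it match the
        # conditioning length at evaluation time, mirroring the MusicGen setup.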
        if 'self_wav' in lm.condition_provider.conditioners:
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MAGNeT(**kwargs)
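
    # Usage sketch (illustrative only, not part of the API defined above): load a
    # checkpoint on an explicit device instead of relying on the automatic choice.
    #
    #     model = MAGNeT.get_pretrained('facebook/magnet-small-10secs', device='cpu')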

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                              top_p: float = 0.9, temperature: float = 3.0,
                              max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                              decoding_steps: tp.List[int] = [20, 10, 10, 10],
                              span_arrangement: str = 'nonoverlap'):
        """Set the generation parameters for MAGNeT.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 0.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
            temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
            max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
            min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
            decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                for each of the n_q RVQ codebooks.
            span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                or overlapping spans ('stride1') in the masking scheme.
        """
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'max_cfg_coef': max_cfg_coef,
            'min_cfg_coef': min_cfg_coef,
            'decoding_steps': [int(s) for s in decoding_steps],
            'span_arrangement': span_arrangement
        }
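
# Minimal end-to-end sketch (assumptions: `generate()` inherited from BaseGenModel
# returns a [B, C, T] tensor at `model.sample_rate`, and `audio_write` is available
# from audiocraft.data.audio, as in the MusicGen examples):
#
#     from audiocraft.data.audio import audio_write
#
#     model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')
#     model.set_generation_params(temperature=3.0, top_p=0.9,
#                                 decoding_steps=[20, 10, 10, 10])
#     wavs = model.generate(['80s electronic track with drums', 'soft acoustic guitar'])
#     for idx, wav in enumerate(wavs):
#         audio_write(f'sample_{idx}', wav.cpu(), model.sample_rate, strategy='loudness')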