from typing import Any, Dict, List, Union

from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write  # kept: may be used by callers elsewhere


class EndpointHandler:
    """Custom inference-endpoint handler around a pretrained AudioGen model.

    Loads the model once at startup and turns text prompts into generated
    audio waveforms on each call.
    """

    def __init__(self, path: str = "") -> None:
        """Load the pretrained AudioGen model.

        Args:
            path: Model identifier or local path passed to
                ``AudioGen.get_pretrained`` (e.g. a hub repo id).
        """
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the given text prompt(s).

        Args:
            data: Payload dict. ``"inputs"`` is a prompt string or a list of
                prompt strings; ``"duration"`` (optional, default 5) is the
                number of seconds of audio to generate per prompt.

        Returns:
            A list with one dict per prompt, each of the form
            ``{"generated_audio": <numpy array of samples>}``.
        """
        inputs: Union[str, List[str]] = data.pop("inputs", data)
        # Generalize: the common HF payload sends a bare string; without this,
        # model.generate would iterate it character by character.
        if isinstance(inputs, str):
            inputs = [inputs]

        duration = data.pop("duration", 5)  # seconds of audio per prompt
        self.model.set_generation_params(duration=duration)

        outputs = self.model.generate(inputs)
        # Bug fix: previously only outputs[0] was returned, silently dropping
        # results for all prompts after the first. Return one entry per prompt.
        return [{"generated_audio": wav.cpu().numpy()} for wav in outputs]