jamesdon's picture
change to AudioGen
296a9ec
from typing import Dict, List, Any
# from transformers import AutoProcessor, MusicgenForConditionalGeneration
# import torch
# import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
class EndpointHandler:
    """Inference-endpoint handler that turns text prompts into audio with AudioGen.

    The handler is constructed once with the model location and is then
    called once per request with the decoded request payload.
    """

    def __init__(self, path=""):
        """Load the pretrained AudioGen model.

        Args:
            path: Model name or checkpoint location, forwarded unchanged to
                ``AudioGen.get_pretrained``.
        """
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the prompt(s) contained in a request payload.

        Args:
            data: Request payload. Recognised keys:

                * ``"inputs"``: a single text prompt, or a list of prompts.
                * ``"duration"``: seconds of audio to generate (default 5).

        Returns:
            A one-element list ``[{"generated_audio": array}]`` where
            ``array`` is the first generated clip moved to the CPU and
            converted with ``.numpy()``. Only the first clip is returned,
            even when several prompts are supplied.

        Raises:
            ValueError: If no usable text prompt is supplied, or if
                ``duration`` is not a positive number.
        """
        # Read with ``get`` so the caller's payload is left untouched.
        prompts = data.get("inputs", data)
        duration = data.get("duration", 5)

        # The model iterates over its argument as a batch of descriptions,
        # so a bare string would be split into one prompt per character.
        if isinstance(prompts, str):
            prompts = [prompts]

        if (
            not isinstance(prompts, (list, tuple))
            or not prompts
            or not all(isinstance(prompt, str) for prompt in prompts)
        ):
            raise ValueError(
                '"inputs" must be a text prompt or a non-empty list of text prompts'
            )

        # ``bool`` is a subclass of ``int`` and is never a sensible duration.
        if (
            isinstance(duration, bool)
            or not isinstance(duration, (int, float))
            or not duration > 0
        ):
            raise ValueError('"duration" must be a positive number of seconds')

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(list(prompts))

        prediction = outputs[0].cpu().numpy()
        return [{"generated_audio": prediction}]