jamesdon committed
Commit 296a9ec
Parent: 7c9edea

change to AudioGen

Files changed (2):
  1. handler.py +16 -21
  2. requirements.txt +1 -1
handler.py CHANGED
@@ -1,13 +1,19 @@
 from typing import Dict, List, Any
-from transformers import AutoProcessor, MusicgenForConditionalGeneration
-import torch
+# from transformers import AutoProcessor, MusicgenForConditionalGeneration
+# import torch
+
+# import torchaudio
+from audiocraft.models import AudioGen
+from audiocraft.data.audio import audio_write
 
 class EndpointHandler:
     def __init__(self, path=""):
         # load model and processor from path
-        path = "jamesdon/audiogen-medium-endpoint"
-        self.processor = AutoProcessor.from_pretrained(path)
-        self.model = MusicgenForConditionalGeneration.from_pretrained(path).to("cuda")
+        # path = "jamesdon/audiogen-medium-endpoint"
+        # self.processor = AutoProcessor.from_pretrained(path)
+        # self.model = MusicgenForConditionalGeneration.from_pretrained(path).to("cuda")
+        self.model = AudioGen.get_pretrained(path)
+
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         """
@@ -16,22 +22,11 @@ class EndpointHandler:
         The payload with the text prompt and generation parameters.
         """
         # process input
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
-
-        # preprocess
-        inputs = self.processor(
-            text=[inputs],
-            padding=True,
-            return_tensors="pt",).to("cuda")
-
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            outputs = self.model.generate(**inputs, **parameters)
-        else:
-            outputs = self.model.generate(**inputs)
-
-        # postprocess the prediction
+        inputs = data.pop("inputs", data)  # list of strings
+        duration = data.pop("duration", 5)  # seconds to generate
+
+        self.model.set_generation_params(duration=duration)
+        outputs = self.model.generate(inputs)
         prediction = outputs[0].cpu().numpy()
 
         return [{"generated_audio": prediction}]
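
The updated handler loads AudioGen through audiocraft instead of the transformers MusicGen classes: the model comes from AudioGen.get_pretrained(path), the clip length is set with set_generation_params(duration=...), and generate() takes the text prompts directly, so no processor step remains. A minimal local smoke test of that flow might look like the sketch below; the checkpoint id "facebook/audiogen-medium" and the example prompt are assumptions for illustration, not part of this commit, and audio_write is shown only because the new imports bring it in even though __call__ does not use it yet.

# Sketch: exercise the new EndpointHandler locally. Assumes audiocraft is installed
# and that `path` can be a model id such as "facebook/audiogen-medium" (an assumption,
# not something this commit pins down).
import torch
from handler import EndpointHandler
from audiocraft.data.audio import audio_write

handler = EndpointHandler(path="facebook/audiogen-medium")

# Payload keys mirror the ones popped in __call__: "inputs" is a list of text
# prompts, "duration" is the clip length in seconds.
payload = {"inputs": ["dog barking in the rain"], "duration": 5}
result = handler(payload)

wav = result[0]["generated_audio"]  # numpy array, shape (channels, samples)
print(wav.shape)

# audio_write is imported by the new handler but unused there; saving the clip
# yourself would look roughly like this (sample rate taken from the model).
audio_write("sample", torch.from_numpy(wav), handler.model.sample_rate, strategy="loudness")

Note that the payload shape changes with this commit: the old MusicGen path accepted a single prompt string plus an optional "parameters" dict, while the new code expects a list of prompts and a top-level "duration" key, so clients of the previous endpoint need updating.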
requirements.txt CHANGED
@@ -1,3 +1,3 @@
 transformers==4.31.0
 accelerate>=0.20.3
-# audiocraft
+audiocraft