MoritzLaurer HF staff commited on
Commit
dff476f
1 Parent(s): fef91df

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -0
handler.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from parler_tts import ParlerTTSForConditionalGeneration
3
+ from transformers import AutoTokenizer
4
+ import torch
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ # load model and processor from path
9
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.model = ParlerTTSForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
13
+ """
14
+ Args:
15
+ data (:dict:):
16
+ The payload with the text prompt and generation parameters.
17
+ """
18
+ # process input
19
+ inputs = data.pop("inputs", data)
20
+ voice_description = data.pop("voice_description", "data")
21
+ parameters = data.pop("parameters", None)
22
+
23
+ gen_kwargs = {"min_new_tokens": 10}
24
+ if parameters is not None:
25
+ gen_kwargs.update(parameters)
26
+
27
+ # preprocess
28
+ inputs = self.tokenizer(
29
+ text=[inputs],
30
+ padding=True,
31
+ return_tensors="pt",).to("cuda")
32
+ voice_description = self.tokenizer(
33
+ text=[voice_description],
34
+ padding=True,
35
+ return_tensors="pt",).to("cuda")
36
+
37
+ # pass inputs with all kwargs in data
38
+ with torch.autocast("cuda"):
39
+ outputs = self.model.generate(**voice_description, prompt_input_ids=inputs.input_ids, **gen_kwargs)
40
+
41
+ # postprocess the prediction
42
+ prediction = outputs[0].cpu().numpy().tolist()
43
+
44
+ return [{"generated_audio": prediction}]