from typing import Any, Dict

import numpy as np
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

# Whisper models expect audio sampled at 16 kHz.
SAMPLE_RATE = 16000


class EndpointHandler:
    def __init__(self, path: str = ""):
        # The handler assumes a CUDA device is available and loads the model in half precision.
        torch_dtype = torch.float16
        device = "cuda"

        whisper_processor = WhisperProcessor.from_pretrained(path)
        whisper_model = WhisperForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch_dtype,
        ).to(device)

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=whisper_model,
            tokenizer=whisper_processor.tokenizer,
            feature_extractor=whisper_processor.feature_extractor,
            chunk_length_s=30,
            torch_dtype=torch_dtype,
            device=device,
            # Force Portuguese transcription. The language must be passed as a
            # generation kwarg; model_kwargs is only forwarded to from_pretrained.
            generate_kwargs={"language": "pt"},
        )

    def parse_audio(self, audio_bytes: bytes) -> np.ndarray:
        # Decode the raw audio bytes with ffmpeg and resample to 16 kHz.
        audio_nparray = ffmpeg_read(audio_bytes, SAMPLE_RATE)

        # Downmix to mono if the decoded array carries two channels.
        if len(audio_nparray.shape) > 1 and audio_nparray.shape[1] == 2:
            return np.mean(audio_nparray, axis=1)

        return audio_nparray

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Optional pipeline/generation parameters forwarded from the request payload.
        parameters = data.get("parameters", {})

        # "inputs" holds the raw audio file bytes sent to the endpoint.
        audio = self.parse_audio(data["inputs"])

        return self.asr_pipeline(audio, **parameters)
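
# Minimal local usage sketch, not part of the endpoint contract. It assumes a
# hypothetical audio file "sample.wav" and a hypothetical local model directory
# "./whisper-model"; the payload shape ("inputs" bytes plus optional "parameters")
# mirrors what __call__ reads above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./whisper-model")

    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()

    result = handler({"inputs": audio_bytes, "parameters": {"return_timestamps": True}})
    print(result)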