---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---

# Whisper-NER

- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
- Code: https://github.com/aiola-lab/whisper-ner

We introduce WhisperNER, a novel model that allows joint speech transcription and entity recognition.
WhisperNER supports open-type NER, enabling recognition of diverse and evolving entities at inference.

---------

## Training Details

`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.

---------

## Usage

Inference can be done using the following code (for the full inference code and more details, check out the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner)):

```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: user is responsible for loading the audio files themselves
target_sample_rate = 16000
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# convert to mono or remove first dim if needed
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process to get the input features
input_features = processor(
    signal.numpy(), sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

# encode the (lowercased) entity-type prompt
prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:],
    skip_special_tokens=True
)[0]
print(transcription)
```
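The usage example passes the entity types as a single comma-separated string and lowercases it before calling `processor.get_prompt_ids`. When the entity types come from a list (for example, user input or a config file), a small helper can build that string consistently. The sketch below is a hypothetical convenience function, `build_entity_prompt`, that is not part of the whisper-ner repo; trimming whitespace and dropping duplicate or empty labels are assumptions about what is harmless to normalize, not documented model requirements.

```python
from typing import Iterable


def build_entity_prompt(entity_types: Iterable[str]) -> str:
    """Join entity-type labels into a comma-separated prompt string.

    Hypothetical helper (not from the whisper-ner repo): collapses inner
    whitespace, lowercases each label (the usage example lowercases the
    prompt before get_prompt_ids), and drops empty or duplicate labels
    while keeping the original order.
    """
    seen = set()
    labels = []
    for raw in entity_types:
        label = " ".join(raw.split()).lower()  # trim, collapse spaces, lowercase
        if label and label not in seen:
            seen.add(label)
            labels.append(label)
    return ", ".join(labels)


if __name__ == "__main__":
    prompt = build_entity_prompt(["Person", " company ", "LOCATION", "person", ""])
    print(prompt)  # person, company, location
```

The resulting string can be used as the `prompt` variable in the usage example above.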