---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---
# Whisper-NER
- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
- Code: https://github.com/aiola-lab/whisper-ner

We introduce WhisperNER, a novel model that performs joint speech transcription and named entity recognition.
WhisperNER supports open-type NER, enabling recognition of diverse and evolving entity types at inference time.
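
Because the entity types are passed to the model as a text prompt at inference time, targeting a new label set only means changing that prompt. The sketch below assumes the comma-separated, lowercased prompt format used in the Usage code. `build_entity_prompt` is a hypothetical helper for illustration and is not part of the whisper-ner package.

```python
from typing import Iterable


def build_entity_prompt(entity_types: Iterable[str]) -> str:
    """Hypothetical helper: join entity types into a comma-separated,
    lowercased prompt such as "person, company, location"."""
    tags: list[str] = []
    for raw in entity_types:
        tag = raw.strip().lower()
        # skip empty strings and duplicates while keeping the caller's order
        if tag and tag not in tags:
            tags.append(tag)
    return ", ".join(tags)


# The label set is chosen per call; nothing is retrained.
print(build_entity_prompt(["Person", "Company", "location", "person"]))
print(build_entity_prompt(["Drug", "Dosage"]))
```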
---------
## Training Details
`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.

---------
## Usage
Inference can be run with the following code (for additional inference code and details, see the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner)):
```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load audio file: user is responsible for loading the audio files themselves
target_sample_rate = 16000  # Whisper expects 16 kHz input
signal, sampling_rate = torchaudio.load(audio_file_path)  # shape: (channels, time)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)

# downmix (channels, time) to mono (time,) by averaging the channels
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process to get the input features
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

# encode the entity tags as the decoder prompt
prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running the model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, removing the prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
)[0]
print(transcription)
```
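
Since the model performs joint transcription and NER tagging, the decoded `transcription` carries entity annotations alongside the words. This card does not specify the exact tag syntax, so the sketch below *assumes* inline markers of the form `<type>span</type>`. Check real model output or the whisper-ner repo before relying on it. `extract_entities` and `strip_tags` are hypothetical helpers, not part of any released API.

```python
import re

# ASSUMPTION: entities are marked inline as <type>span</type>, e.g.
# "<person>John</person> works at <company>Acme</company>".
# Verify this against actual WhisperNER output before use.
TAG_PATTERN = re.compile(r"<([^<>/]+)>(.*?)</\1>")


def extract_entities(tagged_text: str) -> list[tuple[str, str]]:
    """Return (entity_type, span) pairs in order of appearance."""
    return [
        (m.group(1).strip(), m.group(2).strip())
        for m in TAG_PATTERN.finditer(tagged_text)
    ]


def strip_tags(tagged_text: str) -> str:
    """Remove the assumed tags, keeping only the transcribed words."""
    plain = TAG_PATTERN.sub(lambda m: m.group(2), tagged_text)
    return re.sub(r"\s+", " ", plain).strip()


example = (
    "<person>John Smith</person> joined <company>Acme</company> "
    "in <location>Paris</location>."
)
print(extract_entities(example))
print(strip_tags(example))
```

The regex uses a backreference (`\1`) so that each closing tag must match its opening tag. Nested or unbalanced tags are not handled by this sketch.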