---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---

# Whisper-NER

- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
- Code: https://github.com/aiola-lab/whisper-ner

We introduce WhisperNER, a novel model that allows joint speech transcription and entity recognition. 
WhisperNER supports open-type NER, enabling recognition of diverse and evolving entities at inference.

---------

## Training Details
`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging. 
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
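As a quick sanity check, the training data can be inspected directly from the Hub. The sketch below assumes the dataset loads with its default configuration; split and column names are whatever the NuNER release provides:

```python
from datasets import load_dataset

# assumption: numind/NuNER loads with its default configuration
nuner = load_dataset("numind/NuNER")
print(nuner)                    # available splits and columns
split_name = next(iter(nuner))  # take whichever split comes first
print(nuner[split_name][0])     # one raw annotated example
```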

---------

## Usage

Inference can be done using the following code; for additional inference code and details, check out the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner):

```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load the audio file (users are responsible for loading their own audio files)
target_sample_rate = 16000
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# convert to mono or remove first dim if needed
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)
# pre-process to get the input features
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, remove prompt
transcription = processor.batch_decode(
    predicted_ids[:, prompt_ids.shape[0]:], skip_special_tokens=True
)[0]
print(transcription)
```
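Because the entity types are passed as a plain-text prompt, they can be swapped at inference time without retraining, which is what enables open-type NER. Continuing from the snippet above (the entity types below are illustrative):

```python
# reuse the model, processor, device, and input_features from the snippet above;
# only the prompt changes (these entity types are illustrative)
open_prompt = "disease, medication, dosage"
open_prompt_ids = processor.get_prompt_ids(open_prompt.lower(), return_tensors="pt").to(device)

with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=open_prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

print(processor.batch_decode(
    predicted_ids[:, open_prompt_ids.shape[0]:], skip_special_tokens=True
)[0])
```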