Commit 0f8bf48 by jimbozhang
Parent: 945a279

Update README.md

README.md CHANGED
@@ -21,7 +21,6 @@ CED are simple ViT-Transformer-based models for audio tagging. Notable differenc
 - **Demo:** https://huggingface.co/spaces/mispeech/ced-base

 ## Install
-
 ```bash
 git clone https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
 pip install -r requirements.txt
@@ -32,20 +31,21 @@ pip install -r requirements.txt
 >>> from ced_model.feature_extraction_ced import CedFeatureExtractor
 >>> from ced_model.modeling_ced import CedForAudioClassification

->>>
->>> feature_extractor = CedFeatureExtractor.from_pretrained(
->>> model = CedForAudioClassification.from_pretrained(
+>>> model_name = "mispeech/ced-tiny"
+>>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
+>>> model = CedForAudioClassification.from_pretrained(model_name)

 >>> import torchaudio
 >>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
-
+>>> assert sampling_rate == 16000
 >>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
+
+>>> import torch
 >>> with torch.no_grad():
 ...     logits = model(**inputs).logits

->>>
->>>
->>> model.config.id2label[predicted_class_ids]
+>>> predicted_class_id = torch.argmax(logits, dim=-1).item()
+>>> model.config.id2label[predicted_class_id]
 'Finger snapping'
 ```
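
Note on the last two added lines: they reduce the logits to a single class index and look that index up in `model.config.id2label`. A minimal pure-Python sketch of that step follows. The logits and the four-entry label map are made up for illustration and do not come from the CED checkpoint, and `argmax_last_dim` is a hypothetical helper standing in for `torch.argmax(logits, dim=-1)`.

```python
# Hypothetical stand-ins: these values are invented for illustration only.
logits = [[-2.1, 0.3, 4.7, 1.2]]  # shape (batch=1, num_classes=4)
id2label = {0: "Speech", 1: "Music", 2: "Finger snapping", 3: "Clapping"}


def argmax_last_dim(row):
    """Return the index of the largest score in a flat list of scores."""
    return max(range(len(row)), key=lambda i: row[i])


# Take the single item in the batch, then map its best index to a label.
predicted_class_id = argmax_last_dim(logits[0])
predicted_label = id2label[predicted_class_id]
print(predicted_class_id, predicted_label)
```

In the README snippet, `.item()` only succeeds when the tensor holds exactly one element, so the example presumably assumes a single mono clip per call.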