jimbozhang committed
Commit 0f8bf48
Parent: 945a279

Update README.md

Files changed (1)
  1. README.md +8 -8
README.md CHANGED
@@ -21,7 +21,6 @@ CED are simple ViT-Transformer-based models for audio tagging. Notable differenc
 - **Demo:** https://huggingface.co/spaces/mispeech/ced-base
 
 ## Install
-
 ```bash
 git clone https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
 pip install -r requirements.txt
@@ -32,20 +31,21 @@ pip install -r requirements.txt
 >>> from ced_model.feature_extraction_ced import CedFeatureExtractor
 >>> from ced_model.modeling_ced import CedForAudioClassification
 
->>> model_id = "mispeech/ced-tiny"
->>> feature_extractor = CedFeatureExtractor.from_pretrained(model_id)
->>> model = CedForAudioClassification.from_pretrained(model_id)
+>>> model_name = "mispeech/ced-tiny"
+>>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
+>>> model = CedForAudioClassification.from_pretrained(model_name)
 
 >>> import torchaudio
 >>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
-
+>>> assert sampling_rate == 16000
 >>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
+
+>>> import torch
 >>> with torch.no_grad():
 ...     logits = model(**inputs).logits
 
->>> import torch
->>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
->>> model.config.id2label[predicted_class_ids]
+>>> predicted_class_id = torch.argmax(logits, dim=-1).item()
+>>> model.config.id2label[predicted_class_id]
 'Finger snapping'
 ```

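For reference, a runnable version of the quick-start snippet as it stands after this commit, assembled from the new side of the diff. The checkpoint name, audio path, and expected label come verbatim from the README; hoisting the imports to the top and the final `print` are editorial conveniences, not part of the committed text.

```python
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

# Load the ced-tiny checkpoint together with its matching feature extractor.
model_name = "mispeech/ced-tiny"
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
model = CedForAudioClassification.from_pretrained(model_name)

# The example clip ships with the repository; the assert added in this
# commit makes the expected 16 kHz sampling rate explicit.
audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
assert sampling_rate == 16000
inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")

# Inference only, so skip gradient tracking.
with torch.no_grad():
    logits = model(**inputs).logits

# Map the top-scoring class index to its human-readable label.
predicted_class_id = torch.argmax(logits, dim=-1).item()
print(model.config.id2label[predicted_class_id])  # 'Finger snapping'
```

Beyond the rename to `model_name` and the singular `predicted_class_id`, the substantive fix here is ordering: the old snippet called `torch.no_grad()` three lines before `import torch`, which would raise a `NameError` if pasted into a fresh interpreter top to bottom.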
51