wanchichen commited on
Commit
f1f16d2
1 Parent(s): 1db6e2f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -4091,10 +4091,10 @@ xeus_model, xeus_train_args = SSLTask.build_model_from_file(
4091
 
4092
  wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
4093
  wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
4094
- wavs = pad_sequence([wavs], batch_first=True).to(device)
4095
 
4096
  # we recommend use_mask=True during fine-tuning
4097
- feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer
4098
  ```
4099
 
4100
  With Flash Attention:
 
4091
 
4092
  wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
4093
  wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
4094
+ wavs = pad_sequence(torch.Tensor([wavs]), batch_first=True).to(device)
4095
 
4096
  # we recommend use_mask=True during fine-tuning
4097
+ feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer -> batch_size x seq_len x hdim
4098
  ```
4099
 
4100
  With Flash Attention: