m3hrdadfi commited on
Commit
94228a6
1 Parent(s): 15a5eff

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -4
README.md CHANGED
@@ -54,12 +54,11 @@ def speech_file_to_array_fn(path, sampling_rate):
54
 
55
  def predict(path, sampling_rate):
56
  speech = speech_file_to_array_fn(path, sampling_rate)
57
- features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
58
-
59
- input_values = features.input_values.to(device)
60
 
61
  with torch.no_grad():
62
- logits = model(input_values).logits
63
 
64
  scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
65
  outputs = [{"Label": config.id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
 
54
 
55
  def predict(path, sampling_rate):
56
  speech = speech_file_to_array_fn(path, sampling_rate)
57
+ inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
58
+ inputs = {key: inputs[key].to(device) for key in inputs}
 
59
 
60
  with torch.no_grad():
61
+ logits = model(**inputs).logits
62
 
63
  scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
64
  outputs = [{"Label": config.id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]