mskov commited on
Commit
b4725a8
β€’
1 Parent(s): 42fc1b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -4,7 +4,7 @@ os.system("pip install transformers==4.27.0")
4
  os.system("pip install torch")
5
  os.system("pip install openai")
6
  os.system("pip install accelerate")
7
- from transformers import pipeline, WhisperModel, WhisperTokenizer, WhisperFeatureExtractor, AutoFeatureExtractor, AutoProcessor
8
  os.system("pip install evaluate")
9
  #import evaluate
10
  #os.system("pip install evaluate[evaluator]")
@@ -27,17 +27,10 @@ huggingface_token = os.environ["huggingface_token"]
27
  model = WhisperModel.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
28
  feature_extractor = AutoFeatureExtractor.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
29
 
 
 
 
 
30
 
31
- ds = load_dataset("mskov/miso_test", split="test")
32
- print("testing testing")
33
- ds = ds.cast_column("audio", Audio(sampling_rate=16000))
34
-
35
-
36
- inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
37
- print("check check")
38
- print(inputs)
39
- input_features = inputs.input_features
40
- decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
41
- last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
42
- list(last_hidden_state.shape)
43
- print(list(last_hidden_state.shape))
 
4
  os.system("pip install torch")
5
  os.system("pip install openai")
6
  os.system("pip install accelerate")
7
+ from transformers import pipeline, WhisperModel, WhisperTokenizer, WhisperFeatureExtractor, AutoFeatureExtractor, AutoProcessor, WhisperConfig
8
  os.system("pip install evaluate")
9
  #import evaluate
10
  #os.system("pip install evaluate[evaluator]")
 
27
  model = WhisperModel.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
28
  feature_extractor = AutoFeatureExtractor.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
29
 
30
+ model_config = WhisperConfig.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
31
+ model = WhisperModel(config=model_config)
32
+ model.load_state_dict(torch.load("mskov/whisper_miso/checkpoint-4000/pytorch_model.bin"))
33
+ model.eval()
34
 
35
+ dataset = load_dataset("mskov/miso_test", split="test").cast_column("audio", Audio())
36
+ print(dataset)