wetdog commited on
Commit
bb22e1e
1 Parent(s): 2fa71fe

device fix

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -13,8 +13,8 @@ from modules import load_audio, MosPredictor, denorm
13
  mos_checkpoint = "ckpt_mosa_net_plus"
14
 
15
  print('Loading MOSANET+ checkpoint...')
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
  model = MosPredictor().to(device)
19
  model.eval()
20
  model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
@@ -22,8 +22,8 @@ model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
22
  print('Loading Whisper checkpoint...')
23
  feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
24
  #model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3")
25
- model_asli = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
26
- model_asli = model_asli.to(device)
27
 
28
  @spaces.GPU
29
  def predict_mos(wavefile:str):
 
13
  mos_checkpoint = "ckpt_mosa_net_plus"
14
 
15
  print('Loading MOSANET+ checkpoint...')
16
+ device = torch.device("cpu")
17
+ #torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
  model = MosPredictor().to(device)
19
  model.eval()
20
  model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
 
22
  print('Loading Whisper checkpoint...')
23
  feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
24
  #model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3")
25
+ model_asli = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3", low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
26
+ #model_asli = model_asli.to(device)
27
 
28
  @spaces.GPU
29
  def predict_mos(wavefile:str):