import spaces import os import torch from transformers import AutoFeatureExtractor, WhisperModel, AutoModelForSpeechSeq2Seq import numpy as np import torchaudio import librosa import gradio as gr from modules import load_audio, MosPredictor, denorm mos_checkpoint = "ckpt_mosa_net_plus" print('Loading MOSANET+ checkpoint...') device = torch.device("cpu") #torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 model = MosPredictor().to(device) model.eval() model.load_state_dict(torch.load(mos_checkpoint, map_location=device)) print('Loading Whisper checkpoint...') feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3") #model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3") model_asli = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3", low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa") #model_asli = model_asli.to(device) @spaces.GPU def predict_mos(wavefile:str): device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) if device != model_asli.device: model_asli.to(device) print('Starting prediction...') # STFT wav = torchaudio.load(wavefile)[0] lps = torch.from_numpy(np.expand_dims(np.abs(librosa.stft(wav[0].detach().numpy(), n_fft = 512, hop_length=256,win_length=512)).T, axis=0)) lps = lps.unsqueeze(1) # Whisper Feature audio = load_audio(wavefile) inputs = feature_extractor(audio, return_tensors="pt") input_features = inputs.input_features input_features = input_features.to(device) with torch.no_grad(): decoder_input_ids = torch.tensor([[1, 1]]) * model_asli.config.decoder_start_token_id decoder_input_ids = decoder_input_ids.to(device) last_hidden_state = model_asli(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state whisper_feat = last_hidden_state print('Model features shapes...') print(whisper_feat.shape) print(wav.shape) print(lps.shape) # prediction wav = wav.to(device) lps = lps.to(device) Quality_1, Intell_1, frame1, frame2 = model(wav ,lps, whisper_feat) quality_pred = Quality_1.cpu().detach().numpy()[0] intell_pred = Intell_1.cpu().detach().numpy()[0] print("predictions") qa_text = f"Quality: {denorm(quality_pred)[0]:.2f} Inteligibility: {intell_pred[0]:.2f}" print(qa_text) return qa_text title = """