import os
import torch
import argparse
import numpy as np
from transformers import AutoFeatureExtractor, WhisperModel
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import speechbrain
import librosa
from subprocess import CalledProcessError, run

# Audio loading below is adapted from OpenAI Whisper.
SAMPLE_RATE = 16000


def denorm(input_x):
    # Rescale predictions from the sigmoid's [0, 1] range back to [0, 5].
    return input_x * (5 - 0) + 0


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio to, if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
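# Example usage (illustrative; assumes the ffmpeg CLI is installed and a file
# named "sample.wav" exists):
#
#     audio = load_audio("sample.wav")            # float32 mono at 16 kHz
#     wav = torch.from_numpy(audio).unsqueeze(0)  # (1, audio_len) tensor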
class MosPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # CNN stack over the concatenated spectral features; each block ends
        # with a stride-(1, 3) convolution that downsamples the frequency axis.
        self.mean_net_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 3)),
            nn.Dropout(0.3),
            nn.BatchNorm2d(128),
            nn.ReLU())
        self.relu_ = nn.ReLU()
        self.sigmoid_ = nn.Sigmoid()
        # 1280 matches the hidden size of the Whisper-large encoder.
        self.ssl_features = 1280
        self.dim_layer = nn.Linear(self.ssl_features, 512)
        self.mean_net_rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=1,
                                    batch_first=True, bidirectional=True)
        self.mean_net_dnn = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        # Learnable band-pass front end applied directly to the raw waveform.
        self.sinc = speechbrain.nnet.CNN.SincConv(in_channels=1, out_channels=257,
                                                  kernel_size=251, stride=256,
                                                  sample_rate=16000)
        # batch_first=True because forward() passes (batch, seq, feature) tensors.
        self.att_output_layer_quality = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_quality = nn.Linear(128, 1)
        self.qualaverage_score = nn.AdaptiveAvgPool1d(1)

        self.att_output_layer_intell = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_intell = nn.Linear(128, 1)
        self.intellaverage_score = nn.AdaptiveAvgPool1d(1)

        # STOI head (defined for completeness; unused in this forward pass).
        self.att_output_layer_stoi = nn.MultiheadAttention(128, num_heads=8, batch_first=True)
        self.output_layer_stoi = nn.Linear(128, 1)
        self.stoiaverage_score = nn.AdaptiveAvgPool1d(1)

    def forward(self, wav, lps, whisper):
        wav_ = wav.squeeze(1)  # (batch, audio_len)

        # SSL features: project Whisper encoder states (batch, frames, 1280) to 512.
        ssl_feat_red = self.dim_layer(whisper.squeeze(1))
        ssl_feat_red = self.relu_(ssl_feat_red)

        # PS features: SincConv output stacked with the log-power spectrogram.
        sinc_feat = self.sinc(wav_)
        unsq_sinc = torch.unsqueeze(sinc_feat, dim=1)
        concat_lps_sinc = torch.cat((lps, unsq_sinc), dim=2)
        cnn_out = self.mean_net_conv(concat_lps_sinc)
        batch = concat_lps_sinc.shape[0]
        time = concat_lps_sinc.shape[2]
        re_cnn = cnn_out.view((batch, time, 512))  # 128 channels x 4 remaining freq bins

        # Concatenate CNN and SSL features along the sequence axis.
        concat_feat = torch.cat((re_cnn, ssl_feat_red), dim=1)
        out_lstm, (h, c) = self.mean_net_rnn(concat_feat)
        out_dense = self.mean_net_dnn(out_lstm)  # (batch, seq, 128)

        # Quality head: frame-level scores pooled to an utterance-level score.
        quality_att, _ = self.att_output_layer_quality(out_dense, out_dense, out_dense)
        frame_quality = self.output_layer_quality(quality_att)
        frame_quality = self.sigmoid_(frame_quality)  # (batch, seq, 1)
        quality_utt = self.qualaverage_score(frame_quality.permute(0, 2, 1))

        # Intelligibility head, structured identically to the quality head.
        int_att, _ = self.att_output_layer_intell(out_dense, out_dense, out_dense)
        frame_int = self.output_layer_intell(int_att)
        frame_int = self.sigmoid_(frame_int)
        int_utt = self.intellaverage_score(frame_int.permute(0, 2, 1))

        return (quality_utt.squeeze(1), int_utt.squeeze(1),
                frame_quality.squeeze(2), frame_int.squeeze(2))
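# Minimal inference sketch (illustrative, not part of the original script). The
# 512-point STFT configuration and the "openai/whisper-large" checkpoint are
# assumptions chosen so the shapes line up: a 257-bin log-power spectrogram
# matches the 257 SincConv channels, and Whisper-large encoder states have the
# 1280-dim hidden size expected by self.ssl_features.
def example_inference(wav_path: str, model: MosPredictor, device: str = "cpu"):
    audio = load_audio(wav_path)                           # float32 mono, 16 kHz
    wav = torch.from_numpy(audio).unsqueeze(0).to(device)  # (1, audio_len)

    # Log-power spectrogram shaped (1, 1, frames, 257); STFT settings are assumed.
    spec = torch.stft(wav, n_fft=512, hop_length=256, win_length=512,
                      window=torch.hamming_window(512).to(device),
                      return_complex=True)
    lps = torch.log10(spec.abs().pow(2) + 1e-10).permute(0, 2, 1).unsqueeze(1)

    # Whisper encoder states shaped (1, 1, frames, 1280); checkpoint is assumed.
    extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
    whisper = WhisperModel.from_pretrained("openai/whisper-large").to(device)
    feats = extractor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        enc = whisper.encoder(feats.input_features.to(device)).last_hidden_state
        quality, intell, _, _ = model(wav.unsqueeze(1), lps, enc.unsqueeze(1))
    # Map the [0, 1] sigmoid outputs back to the [0, 5] score range.
    return denorm(quality).item(), denorm(intell).item()

# e.g.: q, i = example_inference("sample.wav", MosPredictor().eval())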