# maskgct/evaluation/metrics/similarity/speaker_similarity.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
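
# Speaker-similarity metric: embeds every file in a reference directory and a
# degraded/generated directory with one of three speaker encoders (RawNet3,
# WavLM x-vector, or Resemblyzer) and reports their mean cosine similarity.
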
import os
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from tqdm import tqdm
import librosa
from evaluation.metrics.similarity.models.RawNetModel import RawNet3
from evaluation.metrics.similarity.models.RawNetBasicBlock import Bottle2neck
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
from resemblyzer import VoiceEncoder, preprocess_wav


def extract_rawnet_speaker_embd(
model, fn: str, n_samples: int, n_segments: int = 10, gpu: bool = False
) -> np.ndarray:
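    """Extract RawNet3 speaker embeddings for one mono audio file.

    The audio is resampled to 16 kHz if needed and wrap-padded when shorter
    than ``n_samples``; ``n_segments`` evenly spaced windows of ``n_samples``
    samples are then batched through the model, giving one embedding per
    window (shape ``(n_segments, nOut)``).
    """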
audio, sample_rate = sf.read(fn)
if len(audio.shape) > 1:
raise ValueError(
f"RawNet3 supports mono input only. Input data has a shape of {audio.shape}."
)
    # RawNet3 expects 16 kHz input; resample anything else.
    if sample_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
    # Wrap-pad clips shorter than one window so a full window always fits.
    if len(audio) < n_samples:
        shortage = n_samples - len(audio) + 1
        audio = np.pad(audio, (0, shortage), "wrap")
    # Take n_segments evenly spaced windows of n_samples samples each.
    audios = []
    startframe = np.linspace(0, len(audio) - n_samples, num=n_segments)
    for asf in startframe:
        audios.append(audio[int(asf) : int(asf) + n_samples])
audios = torch.from_numpy(np.stack(audios, axis=0).astype(np.float32))
if gpu:
audios = audios.to("cuda")
with torch.no_grad():
output = model(audios)
    return output


def extract_similarity(path_ref, path_deg, **kwargs):
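    """Compute speaker similarity between two directories of audio files.

    The options dict nested under ``kwargs["kwargs"]`` must provide
    ``model_name`` ("rawnet", "wavlm", or "resemblyzer") and
    ``similarity_mode``: "pairwith" compares files one-to-one in sorted
    order, "overall" compares every reference/degraded pair. Returns the
    mean cosine similarity across all comparisons.
    """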
    # Options arrive nested under a "kwargs" key; unwrap them first.
    kwargs = kwargs["kwargs"]
    model_name = kwargs["model_name"]
ref_embds = []
deg_embds = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model_name == "rawnet":
model = RawNet3(
Bottle2neck,
model_scale=8,
context=True,
summed=True,
encoder_type="ECA",
nOut=256,
out_bn=False,
sinc_stride=10,
log_sinc=True,
norm_sinc="mean",
grad_mult=1,
)
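        # Constructor arguments must match the configuration of the
        # checkpoint loaded below.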
model.load_state_dict(
torch.load(
"pretrained/rawnet3/model.pt",
map_location=lambda storage, loc: storage,
)["model"]
)
model.eval()
model = model.to(device)
        # Iterate in sorted order so "pairwith" mode aligns matching filenames.
        for file in tqdm(sorted(os.listdir(path_ref))):
output = extract_rawnet_speaker_embd(
model,
fn=os.path.join(path_ref, file),
n_samples=48000,
n_segments=10,
gpu=torch.cuda.is_available(),
).mean(0)
ref_embds.append(output)
        for file in tqdm(sorted(os.listdir(path_deg))):
output = extract_rawnet_speaker_embd(
model,
fn=os.path.join(path_deg, file),
n_samples=48000,
n_segments=10,
gpu=torch.cuda.is_available(),
).mean(0)
deg_embds.append(output)
elif model_name == "wavlm":
        # Prefer the hub checkpoint; fall back to a local copy if unavailable.
        try:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"microsoft/wavlm-base-plus-sv"
)
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv")
        except Exception:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"pretrained/wavlm", sampling_rate=16000
)
model = WavLMForXVector.from_pretrained("pretrained/wavlm")
model = model.to(device)
        for file in tqdm(sorted(os.listdir(path_ref))):
wav_path = os.path.join(path_ref, file)
wav, _ = librosa.load(wav_path, sr=16000)
inputs = feature_extractor(
[wav], padding=True, return_tensors="pt", sampling_rate=16000
)
            # Move inputs to the evaluation device (a no-op on CPU).
            for key in inputs.keys():
                inputs[key] = inputs[key].to(device)
with torch.no_grad():
embds = model(**inputs).embeddings
ref_embds.append(embds[0])
        for file in tqdm(sorted(os.listdir(path_deg))):
wav_path = os.path.join(path_deg, file)
wav, _ = librosa.load(wav_path, sr=16000)
inputs = feature_extractor(
[wav], padding=True, return_tensors="pt", sampling_rate=16000
)
            # Move inputs to the evaluation device (a no-op on CPU).
            for key in inputs.keys():
                inputs[key] = inputs[key].to(device)
with torch.no_grad():
embds = model(**inputs).embeddings
deg_embds.append(embds[0])
elif model_name == "resemblyzer":
encoder = VoiceEncoder().to(device)
        for file in tqdm(sorted(os.listdir(path_ref))):
wav_path = os.path.join(path_ref, file)
wav = preprocess_wav(wav_path)
output = encoder.embed_utterance(wav)
ref_embds.append(torch.from_numpy(output).to(device))
        for file in tqdm(sorted(os.listdir(path_deg))):
wav_path = os.path.join(path_deg, file)
wav = preprocess_wav(wav_path)
output = encoder.embed_utterance(wav)
deg_embds.append(torch.from_numpy(output).to(device))
similarity_mode = kwargs["similarity_mode"]
scores = []
    if similarity_mode == "pairwith":
        # One-to-one comparison of files at the same sorted position.
for ref_embd, deg_embd in zip(ref_embds, deg_embds):
scores.append(
F.cosine_similarity(ref_embd, deg_embd, dim=-1).detach().cpu().numpy()
)
    elif similarity_mode == "overall":
        # All-pairs comparison between the reference and degraded sets.
for ref_embd in ref_embds:
for deg_embd in deg_embds:
scores.append(
F.cosine_similarity(ref_embd, deg_embd, dim=-1)
.detach()
.cpu()
.numpy()
)
return np.mean(scores)
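

# A minimal usage sketch (an illustration, not part of the evaluation
# pipeline): the directory paths below are hypothetical placeholders, and the
# pretrained checkpoints are assumed to be available as referenced above.
if __name__ == "__main__":
    score = extract_similarity(
        path_ref="data/ref_wavs",  # hypothetical directory of reference wavs
        path_deg="data/gen_wavs",  # hypothetical directory of generated wavs
        kwargs={"model_name": "resemblyzer", "similarity_mode": "pairwith"},
    )
    print(f"Mean speaker similarity: {score:.4f}")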