|
|
|
# Evaluate StyleTTS2 (msinference) synthesis of the Harvard sentences under
# different style-prompt pools (synthetic English 1x/4x, foreign mimic3,
# human speech from audb): each prompt and each synthesised utterance is
# scored with two speech-emotion models (arousal/dominance/valence + 8
# categories) and transcribed with Whisper large-v3 for CER.
import json
import os
import types
from pathlib import Path
from random import shuffle

import audb
import audiofile
import audmodel
import audresample
import jiwer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile
import torch
import transformers
from torch import nn
from transformers import (AutoModelForAudioClassification,
                          AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline)
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                            Wav2Vec2PreTrainedModel)

import msinference
|
|
|
# device handles for the emotion models (teacher on dev2, Dawn on dev);
# both point to the same GPU here
config = transformers.Wav2Vec2Config()
config.dev = torch.device('cuda:0')
config.dev2 = torch.device('cuda:0')
|
|
|
|
|
|
|
|
|
# output order: 3 continuous attributes (Dawn) followed by the 8 emotion
# categories of the WavLM teacher
LABELS = ['arousal', 'dominance', 'valence',
          'Angry', 'Sad', 'Happy', 'Surprise',
          'Fear', 'Disgust', 'Contempt', 'Neutral']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _infer(self, x):
    '''x: (batch, audio-samples-16KHz)'''
    # normalise the waveform with the statistics stored in the model config
    x = (x + self.config.mean) / self.config.std
    x = self.ssl_model(x, attention_mask=None).last_hidden_state

    # self-attentive statistics pooling: weighted mean and weighted std over time
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention)
    w = w.softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)
|
|
|
# categorical SER teacher (WavLM); forward() is replaced by _infer so it
# returns the raw logits of its 8 emotion categories
teacher_cat = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
    trust_remote_code=True
).to(config.dev2).eval()
teacher_cat.forward = types.MethodType(_infer, teacher_cat)
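
# NOTE: the teacher's 8 categorical outputs are assumed here to follow the
# order of LABELS[3:] ('Angry' ... 'Neutral').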
|
|
|
|
|
|
|
def _prenorm(x, attention_mask=None):
    '''Zero-mean / unit-variance normalisation per utterance (mask-aware if given).'''
    if attention_mask is not None:
        N = attention_mask.sum(1, keepdim=True)
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N
    else:
        x -= x.mean(1, keepdim=True)
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)
|
|
|
class RegressionHead(nn.Module):
    r"""Regression head mapping pooled features to arousal/dominance/valence."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Dawn(Wav2Vec2PreTrainedModel):
    r"""Speech emotion regressor (arousal/dominance/valence)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(self,
                input_values,
                attention_mask=None):
        x = _prenorm(input_values, attention_mask=attention_mask)
        outputs = self.wav2vec2(x, attention_mask=attention_mask)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)  # average over time
        logits = self.classifier(hidden_states)
        return logits


dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(config.dev).eval()
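# dawn predicts [arousal, dominance, valence]; per the audeering model card
# the values lie roughly in [0, 1].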
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Whisper large-v3 ASR pipeline, used to transcribe both the prompt audio and
# the StyleTTS2 output so CER can be computed against the reference texts.
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
).to(config.dev)
processor = AutoProcessor.from_pretrained(model_id)
_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=config.dev,
)
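# _pipe(...) returns a dict with the transcription under 'text'
# (plus 'chunks', since return_timestamps=True).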
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_function(x, sampling_rate, idx):
    '''x: (1, audio-samples-16KHz) -> 11-dim vector ordered like LABELS.'''
    # probabilities over the 8 emotion categories from the WavLM teacher
    logits_cat = teacher_cat(torch.from_numpy(x).to(config.dev2)).softmax(1)
    # arousal / dominance / valence from Dawn
    logits_adv = dawn(torch.from_numpy(x).to(config.dev))

    out = torch.cat([logits_adv,
                     logits_cat],
                    1).cpu().detach().numpy()
    return out[0, :]
|
|
|
|
|
|
|
def load_speech(split=None):
    # audb corpora used as natural-speech prompts:
    # [name, version, table, has_timedeltas]
    DB = [
        ['emodb', '1.2.0', 'emotion.categories.train.gold_standard', False],
    ]

    output_list = []
    for database_name, ver, table, has_timedeltas in DB:
        a = audb.load(database_name,
                      sampling_rate=16000,
                      format='wav',
                      mixdown=True,
                      version=ver,
                      cache_root='/cache/audb/')
        a = a[table].get()
        if has_timedeltas:
            # segmented tables are not handled here
            print(f'{has_timedeltas=}')
        else:
            output_list += [f for f in a.index]
    return output_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# natural (human) speech prompts
natural_wav_paths = load_speech()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Harvard sentences: a list of blocks, each holding 10 sentences under 'sentences'
with open('harvard.json', 'r') as f:
    harvard_individual_sentences = json.load(f)['sentences']
|
|
|
|
|
|
|
# synthetic style-prompt pools: English (1x / 4x) and foreign-language mimic3
# prompts (entries containing 'en_U' are excluded from the foreign pools)
synthetic_wav_paths = ['./enslow/' + i for i in
                       os.listdir('./enslow/')]
synthetic_wav_paths_4x = ['./style_vector_v2/' + i for i in
                          os.listdir('./style_vector_v2/')]
synthetic_wav_paths_foreign = ['./mimic3_foreign/' + i for i in
                               os.listdir('./mimic3_foreign/') if 'en_U' not in i]
synthetic_wav_paths_foreign_4x = ['./mimic3_foreign_4x/' + i for i in
                                  os.listdir('./mimic3_foreign_4x/') if 'en_U' not in i]

# keep only prompts longer than 2 s
synthetic_wav_paths_foreign = [i for i in synthetic_wav_paths_foreign if audiofile.duration(i) > 2]
synthetic_wav_paths_foreign_4x = [i for i in synthetic_wav_paths_foreign_4x if audiofile.duration(i) > 2]
synthetic_wav_paths = [i for i in synthetic_wav_paths if audiofile.duration(i) > 2]
synthetic_wav_paths_4x = [i for i in synthetic_wav_paths_4x if audiofile.duration(i) > 2]

shuffle(synthetic_wav_paths_foreign_4x)
shuffle(synthetic_wav_paths_foreign)
shuffle(synthetic_wav_paths)
shuffle(synthetic_wav_paths_4x)
print(len(synthetic_wav_paths_foreign_4x), len(synthetic_wav_paths_foreign),
      len(synthetic_wav_paths), len(synthetic_wav_paths_4x))
|
|
|
|
|
|
|
for audio_prompt in ['english',
                     'english_4x',
                     'human',
                     'foreign',
                     'foreign_4x']:

    # one row per synthesised Harvard sentence: 11 emotion scores of the prompt,
    # 11 of the synthesis, then CER of the prompt and CER of the synthesis
    data = np.zeros((770, len(LABELS) * 2 + 2))

    OUT_FILE = f'{audio_prompt}_analytic.pkl'
    if not os.path.isfile(OUT_FILE):
        ix = 0
        for list_of_10 in harvard_individual_sentences[:10004]:
            for text in list_of_10['sentences']:
                # pick the style prompt for this condition (round-robin over the pool)
                if audio_prompt == 'english':
                    _p = synthetic_wav_paths[ix % len(synthetic_wav_paths)]
                elif audio_prompt == 'english_4x':
                    _p = synthetic_wav_paths_4x[ix % len(synthetic_wav_paths_4x)]
                elif audio_prompt == 'human':
                    _p = natural_wav_paths[ix % len(natural_wav_paths)]
                elif audio_prompt == 'foreign':
                    _p = synthetic_wav_paths_foreign[ix % len(synthetic_wav_paths_foreign)]
                elif audio_prompt == 'foreign_4x':
                    _p = synthetic_wav_paths_foreign_4x[ix % len(synthetic_wav_paths_foreign_4x)]
                else:
                    raise ValueError(f'unknown audio_prompt: {audio_prompt}')
                style_vec = msinference.compute_style(_p)

                # synthesise the sentence and resample from 24 kHz to 16 kHz
                x = msinference.inference(text,
                                          style_vec,
                                          alpha=0.3,
                                          beta=0.7,
                                          diffusion_steps=7,
                                          embedding_scale=1)
                x = audresample.resample(x, 24000, 16000)

                # load and resample the prompt audio
                _st, fsr = audiofile.read(_p)
                _st = audresample.resample(_st, fsr, 16000)
                print(_st.shape, x.shape)

                # emotion of the prompt and of the synthesis
                emotion_of_prompt = process_function(_st, 16000, None)
                emotion_of_out = process_function(x, 16000, None)
                data[ix, :11] = emotion_of_prompt
                data[ix, 11:22] = emotion_of_out

                # Whisper transcriptions for CER
                transcription_prompt = _pipe(_st[0])
                transcription_styletts2 = _pipe(x[0])
                print(transcription_prompt, transcription_styletts2)

                # CER of the prompt transcription against a fixed reference sentence
                data[ix, 22] = jiwer.cer('Sweet dreams are made of this. I travel the world and the seven seas.',
                                         transcription_prompt['text'])
                # CER of the synthesis against the Harvard sentence it should say
                data[ix, 23] = jiwer.cer(text,
                                         transcription_styletts2['text'])
                print(data[ix, :])

                ix += 1

        df = pd.DataFrame(data, columns=['prompt-' + i for i in LABELS]
                                        + ['styletts2-' + i for i in LABELS]
                                        + ['cer-prompt', 'cer-styletts2'])
        df.to_pickle(OUT_FILE)
    else:
        df = pd.read_pickle(OUT_FILE)
        print(f'\nALREADY EXISTS\n{df}')
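

# Minimal post-hoc summary (a sketch): aggregate the per-condition pickles
# written above and print mean CER / valence for prompt vs. StyleTTS2 output;
# column names follow the DataFrame constructed in the loop.
for _prompt in ['english', 'english_4x', 'human', 'foreign', 'foreign_4x']:
    _f = f'{_prompt}_analytic.pkl'
    if os.path.isfile(_f):
        _d = pd.read_pickle(_f)
        print(_prompt,
              'CER prompt:', _d['cer-prompt'].mean(),
              'CER styletts2:', _d['cer-styletts2'].mean(),
              'valence prompt:', _d['prompt-valence'].mean(),
              'valence styletts2:', _d['styletts2-valence'].mean())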
|
|
|
|