# artificial-styletts2 / visualize_tts_plesantness.py
# Author: Dionyssos - commit fda2aa0 ("fx dir"), 11.7 kB
# (header copied from the Hugging Face file viewer; kept as comments so the file runs)
# 1. engineer_style_foreign_style_vectors.py # for speed=1 & speed=4
# 2. tts_harvard.py # (call inside SHIFT repo - needs StyleTTS msinference.py)
# 3. visualize_tts_pleasantness.py # figures & audinterface
# Visualises the time series of the teacher outputs (see LABELS) for mimic3, human and sped-up mimic3
#
#
# human_770.wav
# mimic3_770.wav
# mimic3_speedup_770.wav
import pandas as pd
import os
import json
import numpy as np
import audonnx
import audb
from pathlib import Path
import transformers
import torch
import audmodel
import audinterface
import matplotlib.pyplot as plt
import audiofile
# Feature column names, in the order process_function concatenates its outputs:
# the adv model's arousal/dominance/valence, the two AudioSet synthesiser
# scores, then the eight categorical-emotion probabilities.
LABELS = [
    'arousal', 'dominance', 'valence',
    'speech_synthesizer', 'synthetic_singing',
    'Angry', 'Sad', 'Happy', 'Surprise',
    'Fear', 'Disgust', 'Contempt', 'Neutral',
]
# A Wav2Vec2Config instance serves only as an attribute bag carrying the torch
# devices used by the categorical teacher (both cuda:0).
args = transformers.Wav2Vec2Config()  # finetuning_task='spef2feat_reg')
args.dev, args.dev2 = torch.device('cuda:0'), torch.device('cuda:0')
# def _softmax(x):
# '''x : (batch, num_class)'''
# x -= x.max(1, keepdims=True) # if all -400 then sum(exp(x)) = 0
# x = np.minimum(-100, x)
# x = np.exp(x)
# x /= x.sum(1, keepdims=True)
# return x
def _softmax(x):
'''x : (batch, num_class)'''
x -= x.max(1, keepdims=True) # if all -400 then sum(exp(x)) = 0
x = np.maximum(-100, x)
x = np.exp(x)
x /= x.sum(1, keepdims=True)
return x
def _sigmoid(x):
'''x : (batch, num_class)'''
return 1 / (1 + np.exp(-x))
# --
# ALL = anger, contempt, disgust, fear, happiness, neutral, no_agreement, other, sadness, surprise
# plot - unpleasant emotions (of the 7x emo-categories): [anger, contempt, disgust, fear, sadness] for artificial/sped-up/natural
# plot - pleasant emotions: [neutral, happiness, surprise]
# plot - cubes: natural vs sped-up (4x speed)
# plot - synthesizer class (AudioSet)
# https://arxiv.org/pdf/2407.12229
# https://arxiv.org/pdf/2312.05187
# https://arxiv.org/abs/2407.05407
# https://arxiv.org/pdf/2408.06577
# https://arxiv.org/pdf/2309.07405
# Are the wavs generated, concatenated and plotted as a time series?
# for mimic3/mimic3speed/human - concatenate all 77 and run the time series (the code below uses a 40 s window with 10 s hop)
# CAT MSP (3loi categorical SER), AudioSet and A/D/V teachers. They are only
# needed when a time-series pickle is missing, so they are loaded lazily and
# at most once for all files (previously reloaded for every missing file).
_TEACHERS = {}


def _infer(self, x):
    '''Replacement forward() for the 3loi categorical SER model.

    x: (batch, audio-samples-16KHz) float tensor.
    Returns the output of the model's SER head (class logits).
    '''
    # NOTE(review): the 3loi model card normalises with (x - mean) / std; this
    # script deliberately uses "+ mean" (original "# plus" remark) - confirm.
    x = (x + self.config.mean) / self.config.std  # plus
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # pool: attention weights over dim 1, then weighted mean and std
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention)
    w = w.softmax(1)  # was `stylesoftmax(1)`: an undefined name (NameError)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt(),
        ], 1)
    return self.ser_model(x)


def _load_teachers():
    '''Load the three teacher models on the first call; return the cached dict.'''
    if not _TEACHERS:
        import types
        from transformers import AutoModelForAudioClassification
        teacher_cat = AutoModelForAudioClassification.from_pretrained(
            '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
            trust_remote_code=True  # fun definitions see 3loi/SER-.. repo
        ).to(args.dev2).eval()
        # The bound instance attribute shadows nn.Module.forward, so
        # teacher_cat(x) runs _infer.
        teacher_cat.forward = types.MethodType(_infer, teacher_cat)
        _TEACHERS['cat'] = teacher_cat
        # Audioset & ADV
        _TEACHERS['audioset'] = audonnx.load(audmodel.load('17c240ec-1.0.0'), device='cuda:0')
        _TEACHERS['adv'] = audonnx.load(audmodel.load('90398682-2.0.0'), device='cuda:0')
    return _TEACHERS


def process_function(x, sampling_rate, idx):
    '''Run the teachers on one audio window and return its feature row(s).

    x: numpy audio window from audinterface (resampled to `sampling_rate`,
       16 kHz here); assumed (channels, samples) with mono audio - confirm.
    Returns an array with 13 columns in LABELS order:
    [3x adv, synth-speech, synth-singing, 8x categorical]. AudioSet scores are
    passed through a sigmoid, categorical logits through a softmax.
    '''
    teachers = _load_teachers()
    # categorical order (ASaSuHFDCN):
    # {0: 'Angry', 1: 'Sad', 2: 'Happy', 3: 'Surprise',
    #  4: 'Fear', 5: 'Disgust', 6: 'Contempt', 7: 'Neutral'}
    with torch.no_grad():  # inference only: no autograd graph
        # input goes to the device the categorical model was moved to
        logits_cat = teachers['cat'](torch.from_numpy(x).to(args.dev2)).cpu().numpy()
    logits_audioset = teachers['audioset'](x, sampling_rate)['logits_sounds']
    logits_audioset = logits_audioset[:, [7, 35]]  # speech synthesizer, synthetic singing
    logits_adv = teachers['adv'](x, sampling_rate)['logits']
    cat = np.concatenate([logits_adv,
                          _sigmoid(logits_audioset),
                          _softmax(logits_cat)],
                         1)
    print(cat)
    return cat


# Sliding-window extractor: 40 s windows, 10 s hop, audio resampled to 16 kHz.
interface = audinterface.Feature(
    feature_names=LABELS,
    process_func=process_function,
    process_func_applies_sliding_window=False,
    win_dur=40.0,
    hop_dur=10.0,
    sampling_rate=16000,
    resample=True,
    verbose=True,
)

# Each long wav is turned into a time series once and cached as a pickle.
for long_audio in [
    'mimic3_english_767_5.wav',
    'mimic3_english_4x_767_5.wav',
    'human_767_5.wav',
    # NOTE(review): "foregin" looks like a typo of "foreign" (cf. next entry);
    # kept unchanged because it must match the file name on disk.
    'mimic3_foregin_767_5.wav',
    'mimic3_foreign_4x_767_5.wav',
]:
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    if os.path.exists(file_interface):
        print(file_interface, 'FOUND')
        continue
    print('_______________________________________\nProcessing\n', file_interface, '\n___________')
    df_pred = interface.process_file(long_audio)
    df_pred.to_pickle(file_interface)
# ===============================================================================
# V I S U A L S by loading all 3 pkl - mimic3 - speedup - human pd
#
# ===============================================================================
preds = {}
for long_audio in [
    # 'mimic3.wav',
    # 'mimic3_speedup.wav',
    'human_770.wav',  # 'mimic3_all_77.wav', #
    'mimic3_770.wav',
    'mimic3_speed_770.wav',
]:
    # NOTE(review): these *_770 names differ from the *_767_5 files the
    # processing loop writes - confirm these pickles exist.
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    preds[long_audio] = pd.read_pickle(file_interface)

# Truncate every series to the shortest one (human is slower than mimic3, so
# it has extra segments). Exact minimum instead of a 100000-segment sentinel.
SHORTEST_PD = min(len(y) for y in preds.values())


def _tidy(frame, n_segments):
    '''Return the first `n_segments` rows re-indexed by window end time in seconds.

    The `file` and `start` index levels are dropped (after reset_index moves
    them to columns) and `end` becomes the index, converted with
    .total_seconds(). Each step returns a new object rather than mutating a
    slice of `frame` in place.
    '''
    tidy = (frame.iloc[:n_segments]
            .reset_index()
            .drop(columns=['file', 'start'])
            .set_index('end'))
    tidy.index = tidy.index.map(lambda t: t.total_seconds())
    return tidy


# clean indexes for plot
preds = {k: _tidy(v, SHORTEST_PD) for k, v in preds.items()}
for p in preds.values():
    print(p, '\n\n\n\n \n')
# Show plots by 2: left column = mimic3, right column = mimic3 4x speed; each
# curve is shaded against the human reference for the same dimension.
REFERENCE = 'human_770.wav'
COLUMNS = [  # (preds key, legend label) for column 0 and column 1
    ('mimic3_770.wav', 'StyleTTS2 style mimic3'),
    ('mimic3_speed_770.wav', 'StyleTTS2 style mimic3 4x speed'),
]
ROWS = ['arousal', 'dominance', 'valence',  # ADV
        'Angry', 'Sad', 'Happy', 'Surprise',
        'Fear', 'Disgust', 'Contempt']  # CATEGORIES ('Neutral' not plotted)
GREY = (.4, .4, .4)

fig, ax = plt.subplots(nrows=len(ROWS), ncols=2, figsize=(24, 24),
                       gridspec_kw={'hspace': 0, 'wspace': .04})
time_stamp = preds[REFERENCE].index.to_numpy()
for j, dim in enumerate(ROWS):
    is_bottom = j == len(ROWS) - 1
    for col, (key, label) in enumerate(COLUMNS):
        a = ax[j, col]
        a.plot(time_stamp, preds[key][dim],
               color=(0, 104/255, 139/255),
               label='mean_1',
               linewidth=2)
        # shaded gap between this system and the human reference
        a.fill_between(time_stamp,
                       preds[key][dim],
                       preds[REFERENCE][dim],
                       color=(.2, .2, .2),
                       alpha=0.244)
        if j == 0:
            a.legend([label, 'StyleTTS2 style crema-d'], prop={'size': 10})
        a.set_ylim([1e-7, .9999])  # same limits for every panel
        a.set_xlim([time_stamp[0], time_stamp[-1]])
        a.grid()
        if is_bottom:
            a.set_xlabel('767 Harvard Sentences (seconds)', fontsize=16, color=GREY)
        else:
            # hspace=0: inner rows carry no x tick labels / xlabel, otherwise
            # they would be drawn over the panel below.
            a.tick_params(labelbottom=False)
    ax[j, 0].set_ylabel(dim.lower(), color=GREY, fontsize=14)
plt.savefig('valence_tts.pdf', bbox_inches='tight')
plt.close(fig)