Make-An-Audio-3 / wav_evaluation /cal_clap_score.py
3v324v23's picture
Add code
a84a65c
import pathlib
import sys
import os
directory = pathlib.Path(os.getcwd())
sys.path.append(str(directory))
import torch
import numpy as np
from wav_evaluation.models.CLAPWrapper import CLAPWrapper
import torch.nn.functional as F
import argparse
import csv
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import json
def _str2bool(value):
    """Convert a command-line token such as 'true', 'No' or '0' to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is True for
    every non-empty string (including 'False'); this converter parses the
    text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``bool('')``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the CLAP-score script.

    Returns:
        argparse.Namespace with the attributes ``csv_path`` (str, default
        ''), ``wavsdir`` (str or None), ``mean`` (bool, default True) and
        ``ckpt_path`` (str, default 'useful_ckpts/CLAP').
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--csv_path', type=str, default='',
                        help='tab-separated file with caption and audio paths')
    parser.add_argument('--wavsdir', type=str,
                        help="directory containing a 'fake_class' sub-directory of wavs")
    parser.add_argument('--mean', type=_str2bool, default=True,
                        help='boolean flag, e.g. --mean false')
    parser.add_argument('--ckpt_path', default="useful_ckpts/CLAP",
                        help='directory holding the CLAP weights and config')
    args = parser.parse_args()
    return args
def add_audio_path(df):
    """Add an 'audio_path' column derived from the 'mel_path' column.

    Every occurrence of '.npy' in a mel path is replaced by '.wav'. The
    frame is modified in place and also returned.
    """
    wav_paths = [mel_path.replace('.npy', '.wav') for mel_path in df['mel_path']]
    df['audio_path'] = wav_paths
    return df
def build_csv_from_wavs(root_dir, fn2cap_path='ldm/data/audiocaps_fn2cap.json'):
    """Write a tab-separated caption/audio table for the wavs under root_dir.

    The wav files are looked up in ``<root_dir>/fake_class``. Files whose
    name ends in 'gt.wav' are skipped. The caption of each remaining file is
    read from the JSON mapping at ``fn2cap_path``.

    Args:
        root_dir: directory that contains the 'fake_class' sub-directory.
        fn2cap_path: JSON file mapping a file key to its caption.

    Returns:
        Path of the written CSV, ``<root_dir>/fake_class/<name>.csv`` where
        ``<name>`` is the last component of ``root_dir``.

    Raises:
        IndexError: if a wav file name does not contain '_sample'.
        KeyError: if the derived key is missing from the caption mapping.
    """
    with open(fn2cap_path, 'r', encoding='utf-8') as f:
        fn2cap = json.load(f)
    wavs_root = os.path.join(root_dir, 'fake_class')
    # os.listdir returns entries in arbitrary order; sort so the CSV rows
    # come out in the same order on every run.
    wavfiles = sorted(
        name for name in os.listdir(wavs_root)
        if name.endswith('.wav') and name[-6:-4] != 'gt'
    )
    print(len(wavfiles))
    rows = []
    for wavfile in wavfiles:
        # Key = text before '_sample' plus the two characters after it,
        # e.g. 'foo_sample_3.wav' -> 'foo' + '_3' -> 'foo_3'.
        parts = wavfile.rsplit('_sample')
        key = parts[0] + parts[1][:2]
        rows.append({
            'audio_path': os.path.join(wavs_root, wavfile),
            'caption': fn2cap[key],
        })
    # Explicit columns keep the header even when no wav file was found.
    df = pd.DataFrame(rows, columns=['audio_path', 'caption'])
    # normpath drops a trailing separator; without it basename('dir/') is ''
    # and the file would be written as a hidden '.csv'.
    csv_name = f'{os.path.basename(os.path.normpath(root_dir))}.csv'
    csv_path = os.path.join(wavs_root, csv_name)
    df.to_csv(csv_path, sep='\t', index=False)
    return csv_path
def cal_score_by_csv(csv_path, clap_model, batch_size=20):
    """Return the mean CLAP audio/caption similarity over a tab-separated file.

    The file must have a 'caption' column and either an 'audio_path' column
    or a 'mel_path' column from which the audio path is derived. Rows are
    scored in batches; for each batch the diagonal of the similarity matrix
    (each audio against its own caption) is kept.

    Note from the original author: the ground-truth audio of the AudioCaps
    validation set scored 0.479077. That figure was presumably obtained while
    rows after the last full batch were still being skipped, so a re-run may
    differ slightly -- TODO confirm.

    Args:
        csv_path: path of the tab-separated input file.
        clap_model: object providing ``get_text_embeddings``,
            ``get_audio_embeddings`` and ``compute_similarity``.
        batch_size: number of rows embedded per model call.

    Returns:
        Mean similarity over all rows, or NaN when the file has no rows.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    df = pd.read_csv(csv_path, sep='\t')
    if 'audio_path' not in df.columns:
        df = add_audio_path(df)
    captions = df['caption'].tolist()
    audio_paths = df['audio_path'].tolist()
    if not captions:
        # Nothing to score; avoid taking the mean of an empty array.
        return float('nan')

    batch_scores = []
    with torch.no_grad(), tqdm(total=len(captions)) as progress:
        # Stepping by batch_size also covers the final, shorter batch, which
        # a plain "every batch_size-th row" trigger would leave unscored.
        for start in range(0, len(captions), batch_size):
            batch_captions = captions[start:start + batch_size]
            batch_audio = audio_paths[start:start + batch_size]
            # Per the original author the returned embeddings are normalised.
            text_embeddings = clap_model.get_text_embeddings(batch_captions)
            # Per the original author this is the slow step: the audio is
            # read and resampled to 44100 Hz.
            audio_embeddings = clap_model.get_audio_embeddings(batch_audio, resample=True)
            score_mat = clap_model.compute_similarity(
                audio_embeddings, text_embeddings, use_logit_scale=False)
            batch_scores.append(score_mat.diagonal().cpu().numpy())
            progress.update(len(batch_captions))
    # Batches may differ in length, so concatenate rather than stack.
    return np.mean(np.concatenate(batch_scores))
def add_clap_score_to_csv(csv_path, clap_model):
    """Score every row of a tab-separated file and save a copy with the scores.

    The input must have a 'caption' column and either an 'audio_path' column
    or a 'mel_path' column from which the audio path is derived. The result
    is written next to the input as ``<name>_clap.csv`` with an added
    'clap_score' column.

    Args:
        csv_path: path of the tab-separated input file.
        clap_model: object providing ``get_text_embeddings``,
            ``get_audio_embeddings`` and ``compute_similarity``.
    """
    df = pd.read_csv(csv_path, sep='\t')
    if 'audio_path' not in df.columns:
        df = add_audio_path(df)
    scores = []
    with torch.no_grad():
        rows = zip(df['caption'], df['audio_path'])
        for caption, audio_path in tqdm(rows, total=len(df)):
            text_embeddings = clap_model.get_text_embeddings([caption])
            audio_embeddings = clap_model.get_audio_embeddings([audio_path], resample=True)
            score = clap_model.compute_similarity(
                audio_embeddings, text_embeddings, use_logit_scale=False)
            # One audio against one caption gives a single similarity value;
            # store it as a plain float rather than as an array.
            scores.append(float(score.cpu().numpy().reshape(-1)[0]))
    # A list is assigned by position, so every score lands on its own row.
    # A dict keyed 1..n is label-based instead and does not line the scores
    # up with the frame's 0-based index.
    df['clap_score'] = scores
    # splitext removes whatever extension is present, not a fixed 4 characters.
    stem, _ = os.path.splitext(csv_path)
    df.to_csv(stem + '_clap.csv', sep='\t', index=False)
if __name__ == '__main__':
    # Script entry point: locate (or build) the caption/audio CSV, load the
    # CLAP model and print the mean CLAP score.
    args = parse_args()
    if not (args.csv_path or args.wavsdir):
        # Without this check os.path.join(None, ...) fails with an opaque TypeError.
        raise SystemExit("either --csv_path or --wavsdir must be given")
    if args.csv_path:
        csv_path = args.csv_path
    else:
        csv_path = os.path.join(args.wavsdir, 'fake_class/result.csv')
    if not os.path.exists(csv_path):
        if not args.wavsdir:
            raise SystemExit(f"{csv_path} does not exist and no --wavsdir was given to build it from")
        print("result csv not exist,build for it")
        csv_path = build_csv_from_wavs(args.wavsdir)
    clap_model = CLAPWrapper(
        os.path.join(args.ckpt_path, 'CLAP_weights_2022.pth'),
        os.path.join(args.ckpt_path, 'config.yml'),
        # Request CUDA only when a device is actually present.
        use_cuda=torch.cuda.is_available(),
    )
    clap_score = cal_score_by_csv(csv_path, clap_model)
    out = args.wavsdir or args.csv_path
    # Reported once; the same line used to be printed three times.
    print(f"clap_score for {out} is:{clap_score}")
    # NOTE: a CSV built above is not removed afterwards.