import pathlib
import sys
import os

# make the repo root importable when the script is run from the repo root
directory = pathlib.Path(os.getcwd())
sys.path.append(str(directory))

import argparse
import json

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from wav_evaluation.models.CLAPWrapper import CLAPWrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--csv_path', type=str, default='')
    parser.add_argument('--wavsdir', type=str)
    parser.add_argument('--mean', type=bool, default=True)  # parsed but currently unused
    parser.add_argument('--ckpt_path', default="useful_ckpts/CLAP")
    args = parser.parse_args()
    return args


def add_audio_path(df):
    # derive the wav path from the mel-spectrogram path, e.g. 'x.npy' -> 'x.wav'
    df['audio_path'] = df.apply(lambda x: x['mel_path'].replace('.npy', '.wav'), axis=1)
    return df


def build_csv_from_wavs(root_dir):
    with open('ldm/data/audiocaps_fn2cap.json', 'r') as f:
        fn2cap = json.load(f)  # filename -> caption lookup
    wavs_root = os.path.join(root_dir, 'fake_class')
    wavfiles = os.listdir(wavs_root)
    # keep generated wavs only; files named like '*gt.wav' are ground truth
    wavfiles = list(filter(lambda x: x.endswith('.wav') and x[-6:-4] != 'gt', wavfiles))
    print(f'found {len(wavfiles)} generated wav files')
    dict_list = []
    for wavfile in wavfiles:
        tmpd = {'audio_path': os.path.join(wavs_root, wavfile)}
        # rebuild the caption lookup key from the filename around its '_sample' infix
        key = wavfile.rsplit('_sample')[0] + wavfile.rsplit('_sample')[1][:2]
        tmpd['caption'] = fn2cap[key]
        dict_list.append(tmpd)
    df = pd.DataFrame.from_dict(dict_list)
    csv_path = os.path.join(wavs_root, f'{os.path.basename(root_dir)}.csv')
    df.to_csv(csv_path, sep='\t', index=False)
    return csv_path


def cal_score_by_csv(csv_path, clap_model):
    # for reference: the ground-truth audio of the AudioCaps val set scores 0.479077
    df = pd.read_csv(csv_path, sep='\t')
    if 'audio_path' not in df.columns:
        df = add_audio_path(df)
    clap_scores = []
    caption_list, audio_list = [], []

    def flush_batch():
        text_embeddings = clap_model.get_text_embeddings(caption_list)  # embeddings are already normalized
        audio_embeddings = clap_model.get_audio_embeddings(audio_list, resample=True)  # slow: loads each file and resamples it to 44100 Hz
        score_mat = clap_model.compute_similarity(audio_embeddings, text_embeddings, use_logit_scale=False)
        clap_scores.append(score_mat.diagonal().cpu().numpy())  # diagonal = matched (audio, caption) pairs
        caption_list.clear()
        audio_list.clear()

    with torch.no_grad():
        for t in tqdm(df.itertuples()):
            caption_list.append(getattr(t, 'caption'))
            audio_list.append(getattr(t, 'audio_path'))
            if len(caption_list) == 20:  # embed in batches of 20 to bound memory
                flush_batch()
        if caption_list:  # score the final partial batch as well
            flush_batch()
    return np.concatenate(clap_scores).mean()


def add_clap_score_to_csv(csv_path, clap_model):
    df = pd.read_csv(csv_path, sep='\t')
    clap_scores = []
    with torch.no_grad():
        for t in tqdm(df.itertuples()):
            text_embeddings = clap_model.get_text_embeddings([getattr(t, 'caption')])
            audio_embeddings = clap_model.get_audio_embeddings([getattr(t, 'audio_path')], resample=True)
            score = clap_model.compute_similarity(audio_embeddings, text_embeddings, use_logit_scale=False)
            clap_scores.append(score.cpu().numpy().item())
    df['clap_score'] = clap_scores  # one score per row, in row order
    df.to_csv(csv_path[:-4] + '_clap.csv', sep='\t', index=False)


if __name__ == '__main__':
    args = parse_args()
    if args.csv_path:
        csv_path = args.csv_path
    else:
        csv_path = os.path.join(args.wavsdir, 'fake_class/result.csv')
        if not os.path.exists(csv_path):
            print("result csv does not exist, building it")
            csv_path = build_csv_from_wavs(args.wavsdir)
    clap_model = CLAPWrapper(os.path.join(args.ckpt_path, 'CLAP_weights_2022.pth'),
                             os.path.join(args.ckpt_path, 'config.yml'), use_cuda=True)
    clap_score = cal_score_by_csv(csv_path, clap_model)
    out = args.wavsdir if args.wavsdir else args.csv_path
    print(f"clap_score for {out} is: {clap_score}")
    # os.remove(csv_path)
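
# Usage sketch (the script name and directory layout below are assumptions for
# illustration; only the flags come from parse_args above). Either point the
# script at a generation output directory containing 'fake_class/':
#   python cal_clap_score.py --wavsdir outputs/my_run --ckpt_path useful_ckpts/CLAP
# or at an existing tab-separated csv with 'caption' and 'audio_path'
# (or 'mel_path') columns:
#   python cal_clap_score.py --csv_path results/generated.csv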