# Spaces: Runtime error
"""Evaluate an AudioBart checkpoint on a caption test split with aac_metrics.

Generates one caption per row of the test CSV and scores the predictions
against the five reference captions (``caption_1`` .. ``caption_5``).
"""
import os
import sys

import pandas as pd

# Let the project-local ``inference`` module be imported whether the script is
# launched from its own directory or from the directory above it.
sys.path.append('..')
sys.path.append('.')

metric_list = ["bleu_1", "bleu_4", "rouge_l", "meteor", "spider_fl"]

# Reference caption columns read from the test CSV, in order.
CAPTION_COLUMNS = [f"caption_{i}" for i in range(1, 6)]

# Generation settings forwarded to AudioBartInference.infer_from_encodec.
GENERATION_CONFIG = {
    "_from_model_config": True,
    "bos_token_id": 0,
    "decoder_start_token_id": 2,
    "early_stopping": True,
    "eos_token_id": 2,
    "forced_bos_token_id": 0,
    "forced_eos_token_id": 2,
    "no_repeat_ngram_size": 3,
    "num_beams": 4,
    "pad_token_id": 1,
    "max_length": 50,
}


def resolve_input_path(row, base_path, from_encodec):
    """Return the model input path for one CSV row.

    Args:
        row: Mapping holding a ``file_path`` (encodec mode) or ``file_name``
            entry.
        base_path: Directory that encodec ``file_path`` entries are relative to.
        from_encodec: If true, use ``file_path`` joined onto *base_path*;
            otherwise use ``file_name`` unchanged.
    """
    # NOTE(review): the original indentation was lost; joining with base_path
    # is assumed to apply only to encodec paths — confirm for the wav branch.
    if from_encodec:
        return os.path.join(base_path, row["file_path"])
    return row["file_name"]


def collect_references(row):
    """Return the reference captions of *row* in CAPTION_COLUMNS order."""
    return [row[column] for column in CAPTION_COLUMNS]


def summarize_scores(corpus_scores, ndigits=4):
    """Round corpus scores and drop every entry whose key contains "fluerr".

    Values may be single-element tensors or plain numbers; ``float()``
    accepts both.
    """
    return {
        key: round(float(value), ndigits)
        for key, value in corpus_scores.items()
        if "fluerr" not in key
    }


def main():
    """Caption every test-set row, then print the rounded corpus scores."""
    # Select the GPU before loading modules that may initialise CUDA; a value
    # already exported in the shell takes precedence.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    # Imported here rather than at module top so the device selection above
    # is in effect when they load.
    from aac_metrics import evaluate
    from inference import AudioBartInference
    from tqdm import tqdm

    dataset = "AudioCaps"  # alternative: "clotho"
    ckpt_path = "/data/jyk/aac_results/bart_base/audiocaps_35e5_2000/checkpoints/epoch_8"
    # alternative ckpt_path:
    # "/data/jyk/aac_results/masking/linear_scalinEg/checkpoints/epoch_14"
    max_encodec_length = 1022
    from_encodec = True
    csv_path = f"/workspace/audiobart/csv/{dataset}/test.csv"
    base_path = f"/data/jyk/aac_dataset/{dataset}/encodec_16"
    clap_name = "clap_audio_fused"

    # Read the CSV first so a bad path fails before the model is loaded.
    df = pd.read_csv(csv_path)
    infer_module = AudioBartInference(ckpt_path, max_encodec_length)

    print(f"> Making Predictions for model {ckpt_path}...")
    predictions = []
    references = []
    missing = []
    for row in tqdm(df.to_dict("records"), dynamic_ncols=True, colour="BLUE"):
        input_path = resolve_input_path(row, base_path, from_encodec)
        if not os.path.exists(input_path):
            # Skip instead of failing inside inference; the skipped rows are
            # reported after the loop so the reduced test set is visible.
            missing.append(input_path)
            continue
        if from_encodec:
            prediction = infer_module.infer_from_encodec(
                input_path, clap_name, GENERATION_CONFIG)
        else:
            prediction = infer_module.infer(input_path)
        # Appended together so predictions and references stay row-aligned.
        predictions.append(prediction[0])
        references.append(collect_references(row))

    if missing:
        print(f"> Warning: skipped {len(missing)}/{len(df)} rows with missing "
              f"input files (first: {missing[0]})")
    if not predictions:
        sys.exit("> No predictions were made; nothing to evaluate.")

    print("> Evaluating predictions...")
    # Only the first element of evaluate()'s result is summarised (presumably
    # the corpus-level scores; the second being per-sentence scores).
    corpus_scores = evaluate(predictions, references, metrics=metric_list)[0]
    print(summarize_scores(corpus_scores))


if __name__ == "__main__":
    main()