File size: 2,390 Bytes
73baeae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sys 
sys.path.append('..')
sys.path.append('.')

from aac_metrics import evaluate
from inference import AudioBartInference
from tqdm import tqdm
import os
import pandas as pd 

# Expose only GPU 0 to CUDA for this process.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

# Names of the captioning metrics requested from `evaluate` further down.
metric_list = [
    "bleu_1",
    "bleu_4",
    "rouge_l",
    "meteor",
    "spider_fl",
]

def collect_references(row, num_captions=5):
    """Return the reference captions of one test-set row.

    Args:
        row: Mapping-like row (e.g. a pandas Series) holding the columns
            ``caption_1`` .. ``caption_<num_captions>``.
        num_captions: Number of caption columns to read.

    Returns:
        A list with the captions in column order.
    """
    return [row[f"caption_{i}"] for i in range(1, num_captions + 1)]


def round_scores(scores, ndigits=4):
    """Round scores for display and drop the ``fluerr`` entries.

    Args:
        scores: Mapping from metric name to a scalar-like value that exposes
            ``.item()`` (e.g. a 0-d tensor or a NumPy scalar).
        ndigits: Number of decimal places to keep.

    Returns:
        A new dict of plain Python numbers; any key containing ``"fluerr"``
        is omitted.
    """
    return {
        name: round(value.item(), ndigits)
        for name, value in scores.items()
        if "fluerr" not in name
    }


def main():
    """Caption every sample listed in the test CSV and print the metric scores.

    Samples whose input file does not exist are skipped (and reported) so the
    remaining predictions and references stay aligned.

    Raises:
        SystemExit: If no sample could be processed at all.
    """
    dataset = "AudioCaps"
    # dataset = "clotho"
    ckpt_path = "/data/jyk/aac_results/bart_base/audiocaps_35e5_2000/checkpoints/epoch_8"

    # ckpt_path = "/data/jyk/aac_results/masking/linear_scalinEg/checkpoints/epoch_14"
    max_encodec_length = 1022
    infer_module = AudioBartInference(ckpt_path, max_encodec_length)
    from_encodec = True
    csv_path = f"/workspace/audiobart/csv/{dataset}/test.csv"
    base_path = f"/data/jyk/aac_dataset/{dataset}/encodec_16"
    clap_name = "clap_audio_fused"
    df = pd.read_csv(csv_path)

    # Forwarded unchanged to `infer_from_encodec`; presumably decoding options
    # for the underlying generate call -- confirm against AudioBartInference.
    generation_config = {
        "_from_model_config": True,
        "bos_token_id": 0,
        "decoder_start_token_id": 2,
        "early_stopping": True,
        "eos_token_id": 2,
        "forced_bos_token_id": 0,
        "forced_eos_token_id": 2,
        "no_repeat_ngram_size": 3,
        "num_beams": 4,
        "pad_token_id": 1,
        "max_length": 50
    }

    # The CSV column holding the relative input path depends on the input kind.
    path_column = "file_path" if from_encodec else "file_name"

    print(f"> Making Predictions for model {ckpt_path}...")
    predictions = []
    references = []
    missing_paths = []
    rows = tqdm(df.iterrows(), total=len(df), dynamic_ncols=True, colour="BLUE")
    for _, row in rows:
        wav_path = os.path.join(base_path, row[path_column])
        if not os.path.exists(wav_path):
            # Skip instead of falling through to inference on a missing file;
            # nothing is appended, so predictions/references stay in step.
            missing_paths.append(wav_path)
            continue

        if from_encodec:
            prediction = infer_module.infer_from_encodec(wav_path, clap_name, generation_config)
        else:
            prediction = infer_module.infer(wav_path)

        # Only the first returned caption is scored.
        predictions.append(prediction[0])
        references.append(collect_references(row))

    if missing_paths:
        print(f"> Skipped {len(missing_paths)} of {len(df)} samples with missing input files:")
        for path in missing_paths:
            print(f"    {path}")

    if not predictions:
        raise SystemExit("> No predictions were made; nothing to evaluate.")

    print("> Evaluating predictions...")
    result = evaluate(predictions, references, metrics=metric_list)
    # Only the first element of the returned pair is reported.
    print(round_scores(result[0]))


if __name__ == "__main__":
    main()