File size: 7,295 Bytes
2a41a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import argparse
from glob import glob
import prettytable as pt

from evaluation.evaluate import evaluator
from config import Config


# Shared project configuration; this script reads its `task`, `data_root_dir`
# and `verbose_eval` attributes.
config = Config()


# Metric columns (placed after "Dataset" and "Method") in the order reported
# for each task. Tasks that share a layout point at the same tuple, so the
# table header and the score order can never drift apart.
_DIS_COLUMNS = (
    "maxFm", "wFmeasure", "MAE", "Smeasure", "meanEm", "HCE",
    "maxEm", "meanFm", "adpEm", "adpFm",
)
_COD_COLUMNS = (
    "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", "MAE",
    "maxFm", "adpEm", "adpFm", "HCE",
)
_HRSOD_COLUMNS = (
    "Smeasure", "maxFm", "meanEm", "MAE", "maxEm", "meanFm",
    "wFmeasure", "adpEm", "adpFm", "HCE",
)
_DEFAULT_COLUMNS = (
    "Smeasure", "MAE", "maxEm", "meanEm", "maxFm", "meanFm",
    "wFmeasure", "adpEm", "adpFm", "HCE",
)
_COLUMNS_BY_TASK = {
    'DIS5K': _DIS_COLUMNS,
    'COD': _COD_COLUMNS,
    'HRSOD': _HRSOD_COLUMNS,
    'DIS5K+HRSOD+HRS10K': _DIS_COLUMNS,
    'P3M-10k': _HRSOD_COLUMNS,
}


def _format_score(score):
    """Render one score as a table cell.

    Values strictly below 1 keep only their three fractional digits, written
    without the integer part (0.873 -> '.873'). Everything else, including an
    exact 1, is left-aligned to a width of 4 so that it is not shown as '.000'.
    """
    if score < 1:
        return '.' + format(score, '.3f').split('.')[-1]
    return format(score, '<4')


def _named_scores(em, sm, fm, mae, wfm, hce):
    """Map each column name to its rounded value from the evaluator outputs.

    `em` and `fm` are indexed with the keys 'curve' and 'adp'; HCE is reported
    as an integer, every other metric is rounded to three decimals.
    """
    return {
        "Smeasure": sm.round(3),
        "MAE": mae.round(3),
        "wFmeasure": wfm.round(3),
        "maxFm": fm['curve'].max().round(3),
        "meanFm": fm['curve'].mean().round(3),
        "adpFm": fm['adp'].round(3),
        "maxEm": em['curve'].max().round(3),
        "meanEm": em['curve'].mean().round(3),
        "adpEm": em['adp'].round(3),
        "HCE": int(hce.round()),
    }


def do_eval(args):
    """Evaluate every model on every dataset and write one table per dataset.

    For each dataset in `args.data_lst` (names joined by '+'), the ground-truth
    files under `<gt_root>/<dataset>/gt/` are paired with the same-named files
    under `<pred_root>/<model>/<dataset>/` for each model in `args.model_lst`.
    The resulting table is printed and saved to
    `<save_dir>/<dataset>_eval.txt`; the column order depends on `config.task`.

    Args:
        args: namespace providing `data_lst`, `model_lst`, `pred_root`,
            `gt_root`, `save_dir` and `metrics` (metric names joined by '+').
    """
    if not args.model_lst:
        # Nothing to evaluate; indexing the first model below would raise.
        print('No models found in {}.'.format(args.pred_root))
        return

    columns = _COLUMNS_BY_TASK.get(config.task, _DEFAULT_COLUMNS)

    # evaluation for whole dataset
    # dataset first in evaluation
    for _data_name in args.data_lst.split('+'):
        # The first model decides whether predictions for this dataset exist.
        if not glob(os.path.join(args.pred_root, args.model_lst[0], _data_name)):
            print('Skip dataset {}.'.format(_data_name))
            continue
        gt_paths = sorted(glob(os.path.join(args.gt_root, _data_name, 'gt', '*')))
        print('#' * 20, _data_name, '#' * 20)
        filename = os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name))
        tb = pt.PrettyTable()
        tb.vertical_char = '&'
        tb.field_names = ["Dataset", "Method"] + list(columns)
        for _model_name in args.model_lst:
            print('\t', 'Evaluating model: {}...'.format(_model_name))
            # Build prediction paths from their parts instead of rewriting the
            # ground-truth path as a string: that broke when gt_root ended with
            # a separator or the platform separator was not '/'.
            pred_dir = os.path.join(args.pred_root, _model_name, _data_name)
            pred_paths = [os.path.join(pred_dir, os.path.basename(p)) for p in gt_paths]
            em, sm, fm, mae, wfm, hce = evaluator(
                gt_paths=gt_paths,
                pred_paths=pred_paths,
                metrics=args.metrics.split('+'),
                verbose=config.verbose_eval
            )
            scores = _named_scores(em, sm, fm, mae, wfm, hce)
            cells = [_format_score(scores[column]) for column in columns]
            tb.add_row([_data_name, _model_name] + cells)
            # Write results after every check, so partial results survive an
            # interrupted run.
            with open(filename, 'w') as file_to_write:
                file_to_write.write(str(tb) + '\n')
        print(tb)


def _str2bool(value):
    """Parse a command-line boolean.

    argparse's `type=bool` calls `bool()` on the raw string, so every
    non-empty value (including "False") becomes True. The text is therefore
    interpreted explicitly here.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _list_models(pred_root):
    """Return the entries of `pred_root`, newest epoch first when possible.

    Entries are ordered by the integer following 'epoch_' in descending order.
    If any entry has no such integer, plain alphabetical order is used instead.
    """
    names = os.listdir(pred_root)
    try:
        return sorted(names, key=lambda name: int(name.split('epoch_')[-1]), reverse=True)
    except ValueError:
        return sorted(names)


def _sorted_listing(path):
    """Return the sorted entries of `path`, or an empty list if it is not a directory."""
    return sorted(os.listdir(path)) if os.path.isdir(path) else []


if __name__ == '__main__':
    # set parameters
    default_data_lst = {
        'DIS5K': '+'.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']),
        'COD': '+'.join(['TE-COD10K', 'NC4K', 'TE-CAMO', 'CHAMELEON']),
        'HRSOD': '+'.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'TE-DUTS', 'DUT-OMRON']),
        'DIS5K+HRSOD+HRS10K': '+'.join(['DIS-VD']),
        'P3M-10k': '+'.join(['TE-P3M-500-P', 'TE-P3M-500-NP']),
    }
    all_metrics = ['S', 'MAE', 'E', 'F', 'WF', 'HCE']
    # HCE is only part of the default metric set when the task name contains 'DIS5K'.
    default_metrics = all_metrics if 'DIS5K' in config.task else all_metrics[:-1]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--gt_root', type=str, help='ground-truth root',
        default=os.path.join(config.data_root_dir, config.task))
    parser.add_argument(
        '--pred_root', type=str, help='prediction root',
        default='./e_preds')
    parser.add_argument(
        '--data_lst', type=str, help='test dataset',
        # No built-in default for tasks outside the table; checked after parsing.
        default=default_data_lst.get(config.task))
    parser.add_argument(
        '--save_dir', type=str, help='directory to save the evaluation tables',
        default='e_results')
    parser.add_argument(
        '--check_integrity', type=_str2bool, help='whether to check the file integrity',
        default=False)
    parser.add_argument(
        '--metrics', type=str, help='metrics to compute, joined by "+"',
        default='+'.join(default_metrics))
    args = parser.parse_args()
    if args.data_lst is None:
        parser.error('--data_lst is required for task {!r}'.format(config.task))

    os.makedirs(args.save_dir, exist_ok=True)
    args.model_lst = _list_models(args.pred_root)

    # check the integrity of each candidates
    if args.check_integrity:
        for _data_name in args.data_lst.split('+'):
            # Ground-truth files live in the 'gt' sub-directory, which is where
            # do_eval reads them from; predictions use the same file names.
            gt_files = _sorted_listing(os.path.join(args.gt_root, _data_name, 'gt'))
            for _model_name in args.model_lst:
                pred_files = _sorted_listing(os.path.join(args.pred_root, _model_name, _data_name))
                if gt_files != pred_files:
                    print(len(gt_files), len(pred_files))
                    print('The {} Dataset of {} Model is not matching to the ground-truth'.format(_data_name, _model_name))
    else:
        print('>>> skip check the integrity of each candidates')

    # start engine
    do_eval(args)