import os
import argparse
from glob import glob

import prettytable as pt

from evaluation.evaluate import evaluator
from config import Config

config = Config()
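
# NOTE: config.task (e.g. 'DIS5K', 'COD', 'HRSOD') drives both the metric column
# layout in do_eval() and the default dataset/metric choices in the argument parser below.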
def do_eval(args):
    # Evaluate the whole benchmark: datasets are the outer loop, models the inner loop.
    for _data_name in args.data_lst.split('+'):
        pred_data_dir = sorted(glob(os.path.join(args.pred_root, args.model_lst[0], _data_name)))
        if not pred_data_dir:
            print('Skip dataset {}.'.format(_data_name))
            continue
        gt_src = os.path.join(args.gt_root, _data_name)
        gt_paths = sorted(glob(os.path.join(gt_src, 'gt', '*')))
        print('#' * 20, _data_name, '#' * 20)
        filename = os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name))
        tb = pt.PrettyTable()
        tb.vertical_char = '&'
        if config.task == 'DIS5K':
            tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm"]
        elif config.task == 'COD':
            tb.field_names = ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE"]
        elif config.task == 'HRSOD':
            tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
        elif config.task == 'DIS5K+HRSOD+HRS10K':
            tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm"]
        elif config.task == 'P3M-10k':
            tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
        else:
            tb.field_names = ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE"]
        for _model_name in args.model_lst[:]:
            print('\t', 'Evaluating model: {}...'.format(_model_name))
            # Map each GT path to the corresponding prediction path (same layout, minus the 'gt' folder).
            pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, _model_name)).replace('/gt/', '/') for p in gt_paths]
            # print(pred_paths[:1], gt_paths[:1])
            em, sm, fm, mae, wfm, hce = evaluator(
                gt_paths=gt_paths,
                pred_paths=pred_paths,
                metrics=args.metrics.split('+'),
                verbose=config.verbose_eval
            )
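            # em and fm come back as dicts with a per-threshold 'curve' array and an adaptive-threshold
            # 'adp' value; sm, mae and wfm are scalars, and hce rounds to an integer count (see their usage below).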
            if config.task == 'DIS5K':
                scores = [
                    fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()),
                    em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3),
                ]
            elif config.task == 'COD':
                scores = [
                    sm.round(3), wfm.round(3), fm['curve'].mean().round(3), em['curve'].mean().round(3), em['curve'].max().round(3), mae.round(3),
                    fm['curve'].max().round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
                ]
            elif config.task == 'HRSOD':
                scores = [
                    sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3),
                    em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
                ]
            elif config.task == 'DIS5K+HRSOD+HRS10K':
                scores = [
                    fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()),
                    em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3),
                ]
            elif config.task == 'P3M-10k':
                scores = [
                    sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3),
                    em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
                ]
            else:
                scores = [
                    sm.round(3), mae.round(3), em['curve'].max().round(3), em['curve'].mean().round(3),
                    fm['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3),
                    em['adp'].round(3), fm['adp'].round(3), int(hce.round()),
                ]
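            # Compact display: scores <= 1 keep only the three digits after the decimal point ('.xxx');
            # the integer HCE is left-aligned in a 4-character field.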
            for idx_score, score in enumerate(scores):
                scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4')
            records = [_data_name, _model_name] + scores
            tb.add_row(records)
            # Rewrite the full table after every model so partial results are saved as the run progresses.
            with open(filename, 'w+') as file_to_write:
                file_to_write.write(str(tb) + '\n')
        print(tb)


if __name__ == '__main__':
    # set parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--gt_root', type=str, help='ground-truth root',
        default=os.path.join(config.data_root_dir, config.task))
    parser.add_argument(
        '--pred_root', type=str, help='prediction root',
        default='./e_preds')
    parser.add_argument(
        '--data_lst', type=str, help='test datasets, joined by "+"',
        default={
            'DIS5K': '+'.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:]),
            'COD': '+'.join(['TE-COD10K', 'NC4K', 'TE-CAMO', 'CHAMELEON'][:]),
            'HRSOD': '+'.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'TE-DUTS', 'DUT-OMRON'][:]),
            'DIS5K+HRSOD+HRS10K': '+'.join(['DIS-VD'][:]),
            'P3M-10k': '+'.join(['TE-P3M-500-P', 'TE-P3M-500-NP'][:]),
        }[config.task])
    parser.add_argument(
        '--save_dir', type=str, help='directory to save the evaluation results',
        default='e_results')
    parser.add_argument(
        # argparse's `type=bool` would treat any non-empty string (even 'False') as True,
        # so parse the string explicitly instead.
        '--check_integrity', type=lambda x: str(x).lower() in ('true', '1', 'yes'),
        help='whether to check the file integrity', default=False)
    parser.add_argument(
        '--metrics', type=str, help='evaluation metrics, joined by "+"',
        # HCE is only meaningful for DIS5K-style tasks; the trailing slice drops it otherwise.
        default='+'.join(['S', 'MAE', 'E', 'F', 'WF', 'HCE'][:100 if 'DIS5K' in config.task else -1]))
    args = parser.parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    # Prefer checkpoint folders named like '..._epoch_N', sorted by epoch (newest first); the `% 1`
    # filter keeps every epoch (raise the modulus to subsample). Fall back to alphabetical order
    # if the folder names do not follow that pattern.
    try:
        args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1]), reverse=True) if int(m.split('epoch_')[-1]) % 1 == 0]
    except ValueError:
        args.model_lst = [m for m in sorted(os.listdir(args.pred_root))]
    # Check the integrity of each candidate's predictions against the ground truth.
    if args.check_integrity:
        for _data_name in args.data_lst.split('+'):
            for _model_name in args.model_lst:
                gt_pth = os.path.join(args.gt_root, _data_name)
                pred_pth = os.path.join(args.pred_root, _model_name, _data_name)
                if not sorted(os.listdir(gt_pth)) == sorted(os.listdir(pred_pth)):
                    print(len(sorted(os.listdir(gt_pth))), len(sorted(os.listdir(pred_pth))))
                    print('The {} predictions of model {} do not match the ground truth.'.format(_data_name, _model_name))
    else:
        print('>>> Skipping the integrity check of candidate predictions.')

    # start engine
    do_eval(args)
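
# Example usage (a sketch; the file name `eval.py` and the paths are assumptions, adjust to your setup):
#   python eval.py --pred_root ./e_preds --data_lst DIS-VD+DIS-TE1
# Predictions are expected under <pred_root>/<model_name>/<dataset_name>/, mirroring the
# ground-truth layout <gt_root>/<dataset_name>/gt/. One table per dataset is written to
# <save_dir>/<dataset_name>_eval.txt; the '&' column separator keeps rows easy to paste
# into LaTeX tabulars.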