|
import os |
|
import prettytable as pt |
|
|
|
from evaluation.metrics import evaluator |
|
from config import Config |
|
|
|
|
|
# Shared project configuration; evaluate() reads data_root_dir, task and
# verbose_eval from this instance.
config = Config()
|
|
|
def evaluate(pred_dir, method, testset, only_S_MAE=False, epoch=0):
    """Score one method's predictions on one test set and record the result.

    Ground-truth files are listed from
    ``<config.data_root_dir>/<config.task>/<testset>/gt`` and predictions
    from ``<pred_dir>/<method>/<testset>``. Both listings are sorted, so the
    two directories are expected to contain matching file names.

    The result table is printed and written to a fresh report file
    ``evaluation/eval-<method>.txt`` (or ``eval-<method>_<k>.txt`` with the
    first free ``k >= 1`` when that name is taken).

    Args:
        pred_dir: Root directory holding the predictions.
        method: Name of the method / run; also used in the report file name.
        testset: Name of the test set sub-directory.
        only_S_MAE: When true, only the 'S' and 'MAE' metrics are requested
            and only those two scores are put in the table.
        epoch: Appended to ``method`` in the "Method" column of the full
            table (not used in the reduced ``only_S_MAE`` table).

    Returns:
        dict with keys ``e_max``, ``e_mean``, ``e_adp``, ``sm``, ``mae``,
        ``f_max``, ``f_mean``, ``f_wfm``, ``f_adp`` and ``hce``.
    """
    # Choose a report path that does not exist yet. The suffix is appended to
    # the stem explicitly: str.rstrip('.txt') strips a *set of characters*
    # and would eat trailing 't', 'x' or '.' characters of the method name.
    stem = os.path.join('evaluation', 'eval-{}'.format(method))
    filename = stem + '.txt'
    id_suffix = 0
    while os.path.exists(filename):
        id_suffix += 1
        filename = '{}_{}.txt'.format(stem, id_suffix)

    gt_dir = os.path.join(config.data_root_dir, config.task, testset, 'gt')
    gt_paths = sorted(os.path.join(gt_dir, p) for p in os.listdir(gt_dir))
    pred_root = os.path.join(pred_dir, method, testset)
    pred_paths = sorted(os.path.join(pred_root, p) for p in os.listdir(pred_root))

    metrics = ['S', 'MAE'] if only_S_MAE else ['S', 'MAE', 'E', 'F', 'HCE']
    # Run the evaluation before touching the report file so that a failure
    # here does not leave an empty report behind.
    em, sm, fm, mae, wfm, hce = evaluator(
        gt_paths=gt_paths,
        pred_paths=pred_paths,
        metrics=metrics,
        verbose=config.verbose_eval,
    )
    # NOTE(review): these lookups also run when only 'S' and 'MAE' were
    # requested, so the evaluator is assumed to return placeholder values
    # for metrics it did not compute -- confirm against evaluation.metrics.
    e_max, e_mean, e_adp = em['curve'].max(), em['curve'].mean(), em['adp'].mean()
    f_max, f_mean, f_wfm, f_adp = fm['curve'].max(), fm['curve'].mean(), wfm, fm['adp']

    tb = pt.PrettyTable()
    if only_S_MAE:
        # Header and row must have the same length; HCE is not requested in
        # this mode, so it gets no column.
        tb.field_names = ["Dataset", "Method", 'MAE', "Smeasure"]
        tb.add_row([testset, method, mae.round(3), sm.round(3)])
    else:
        tb.field_names = [
            "Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "maxEm", "meanFm",
            "adpEm", "adpFm", 'HCE',
        ]
        # Values are listed in header order: data set first, then method.
        tb.add_row([
            testset, method + str(epoch), f_max.round(3), f_wfm.round(3), mae.round(3), sm.round(3),
            e_mean.round(3), e_max.round(3), f_mean.round(3), e_adp.round(3), f_adp.round(3), hce.round(3),
        ])
    print(tb)

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # The context manager closes the file; no explicit close() is needed.
    with open(filename, 'a+') as file_to_write:
        file_to_write.write(str(tb).replace('+', '|') + '\n')

    return {
        'e_max': e_max, 'e_mean': e_mean, 'e_adp': e_adp,
        'sm': sm, 'mae': mae,
        'f_max': f_max, 'f_mean': f_mean, 'f_wfm': f_wfm, 'f_adp': f_adp,
        'hce': hce,
    }
|
|
|
|
|
def main():
    """Evaluate the 'tmp_val' predictions found under '.' on each test set.

    The test sets are given as one '+'-separated string and evaluated one
    after another with the full metric set (only_S_MAE disabled).
    """
    eval_kwargs = dict(only_S_MAE=False)
    prediction_root, run_name = '.', 'tmp_val'
    for dataset_name in 'DIS-VD+DIS-TE1'.split('+'):
        # The returned metrics dict is not needed here.
        evaluate(prediction_root, run_name, dataset_name, **eval_kwargs)
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|