|
import os |
|
from os.path import join |
|
import glob |
|
from sklearn.metrics import roc_auc_score, roc_curve |
|
import sys |
|
|
|
def file_to_score(file): |
|
try: |
|
content = open(file, 'r').readlines() |
|
st = 0 |
|
if len(content) == 0: |
|
|
|
return 0. |
|
if content[0].split()[0].isalpha(): |
|
st = 1 |
|
return max([float(line.split()[st]) for line in content]) |
|
except FileNotFoundError: |
|
print(f'No Corresponding Box Found for file {file}, using [] as preds') |
|
return [] |
|
except Exception as e: |
|
print('Some Error',e) |
|
return [] |
|
|
|
|
|
def generate_image_dict(preds_folder_name='preds_42', |
|
root_fol='/home/krithika_1/densebreeast_datasets/AIIMS_C1', |
|
mal_path=None, ben_path=None, gt_path=None, |
|
mal_img_path = None, ben_img_path = None |
|
): |
|
|
|
mal_path = join(root_fol, mal_path) if mal_path else join( |
|
root_fol, 'mal', preds_folder_name) |
|
ben_path = join(root_fol, ben_path) if ben_path else join( |
|
root_fol, 'ben', preds_folder_name) |
|
mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join( |
|
root_fol, 'mal', 'images') |
|
ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join( |
|
root_fol, 'ben', 'images') |
|
gt_path = join(root_fol, gt_path) if gt_path else join( |
|
root_fol, 'mal', 'gt') |
|
|
|
|
|
''' |
|
image_dict structure: |
|
'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : score} |
|
''' |
|
image_dict = dict() |
|
|
|
|
|
|
|
for file in os.listdir(mal_img_path): |
|
|
|
if not file.endswith('.png'): |
|
continue |
|
file = file[:-4] + '.txt' |
|
file = join(gt_path, file) |
|
key = os.path.split(file)[-1][:-4] |
|
image_dict[key] = dict() |
|
image_dict[key]['gt'] = 1. |
|
image_dict[key]['preds'] = 0. |
|
|
|
for file in glob.glob(join(mal_path, '*.txt')): |
|
key = os.path.split(file)[-1][:-4] |
|
assert key in image_dict |
|
image_dict[key]['preds'] = file_to_score(file) |
|
|
|
for file in os.listdir(ben_img_path): |
|
|
|
if not file.endswith('.png'): |
|
continue |
|
|
|
file = file[:-4] + '.txt' |
|
file = join(ben_path, file) |
|
key = os.path.split(file)[-1][:-4] |
|
|
|
|
|
if key in image_dict: |
|
print(key) |
|
print('SHIT') |
|
continue |
|
|
|
image_dict[key] = dict() |
|
image_dict[key]['preds'] = file_to_score(file) |
|
image_dict[key]['gt'] = 0. |
|
return image_dict |
|
|
|
def get_auc_score_from_imdict(image_dict): |
|
keys = list(image_dict.keys()) |
|
y = [image_dict[k]['gt']for k in keys] |
|
preds = [image_dict[k]['preds']for k in keys] |
|
return roc_auc_score(y, preds) |
|
|
|
def get_accuracy_from_imdict(image_dict, thresh = 0.3): |
|
keys = list(image_dict.keys()) |
|
ys = [image_dict[k]['gt']for k in keys] |
|
preds = [image_dict[k]['preds']for k in keys] |
|
acc = 0 |
|
for y,pred in zip(ys,preds): |
|
if pred < thresh and y == 0.: |
|
acc+=1 |
|
elif pred > thresh and y == 1.: |
|
acc+=1 |
|
return acc/len(preds) |
|
|
|
|
|
def get_auc_score(preds_image_folder, root_fol, retAcc = False, acc_thresh = 0.3): |
|
im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol) |
|
if retAcc: |
|
return get_auc_score_from_imdict(im_dict), get_accuracy_from_imdict(im_dict, acc_thresh) |
|
else: |
|
return get_auc_score_from_imdict(im_dict) |
|
|
|
if __name__ == '__main__': |
|
seed = '42' if len(sys.argv)== 1 else sys.argv[1] |
|
|
|
root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test' |
|
|
|
auc_score = get_auc_score(f'preds_{seed}',root_fol) |
|
print(f'ROC AUC Score: {auc_score}') |
|
|