# Author: lfolle
# Commit 44a5c17: small fix to video
import os
import torch
import numpy as np
from PIL import Image
from utils import *
from tqdm import tqdm
from gradcam import GradCAM, GradCAMpp
from overlay_image import overlay_numpy
# Class names. A name's position in this list is the class index that
# GradCamGenerator.target_from_path hands to Grad-CAM, so order matters.
DISEASES = [
    'Atelectasis',         # 0
    'Cardiomegaly',        # 1
    'Effusion',            # 2
    'Infiltration',        # 3
    'Mass',                # 4
    'Nodule',              # 5
    'Pneumonia',           # 6
    'Pneumothorax',        # 7
    'Consolidation',       # 8
    'Edema',               # 9
    'Emphysema',           # 10
    'Fibrosis',            # 11
    'Pleural Thickening',  # 12
    'Hernia',              # 13
]
class GradCamGenerator:
    """Compute Grad-CAM attribution maps for images using a checkpointed classifier.

    A model is loaded from a checkpoint file and a ``GradCAM`` object is
    attached to one of its sub-modules. The target class for each image is
    taken from the name of the folder that contains it (``target_from_path``).

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry is a full
            pickled ``torch.nn.Module`` (not a state dict).
        layer: Dotted sub-module name accepted by
            ``torch.nn.Module.get_submodule`` (e.g. ``'features.norm5'``).
        overlay: If True, ``generate_grad_cam`` also calls ``overlay_numpy``
            with the image, the attribution map and the image path.
    """

    def __init__(self, model_path, layer, overlay=False):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.layer = layer
        self.overlay = overlay  # overlay Grad-CAM heatmaps on the input image
        self.layer_module = self.model.get_submodule(layer)
        self.gc_model2 = GradCAM(self.model, self.layer_module)

    def load_model(self, model_path, print_net=False):
        """Load the module stored under ``'model'`` in a checkpoint file.

        Tensors are mapped to ``self.device``. ``nn.ReLU`` modules are made
        non-inplace, and the model is switched to eval mode.

        Args:
            model_path: Path of the checkpoint file.
            print_net: If True, print the model architecture.

        Returns:
            The loaded ``torch.nn.Module``.
        """
        import inspect  # only needed to probe torch.load's signature

        load_kwargs = {'map_location': self.device}
        # The checkpoint stores a whole pickled module, not a state dict.
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle such objects, so opt out explicitly where the option exists.
        # SECURITY: this unpickles arbitrary objects -- only load trusted files.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint = torch.load(model_path, **load_kwargs)
        model = checkpoint['model']
        self.set_inplace_False(model)
        # Inference behaviour for BatchNorm/Dropout. If the checkpoint was
        # saved in training mode, BatchNorm would otherwise use per-image
        # batch statistics and update its running stats on every Grad-CAM
        # pass. This is a no-op if the model is already in eval mode.
        model.eval()
        if print_net:
            print(model)
        return model

    def set_inplace_False(self, module):
        """Set ``inplace = False`` on every ``torch.nn.ReLU`` inside ``module``.

        Gradient-based attribution may need activations that an in-place
        ReLU would overwrite. Only ``nn.ReLU`` module instances are affected.
        Functional ``relu`` calls inside a model's ``forward`` are not.
        NOTE(review): confirm the model in use has no in-place functional
        ReLU after the explained layer.
        """
        for submodule in module.modules():
            if isinstance(submodule, torch.nn.ReLU):
                submodule.inplace = False

    def generate_grad_cam(self, path):
        """Compute the Grad-CAM attribution for the image stored at ``path``.

        The target class is derived from the image's parent folder name.
        When ``self.overlay`` is set, ``overlay_numpy`` is also called.

        Args:
            path: Image path of the form ``.../<class folder>/<file>``.

        Returns:
            The squeezed CAM as a NumPy array divided by its maximum, so the
            peak value is 1.0. An all-zero array of the same shape is returned
            when the maximum is not positive (including NaN).
        """
        img = self.pil_loader(path, 3)
        input_image = self.transform_pil_to_tensor(img)
        tclass = self.target_from_path(path)
        # Gradients are required for Grad-CAM, so no torch.no_grad() here.
        grayscale_cams = self.gc_model2(input=input_image, class_idx=tclass)
        cam = grayscale_cams[0].detach().cpu().numpy().squeeze()
        # Normalise to a peak of 1.0; an empty (all-zero or NaN) map would
        # otherwise turn into NaNs through 0 / 0.
        peak = cam.max()
        if peak > 0:
            attribution = cam / peak
        else:
            attribution = np.zeros_like(cam)
        if self.overlay:
            overlay_numpy(img, attribution, path)
        return attribution

    def target_from_path(self, path):
        """Return the class index for ``path`` as a tensor on ``self.device``.

        The class is the name of the image's parent folder, looked up in
        ``DISEASES``. Images in a ``'No Finding'`` folder are mapped to index 0.

        Raises:
            ValueError: if the folder name is neither in ``DISEASES`` nor
                ``'No Finding'``.
        """
        disease = os.path.basename(os.path.dirname(path))
        indx = 0 if disease == 'No Finding' else DISEASES.index(disease)
        return torch.tensor(indx, device=self.device)

    def save_img(self, image, input_path):
        """Save ``image`` as a grayscale file next to ``input_path``.

        ``'_gc'`` is inserted before the extension, e.g. ``x.png`` -> ``x_gc.png``.

        NOTE(review): ``convert('L')`` does not rescale values. A float map
        peaking at 1.0 (as ``generate_grad_cam`` returns) will likely look
        almost black, so scale to 0-255 first if that is the input.
        """
        root, ext = os.path.splitext(input_path)
        gc_filename = root + '_gc' + ext
        image_PIL = Image.fromarray(image).convert('L')
        image_PIL.save(gc_filename)

    def transform_pil_to_tensor(self, pil_image):
        """Convert a PIL image into a normalised batch of one on ``self.device``.

        The image is resized with ``transforms.Resize(224)``, converted to a
        tensor and normalised per channel with the fixed mean/std below (the
        commonly used ImageNet statistics).

        NOTE(review): ``transforms`` is not imported in this file directly.
        Presumably it is ``torchvision.transforms`` via ``from utils import *``.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        tensor = transform(pil_image).to(self.device)
        return tensor.unsqueeze(0)

    def pil_loader(self, path, n_channels):
        """Load an image file as a 1-channel ('L') or 3-channel ('RGB') PIL image.

        Raises:
            ValueError: if ``n_channels`` is not 1 or 3 (checked before the
                file is opened).
        """
        modes = {1: 'L', 3: 'RGB'}
        if n_channels not in modes:
            raise ValueError('Invalid value for parameter n_channels!')
        with open(path, 'rb') as f:
            # convert() forces the pixel data to be read while f is still open.
            return Image.open(f).convert(modes[n_channels])
def create_GC_from_folder(path, classifier='checkpoint', layer_name='features.norm5', overlay=True,
                          override_gc=True, skip_folders=10):
    """Run Grad-CAM over the images in every class sub-folder of ``path``.

    Each sub-directory of ``path`` whose name does not contain ``'No Finding'``
    is treated as a class folder. Folders are sorted by name, and the first
    ``skip_folders`` of them are skipped.

    Args:
        path: Directory containing one sub-folder per class.
        classifier: Checkpoint path passed to ``GradCamGenerator``.
        layer_name: Sub-module to explain, passed to ``GradCamGenerator``.
        overlay: Passed to ``GradCamGenerator``. When True, images that
            already have a ``<stem>_overlay.png`` next to them are skipped.
        override_gc: When False, images that already have a ``<stem>_gc.png``
            next to them are skipped.
        skip_folders: Number of leading (sorted) class folders to skip.
            Defaults to 10, matching the previous hard-coded ``[10:]``.
    """
    GC = GradCamGenerator(classifier, layer_name, overlay=overlay)
    # Paths are joined with '/' so the class can be read back from the parent
    # folder name. Entries are sorted because os.listdir order is arbitrary,
    # which would otherwise make the skip non-deterministic.
    folds = [f'{path}/{name}' for name in sorted(os.listdir(path))
             if 'No Finding' not in name and os.path.isdir(os.path.join(path, name))]
    for cf in folds[skip_folders:]:
        files = []
        for f in sorted(os.listdir(cf)):
            if '_gc' in f or 'overlay' in f:
                continue  # previously generated output, not an input image
            stem = os.path.splitext(f)[0]
            if not override_gc and os.path.exists(f'{cf}/{stem}_gc.png'):
                continue
            if overlay and os.path.exists(f'{cf}/{stem}_overlay.png'):
                continue
            files.append(f'{cf}/{f}')
        for cfil in tqdm(files):
            GC.generate_grad_cam(cfil)