Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import numpy as np | |
from PIL import Image | |
from utils import * | |
from tqdm import tqdm | |
from gradcam import GradCAM, GradCAMpp | |
from overlay_image import overlay_numpy | |
# Finding labels. A label's position in this list is used as the Grad-CAM
# target class index (see GradCamGenerator.target_from_path); presumably it
# matches the classifier's output order -- confirm against the training code.
DISEASES = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural Thickening',
    'Hernia',
]
class GradCamGenerator:
    """Produce Grad-CAM attribution maps for images using a pickled classifier.

    The model is loaded from a checkpoint, a ``GradCAM`` object is attached to
    the submodule named ``layer``, and ``generate_grad_cam`` returns one
    attribution map per image path (optionally also calling ``overlay_numpy``).
    """

    def __init__(self, model_path, layer, overlay=False):
        """
        Args:
            model_path: Path to a checkpoint dict that stores the full model
                object under the ``'model'`` key.
            layer: Dotted submodule name (as accepted by
                ``nn.Module.get_submodule``) used as the Grad-CAM layer.
            overlay: If True, ``generate_grad_cam`` also calls
                ``overlay_numpy`` for each processed image.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.layer = layer
        self.overlay = overlay  # Overlay GC heatmaps with image
        self.layer_module = self.model.get_submodule(layer)
        self.gc_model2 = GradCAM(self.model, self.layer_module)

    def load_model(self, model_path, print_net=False):
        """Load ``checkpoint['model']`` onto ``self.device`` and prepare it.

        The model is put in eval mode and all its ReLUs are made non-inplace.

        Args:
            model_path: Checkpoint file path.
            print_net: If True, print the model architecture.

        Returns:
            The loaded ``torch.nn.Module``.
        """
        # SECURITY: this unpickles arbitrary Python objects; only load
        # checkpoints from a trusted source.
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle a full nn.Module object, so opt out explicitly.
            checkpoint = torch.load(model_path, map_location=self.device,
                                    weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            checkpoint = torch.load(model_path, map_location=self.device)
        model = checkpoint['model']
        # Inference-mode BatchNorm/Dropout: explanations must not depend on
        # batch statistics, nor update the running statistics.
        model.eval()
        self.set_inplace_False(model)
        if print_net:
            print(model)
        return model

    def set_inplace_False(self, module):
        """Set ``inplace = False`` on every ReLU in ``module``'s submodule tree.

        Presumably needed because in-place activations conflict with the
        gradient hooks Grad-CAM installs.
        """
        # modules() walks the whole tree and skips ``None`` child placeholders,
        # which the manual ``_modules`` recursion would crash on. torch.nn is
        # spelled out because ``nn`` is not imported directly in this file.
        for submodule in module.modules():
            if isinstance(submodule, torch.nn.ReLU):
                submodule.inplace = False

    def generate_grad_cam(self, path):
        """Return the Grad-CAM map for the image at ``path``, scaled to peak at 1.

        The target class comes from the image's parent folder name
        (``target_from_path``). When ``self.overlay`` is set, the map and
        the source image are also passed to ``overlay_numpy``.

        Args:
            path: Image path of the form ``.../<class folder>/<file>``.

        Returns:
            numpy array with singleton dimensions squeezed out.
        """
        img = self.pil_loader(path, 3)
        input_image = self.transform_pil_to_tensor(img)
        tclass = self.target_from_path(path)
        grayscale_cams = self.gc_model2(input=input_image, class_idx=tclass)
        attribution = self._normalize_cam(grayscale_cams[0])
        if self.overlay:
            overlay_numpy(img, attribution, path)
        return attribution

    @staticmethod
    def _normalize_cam(cam):
        """Convert a Grad-CAM tensor to a squeezed numpy array divided by its max.

        No 0-255 scaling is applied, so the result peaks at 1.0. A map whose
        maximum is not positive (e.g. all zeros) is returned unscaled rather
        than turned into NaNs by a 0/0 division.
        """
        attribution = cam.detach().cpu().numpy().squeeze()
        peak = attribution.max()
        if peak > 0:
            attribution = attribution / peak
        return attribution

    def target_from_path(self, path):
        """Return the target class index (as a tensor) from the parent folder name.

        The folder name must be an entry of ``DISEASES``; ``'No Finding'`` maps
        to index 0.

        Raises:
            ValueError: if the folder name is not in ``DISEASES``.
        """
        disease = path.split('/')[-2]
        # NOTE(review): index 0 is 'Atelectasis', so 'No Finding' images are
        # explained w.r.t. that class -- confirm this is intended.
        indx = DISEASES.index(disease) if disease != 'No Finding' else 0
        return torch.tensor(indx, device=self.device)

    def save_img(self, image, input_path):
        """Save ``image`` as grayscale next to ``input_path`` with a ``_gc`` suffix.

        NOTE(review): maps from ``generate_grad_cam`` peak at 1.0; converting
        such a float array to mode 'L' likely produces a near-black image.
        Scale to 0-255 first if this is used with those maps.
        """
        image_PIL = Image.fromarray(image).convert('L')
        image_PIL.save(self._gc_path(input_path))

    @staticmethod
    def _gc_path(input_path):
        """Return ``input_path`` with ``_gc`` inserted before its extension."""
        # splitext handles extensions of any length (e.g. '.jpeg').
        root, ext = os.path.splitext(input_path)
        return root + '_gc' + ext

    def transform_pil_to_tensor(self, pil_image):
        """Turn a PIL image into a normalized ``(1, C, H, W)`` tensor on ``self.device``.

        The shorter image edge is resized to 224 px (aspect ratio kept).
        """
        # Per-channel mean/std; these match the usual ImageNet statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        tensor = transform(pil_image).to(self.device)
        return tensor.unsqueeze(0)

    def pil_loader(self, path, n_channels):
        """Open ``path`` and convert it to 'L' (1 channel) or 'RGB' (3 channels).

        Raises:
            ValueError: if ``n_channels`` is neither 1 nor 3.
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces PIL's lazy decode while the file is still open.
            if n_channels == 1:
                return img.convert('L')
            elif n_channels == 3:
                return img.convert('RGB')
            else:
                raise ValueError('Invalid value for parameter n_channels!')
def create_GC_from_folder(path, classifier='checkpoint', layer_name='features.norm5',
                          overlay=True, override_gc=True, skip_folders=10):
    """Run Grad-CAM over every image in the class sub-folders of ``path``.

    Args:
        path: Directory whose immediate sub-directories are named after
            entries of ``DISEASES``; folders containing 'No Finding' are
            skipped.
        classifier: Checkpoint path passed to ``GradCamGenerator``.
        layer_name: Submodule name used as the Grad-CAM target layer.
        overlay: Forwarded to ``GradCamGenerator``. When True, images that
            already have a ``<stem>_overlay.png`` sibling are skipped.
        override_gc: When False, images that already have a
            ``<stem>_gc.png`` sibling are skipped.
        skip_folders: Number of class folders (in sorted order) to skip
            before processing. Defaults to the previously hard-coded 10.
    """
    GC = GradCamGenerator(classifier, layer_name, overlay=overlay)
    # Sorted so the skipped prefix is deterministic (os.listdir order is
    # arbitrary); plain files are ignored because they cannot be listed.
    class_dirs = sorted(
        entry for entry in os.listdir(path)
        if 'No Finding' not in entry and os.path.isdir(os.path.join(path, entry))
    )
    # '/' separators (not os.path.join) because target_from_path splits on '/'.
    base = path.rstrip('/') or '/'
    folds = [base + '/' + entry for entry in class_dirs][skip_folders:]
    for cf in folds:
        files = []
        for f in os.listdir(cf):
            if '_gc' in f or 'overlay' in f:
                continue  # derived outputs, not source images
            stem = cf + '/' + os.path.splitext(f)[0]
            if not override_gc and os.path.exists(stem + '_gc.png'):
                continue
            if overlay and os.path.exists(stem + '_overlay.png'):
                continue
            files.append(cf + '/' + f)
        for cfil in tqdm(files):
            GC.generate_grad_cam(cfil)