kritsg's picture
Added image_posterior (to create segmentation)
0ff7e03
raw
history blame
2.87 kB
"""Create a gif sampling from the posterior from an image.
The file includes routines to create gifs of posterior samples for image
explanations. To create the gif, we sample a number of draws from the posterior,
plot the explanation and the image, and repeat this to stitch together a gif.
The interpretation is that regions of the image that more frequently show up as
green are more likely to positively impact the prediction. Similarly, regions that
more frequently show up as red are more likely to negatively impact the prediction.
"""
import os
from os.path import exists, dirname
import sys
import imageio
import matplotlib.pyplot as plt
import numpy as np
from skimage.segmentation import mark_boundaries
import tempfile
from tqdm import tqdm
import lime.lime_tabular as baseline_lime_tabular
import shap
# Make sure we can get bayes explanations
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)
from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
def fill_segmentation(values, segmentation, image, n_max=5):
    """Build the colour overlay and signed mask for the top superpixels.

    The ``n_max`` segments with the largest absolute value in ``values`` are
    selected. Each selected segment is marked in a signed mask and tinted in
    a colour overlay according to the sign of its value.

    Arguments:
        values: One value per segment; ``values[i]`` belongs to the pixels
            where ``segmentation == i``.
        segmentation: Integer array of segment ids. The returned mask has
            this array's shape.
        image: The image being explained. The overlay has this array's shape
            and is indexed as ``[pixel_mask, channel]``, so the last axis is
            expected to hold at least two channels.
        n_max: Number of segments to highlight. Values of zero or less
            highlight nothing.

    Returns:
        A tuple ``(c_image, out)`` of integer arrays. ``c_image`` is the
        overlay: selected segments with a positive value get ``np.max(image)``
        in channel 1 (green, assuming RGB order) and the others get it in
        channel 0 (red). ``out`` holds 1 for selected positive segments, -1
        for the other selected segments and 0 elsewhere.
    """
    values = np.asarray(values)
    segmentation = np.asarray(segmentation)

    # Slice from an explicit start index: ``order[-n_max:]`` would return
    # every segment when n_max is 0.
    order = np.argsort(np.abs(values))
    top_segments = order[max(len(order) - n_max, 0):]

    # Size the mask from the segmentation rather than assuming 224x224 input.
    out = np.zeros(segmentation.shape)
    c_image = np.zeros(image.shape)
    fill_value = np.max(image)

    for seg in top_segments:
        mask = segmentation == seg
        positive = values[seg] > 0
        out[mask] = 1 if positive else -1
        c_image[mask, 1 if positive else 0] = fill_value

    # NOTE: the integer cast truncates, so a float image whose maximum is
    # below 1 yields an all-zero overlay.
    return c_image.astype(int), out.astype(int)
def create_gif(explanation_blr, segments, image, n_images=20, n_max=5):
    """Create the gif corresponding to the image explanation.

    Draws ``n_images`` samples from the explanation's posterior, renders one
    frame per draw (the image with the selected superpixels outlined and
    tinted), and encodes the frames as a gif.

    Arguments:
        explanation_blr: The explanation blr object. It must provide
            ``draw_posterior_samples(n)``.
        segments: The image segmentation.
        image: The image for which to compute the explanation.
        n_images: Number of images to create the gif with.
        n_max: The number of superpixels to draw on the image.

    Returns:
        The result of ``imageio.mimwrite`` called with
        ``imageio.RETURN_BYTES``, i.e. the encoded gif rather than a file on
        disk.
    """
    draws = explanation_blr.draw_posterior_samples(n_images)

    # Render on a dedicated figure so the caller's current pyplot figure is
    # left untouched and the figure is released when we are done.
    fig, ax = plt.subplots()
    try:
        # Setup temporary directory to store the frame images in
        with tempfile.TemporaryDirectory() as tmpdirname:
            paths = []
            for i, d in enumerate(tqdm(draws)):
                c_image, filled_segs = fill_segmentation(d, segments, image, n_max=n_max)
                ax.cla()
                ax.axis('off')
                ax.imshow(mark_boundaries(image, filled_segs))
                ax.imshow(c_image, alpha=0.3)
                frame_path = os.path.join(tmpdirname, f"{i}.png")
                fig.savefig(frame_path)
                paths.append(frame_path)

            # Read the frames back before the temporary directory is removed.
            # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
            ims = [imageio.imread(f) for f in paths]
    finally:
        plt.close(fig)

    # There is no filename to infer the format from, so name it explicitly.
    return imageio.mimwrite(imageio.RETURN_BYTES, ims, format="GIF")