Spaces:
Runtime error
Runtime error
"""Create a gif by sampling from the posterior of an image explanation.
The file includes routines to create gifs of posterior samples for image
explanations. To create the gif, we sample a number of draws from the posterior,
plot the explanation over the image, and repeat this to stitch together a gif.
The interpretation is that regions of the image that more frequently show up as
green are more likely to positively impact the prediction. Similarly, regions that
more frequently show up as red are more likely to negatively impact the prediction.
"""
import os | |
from os.path import exists, dirname | |
import sys | |
import imageio | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from skimage.segmentation import mark_boundaries | |
import tempfile | |
from tqdm import tqdm | |
import lime.lime_tabular as baseline_lime_tabular | |
import shap | |
# Make sure we can get bayes explanations | |
parent_dir = dirname(os.path.abspath(os.getcwd())) | |
sys.path.append(parent_dir) | |
from bayes.explanations import BayesLocalExplanations, explain_many | |
from bayes.data_routines import get_dataset_by_name | |
from bayes.models import * | |
def fill_segmentation(values, segmentation, image, n_max=5):
    """Color the ``n_max`` most important superpixels of one explanation.

    Arguments:
        values: Per-superpixel explanation weights; ``values[i]`` belongs to
            the superpixel labelled ``i`` in ``segmentation``.
        segmentation: Array of superpixel labels, one label per pixel.
        image: The image being explained. Presumably its leading dimensions
            match ``segmentation`` and it has a trailing channel axis with at
            least two channels (channels 0 and 1 are written) -- TODO confirm.
        n_max: Number of superpixels (those with the largest ``|values|``) to
            mark. ``n_max <= 0`` marks none.
    Returns:
        ``(c_image, out)``, both cast to ``int``:
        ``c_image`` has ``image``'s shape and is zero except in the selected
        superpixels, where channel 1 is set to ``max(image)`` for positive
        weights and channel 0 for non-positive weights.
        ``out`` has ``segmentation``'s shape and holds ``1`` / ``-1`` in the
        selected positive / non-positive superpixels and ``0`` elsewhere.
    """
    values = np.asarray(values)
    segmentation = np.asarray(segmentation)
    # Size the mask from the segmentation rather than assuming a 224x224 image.
    out = np.zeros(segmentation.shape)
    c_image = np.zeros(np.shape(image))
    # Guard: ``argsort(...)[-0:]`` would select *every* segment, not none.
    if n_max > 0 and values.size:
        peak = np.max(image)
        # Segment label masks are disjoint, so visiting only the selected
        # segments (in any order) yields the same result as scanning all.
        for seg in np.argsort(np.abs(values))[-n_max:]:
            mask = segmentation == seg
            # A weight of exactly 0 is treated as negative (-1 / channel 0).
            positive = values[seg] > 0
            out[mask] = 1 if positive else -1
            c_image[mask, 1 if positive else 0] = peak
    return c_image.astype(int), out.astype(int)
def create_gif(explanation_blr, segments, image, n_images=20, n_max=5, save_loc=None):
    """Create the gif corresponding to the image explanation.

    ``n_images`` draws are sampled from the explanation's posterior. For each
    draw, the ``n_max`` highest-magnitude superpixels are overlaid on the image
    and the frame is rendered to a PNG in a temporary directory. The frames are
    then stitched together into a gif.

    Arguments:
        explanation_blr: The explanation blr object; must provide
            ``draw_posterior_samples(n)``.
        segments: The image segmentation (superpixel label per pixel).
        image: The image for which to compute the explanation.
        n_images: Number of images (posterior draws) to create the gif with.
        n_max: The number of superpixels to draw on the image.
        save_loc: Optional path; when given, the gif is also written there.
    Returns:
        The encoded gif as ``bytes``.
    """
    draws = explanation_blr.draw_posterior_samples(n_images)
    # Use a dedicated figure so repeated calls neither draw on an unrelated
    # "current" figure nor leak figures; it is always closed below.
    fig, ax = plt.subplots()
    try:
        # Setup temporary directory to store the frame PNGs in.
        with tempfile.TemporaryDirectory() as tmpdirname:
            paths = []
            for i, d in tqdm(enumerate(draws)):
                c_image, filled_segs = fill_segmentation(d, segments, image, n_max=n_max)
                ax.cla()
                ax.axis('off')
                ax.imshow(mark_boundaries(image, filled_segs))
                ax.imshow(c_image, alpha=0.3)
                paths.append(os.path.join(tmpdirname, f"{i}.png"))
                fig.savefig(paths[-1])
            # Read the frames back while the temporary directory still exists;
            # it and the PNGs are deleted as soon as this ``with`` block exits.
            # https://stackoverflow.com/questions/61716066/creating-an-animation-out-of-matplotlib-pngs
            ims = [imageio.imread(f) for f in paths]
    finally:
        plt.close(fig)
    # RETURN_BYTES has no file extension to infer the format from, so name it.
    gif_bytes = imageio.mimwrite(imageio.RETURN_BYTES, ims, format="GIF")
    if save_loc is not None:
        print(f"Saving gif to {save_loc}")
        with open(save_loc, "wb") as f:
            f.write(gif_bytes)
    return gif_bytes