Spaces:
Sleeping
Sleeping
import os | |
import tensorflow as tf | |
from tensorflow.keras import backend as K | |
import tf_keras | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use('agg') | |
def dice_coefficients(y_true, y_pred, smooth=100):
    """Return the smoothed Dice coefficient between two tensors.

    Both inputs are flattened and the result is
    ``(2 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)``,
    where ``intersection`` is the sum of the element-wise product.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor (presumably values in [0, 1] — confirm
            against the model's output activation).
        smooth: Additive term in numerator and denominator; keeps the ratio
            defined (equal to 1) when both masks are all zeros.

    Returns:
        A scalar tensor.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    numerator = 2 * overlap + smooth
    denominator = total + smooth
    return numerator / denominator
def dice_coefficients_loss(y_true, y_pred, smooth=100):
    """Return the Dice loss, i.e. one minus ``dice_coefficients``.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.
        smooth: Smoothing term forwarded unchanged to ``dice_coefficients``.

    Returns:
        A scalar tensor; lower means better overlap.
    """
    dice = dice_coefficients(y_true, y_pred, smooth)
    return 1.0 - dice
def iou(y_true, y_pred, smooth=100):
    """Return the smoothed intersection-over-union (Jaccard index).

    Computes ``(intersection + smooth) / (union + smooth)`` where
    ``intersection`` is the sum of the element-wise product and
    ``union = sum(y_true + y_pred) - intersection``.

    Inputs are flattened first, as in ``dice_coefficients``. Tensors with the
    same number of elements but different shapes (e.g. with and without a
    trailing channel axis) are therefore compared element by element rather
    than broadcast against each other. Same-shaped inputs give exactly the
    same value as without flattening.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.
        smooth: Additive term in numerator and denominator; keeps the ratio
            defined when both masks are all zeros.

    Returns:
        A scalar tensor.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    intersection = K.sum(truth * pred)
    # Named `total`, not `sum`, to avoid shadowing the builtin.
    total = K.sum(truth + pred)
    return (intersection + smooth) / (total - intersection + smooth)
def jaccard_distance(y_true, y_pred):
    """Return the negated IoU of two tensors, intended as a loss.

    Both tensors are flattened and passed to ``iou`` with its default
    smoothing.

    NOTE(review): despite the name, this returns ``-IoU``, which lies in
    [-1, 0) for inputs in [0, 1]. It is not the conventional Jaccard distance
    ``1 - IoU``. The two differ only by a constant, so gradients are identical.
    """
    similarity = iou(K.flatten(y_true), K.flatten(y_pred))
    return -similarity
# Load the trained segmentation model once, at import time. The custom
# loss/metric functions are passed in because they are not built into
# tf_keras; presumably the saved model's compile config references them.
# NOTE(review): the path is relative to the current working directory.
segmodel = tf_keras.models.load_model(
    "segment_model/V2",
    custom_objects=dict(
        dice_coefficients_loss=dice_coefficients_loss,
        iou=iou,
        dice_coefficients=dice_coefficients,
    ),
)
def load_image_for_pred(image_path):
    """Read one image from disk and turn it into a single-item model batch.

    The image is loaded as RGB and resized to 256x256 with nearest-neighbour
    interpolation; the aspect ratio is not preserved. Pixel values are then
    scaled from [0, 255] to [0, 1].

    Args:
        image_path: Path of the image file to read.

    Returns:
        A NumPy array with a leading batch axis of size 1. Its shape is
        (1, 256, 256, 3), assuming Keras' default channels-last image format.
    """
    loaded = tf.keras.utils.load_img(
        image_path,
        color_mode='rgb',
        target_size=(256, 256),
        interpolation='nearest',
        keep_aspect_ratio=False,
    )
    scaled = tf.keras.utils.img_to_array(loaded) / 255
    # Prepend the batch axis expected by `predict`.
    return np.expand_dims(scaled, axis=0)
def make_segmentation(image_path):
    """Segment an image with ``segmodel`` and save a three-panel figure.

    The saved figure shows, left to right:
    1. the preprocessed input image;
    2. the predicted mask thresholded at 0.5;
    3. the mask overlaid on the image at 50% opacity.

    Args:
        image_path: Path of the input image. The figure is written next to
            it: the file extension is replaced by ``_segmented.png``.

    Returns:
        The path (str) of the saved PNG.
    """
    img = load_image_for_pred(image_path)
    predicted_img = segmodel.predict(img)

    # Drop size-1 axes (e.g. the batch axis) so imshow receives plain
    # image data; compute the image and binary mask once and reuse them.
    image = np.squeeze(img)
    mask = np.squeeze(predicted_img) > 0.5

    # All three panels share one 1x3 grid. Use an explicit figure and always
    # close it: pyplot keeps every figure alive until closed, so repeated
    # calls would otherwise accumulate figures and leak memory.
    fig, (ax_orig, ax_pred, ax_overlay) = plt.subplots(1, 3, figsize=(5, 3))
    try:
        ax_orig.imshow(image)
        ax_orig.set_title('Original Image')
        ax_orig.axis('off')

        ax_pred.imshow(mask)
        ax_pred.set_title('Prediction')
        ax_pred.axis('off')

        ax_overlay.imshow(image)
        ax_overlay.imshow(mask, cmap='gray', alpha=0.5)
        ax_overlay.set_title('Image with Mask')
        ax_overlay.axis('off')

        save_file_name = os.path.splitext(image_path)[0] + '_segmented.png'
        fig.savefig(save_file_name)
    finally:
        plt.close(fig)
    return save_file_name