# Source: Hugging Face file view -- uploaded by yanranxiaoxi,
# commit 7d6aa54 ("First commit", verified), file size 6.58 kB.
import pandas as pd
import numpy as np
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
from streamlit_image_select import image_select
from tqdm import tqdm
import os
import shutil
from PIL import Image
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForMaskGeneration
def show_mask(image, mask, ax=None):
    """Render ``image`` with ``mask`` overlaid and return the result as a PIL image.

    Args:
        image: Anything ``np.array`` can convert into an array accepted by
            ``Axes.imshow`` (the script passes a PIL image).
        mask: Array whose last two dimensions are the mask height and width;
            it is reshaped to ``(h, w, 1)`` and multiplied by the overlay
            colour, so zero entries end up fully transparent.
        ax: Unused. Kept in the signature so existing callers keep working.

    Returns:
        An RGB ``PIL.Image.Image`` containing the rendered figure.
    """
    fig, axes = plt.subplots()
    try:
        axes.imshow(np.array(image))
        # RGBA overlay colour: blue (30, 144, 255) at 60% opacity.
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        axes.imshow(mask_image)
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # tostring_rgb() is deprecated and gone from recent Matplotlib
        # releases; buffer_rgba() is the supported accessor and already
        # carries the pixel dimensions, so no separate size lookup is needed.
        rgba = np.asarray(canvas.buffer_rgba())
        # convert() copies the pixels, so the result stays valid after the
        # figure (and its buffer) is closed below.
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Always release the pyplot figure, even if rendering failed.
        plt.close(fig)
def get_bounding_box(ground_truth_map, max_jitter=20):
    """Return a randomly padded bounding box around the positive mask pixels.

    Args:
        ground_truth_map: 2-D array-like mask; pixels ``> 0`` are foreground.
        max_jitter: Exclusive upper bound of the random padding (in pixels)
            applied independently to each side. ``0`` or a negative value
            disables the padding, which makes the result deterministic.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain Python ints, clamped to
        ``[0, width]`` horizontally and ``[0, height]`` vertically.

    Raises:
        ValueError: If the mask is not 2-D or contains no positive pixel.
    """
    mask = np.asarray(ground_truth_map)
    if mask.ndim != 2:
        raise ValueError(
            f"ground_truth_map must be 2-D, got {mask.ndim} dimension(s)"
        )
    y_indices, x_indices = np.nonzero(mask > 0)
    if x_indices.size == 0:
        # Without this check np.min/np.max fail with an opaque reduction error.
        raise ValueError(
            "ground_truth_map has no positive pixels; cannot derive a bounding box"
        )
    height, width = mask.shape

    def _jitter():
        # np.random.randint(0, 0) would raise, so short-circuit "no jitter".
        if max_jitter <= 0:
            return 0
        return int(np.random.randint(0, max_jitter))

    # Draw order (x_min, x_max, y_min, y_max) matches the original code so a
    # seeded global NumPy RNG still produces the same boxes.
    x_min = max(0, int(x_indices.min()) - _jitter())
    x_max = min(width, int(x_indices.max()) + _jitter())
    y_min = max(0, int(y_indices.min()) - _jitter())
    y_max = min(height, int(y_indices.max()) + _jitter())
    return [x_min, y_min, x_max, y_max]
def get_output(image,prompt):
    """Run the segmentation model on ``image`` with one box prompt.

    Relies on the module-level ``processor``, ``model`` and ``device``.

    Args:
        image: Input image handed to ``processor``.
        prompt: Bounding box, passed to the processor as ``[[prompt]]``.

    Returns:
        A ``uint8`` NumPy array that is 1 where the sigmoid of the predicted
        mask exceeds 0.5 and 0 elsewhere.
    """
    model_inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
    model_inputs = model_inputs.to(device)
    model.eval()
    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(**model_inputs, multimask_output=False)
    probabilities = torch.sigmoid(prediction.pred_masks.squeeze(1))
    probabilities = probabilities.cpu().numpy().squeeze()
    binary_mask = probabilities > 0.5
    return binary_mask.astype(np.uint8)
def generate_image(np_array):
    """Turn a mask array into an 8-bit grayscale (mode ``'L'``) PIL image.

    The array is multiplied by 255 and cast to ``uint8`` before conversion,
    so a 0/1 mask becomes a 0/255 image.
    """
    scaled = np_array * 255
    pixels = scaled.astype('uint8')
    return Image.fromarray(pixels, mode='L')
def iou_calculation(result1, result2):
    """Return the intersection-over-union of two masks, rounded to 4 decimals.

    Any non-zero element counts as foreground (``np.logical_and`` /
    ``np.logical_or`` semantics). The inputs must have broadcast-compatible
    shapes.

    Args:
        result1: First mask (array-like).
        result2: Second mask (array-like).

    Returns:
        IoU as a Python ``float`` in ``[0, 1]``. When neither mask has any
        foreground the ratio is undefined (0/0); ``0.0`` is returned instead
        of NaN, matching what ``calculate_f1_score`` yields in that case.
    """
    intersection = np.count_nonzero(np.logical_and(result1, result2))
    union = np.count_nonzero(np.logical_or(result1, result2))
    if union == 0:
        # Avoid 0/0 -> NaN (and NumPy's RuntimeWarning) for two empty masks.
        return 0.0
    return round(intersection / union, 4)
def calculate_pixel_accuracy(image1, image2):
    """Return the fraction of pixels that are identical in both images.

    ``image1`` is first brought to ``image2``'s size and mode; both images
    are then compared as RGB, and a pixel counts as correct only when all
    three channels match.

    Args:
        image1: PIL image (the script passes the ground-truth label).
        image2: PIL image (the script passes the predicted mask).

    Returns:
        Accuracy as a Python ``float`` in ``[0, 1]``.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps the original mask values. Bilinear
        # interpolation would invent intermediate greys along mask edges
        # that can never equal the other image's pixels.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    rgb1 = np.asarray(image1.convert("RGB"))
    rgb2 = np.asarray(image2.convert("RGB"))
    # Vectorised equivalent of comparing the RGB tuples pixel by pixel.
    matches = np.all(rgb1 == rgb2, axis=-1)
    return int(np.count_nonzero(matches)) / matches.size
def calculate_f1_score(image1, image2):
    """Return the F1 score between two mask images.

    Both images are compared in grayscale; a pixel is foreground only when
    its value is exactly 255. ``image1`` is treated as the reference and
    ``image2`` as the prediction.

    Args:
        image1: Reference PIL image; resized and converted to match
            ``image2`` if needed.
        image2: Predicted PIL image.

    Returns:
        Epsilon-smoothed F1 score as a Python ``float``. It is ``0.0`` when
        there are no true positives.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps pixel values at exactly 0/255. Bilinear
        # interpolation would create in-between greys that are silently
        # counted as background.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    reference = np.asarray(image1.convert("L")) == 255
    predicted = np.asarray(image2.convert("L")) == 255
    true_positives = int(np.count_nonzero(reference & predicted))
    false_positives = int(np.count_nonzero(~reference & predicted))
    false_negatives = int(np.count_nonzero(reference & ~predicted))
    # The 1e-7 terms guard every division against a zero denominator.
    precision = true_positives / (true_positives + false_positives + 1e-7)
    recall = true_positives / (true_positives + false_negatives + 1e-7)
    return 2 * (precision * recall) / (precision + recall + 1e-7)
def calculate_dice_coefficient(image1, image2):
    """Return the Dice coefficient between two mask images.

    Both images are compared in grayscale; a pixel is foreground only when
    its value is exactly 255.

    Args:
        image1: Reference PIL image; resized and converted to match
            ``image2`` if needed.
        image2: Predicted PIL image.

    Returns:
        ``2*TP / (2*TP + FP + FN)`` as a Python ``float``. When neither image
        has a foreground pixel the ratio is undefined (0/0); ``0.0`` is
        returned instead of NaN, matching what ``calculate_f1_score`` yields
        in that case.
    """
    if image1.size != image2.size:
        # Nearest-neighbour keeps pixel values at exactly 0/255. Bilinear
        # interpolation would create in-between greys that are silently
        # counted as background.
        image1 = image1.resize(image2.size, Image.NEAREST)
    if image1.mode != image2.mode:
        image1 = image1.convert(image2.mode)
    reference = np.asarray(image1.convert("L")) == 255
    predicted = np.asarray(image2.convert("L")) == 255
    true_positives = int(np.count_nonzero(reference & predicted))
    false_positives = int(np.count_nonzero(~reference & predicted))
    false_negatives = int(np.count_nonzero(reference & ~predicted))
    denominator = 2 * true_positives + false_positives + false_negatives
    if denominator == 0:
        # Both masks are empty: avoid 0/0 -> NaN.
        return 0.0
    return (2 * true_positives) / denominator
# ---------------------------------------------------------------------------
# Streamlit app: pick one of four sample lesions, segment it with the model,
# and show the overlay plus overlap metrics.
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Page configuration is kept as the very first Streamlit call.
st.set_page_config(layout='wide')

# Dataset rows offered as examples in the picker.
SAMPLE_INDICES = (7, 8, 9, 10)


@st.cache_resource
def _load_samples():
    """Load the example images and labels once instead of on every rerun."""
    ds = load_dataset('yanranxiaoxi/skin-lesion-mask',split='train')
    rows = [ds[i] for i in SAMPLE_INDICES]
    return [row['image'] for row in rows], [row['label'] for row in rows]


@st.cache_resource
def _load_model(target_device):
    """Load the processor and model once and move the model to ``target_device``."""
    loaded_processor = AutoProcessor.from_pretrained('yanranxiaoxi/skin-lesion-base')
    loaded_model = AutoModelForMaskGeneration.from_pretrained('yanranxiaoxi/skin-lesion-focalloss-base-combined')
    loaded_model.to(target_device)
    return loaded_processor, loaded_model


image_arr, label_arr = _load_samples()

# ``img`` is the index of the selected example, not the image itself.
img = image_select(
    label="选择一个皮肤病变图像",
    images=image_arr,
    captions=["例 1","例 2","例 3","例 4"],
    return_value='index'
)

# ``processor`` and ``model`` must stay module-level: get_output() reads them.
processor, model = _load_model(device)

# The ground-truth label supplies the box prompt for the model.
p = get_bounding_box(np.array(label_arr[img]))
predicted_mask_array = get_output(image_arr[img],p)
predicted_mask = generate_image(predicted_mask_array)
result_image = show_mask(image_arr[img],predicted_mask_array)

with st.container():
    tab1, tab2 = st.tabs(['可视化','指标'])
    with tab1:
        col1, col2 = st.columns(2)
        with col1:
            st.image(image_arr[img],caption='原始皮肤病变图像',use_column_width=True)
        with col2:
            st.image(result_image,caption='叠加标注遮罩区域',use_column_width=True)
    with tab2:
        st.write(f'IOU 得分:{iou_calculation(label_arr[img],predicted_mask)}')
        st.write(f'像素精确度:{calculate_pixel_accuracy(label_arr[img],predicted_mask)}')
        st.write(f'骰子系数(DC)得分:{calculate_dice_coefficient(label_arr[img],predicted_mask)}')