# Hugging Face file-viewer header (scrape residue), kept as a comment:
# author Towhidul; commit 582132f "Update app.py" (verified); 3.14 kB;
# virus scan: no virus found.
import numpy as np
import pandas as pd
import cv2
import torch
from tensorflow import keras
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tensorflow.keras.models import load_model
import streamlit as st
st.title("Skin Cancer Classification App")

# Streamlit re-runs this whole script on every widget interaction. Loading the
# models inside a cached function keeps them in memory across reruns instead
# of re-reading every model file each time. st.cache_resource only exists in
# newer Streamlit releases; without it, fall back to uncached loading.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_models():
    """Load every classifier offered by the app.

    Returns:
        dict mapping the display name shown in the model selector to the
        loaded model: four Keras models read from local ``.h5`` files, plus a
        PyTorch ViT with a 7-class head under the key ``"ViT_Model"``.
        Insertion order is the order the selector lists the models in.
    """
    loaded = {
        "Le_Net": load_model('LeNet_5.h5'),
        "Simple_CNN": load_model('Simple CNN.h5'),
        "Alex_Net": load_model('AlexNet.h5'),
        "Deeper_CNN": load_model('Deeper CNN.h5'),
    }

    # The base checkpoint ships a 1000-class ImageNet head, so requesting
    # num_labels=7 makes the classifier shapes mismatch; transformers refuses
    # that unless ignore_mismatched_sizes=True. The head is overwritten by the
    # fine-tuned weights loaded right after, so nothing is lost.
    vit = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224',
        num_labels=7,
        ignore_mismatched_sizes=True,
    )
    # map_location="cpu" lets a checkpoint that was saved from a GPU load on
    # a CPU-only host.
    state_dict = torch.load('./vit_skin_cancer_model.pth', map_location="cpu")
    vit.load_state_dict(state_dict)
    vit.eval()  # evaluation mode (e.g. disables dropout) for inference
    loaded["ViT_Model"] = vit
    return loaded


# NOTE(review): when st.cache_resource is available, these model objects are
# shared by every session of the app.
models = _load_models()
vit_model = models["ViT_Model"]
# Model picker: the options follow the insertion order of `models`.
selectable_names = [*models]
model_name = st.selectbox("Choose a model", selectable_names)
model = models[model_name]

# The image to classify (PNG or JPEG).
file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# HAM10000 metadata table; its `image_id` / `dx` columns are used further down
# to look up the ground-truth diagnosis of an uploaded image by file name.
true_file = pd.read_csv("HAM10000_metadata.csv")
# Model output index -> (HAM10000 `dx` code, human-readable diagnosis).
# NOTE(review): this index order presumably matches the label encoding used
# when the models were trained — confirm against the training code.
classes = {
    4: ('nv', 'melanocytic nevi'),
    6: ('mel', 'melanoma'),
    2: ('bkl', 'benign keratosis-like lesions'),
    1: ('bcc', 'basal cell carcinoma'),
    5: ('vasc', 'pyogenic granulomas and hemorrhage'),
    0: ('akiec', 'Actinic keratoses and intraepithelial carcinomae'),
    3: ('df', 'dermatofibroma')
}

# `dx` code -> human-readable diagnosis. Derived from `classes` so the two
# tables cannot drift apart (same entries, same order as before).
classes_map = dict(classes.values())
if file is not None:
    # Decode the uploaded bytes. Flag 1 (IMREAD_COLOR) yields a 3-channel BGR
    # array, or None when the bytes are not a decodable image. getvalue()
    # returns the whole upload regardless of the current stream position,
    # unlike read().
    file_bytes = np.frombuffer(file.getvalue(), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, 1)
    if opencv_image is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()

    if model_name == "ViT_Model":
        # PyTorch ViT inference. OpenCV decodes to BGR, while the Hugging Face
        # ViT feature extractor / checkpoint work in RGB, so swap channels.
        # NOTE(review): assumes the fine-tuned weights were trained on RGB
        # input as well — confirm against the training pipeline.
        rgb_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
        feature_extractor = ViTFeatureExtractor.from_pretrained(
            'google/vit-base-patch16-224'
        )
        pixel_values = feature_extractor(
            images=rgb_image, return_tensors="pt"
        )['pixel_values']
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        class_ind = outputs.logits.argmax(-1).item()
    else:
        # Keras inference on one 32x32 image, reshaped to a batch of one.
        # NOTE(review): raw BGR uint8 pixels are fed without scaling, exactly
        # as before — confirm this matches the training preprocessing.
        img1 = cv2.resize(opencv_image, (32, 32))
        result = model.predict(img1.reshape(1, 32, 32, 3))
        # argmax picks the first maximal score, same as list.index(max(...)).
        class_ind = int(np.argmax(result[0]))
    class_name = classes[class_ind]

    # Display image and result side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input Image")
        st.image(opencv_image, channels="BGR")
    with col2:
        st.header("Results")
        # Ground truth is looked up by the file name without its extension,
        # compared against the metadata's `image_id` column.
        name = file.name.rsplit(".", 1)[0]
        true_dx = true_file.loc[true_file['image_id'] == name, 'dx']
        if not true_dx.empty:
            st.write("True Label: ", classes_map[true_dx.iloc[0]])
        else:
            st.write("No match")
        st.write("Prediction:", class_name[1])