"""Streamlit app that classifies an uploaded skin-lesion image.

The user picks one of four Keras models (fed a 32x32 image) or a fine-tuned
ViT (PyTorch / Hugging Face transformers). If the uploaded file's name
(without extension) matches an ``image_id`` in ``HAM10000_metadata.csv``, the
ground-truth diagnosis from its ``dx`` column is shown next to the prediction.
"""
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from tensorflow import keras
from tensorflow.keras.models import load_model
from transformers import ViTFeatureExtractor, ViTForImageClassification

VIT_CHECKPOINT = 'google/vit-base-patch16-224'
TF_INPUT_SIZE = (32, 32)  # (width, height) passed to cv2.resize for the Keras models

# Model output index -> (HAM10000 ``dx`` code, human-readable label).
classes = {
    4: ('nv', 'melanocytic nevi'),
    6: ('mel', 'melanoma'),
    2: ('bkl', 'benign keratosis-like lesions'),
    1: ('bcc', 'basal cell carcinoma'),
    5: ('vasc', 'pyogenic granulomas and hemorrhage'),
    0: ('akiec', 'Actinic keratoses and intraepithelial carcinomae'),
    3: ('df', 'dermatofibroma'),
}
# ``dx`` code -> human-readable label. Derived from ``classes`` so the two
# tables cannot drift apart (same contents as the previous literal dict).
classes_map = dict(classes.values())

# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loaders are cached to run once per process instead of per rerun.
# st.cache_resource / st.cache_data exist from Streamlit 1.18; on older
# versions fall back to a pass-through decorator (i.e. load on every run).
_cache_resource = getattr(st, "cache_resource", lambda func: func)
_cache_data = getattr(st, "cache_data", lambda func: func)


@_cache_resource
def _load_models():
    """Load the four Keras models and the fine-tuned ViT, keyed by display name."""
    loaded = {
        "Le_Net": load_model('LeNet_5.h5'),
        "Simple_CNN": load_model('Simple CNN.h5'),
        "Alex_Net": load_model('AlexNet.h5'),
        "Deeper_CNN": load_model('Deeper CNN.h5'),
    }
    # The Hub checkpoint ships a 1000-way ImageNet head. Without
    # ignore_mismatched_sizes, asking for num_labels=7 raises a size-mismatch
    # error; with it the head is re-initialised and then overwritten below.
    vit_model = ViTForImageClassification.from_pretrained(
        VIT_CHECKPOINT, num_labels=7, ignore_mismatched_sizes=True
    )
    # map_location="cpu" lets a checkpoint saved on GPU load on CPU-only hosts;
    # load_state_dict copies into the CPU-resident parameters either way.
    vit_model.load_state_dict(
        torch.load('./vit_skin_cancer_model.pth', map_location="cpu")
    )
    vit_model.eval()  # Set the model to evaluation mode
    loaded["ViT_Model"] = vit_model
    return loaded


@_cache_resource
def _load_feature_extractor():
    """Return the ViT preprocessing pipeline for VIT_CHECKPOINT."""
    return ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)


@_cache_data
def _load_metadata():
    """Read the HAM10000 metadata table used to look up true labels."""
    return pd.read_csv("HAM10000_metadata.csv")


def _predict_class_index(model_name, model, bgr_image):
    """Return the arg-max output index of ``model`` for a cv2-decoded image."""
    if model_name == "ViT_Model":
        # NOTE(review): cv2 decodes to BGR and the array is passed unchanged,
        # while HF image processors treat arrays as RGB. Confirm the channel
        # order used when fine-tuning; cv2.COLOR_BGR2RGB may be required.
        pixel_values = _load_feature_extractor()(
            images=bgr_image, return_tensors="pt"
        )['pixel_values']
        with torch.no_grad():
            outputs = model(pixel_values)
        return outputs.logits.argmax(-1).item()

    # Keras models: a batch of one 32x32x3 image.
    # NOTE(review): pixels are fed as raw uint8 BGR with no scaling, as
    # before; confirm this matches the training preprocessing.
    small = cv2.resize(bgr_image, TF_INPUT_SIZE)
    probs = model.predict(np.expand_dims(small, axis=0))[0]
    # np.argmax returns the first maximal index, same as list.index(max(...)).
    return int(np.argmax(probs))


st.title("Skin Cancer Classification App")

models = _load_models()

# Allow user to select model
model_name = st.selectbox("Choose a model", list(models.keys()))
model = models[model_name]

# Upload Image
file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
true_file = _load_metadata()

if file is not None:
    # getvalue() returns the full upload regardless of the stream position,
    # unlike read(), which yields b"" if the buffer was already consumed.
    file_bytes = np.frombuffer(file.getvalue(), dtype=np.uint8)
    # imdecode returns None for undecodable data (and errors on an empty
    # buffer), so guard before any cv2 processing.
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if opencv_image is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()

    class_name = classes[_predict_class_index(model_name, model, opencv_image)]

    # Display image and result
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input Image")
        st.image(opencv_image, channels="BGR")
    with col2:
        st.header("Results")
        # HAM10000 image ids are file names without the extension.
        name = Path(file.name).stem
        true_dx = true_file.loc[true_file['image_id'] == name, 'dx']
        if not true_dx.empty:
            st.write("True Label: ", classes_map[true_dx.iloc[0]])
        else:
            st.write("No match")
        st.write("Prediction:", class_name[1])