Model card
Try our model here
Model description
This is an image classification model that uses ResNet-50 as its base model to classify the severity of diabetic retinopathy (DR).
Intended uses & limitations
Given a retina image taken with fundus photography, this model grades diabetic retinopathy on a scale of 0 to 4:
0 - No DR
1 - Mild
2 - Moderate
3 - Severe
4 - Proliferative DR
Training
- We trained our model on retina images taken with fundus photography under a variety of imaging conditions.
- The training data was gathered for a 2019 Kaggle competition run by the Asia Pacific Tele-Ophthalmology Society (APTOS).
- Training data
- Training Process
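The APTOS data ships as a CSV with an `id_code` and a `diagnosis` (0 to 4) per image, and `app.py` re-declares getters named `get_x` and `get_y` that were used during training. The sketch below shows what such getters could look like. Only `get_y` matches the card; `DATA_ROOT`, the `train_images` folder layout, and this `get_x` body are assumptions, and the row id is made up.

```python
from pathlib import Path

# Assumed dataset root and layout (train_images/<id_code>.png); adjust to your copy.
DATA_ROOT = Path("aptos2019-blindness-detection")


def get_x(r):
    # Hypothetical training-time getter: build the image path from the row's id_code.
    return DATA_ROOT / "train_images" / f"{r['id_code']}.png"


def get_y(r):
    # Same as in app.py: the label is the 0-4 diagnosis grade.
    return r["diagnosis"]


# Rows behave like dicts here; with pandas they would be DataFrame rows.
# In fastai these getters would be passed as DataBlock(..., get_x=get_x, get_y=get_y).
row = {"id_code": "abc123", "diagnosis": 2}  # made-up row
print(get_x(row), get_y(row))
```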
Evaluation
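Training accuracy: the model was trained for 50 epochs, reaching roughly 84% accuracy (0.838798 after the final epoch) on the validation split of our training data. Accuracy and error rate are measured on that validation split, and error rate is 1 - accuracy. The epoch counter restarts at 0 after the first three rows, which suggests a short initial training phase before the 50-epoch run; epochs 5 to 47 are omitted.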
Epoch | Train Loss | Valid Loss | Accuracy | Error Rate | Time (mm:ss) |
---|---|---|---|---|---|
0 | 1.271288 | 1.351223 | 0.665301 | 0.334699 | 03:47 |
1 | 1.013268 | 0.742499 | 0.741803 | 0.258197 | 04:12 |
2 | 0.806825 | 0.687152 | 0.754098 | 0.245902 | 03:42 |
0 | 0.631816 | 0.533298 | 0.789617 | 0.210383 | 04:22 |
1 | 0.537469 | 0.457713 | 0.829235 | 0.170765 | 04:23 |
2 | 0.498419 | 0.515875 | 0.810109 | 0.189891 | 04:20 |
3 | 0.478353 | 0.511856 | 0.815574 | 0.184426 | 04:13 |
4 | 0.459457 | 0.475843 | 0.801913 | 0.198087 | 04:17 |
... | ... | ... | ... | ... | ... |
48 | 0.024947 | 0.800241 | 0.840164 | 0.159836 | 03:21 |
49 | 0.027916 | 0.803851 | 0.838798 | 0.161202 | 03:26 |
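The rows in the table can be checked directly. The sketch below copies the listed rows (the elided epochs are not included), confirms that error rate is the complement of accuracy, and finds where validation loss and accuracy peak. The `phase` field is our own labeling of where the epoch counter restarts, not something recorded in the training logs. Validation loss is lowest early in the second phase and climbs to about 0.80 by the end while training loss keeps falling, which suggests overfitting even though accuracy still edges up.

```python
# Rows copied from the table as (phase, epoch, train_loss, valid_loss, accuracy, error_rate).
# "phase" is an assumed labeling of where the epoch counter restarts.
rows = [
    (1, 0, 1.271288, 1.351223, 0.665301, 0.334699),
    (1, 1, 1.013268, 0.742499, 0.741803, 0.258197),
    (1, 2, 0.806825, 0.687152, 0.754098, 0.245902),
    (2, 0, 0.631816, 0.533298, 0.789617, 0.210383),
    (2, 1, 0.537469, 0.457713, 0.829235, 0.170765),
    (2, 2, 0.498419, 0.515875, 0.810109, 0.189891),
    (2, 3, 0.478353, 0.511856, 0.815574, 0.184426),
    (2, 4, 0.459457, 0.475843, 0.801913, 0.198087),
    (2, 48, 0.024947, 0.800241, 0.840164, 0.159836),
    (2, 49, 0.027916, 0.803851, 0.838798, 0.161202),
]

# Error rate is 1 - accuracy in every listed row.
assert all(abs((1 - acc) - err) < 1e-6 for *_, acc, err in rows)

best_loss = min(rows, key=lambda r: r[3])  # lowest validation loss
best_acc = max(rows, key=lambda r: r[4])   # highest validation accuracy
final = rows[-1]

print(f"lowest valid loss: phase {best_loss[0]}, epoch {best_loss[1]} ({best_loss[3]})")
print(f"highest accuracy: phase {best_acc[0]}, epoch {best_acc[1]} ({best_acc[4]})")
print(f"final train / valid loss: {final[2]} / {final[3]}")
```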
We submitted our model for validation to the APTOS 2019 Blindness Detection competition on Kaggle, achieving a private leaderboard score of 0.869345. The competition scores submissions with quadratic weighted kappa.
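Quadratic weighted kappa measures agreement between true and predicted grades while penalizing a miss by the squared distance between grades, so predicting 3 for a true 4 costs far less than predicting 0. A minimal pure-Python sketch of the standard definition follows. It should agree with scikit-learn's `cohen_kappa_score(y_true, y_pred, weights="quadratic")`, but it is an illustration, not the Kaggle scoring code.

```python
def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """Quadratic weighted kappa for integer grades 0..num_classes-1."""
    n = num_classes
    # Observed matrix: O[i][j] counts samples with true grade i and predicted grade j.
    observed = [[0.0] * n for _ in range(n)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    total = len(y_true)
    # Marginal histograms of true and predicted grades.
    hist_true = [sum(observed[i]) for i in range(n)]
    hist_pred = [sum(observed[i][j] for i in range(n)) for j in range(n)]

    num = den = 0.0
    for i in range(n):
        for j in range(n):
            weight = (i - j) ** 2 / (n - 1) ** 2            # quadratic penalty
            expected = hist_true[i] * hist_pred[j] / total  # chance agreement
            num += weight * observed[i][j]
            den += weight * expected
    return 1.0 - num / den


# Near misses score higher than distant misses on the same labels.
print(quadratic_weighted_kappa([0, 0, 4, 4], [0, 1, 4, 3]))
print(quadratic_weighted_kappa([0, 0, 4, 4], [0, 4, 4, 0]))
```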
Trying the model
Note: You can easily try our model here
This application uses a trained model to detect the severity of diabetic retinopathy from a given retina image taken using fundus photography. The severity levels are:
- 0 - No DR
- 1 - Mild
- 2 - Moderate
- 3 - Severe
- 4 - Proliferative DR
How to Use the Model
To use the model, you need to provide an image of the retina taken using fundus photography. The model will then predict the severity of diabetic retinopathy and return a dictionary where the keys are the severity levels and the values are the corresponding probabilities.
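Since the result is a plain dictionary of severity labels to probabilities, picking the most likely grade is a one-liner. In this sketch `most_likely_severity` is a hypothetical helper, and the probabilities are made up purely to show the dictionary's shape.

```python
def most_likely_severity(result):
    """Return the (label, probability) pair with the highest probability."""
    label = max(result, key=result.get)
    return label, result[label]


# Made-up probabilities, only to illustrate the format the model returns.
example = {
    "No DR": 0.05,
    "Mild": 0.10,
    "Moderate": 0.70,
    "Severe": 0.10,
    "Proliferative DR": 0.05,
}
label, prob = most_likely_severity(example)
print(f"{label} ({prob:.0%})")
```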
Breakdown of the app.py File
Here's a breakdown of what the `app.py` file does:
- Import necessary libraries: The file starts by importing `gradio` for creating the UI, `fastai.vision.all` for loading the trained model, and `skimage` for image processing (the code shown does not call `skimage` directly).
- Define helper functions: The `get_x` and `get_y` functions are defined. During training they returned the x value (the image) and the y value (the diagnosis) for each row of the training data. They must be defined again here because the saved model refers to them by name and cannot be loaded without them.
- Load the trained model: The trained model is loaded from the `model.pkl` file using the `load_learner` function from `fastai`, and its class labels are read from `learn.dls.vocab`.
- Define label descriptions: A dictionary maps label numbers to descriptions, so the prediction result shows descriptions instead of numbers.
- Define the prediction function: The `predict` function takes an image as input, makes a prediction using the trained model, and returns a dictionary where the keys are the severity levels and the values are the corresponding probabilities.
- Define title and description: The title, description, and an article link to the training notebook are defined. These are displayed in the Gradio UI.
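Why the getters have to exist before loading: fastai's `load_learner` restores the Learner with `torch.load`, which uses Python's `pickle`, and pickle stores functions as a module and name rather than as code. The stdlib-only illustration below uses a hypothetical `TinyLearner` stand-in (not fastai code) to show that unpickling fails until `get_y` is defined again.

```python
import pickle


def get_y(r):
    return r["diagnosis"]


class TinyLearner:
    """Hypothetical stand-in for a Learner that keeps a reference to a getter."""

    def __init__(self, label_fn):
        self.label_fn = label_fn


# Pickle records get_y by its module and name, not by its code.
blob = pickle.dumps(TinyLearner(get_y))

del get_y  # simulate an app that forgot to redefine the getter
try:
    pickle.loads(blob)
    loaded_without_getter = True
except AttributeError:
    loaded_without_getter = False  # "Can't get attribute 'get_y' ..."


def get_y(r):  # redefine it before loading, as app.py does
    return r["diagnosis"]


restored = pickle.loads(blob)
print(loaded_without_getter, restored.label_fn({"diagnosis": 3}))
```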
To run the application, create a Gradio interface with the `predict` function as the prediction function, an image as the input, and a label as the output, then launch the interface to start the application.
```python
from fastai.vision.all import *
import os
import gradio as gr
import skimage  # imported in the original app; not used directly below

# get_x and get_y were used to build the DataLoaders during training.
# The pickled Learner refers to them by name, so they must exist before
# load_learner is called. get_x is only a placeholder: predict() passes the
# image in directly. get_y returns the diagnosis label.
def get_x(r): return ""
def get_y(r): return r['diagnosis']

learn = load_learner('model.pkl')
labels = learn.dls.vocab  # class labels, in the same order as the predicted probabilities

# Define the mapping from label numbers to descriptions
label_descriptions = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR"
}

def predict(img):
    img = PILImage.create(img)
    pred, pred_idx, probs = learn.predict(img)
    # Return descriptions instead of numbers; int() makes the lookup work
    # whether the vocab stores labels as ints or as digit strings
    return {label_descriptions[int(labels[i])]: float(probs[i]) for i in range(len(labels))}

title = "Diabetic Retinopathy Detection"
description = """Detects severity of diabetic retinopathy from a given retina image taken using fundus photography -
0 - No DR
1 - Mild
2 - Moderate
3 - Severe
4 - Proliferative DR
"""
article = "<p style='text-align: center'><a href='https://www.kaggle.com/code/josemauriciodelgado/proliferative-retinopathy' target='_blank'>Notebook</a></p>"

# Collect example images from the test folder (no examples if the folder is missing)
test_folder = "test"  # replace with the actual path to your test folder
image_paths = []
if os.path.isdir(test_folder):
    image_paths = [os.path.join(test_folder, img) for img in os.listdir(test_folder)
                   if img.lower().endswith(('.png', '.jpg', '.jpeg'))]

gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=image_paths or None,  # list of example image paths, or None if there are none
    article=article,
    title=title,
    description=description,
).launch()
```
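To check the label mapping inside `predict` without loading `model.pkl`, you can swap in a stand-in learner. In this sketch `FakeLearner` and `predict_with` are hypothetical helpers, the probabilities are made up, and the string vocabulary is an assumption about what `learn.dls.vocab` might hold; wrapping each vocab entry in `int()` makes the lookup work either way, whereas `label_descriptions["0"]` would raise a `KeyError`.

```python
from types import SimpleNamespace

# Same mapping as in app.py
label_descriptions = {0: "No DR", 1: "Mild", 2: "Moderate", 3: "Severe", 4: "Proliferative DR"}


class FakeLearner:
    """Hypothetical stand-in for a fastai Learner: only dls.vocab and predict()."""

    def __init__(self, vocab, probs):
        self.dls = SimpleNamespace(vocab=vocab)
        self._probs = probs

    def predict(self, img):
        # Return a (label, index, probabilities) triple like the one app.py unpacks
        idx = max(range(len(self._probs)), key=self._probs.__getitem__)
        return self.dls.vocab[idx], idx, self._probs


def predict_with(learn, img):
    """The label-mapping part of predict(), with the learner passed in explicitly."""
    _, _, probs = learn.predict(img)
    labels = learn.dls.vocab
    # int() handles int, NumPy int, or digit-string vocab entries
    return {label_descriptions[int(labels[i])]: float(probs[i]) for i in range(len(labels))}


probs = [0.1, 0.2, 0.4, 0.2, 0.1]  # made-up probabilities
int_vocab = predict_with(FakeLearner([0, 1, 2, 3, 4], probs), img=None)
str_vocab = predict_with(FakeLearner(["0", "1", "2", "3", "4"], probs), img=None)
print(int_vocab == str_vocab, int_vocab["Moderate"])
```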
Model tree for jdelgado2002/diabetic_retinopathy_detection
Base model
microsoft/resnet-50