---
tags:
- fastai
- vision
- image-classification
license: mit
language:
- en
library_name: fastai
base_model: microsoft/resnet-50
pipeline_tag: image-classification
metrics:
- accuracy
---

# Model card

Try our model [here](https://huggingface.co/spaces/jdelgado2002/proliferative_retinopathy_detection)

## Model description

This is an image classification model that uses ResNet-50 as the base model to grade the severity of diabetic retinopathy.

## Intended uses & limitations

Given a retina image taken using fundus photography, this model grades the severity of diabetic retinopathy (DR) on a scale of 0 to 4:

- 0 - No DR
- 1 - Mild
- 2 - Moderate
- 3 - Severe
- 4 - Proliferative DR

## Training

* We trained our model with retina images taken using fundus photography under a variety of imaging conditions.
* The training data was gathered for a Kaggle competition held by the Asia Pacific Tele-Ophthalmology Society (APTOS) in 2019.
* [Training data](https://www.kaggle.com/competitions/aptos2019-blindness-detection/data)
* [Training process](https://www.kaggle.com/code/josemauriciodelgado/proliferative-retinopathy)

## Evaluation

We trained for 50 epochs, reaching about 84% accuracy (0.838798 at epoch 49) on the validation split of our training data.

| Epoch | Train Loss | Valid Loss | Accuracy | Error Rate | Time  |
|-------|------------|------------|----------|------------|-------|
| 0     | 1.271288   | 1.351223   | 0.665301 | 0.334699   | 03:47 |
| 1     | 1.013268   | 0.742499   | 0.741803 | 0.258197   | 04:12 |
| 2     | 0.806825   | 0.687152   | 0.754098 | 0.245902   | 03:42 |
| 0     | 0.631816   | 0.533298   | 0.789617 | 0.210383   | 04:22 |
| 1     | 0.537469   | 0.457713   | 0.829235 | 0.170765   | 04:23 |
| 2     | 0.498419   | 0.515875   | 0.810109 | 0.189891   | 04:20 |
| 3     | 0.478353   | 0.511856   | 0.815574 | 0.184426   | 04:13 |
| 4     | 0.459457   | 0.475843   | 0.801913 | 0.198087   | 04:17 |
| ...   | ...        | ...        | ...      | ...        | ...   |
| 48    | 0.024947   | 0.800241   | 0.840164 | 0.159836   | 03:21 |
| 49    | 0.027916   | 0.803851   | 0.838798 | 0.161202   | 03:26 |

[Confusion matrix](https://drive.google.com/file/d/1lI7pps03RXTFKYjY_iv4UPeSOhqQhxQB/view)

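For readers who want to reproduce this kind of summary on their own predictions, here is a minimal sketch of how a confusion matrix over the five severity grades can be tallied, and how accuracy falls out of its diagonal. The helper names `confusion_matrix` and `accuracy` and the sample labels are illustrative assumptions, not code or results from the training notebook.

```python
from typing import List, Sequence

NUM_CLASSES = 5  # severity grades 0 to 4


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int],
                     num_classes: int = NUM_CLASSES) -> List[List[int]]:
    """Count predictions: rows are true grades, columns are predicted grades."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def accuracy(matrix: List[List[int]]) -> float:
    """Fraction of samples on the diagonal (true grade == predicted grade)."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


# Made-up labels for illustration only; these are not results from this model.
y_true = [0, 0, 1, 2, 2, 3, 4, 4]
y_pred = [0, 0, 2, 2, 2, 4, 4, 3]

cm = confusion_matrix(y_true, y_pred)
for row in cm:
    print(row)
print("accuracy:", accuracy(cm))
```

Off-diagonal cells show which grades get confused with each other, which a single accuracy number hides.
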
We submitted our model for validation to the [APTOS 2019 Blindness Detection Competition](https://www.kaggle.com/competitions/aptos2019-blindness-detection/submissions#), achieving a private score of 0.869345 on the competition metric, quadratic weighted kappa.

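APTOS 2019 submissions are scored with quadratic weighted kappa, which measures agreement between predicted and true grades and penalizes a prediction more the further it lands from the true grade. The sketch below is a plain-Python illustration of that metric under the standard definition; the function name `quadratic_weighted_kappa` and the sample labels are assumptions for illustration, not the competition's scoring code.

```python
from typing import Sequence


def quadratic_weighted_kappa(y_true: Sequence[int], y_pred: Sequence[int],
                             num_classes: int = 5) -> float:
    """Agreement between two gradings, with squared-distance penalties."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(y_true)
    k = num_classes

    # Observed counts: rows are true grades, columns are predicted grades.
    observed = [[0] * k for _ in range(k)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1

    hist_true = [sum(row) for row in observed]
    hist_pred = [sum(observed[i][j] for i in range(k)) for j in range(k)]

    numerator = 0.0
    denominator = 0.0
    for i in range(k):
        for j in range(k):
            weight = (i - j) ** 2 / (k - 1) ** 2
            # Counts expected if the two gradings were independent.
            expected = hist_true[i] * hist_pred[j] / n
            numerator += weight * observed[i][j]
            denominator += weight * expected

    # A zero denominator only occurs when every true and predicted grade
    # is the same single class; this sketch treats that as full agreement.
    if denominator == 0:
        return 1.0
    return 1.0 - numerator / denominator


# Made-up labels for illustration only; these are not results from this model.
perfect = quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 4])
reversed_grades = quadratic_weighted_kappa([0, 1, 2, 3, 4], [4, 3, 2, 1, 0])
print(perfect, reversed_grades)
```

A score of 1 means perfect agreement, 0 means agreement no better than chance, and negative values mean systematic disagreement.
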
## Trying the model

You can try the model in our [Hugging Face Space](https://huggingface.co/spaces/jdelgado2002/proliferative_retinopathy_detection).

The application uses the trained model to detect the severity of diabetic retinopathy from a retina image taken using fundus photography. The severity levels are:

- 0 - No DR
- 1 - Mild
- 2 - Moderate
- 3 - Severe
- 4 - Proliferative DR

### How to Use the Model

To use the model, you need to provide an image of the retina taken using fundus photography. The model will then predict the severity of diabetic retinopathy and return a dictionary where the keys are the severity levels and the values are the corresponding probabilities.

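As a small illustration of working with that return value, the sketch below picks the most likely severity level from a probability dictionary. The helper name `top_prediction` and the probabilities are made up for illustration; real values come from the model.

```python
from typing import Dict, Tuple


def top_prediction(probs: Dict[str, float]) -> Tuple[str, float]:
    """Return the (severity level, probability) pair with the highest probability."""
    if not probs:
        raise ValueError("probs must not be empty")
    label = max(probs, key=probs.get)
    return label, probs[label]


# Hypothetical output in the shape described above; not a real prediction.
example = {
    "No DR": 0.05,
    "Mild": 0.10,
    "Moderate": 0.60,
    "Severe": 0.15,
    "Proliferative DR": 0.10,
}

label, probability = top_prediction(example)
print(f"{label}: {probability:.2f}")
```

Keeping the full dictionary is still useful, since the runner-up grades show how close the model was to a neighboring severity level.
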
### Breakdown of the `app.py` File

Here's a breakdown of what the `app.py` file is doing:

1. **Import necessary libraries**: The file starts by importing `gradio` for creating the UI, `fastai.vision.all` for loading the trained model, and `skimage` for image processing.

2. **Define helper functions**: The `get_x` and `get_y` functions are defined. During training they pulled the x value (the image) and the y value (the diagnosis) out of each input record. They must be defined again here, under the same names, so that the saved model can be loaded.

3. **Load the trained model**: The trained model is loaded from the `model.pkl` file using the `load_learner` function from `fastai`.

4. **Define label descriptions**: A dictionary maps label numbers to descriptions, so that the prediction result contains descriptions instead of numbers.

5. **Define the prediction function**: The `predict` function takes an image as input, makes a prediction using the trained model, and returns a dictionary where the keys are the severity levels and the values are the corresponding probabilities.

6. **Define title and description**: The title and description of the application are defined. These are displayed in the Gradio UI.

To run the application, create a Gradio interface with `predict` as the prediction function, an image as the input, and a label as the output, then launch the interface.

```python
import os

import gradio as gr
from fastai.vision.all import *
import skimage


# get_x and get_y return the x and y values from the input record: the x value
# is the image and the y value is the diagnosis. They were defined during
# training, so they must exist here for load_learner to load the model.
# get_x only needs to exist under this name, so it is reduced to a stub.
def get_x(r):
    return ""


def get_y(r):
    return r['diagnosis']


learn = load_learner('model.pkl')
labels = learn.dls.vocab

# Define the mapping from label numbers to descriptions
label_descriptions = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR"
}


def predict(img):
    """Return a {description: probability} dictionary for a retina image."""
    img = PILImage.create(img)
    pred, pred_idx, probs = learn.predict(img)
    # Use the label_descriptions dictionary to return descriptions instead of numbers
    return {label_descriptions[labels[i]]: float(probs[i]) for i in range(len(labels))}


title = "Diabetic Retinopathy Detection"
description = """Detects severity of diabetic retinopathy from a given retina image taken using fundus photography -

0 - No DR

1 - Mild

2 - Moderate

3 - Severe

4 - Proliferative DR
"""
article = "<p style='text-align: center'><a href='https://www.kaggle.com/code/josemauriciodelgado/proliferative-retinopathy' target='_blank'>Notebook</a></p>"

# Get a list of all image paths in the test folder
test_folder = "test"  # replace with the actual path to your test folder
image_paths = [
    os.path.join(test_folder, img)
    for img in os.listdir(test_folder)
    if img.endswith(('.png', '.jpg', '.jpeg'))
]

gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=image_paths,  # set the examples parameter to the list of image paths
    article=article,
    title=title,
    description=description,
).launch()
```

[Source code](https://huggingface.co/spaces/jdelgado2002/proliferative_retinopathy_detection/tree/main)