File size: 6,343 Bytes
0c9da96
 
 
 
 
 
 
 
 
 
 
 
 
a05049f
 
 
 
 
 
 
 
0c9da96
 
 
 
 
 
 
 
a05049f
0c9da96
a05049f
 
 
0c9da96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a05049f
0c9da96
 
 
a05049f
0c9da96
 
 
 
 
 
 
 
 
a05049f
 
0c9da96
 
a05049f
0c9da96
 
a05049f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c9da96
a05049f
0c9da96
 
 
12d56d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from transformers import ViTForImageClassification, ViTImageProcessor

import matplotlib.pyplot as plt
import gradio as gr
import plotly.graph_objects as go
import torch
import numpy as np
from PIL import Image
model_name = "./best_model"
processor = ViTImageProcessor.from_pretrained(model_name)

# Human-readable class names; position i names logit i of the classifier
# (classify_image indexes this list with the predicted class index).
labels = [
    "Acne or Rosacea",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis",
    "Bullous Disease",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Contact Dermatitis",
    "Eczema",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis",
    "Warts Molluscum and other Viral Infections",
]

class ViTForImageClassificationWithAttention(ViTForImageClassification):
    """ViT classifier whose forward pass also hands back the attention maps.

    ``forward`` returns a pair ``(outputs, attention)``: ``outputs`` is the
    regular Hugging Face classification output and ``attention`` is
    ``outputs.attentions`` (requested via ``output_attentions``).
    """

    def forward(self, pixel_values, output_attentions=True):
        outputs = super().forward(pixel_values, output_attentions=output_attentions)
        return outputs, outputs.attentions


# Load the checkpoint once. The eager attention implementation is requested so
# that attention weights can be returned (fused SDPA-style kernels may not
# provide them).
model = ViTForImageClassificationWithAttention.from_pretrained(model_name, attn_implementation="eager")

i_count = 0  # NOTE(review): not referenced anywhere in this file; kept in case external code imports it.

# Modification time of the checkpoint most recently loaded into `model`
# (None until one has been loaded); lets classify_image skip redundant reloads.
_checkpoint_mtime = None


def _load_checkpoint_if_changed(path):
    """Load the state dict at *path* into ``model`` if the file is new or changed.

    If the file does not exist yet (e.g. ``update_model`` has never run), the
    weights already in ``model`` are kept instead of raising.
    """
    import os  # local import: only this helper needs file metadata

    global _checkpoint_mtime
    try:
        mtime = os.path.getmtime(path)
    except FileNotFoundError:
        return
    if mtime == _checkpoint_mtime:
        return
    # The file holds a plain state_dict, so restrict unpickling to tensors.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    _checkpoint_mtime = mtime


def _top_k_bar_chart(top_k_label_probs):
    """Return a Plotly bar chart of ``(label, probability)`` pairs."""
    fig = go.Figure(
        data=[go.Bar(x=[label for label, _ in top_k_label_probs], y=[prob for _, prob in top_k_label_probs])])
    fig.update_layout(title="Top 5 Predicted Labels with Probabilities", xaxis_title="Label",
                      yaxis_title="Probability")
    return fig


def _attention_heatmap_figure(layer_attention):
    """Return a heatmap of one layer's attention matrix (first image, first head).

    *layer_attention* is indexed as (batch, head, query, key).
    """
    fig = go.Figure(
        data=[go.Heatmap(z=layer_attention[0, 0].detach().cpu().numpy(), colorscale='Viridis', showscale=False)])
    fig.update_layout(title="Attention Heatmap")
    return fig


def _attention_overlay(image, layer_attention):
    """Blend the CLS-token attention of the first head over *image*.

    Returns a PIL image, or None when the patch tokens cannot be arranged in a
    square grid.
    """
    # Query row 0 is taken as the CLS token and key columns 1: as the image patches.
    # NOTE(review): assumes exactly one special token precedes the patches (standard ViT) — confirm for other checkpoints.
    cls_to_patches = layer_attention[0, 0, 0, 1:].detach().float()
    num_patches = cls_to_patches.numel()
    grid = int(round(num_patches ** 0.5))
    if grid == 0 or grid * grid != num_patches:
        return None

    rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
    height, width = rgb.shape[:2]
    # Upsample the patch grid to the image size; np.resize would tile the data, not scale it.
    attn_map = torch.nn.functional.interpolate(
        cls_to_patches.reshape(1, 1, grid, grid), size=(height, width), mode="bilinear", align_corners=False,
    )[0, 0].cpu().numpy()
    peak = attn_map.max()
    if peak > 0:
        attn_map = attn_map / peak  # normalise to [0, 1]; guard against an all-zero map
    intensity = attn_map * 255
    # Red + green channels (yellow tint), blue left at 0, blended 50/50 with the photo.
    heat = np.stack([intensity, intensity, np.zeros_like(intensity)], axis=-1)
    blended = rgb * 0.5 + heat * 0.5
    return Image.fromarray(blended.astype(np.uint8))


def classify_image(image):
    """Classify a skin image and visualise the model's first-layer attention.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input.

    Returns:
        ``(predicted_label, bar_chart, attention_heatmap, attention_overlay)``:
        the label string, two Plotly figures, and a PIL image (None when no
        attention overlay could be produced).

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Pick up weights saved by update_model (or another process) without
    # re-reading the checkpoint on every request.
    _load_checkpoint_if_changed("best_model.pth")

    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs, attention = model(**inputs, output_attentions=True)

    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    # Show the top 5 predicted labels (fewer if the model has fewer classes).
    top_k_probs, top_k_indices = torch.topk(probs, k=min(5, probs.shape[-1]))
    predicted_class_label = labels[int(torch.argmax(logits))]
    top_k_label_probs = [
        (labels[idx], prob)
        for idx, prob in zip(top_k_indices[0].tolist(), top_k_probs[0].tolist())
    ]

    fig_bar = _top_k_bar_chart(top_k_label_probs)
    if attention:
        first_layer = attention[0]
        fig_heatmap = _attention_heatmap_figure(first_layer)
        attention_overlay = _attention_overlay(image, first_layer)
    else:
        fig_heatmap = go.Figure()  # empty plot
        attention_overlay = None

    return predicted_class_label, fig_bar, fig_heatmap, attention_overlay


def update_model(image, label, lr=0.001):
    """Run one supervised training step on a single labelled image and save it.

    Args:
        image: PIL image from the Gradio image input.
        label: One of ``labels``; its list index is the target class.
        lr: Adam learning rate for this step (0.001 matches the previous
            hard-coded value).

    Returns:
        A status message for the Gradio textbox.
    """
    if image is None:
        return "Please upload an image before updating the model"

    label_idx = labels.index(label)
    target = torch.tensor([label_idx])

    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    loss_fn = torch.nn.CrossEntropyLoss()
    # NOTE(review): a fresh optimizer is built per call, so Adam's moment
    # estimates do not carry over between updates.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    was_training = model.training
    model.train()  # training-time behaviour (e.g. dropout) for this step
    try:
        optimizer.zero_grad()
        # Attention maps are not needed for the loss; skip materialising them.
        outputs, _ = model(**inputs, output_attentions=False)
        loss = loss_fn(outputs.logits, target)
        loss.backward()
        optimizer.step()
    finally:
        # The same model object serves classify_image: restore its mode and
        # free the gradient buffers.
        model.train(was_training)
        model.zero_grad(set_to_none=True)

    # Persist the updated weights (classify_image reads this same path).
    torch.save(model.state_dict(), "best_model.pth")

    return "Model updated successfully"


# Two-tab UI: inference (classify_image) and single-example training (update_model).
demo = gr.TabbedInterface(
    [
        gr.Interface(
            fn=classify_image,
            inputs=[
                gr.Image(type="pil", label="Image")
            ],
            # One output component per value returned by classify_image.
            outputs=[
                gr.Label(label="Predicted Class Label"),
                gr.Plot(label="Top 5 Predicted Labels with Probabilities"),
                gr.Plot(label="Attention Heatmap"),
                gr.Image(label="Attention Overlay"),
            ],
            title="Dermatological Image Classification Demo",
            description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention heatmap",
            allow_flagging="never",
        ),
        gr.Interface(
            fn=update_model,
            inputs=[
                gr.Image(type="pil", label="Image"),
                gr.Radio(
                    choices=labels,
                    type="value",
                    label="Label",
                    value=labels[0],
                ),
            ],
            outputs=[
                gr.Textbox(label="Model Update Status")
            ],
            title="Train Model",
            description="Upload an image and label to update the model",
            allow_flagging="never",
        ),
    ],
    tab_names=["Classify", "Train"],
    title="Dermatological Image Classification and Training",
)

if __name__ == "__main__":
    demo.launch()