Spaces:
Running
Running
Update app.py
Browse filesRussian Translation Added
app.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np
|
|
9 |
from PIL import Image
|
10 |
model_name = "./best_model"
|
11 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
12 |
-
labels = ['
|
13 |
|
14 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
15 |
def forward(self, pixel_values):
|
@@ -44,14 +44,14 @@ def classify_image(image):
|
|
44 |
# Create a bar chart
|
45 |
fig_bar = go.Figure(
|
46 |
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
|
47 |
-
fig_bar.update_layout(title="
|
48 |
-
yaxis_title="
|
49 |
|
50 |
# Create a heatmap
|
51 |
if attention is not None:
|
52 |
fig_heatmap = go.Figure(
|
53 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
54 |
-
fig_heatmap.update_layout(title="
|
55 |
else:
|
56 |
fig_heatmap = go.Figure() # Return an empty plot
|
57 |
|
@@ -66,7 +66,7 @@ def classify_image(image):
|
|
66 |
heatmap_color[:, :, 0] = heatmap # Red channel
|
67 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
68 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
69 |
-
attention_overlay = (img_array * 0.
|
70 |
attention_overlay = Image.fromarray(attention_overlay)
|
71 |
attention_overlay.save("attention_overlay.png")
|
72 |
attention_overlay = gr.Image("attention_overlay.png")
|
@@ -102,7 +102,7 @@ def update_model(image, label):
|
|
102 |
# Save the updated model
|
103 |
torch.save(model.state_dict(), "best_model.pth")
|
104 |
|
105 |
-
return "
|
106 |
|
107 |
|
108 |
demo = gr.TabbedInterface(
|
@@ -113,11 +113,11 @@ demo = gr.TabbedInterface(
|
|
113 |
gr.Image(type="pil", label="Image")
|
114 |
],
|
115 |
outputs=[
|
116 |
-
gr.Label(label="
|
117 |
-
gr.Plot(label="
|
118 |
],
|
119 |
-
title="
|
120 |
-
description="
|
121 |
allow_flagging=False
|
122 |
),
|
123 |
gr.Interface(
|
@@ -132,14 +132,14 @@ demo = gr.TabbedInterface(
|
|
132 |
)
|
133 |
],
|
134 |
outputs=[
|
135 |
-
gr.Textbox(label="
|
136 |
],
|
137 |
-
title="
|
138 |
-
description="
|
139 |
allow_flagging=False
|
140 |
)
|
141 |
],
|
142 |
-
title="
|
143 |
)
|
144 |
|
145 |
if __name__ == "__main__":
|
|
|
9 |
from PIL import Image
|
10 |
model_name = "./best_model"
|
11 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
12 |
+
labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции']
|
13 |
|
14 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
15 |
def forward(self, pixel_values):
|
|
|
44 |
# Create a bar chart
|
45 |
fig_bar = go.Figure(
|
46 |
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
|
47 |
+
fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз",
|
48 |
+
yaxis_title="Вероятность")
|
49 |
|
50 |
# Create a heatmap
|
51 |
if attention is not None:
|
52 |
fig_heatmap = go.Figure(
|
53 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
54 |
+
fig_heatmap.update_layout(title="Карта внимания системы")
|
55 |
else:
|
56 |
fig_heatmap = go.Figure() # Return an empty plot
|
57 |
|
|
|
66 |
heatmap_color[:, :, 0] = heatmap # Red channel
|
67 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
68 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
69 |
+
attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8)
|
70 |
attention_overlay = Image.fromarray(attention_overlay)
|
71 |
attention_overlay.save("attention_overlay.png")
|
72 |
attention_overlay = gr.Image("attention_overlay.png")
|
|
|
102 |
# Save the updated model
|
103 |
torch.save(model.state_dict(), "best_model.pth")
|
104 |
|
105 |
+
return "Модель успешно обновлена"
|
106 |
|
107 |
|
108 |
demo = gr.TabbedInterface(
|
|
|
113 |
gr.Image(type="pil", label="Image")
|
114 |
],
|
115 |
outputs=[
|
116 |
+
gr.Label(label="Предсказанный диагноз"),
|
117 |
+
gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности")
|
118 |
],
|
119 |
+
title="DermaScan Demo",
|
120 |
+
description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.",
|
121 |
allow_flagging=False
|
122 |
),
|
123 |
gr.Interface(
|
|
|
132 |
)
|
133 |
],
|
134 |
outputs=[
|
135 |
+
gr.Textbox(label="Обновление модели")
|
136 |
],
|
137 |
+
title="Обучить модель",
|
138 |
+
description="Загрузите изображение и метку для обновления модели.",
|
139 |
allow_flagging=False
|
140 |
)
|
141 |
],
|
142 |
+
title="DermaScan Demo"
|
143 |
)
|
144 |
|
145 |
if __name__ == "__main__":
|