kamalkraj commited on
Commit
75cff75
1 Parent(s): 5a636d4

add app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import torchvision.transforms
6
+
7
+ import torchxrayvision as xrv
8
+
9
+
10
+ def classify_image(img, model_name):
11
+
12
+ model = xrv.models.get_model(model_name, from_hf_hub=True)
13
+
14
+ img = xrv.datasets.normalize(img, 255)
15
+
16
+ # Check that images are 2D arrays
17
+ if len(img.shape) > 2:
18
+ img = img[:, :, 0]
19
+ if len(img.shape) < 2:
20
+ print("error, dimension lower than 2 for image")
21
+
22
+ # Add color channel
23
+ img = img[None, :, :]
24
+
25
+ transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
26
+
27
+ img = transform(img)
28
+
29
+ with torch.no_grad():
30
+ img = torch.from_numpy(img).unsqueeze(0)
31
+ preds = model(img).cpu()
32
+ output = {
33
+ k: float(v)
34
+ for k, v in zip(xrv.datasets.default_pathologies, preds[0].detach().numpy())
35
+ }
36
+ return output
37
+
38
+
39
+ gr.Interface(
40
+ fn=classify_image,
41
+ inputs=[
42
+ gr.Image(shape=(224, 224), image_mode="L"),
43
+ gr.Dropdown(
44
+ [
45
+ "densenet121-res224-all",
46
+ "densenet121-res224-nih",
47
+ "densenet121-res224-pc",
48
+ "densenet121-res224-chex",
49
+ "densenet121-res224-rsna",
50
+ "densenet121-res224-mimic_nb",
51
+ "densenet121-res224-mimic_ch",
52
+ "resnet50-res512-all",
53
+ ],
54
+ value="densenet121-res224-all",
55
+ type="value",
56
+ label="Pre-trained model",
57
+ ),
58
+ ],
59
+ outputs=gr.outputs.Label(),
60
+ title="Classify chest x-ray image",
61
+ examples=[
62
+ ["16747_3_1.jpg", "densenet121-res224-all"],
63
+ ["00000001_000.png", "resnet50-res512-all"],
64
+ ],
65
+ ).launch()