Update README.md
Browse files
README.md
CHANGED
@@ -68,12 +68,24 @@ You can use this model with the `transformers` and `torch` libraries.
|
|
68 |
|
69 |
```python
|
70 |
import torch
|
|
|
71 |
from torchvision import transforms
|
72 |
from PIL import Image
|
73 |
import requests
|
74 |
|
75 |
-
#
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
model.eval()
|
78 |
|
79 |
# Image preprocessing
|
@@ -84,7 +96,7 @@ transform = transforms.Compose([
|
|
84 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
85 |
])
|
86 |
|
87 |
-
# Sample Image (replace with your own image)
|
88 |
url = 'https://storage.googleapis.com/kagglesdsdata/datasets/17810/23812/chest_xray/test/NORMAL/IM-0005-0001.jpeg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240913%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240913T014624Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=1f6b37d181f12d083ffc951657e85fea087bb4e81ab955ec955dafcdae49c0d53ce20bc0be93605e2672b9bdd59e752eba9d5a3a0da2e3b3a03c888580b88d63d87611b4e4cec8b8802d53abd53fda165dd04765b8d9f30ddd4e908cd7a2a389ce8244fca7bfa36b3c9cff79d7c5e3f9ee7d59d5b9ef97a2e5c083997892ee3023302313fafff48ded58232db57d6affcfaee704eebba55f2b0abac40b14a38137275ad19cdb1b787930d134f7c30710e29c409bd765ca02e46851470a871cc697f614d464086373f43f5462f241eaf023cfd31e217d7b11e24e1ff34857deb200f5dc1a8c28c8115048ee840be8481f1bd79a2d8e2de1b30cb71420c007d32c'
|
89 |
img = Image.open(requests.get(url, stream=True).raw)
|
90 |
|
@@ -95,6 +107,7 @@ input_img = transform(img).unsqueeze(0)
|
|
95 |
with torch.no_grad():
|
96 |
output = model(input_img)
|
97 |
_, predicted = torch.max(output, 1)
|
98 |
-
|
|
|
99 |
labels = {0: 'Pneumonia', 1: 'Normal'}
|
100 |
-
print(f'Predicted label: {labels[predicted.item()]}')
|
|
|
68 |
|
69 |
```python
|
70 |
import torch
|
71 |
+
from huggingface_hub import hf_hub_download
|
72 |
from torchvision import transforms
|
73 |
from PIL import Image
|
74 |
import requests
|
75 |
|
76 |
+
# Download the model weights from Hugging Face Hub
|
77 |
+
model_path = hf_hub_download(repo_id="izeeek/resnet18_pneumonia_classifier", filename="resnet18_pneumonia_classifier.pth")
|
78 |
+
|
79 |
+
# Load the model architecture (ResNet18)
|
80 |
+
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
|
81 |
+
|
82 |
+
# Adjust the final layer for binary classification (if necessary)
|
83 |
+
model.fc = torch.nn.Linear(model.fc.in_features, 2)
|
84 |
+
|
85 |
+
# Load the downloaded weights
|
86 |
+
model.load_state_dict(torch.load(model_path))
|
87 |
+
|
88 |
+
# Set the model to evaluation mode
|
89 |
model.eval()
|
90 |
|
91 |
# Image preprocessing
|
|
|
96 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
97 |
])
|
98 |
|
99 |
+
# Sample Image (replace with your own image URL)
|
100 |
url = 'https://storage.googleapis.com/kagglesdsdata/datasets/17810/23812/chest_xray/test/NORMAL/IM-0005-0001.jpeg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240913%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240913T014624Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=1f6b37d181f12d083ffc951657e85fea087bb4e81ab955ec955dafcdae49c0d53ce20bc0be93605e2672b9bdd59e752eba9d5a3a0da2e3b3a03c888580b88d63d87611b4e4cec8b8802d53abd53fda165dd04765b8d9f30ddd4e908cd7a2a389ce8244fca7bfa36b3c9cff79d7c5e3f9ee7d59d5b9ef97a2e5c083997892ee3023302313fafff48ded58232db57d6affcfaee704eebba55f2b0abac40b14a38137275ad19cdb1b787930d134f7c30710e29c409bd765ca02e46851470a871cc697f614d464086373f43f5462f241eaf023cfd31e217d7b11e24e1ff34857deb200f5dc1a8c28c8115048ee840be8481f1bd79a2d8e2de1b30cb71420c007d32c'
|
101 |
img = Image.open(requests.get(url, stream=True).raw)
|
102 |
|
|
|
107 |
with torch.no_grad():
|
108 |
output = model(input_img)
|
109 |
_, predicted = torch.max(output, 1)
|
110 |
+
|
111 |
+
# Labels for classification
|
112 |
labels = {0: 'Pneumonia', 1: 'Normal'}
|
113 |
+
print(f'Predicted label: {labels[predicted.item()]}')
|