izeeek commited on
Commit
f100315
1 Parent(s): 42461e2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +18 -5
README.md CHANGED
@@ -68,12 +68,24 @@ You can use this model with the `transformers` and `torch` libraries.
68
 
69
  ```python
70
  import torch
 
71
  from torchvision import transforms
72
  from PIL import Image
73
  import requests
74
 
75
- # Load the model from Hugging Face Hub
76
- model = torch.hub.load('huggingface/pytorch', 'resnet18_pneumonia_classifier')
 
 
 
 
 
 
 
 
 
 
 
77
  model.eval()
78
 
79
  # Image preprocessing
@@ -84,7 +96,7 @@ transform = transforms.Compose([
84
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
85
  ])
86
 
87
- # Sample Image (replace with your own image)
88
  url = 'https://storage.googleapis.com/kagglesdsdata/datasets/17810/23812/chest_xray/test/NORMAL/IM-0005-0001.jpeg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240913%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240913T014624Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=1f6b37d181f12d083ffc951657e85fea087bb4e81ab955ec955dafcdae49c0d53ce20bc0be93605e2672b9bdd59e752eba9d5a3a0da2e3b3a03c888580b88d63d87611b4e4cec8b8802d53abd53fda165dd04765b8d9f30ddd4e908cd7a2a389ce8244fca7bfa36b3c9cff79d7c5e3f9ee7d59d5b9ef97a2e5c083997892ee3023302313fafff48ded58232db57d6affcfaee704eebba55f2b0abac40b14a38137275ad19cdb1b787930d134f7c30710e29c409bd765ca02e46851470a871cc697f614d464086373f43f5462f241eaf023cfd31e217d7b11e24e1ff34857deb200f5dc1a8c28c8115048ee840be8481f1bd79a2d8e2de1b30cb71420c007d32c'
89
  img = Image.open(requests.get(url, stream=True).raw)
90
 
@@ -95,6 +107,7 @@ input_img = transform(img).unsqueeze(0)
95
  with torch.no_grad():
96
  output = model(input_img)
97
  _, predicted = torch.max(output, 1)
98
-
 
99
  labels = {0: 'Pneumonia', 1: 'Normal'}
100
- print(f'Predicted label: {labels[predicted.item()]}')
 
68
 
69
  ```python
70
  import torch
71
+ from huggingface_hub import hf_hub_download
72
  from torchvision import transforms
73
  from PIL import Image
74
  import requests
75
 
76
+ # Download the model weights from Hugging Face Hub
77
+ model_path = hf_hub_download(repo_id="izeeek/resnet18_pneumonia_classifier", filename="resnet18_pneumonia_classifier.pth")
78
+
79
+ # Load the model architecture (ResNet18)
80
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
81
+
82
+ # Adjust the final layer for binary classification (if necessary)
83
+ model.fc = torch.nn.Linear(model.fc.in_features, 2)
84
+
85
+ # Load the downloaded weights
86
+ model.load_state_dict(torch.load(model_path))
87
+
88
+ # Set the model to evaluation mode
89
  model.eval()
90
 
91
  # Image preprocessing
 
96
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
97
  ])
98
 
99
+ # Sample Image (replace with your own image URL)
100
  url = 'https://storage.googleapis.com/kagglesdsdata/datasets/17810/23812/chest_xray/test/NORMAL/IM-0005-0001.jpeg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240913%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240913T014624Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=1f6b37d181f12d083ffc951657e85fea087bb4e81ab955ec955dafcdae49c0d53ce20bc0be93605e2672b9bdd59e752eba9d5a3a0da2e3b3a03c888580b88d63d87611b4e4cec8b8802d53abd53fda165dd04765b8d9f30ddd4e908cd7a2a389ce8244fca7bfa36b3c9cff79d7c5e3f9ee7d59d5b9ef97a2e5c083997892ee3023302313fafff48ded58232db57d6affcfaee704eebba55f2b0abac40b14a38137275ad19cdb1b787930d134f7c30710e29c409bd765ca02e46851470a871cc697f614d464086373f43f5462f241eaf023cfd31e217d7b11e24e1ff34857deb200f5dc1a8c28c8115048ee840be8481f1bd79a2d8e2de1b30cb71420c007d32c'
101
  img = Image.open(requests.get(url, stream=True).raw)
102
 
 
107
  with torch.no_grad():
108
  output = model(input_img)
109
  _, predicted = torch.max(output, 1)
110
+
111
+ # Labels for classification
112
  labels = {0: 'Pneumonia', 1: 'Normal'}
113
+ print(f'Predicted label: {labels[predicted.item()]}')