Spaces:
Sleeping
Sleeping
VinayHajare
commited on
Commit
•
5c05ddf
1
Parent(s):
4905c12
Update app.py
Browse files
app.py
CHANGED
@@ -11,12 +11,13 @@ Fruits = ['Acerola', 'Apple', 'Apricot', 'Avocado', 'Banana', 'Black Berry', 'Bl
|
|
11 |
'Passion Fruit', 'Peach', 'Pear', 'Pineapple', 'Plum', 'Pomegranate', 'Raspberry', 'Strawberry', 'Tomato',
|
12 |
'Watermelon']
|
13 |
|
14 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
|
16 |
repo_name = "VinayHajare/fruits30-resnet18"
|
17 |
file_name = "fruit_resnet18(99.40%).pt"
|
18 |
model_path = hf_hub_download(repo_id = repo_name, filename = file_name)
|
19 |
-
model = torch.load(model_path
|
|
|
20 |
model.eval()
|
21 |
|
22 |
print("Stage 1: Completed ")
|
@@ -29,7 +30,8 @@ transform = transforms.Compose([
|
|
29 |
])
|
30 |
|
31 |
def predict_image(image):
|
32 |
-
image_tensor = transform(Image.fromarray(image)).unsqueeze(0)
|
|
|
33 |
with torch.no_grad():
|
34 |
output = model(image_tensor)
|
35 |
predicted = torch.argmax(output).item()
|
|
|
11 |
'Passion Fruit', 'Peach', 'Pear', 'Pineapple', 'Plum', 'Pomegranate', 'Raspberry', 'Strawberry', 'Tomato',
|
12 |
'Watermelon']
|
13 |
|
14 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
15 |
|
16 |
repo_name = "VinayHajare/fruits30-resnet18"
|
17 |
file_name = "fruit_resnet18(99.40%).pt"
|
18 |
model_path = hf_hub_download(repo_id = repo_name, filename = file_name)
|
19 |
+
model = torch.load(model_path, map_location=torch.device('cpu'))
|
20 |
+
model.to(device)
|
21 |
model.eval()
|
22 |
|
23 |
print("Stage 1: Completed ")
|
|
|
30 |
])
|
31 |
|
32 |
def predict_image(image):
|
33 |
+
image_tensor = transform(Image.fromarray(image)).unsqueeze(0)
|
34 |
+
image_tensor = image_tensor.to(device)
|
35 |
with torch.no_grad():
|
36 |
output = model(image_tensor)
|
37 |
predicted = torch.argmax(output).item()
|