VinayHajare commited on
Commit
5c05ddf
1 Parent(s): 4905c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -11,12 +11,13 @@ Fruits = ['Acerola', 'Apple', 'Apricot', 'Avocado', 'Banana', 'Black Berry', 'Bl
11
  'Passion Fruit', 'Peach', 'Pear', 'Pineapple', 'Plum', 'Pomegranate', 'Raspberry', 'Strawberry', 'Tomato',
12
  'Watermelon']
13
 
14
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
 
16
  repo_name = "VinayHajare/fruits30-resnet18"
17
  file_name = "fruit_resnet18(99.40%).pt"
18
  model_path = hf_hub_download(repo_id = repo_name, filename = file_name)
19
- model = torch.load(model_path).to(device)
 
20
  model.eval()
21
 
22
  print("Stage 1: Completed ")
@@ -29,7 +30,8 @@ transform = transforms.Compose([
29
  ])
30
 
31
  def predict_image(image):
32
- image_tensor = transform(Image.fromarray(image)).unsqueeze(0).to(device)
 
33
  with torch.no_grad():
34
  output = model(image_tensor)
35
  predicted = torch.argmax(output).item()
 
11
  'Passion Fruit', 'Peach', 'Pear', 'Pineapple', 'Plum', 'Pomegranate', 'Raspberry', 'Strawberry', 'Tomato',
12
  'Watermelon']
13
 
14
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
15
 
16
  repo_name = "VinayHajare/fruits30-resnet18"
17
  file_name = "fruit_resnet18(99.40%).pt"
18
  model_path = hf_hub_download(repo_id = repo_name, filename = file_name)
19
+ model = torch.load(model_path, map_location=torch.device('cpu'))
20
+ model.to(device)
21
  model.eval()
22
 
23
  print("Stage 1: Completed ")
 
30
  ])
31
 
32
  def predict_image(image):
33
+ image_tensor = transform(Image.fromarray(image)).unsqueeze(0)
34
+ image_tensor = image_tensor.to(device)
35
  with torch.no_grad():
36
  output = model(image_tensor)
37
  predicted = torch.argmax(output).item()