andythetechnerd03 commited on
Commit
5a73567
β€’
1 Parent(s): f6a046f

Refactor app.py: fix predict() class-name indexing and typing imports, reorganize into numbered sections, and migrate Gradio components to the top-level API

Browse files
Files changed (1) hide show
  1. app.py +46 -34
app.py CHANGED
@@ -1,65 +1,77 @@
 
1
  import gradio as gr
2
  import os
3
  import torch
4
 
5
  from model import create_effnetb2_model
6
  from timeit import default_timer as timer
7
- from typing import Tuple, TypedDict
8
 
9
  # Setup class names
10
  class_names = ["pizza", "steak", "sushi"]
11
 
12
- # Create EffNetB2 model instance and transform
13
- effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))
14
 
15
- # Load model weights
 
 
 
 
 
16
  effnetb2.load_state_dict(
17
  torch.load(
18
- os.path.join("models", "09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth"),
19
- map_location=torch.device("cpu")
20
- )
21
  )
 
22
 
23
- # Predict function
24
- def predict(img) -> Tuple[Dict, float]:
25
- # Start a timer
26
- start_time = timer()
27
-
28
- # Transform the input image for use with EffNetB2
29
- img = effnetb2_transforms(img).unsqueeze(0)
30
 
31
- # put model into eval mode, make prediction
32
- effnetb2.eval()
33
- with torch.inference_mode():
34
- pred_probs = torch.softmax(effnetb2(img), dim=-1)
35
-
36
- # Create a prediction label and predcition probability
37
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(class_names)}
38
-
39
- # Calculate pred time and pred dict
40
- pred_time = round(timer() - start_time, 5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- return pred_labels_and_probs, pred_time
43
 
44
- # Gradio app
45
  # Create title, description and article strings
46
  title = "FoodVision Mini πŸ•πŸ₯©πŸ£"
47
  description = "An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
48
  article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
49
 
50
- # Create an example list
51
  example_list = [["examples/" + example] for example in os.listdir("examples")]
52
 
53
  # Create the Gradio demo
54
- demo = gr.Interface(fn=predict,
55
- inputs=gr.inputs.Image(type="pil"),
56
- outputs=[gr.outputs.Label(num_top_classes=3, label="Predictions"),
57
- gr.outputs.Number(label="Prediction time (s)")],
58
- examples=example_list,
 
59
  title=title,
60
  description=description,
61
  article=article)
62
 
63
  # Launch the demo!
64
- demo.launch(debug=False,
65
- share=True)
 
1
+ ### 1. Imports and class names setup ###
2
  import gradio as gr
3
  import os
4
  import torch
5
 
6
  from model import create_effnetb2_model
7
  from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
 
10
  # Setup class names
11
  class_names = ["pizza", "steak", "sushi"]
12
 
13
### 2. Model and transforms preparation ###

# Create EffNetB2 model and its matching preprocessing transforms.
# Both are defined by the project-local helper in model.py.
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=3, # len(class_names) would also work
)

# Load saved weights into the freshly created architecture.
# NOTE(review): the .pth file is loaded from the working directory —
# it must sit alongside app.py for this to succeed; verify at deploy time.
# NOTE(review): consider passing weights_only=True to torch.load for
# safer deserialization if the installed torch version supports it.
effnetb2.load_state_dict(
    torch.load(
        f="09_pretrained_effnetb2_feature_extractor_pizza_steak_sushi_20_percent.pth",
        map_location=torch.device("cpu"), # load to CPU (no GPU assumed on the host)
    )
)
27
 
28
+ ### 3. Predict function ###
 
 
 
 
 
 
29
 
30
# Create predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify an input image and report how long inference took.

    Runs the image through the module-level EffNetB2 model and returns a
    {class_name: probability} mapping (the format Gradio's Label output
    expects) together with the wall-clock prediction time in seconds,
    rounded to 5 decimal places.
    """
    # Clock the whole prediction so the UI can display latency.
    start_time = timer()

    # Preprocess the PIL image and prepend the batch dimension the model expects.
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Evaluation mode + inference mode: no dropout/batch-norm updates, no autograd.
    effnetb2.eval()
    with torch.inference_mode():
        logits = effnetb2(batch)
        probabilities = torch.softmax(logits, dim=1)

    # Pair each class name with its probability for Gradio's Label component.
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, probabilities[0])
    }

    # Elapsed wall-clock time, rounded for display.
    pred_time = round(timer() - start_time, 5)

    # Return the prediction dictionary and prediction time
    return pred_labels_and_probs, pred_time
54
 
55
### 4. Gradio app ###

# Create title, description and article strings shown around the demo UI.
title = "FoodVision Mini 🍕🥩🍣"
description = "An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Create examples list from "examples/" directory.
# NOTE(review): assumes an "examples/" folder exists next to app.py and
# contains only image files — os.listdir does no filtering; confirm at deploy.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create the Gradio demo
demo = gr.Interface(fn=predict, # mapping function from input to output
                    inputs=gr.Image(type="pil"), # what are the inputs?
                    outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
                             gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
                    # Examples populate clickable sample inputs below the interface
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo! (Blocking call; serves the app until interrupted.)
demo.launch()