File size: 4,220 Bytes
3898f71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from timeit import default_timer as timer
from typing import Tuple
from pathlib import Path
from PIL import Image

import gradio as gr
import torch
from torch import nn
from torchvision import transforms

from model import create_effnetb3_model

# Food-101 category labels in alphabetical order. predict() pairs the label at
# position i with the model's i-th output probability, so order matters.
class_names = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare
    beet_salad beignets bibimbap bread_pudding breakfast_burrito
    bruschetta caesar_salad cannoli caprese_salad carrot_cake
    ceviche cheese_plate cheesecake chicken_curry chicken_quesadilla
    chicken_wings chocolate_cake chocolate_mousse churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice
    frozen_yogurt garlic_bread gnocchi greek_salad grilled_cheese_sandwich
    grilled_salmon guacamole gyoza hamburger hot_and_sour_soup
    hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    mussels nachos omelette onion_rings oysters
    pad_thai paella pancakes panna_cotta peking_duck
    pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake risotto
    samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak strawberry_shortcake
    sushi tacos takoyaki tiramisu tuna_tartare waffles
""".split()

device = "cpu"

# Create model (and its matching preprocessing transforms)
effnetb3, effnetb3_transforms = create_effnetb3_model(num_classes=len(class_names))

# Load saved weights onto `device`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer torch versions accept weights_only=True to restrict this).
effnetb3_state_dict = torch.load("effnetb3_full_food101.pth",
                                  map_location=torch.device(device))

# The model built above expects its head under "classifier.1.*"; this checkpoint
# may store it as "classifier.*". Remap only keys that are actually present, so a
# checkpoint already using the new names loads instead of raising KeyError.
for _param in ("weight", "bias"):
    _old_key, _new_key = f"classifier.{_param}", f"classifier.1.{_param}"
    if _old_key in effnetb3_state_dict:
        effnetb3_state_dict[_new_key] = effnetb3_state_dict.pop(_old_key)

effnetb3.load_state_dict(effnetb3_state_dict)
effnetb3.to(device)
# This app only runs inference, so switch to eval mode once at load time.
effnetb3.eval()

# Define predict function
def predict(img: Image.Image) -> Tuple[dict, float]:
    """Classify a food image with the EffnetB3 model and time the prediction.

    Args:
      img (PIL.Image.Image): Image to predict on.

    Returns:
      A tuple (pred_labels_and_probs, pred_time), where pred_labels_and_probs
      is a dict mapping each class name to the softmax probability the model
      assigns to it, and pred_time is the elapsed time (measured with
      timeit.default_timer, in seconds, rounded to 4 decimal places).
    """
    start_time = timer()
    # Preprocess, add a leading batch dimension of size 1, and move the input
    # to the same device as the model.
    batch = effnetb3_transforms(img).unsqueeze(0).to(device)
    effnetb3.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb3(batch), dim=1)
    # Convert the single row of probabilities to Python floats in one call,
    # then pair each probability with its class name by position.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))
    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time

# Initialize Gradio app
title = "FoodVision"
description = "EfficientNetB3 feature extractor to classify images of food. Upload an image or click on one of the examples to try it out!"
article = """
From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/), using the
[Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/).
"""
# Each example is a one-element list (one value per input component). Sort the
# glob results so the gallery order is stable (glob order is filesystem-dependent),
# and pass plain string paths.
examples = [[str(example)] for example in sorted(Path("examples").glob("*.jpg"))]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (s)")],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()