# foodvision / app.py
# Committed by allispaul: "initial commit" (3898f71)
import os
from timeit import default_timer as timer
from typing import Tuple
from pathlib import Path
from PIL import Image
import gradio as gr
import torch
from torch import nn
from torchvision import transforms
from model import create_effnetb3_model
# Food-101 category labels. predict() pairs softmax output index i with
# class_names[i], so this order must match the classifier head's output order
# (presumably the alphabetical order used when the checkpoint was trained).
# The block below is split on whitespace into a plain list of 101 strings.
class_names = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare
    beet_salad beignets bibimbap bread_pudding breakfast_burrito
    bruschetta caesar_salad cannoli caprese_salad carrot_cake
    ceviche cheese_plate cheesecake chicken_curry chicken_quesadilla
    chicken_wings chocolate_cake chocolate_mousse churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice
    frozen_yogurt garlic_bread gnocchi greek_salad grilled_cheese_sandwich
    grilled_salmon guacamole gyoza hamburger hot_and_sour_soup
    hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    mussels nachos omelette onion_rings oysters
    pad_thai paella pancakes panna_cotta peking_duck
    pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake risotto
    samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak strawberry_shortcake
    sushi tacos takoyaki tiramisu tuna_tartare waffles
""".split()
device = "cpu"

# Create model (architecture + the preprocessing transforms that go with it).
effnetb3, effnetb3_transforms = create_effnetb3_model(num_classes=len(class_names))

# Load saved weights onto `device`.
# NOTE(review): torch.load unpickles the checkpoint. It is a trusted local asset
# here, but consider weights_only=True (torch >= 1.13) to restrict loading to tensors.
effnetb3_state_dict = torch.load("effnetb3_full_food101.pth",
                                 map_location=torch.device(device))

# The checkpoint stores the classification head under "classifier.weight" /
# "classifier.bias", while the model from create_effnetb3_model expects
# "classifier.1.*" (presumably a Sequential whose index 1 is the Linear layer;
# confirm in model.py). Only remap keys that are actually present, so a
# checkpoint that already uses the "classifier.1.*" names still loads instead
# of raising KeyError.
for _param_name in ("weight", "bias"):
    _old_key = f"classifier.{_param_name}"
    if _old_key in effnetb3_state_dict:
        effnetb3_state_dict[f"classifier.1.{_param_name}"] = effnetb3_state_dict.pop(_old_key)

effnetb3.load_state_dict(effnetb3_state_dict)
effnetb3.to(device)
# Define predict function
def predict(img: Image.Image) -> Tuple[dict, float]:
    """Classify a food image with the EffNetB3 model and time the prediction.

    Args:
        img (PIL.Image.Image): Image to predict on. It is converted to RGB
            first, so greyscale, palette and RGBA images are accepted too.

    Returns:
        A tuple (pred_labels_and_probs, pred_time), where pred_labels_and_probs
        is a dict mapping each class name to the softmax probability the model
        assigns to it, and pred_time is the elapsed time for the whole
        prediction (preprocessing + forward pass) in seconds, rounded to 4
        decimal places.
    """
    start_time = timer()

    # Normalise the colour mode before preprocessing. Presumably the transforms
    # and model expect 3-channel input (confirm in model.py); a 1- or 4-channel
    # image would otherwise fail. On an RGB image this is just a copy.
    batch = effnetb3_transforms(img.convert("RGB")).unsqueeze(0)  # add batch dim

    effnetb3.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb3(batch), dim=1)

    # Row 0 is the single image in the batch. .tolist() yields plain Python
    # floats, in the same order as class_names.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# Initialize Gradio app: page text, example images, and the Interface wiring
# `predict` (image in -> top-3 labels + timing out).
title = "FoodVision"
description = "EfficientNetB3 feature extractor to classify images of food. Upload an image or click on one of the examples to try it out!"
article = """
From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/), using the
[Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/).
"""

# Path.glob yields files in arbitrary, filesystem-dependent order. Sort so the
# example gallery is identical across machines and restarts. "examples" is
# resolved relative to the current working directory.
examples = [[example] for example in sorted(Path("examples").glob("*.jpg"))]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (s)")],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Launches the server as soon as this module runs; there is no __main__ guard.
demo.launch()