Spaces:
Runtime error
Runtime error
import functools

import gradio as gr
import openpyxl
import torch
from transformers import pipeline
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
@functools.lru_cache(maxsize=1)
def _load_food_model():
    """Load the "nateraw/food" feature extractor and classifier once per process."""
    extractor = AutoFeatureExtractor.from_pretrained("nateraw/food")
    model = AutoModelForImageClassification.from_pretrained("nateraw/food")
    return extractor, model


def predict(image):
    """Return the food label the pre-trained "nateraw/food" model assigns to *image*.

    Args:
        image: the picture to classify, in any form the feature extractor
            accepts (in this app, whatever the Gradio Image input delivers).

    Returns:
        The label string of the highest-scoring class, looked up in
        ``model.config.id2label``.
    """
    # The cached loader fetches/deserialises the weights only on the first
    # request instead of on every call.
    extractor, model = _load_food_model()
    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = logits.argmax(-1).item()
    return model.config.id2label[pred_class]
@functools.lru_cache(maxsize=4)
def _load_food_rows(path):
    """Read the data rows of the workbook at *path* once and cache them.

    Returns a tuple with one entry per row from row 3 to the active sheet's
    last row. Each entry is a tuple of the values in columns 1-10; empty
    cells are None.
    """
    wb_obj = openpyxl.load_workbook(path)
    sheet_obj = wb_obj.active
    # Rows 1-2 are skipped (presumably header rows — confirm against the file).
    return tuple(sheet_obj.iter_rows(min_row=3, max_col=10, values_only=True))


def check_food(food, counter, path='./database.xlsx'):
    """Retrieve nutritional values for a food from database.xlsx (downloaded from USDA).

    The sheet is scanned top to bottom. The first row whose description
    (column 2) matches the query is returned. Only the first three words of
    *food* are used, and matching is case-sensitive. *counter* selects the
    strictness, from most to least specific:

    * 0 -- description starts with "<Word1> word2 word3,"
    * 1 -- description starts with "<Word1> word2 word3"
    * 2 -- "word1 word2 word3" appears anywhere in the description
    * 3 -- "word1" appears anywhere (only tried when *food* has >1 word)

    Here "<Word1>" is the first word after ``str.capitalize``.

    Args:
        food: list of words describing the food.
        counter: strictness level as above; any other value matches nothing.
        path: workbook to search (defaults to ./database.xlsx).

    Returns:
        A 5-tuple of (description, energy, carbohydrate, protein, fat) taken
        from columns 2, 5, 7, 6 and 10 of the first matching row. If nothing
        matches, or *food* is empty, all five values are None.
    """
    nothing = (None, None, None, None, None)
    if not food:
        return nothing

    words = list(food[:3])
    phrase = " ".join(words)
    capitalized = " ".join([words[0].capitalize()] + words[1:])

    # Build the query once (it does not depend on the row); `anchored`
    # means prefix match, otherwise substring match.
    if counter == 0:
        needle, anchored = capitalized + ",", True
    elif counter == 1:
        needle, anchored = capitalized, True
    elif counter == 2:
        needle, anchored = phrase, False
    elif counter == 3 and len(food) > 1:
        needle, anchored = words[0], False
    else:
        return nothing

    for row in _load_food_rows(path):
        description = row[1]
        # Empty or numeric description cells previously raised TypeError.
        if not isinstance(description, str):
            continue
        found = description.startswith(needle) if anchored else needle in description
        if found:
            return description, row[4], row[6], row[5], row[9]
    return nothing
def get_cc(food, weight):
    """Prepare the nutrition summary text for a predicted food label.

    The label is split on "_" and a trailing "s" is stripped from its last
    word (a crude singularisation). Then check_food is tried with counter
    0, 1, 2, 3 (strict to loose) until a database row matches.

    Args:
        food: label string from predict, words separated by "_".
        weight: portion weight in grams. Each database value v is reported
            as round(v * weight) / 100 (presumably the database lists values
            per 100 g — confirm).

    Returns:
        A multi-line string with the matched description and the scaled
        energy, carbohydrate, protein and total fat. Returns
        "No data for food" when no row matches. A missing nutrient cell is
        shown as "n/a".
    """
    words = food.split("_")
    # endswith avoids the IndexError that [-1] raises on an empty last word.
    if words[-1].endswith("s"):
        words[-1] = words[-1][:-1]

    foodPred, cal, carb, prot, fat = None, None, None, None, None
    counter = 0
    # Try progressively looser matches until one succeeds.
    while not foodPred and counter <= 3:
        foodPred, cal, carb, prot, fat = check_food(words, counter)
        counter += 1

    # Check the lookup result, not the word list: the list is never empty,
    # so testing it sent unmatched foods into None + str.
    if not foodPred:
        return "No data for food"

    def scaled(value):
        # Empty spreadsheet cells come back as None.
        if value is None:
            return "n/a"
        return str(round(value * weight) / 100)

    # NOTE(review): energy is labelled "Calories" but given in kJ — confirm
    # which unit the database's energy column uses.
    return (
        f"{foodPred}\n"
        f"Calories: {scaled(cal)} kJ\n"
        f"Carbohydrate: {scaled(carb)} g\n"
        f"Protein: {scaled(prot)} g\n"
        f"Total Fat: {scaled(fat)} g"
    )
def CC(image, weight):
    """Gradio callback: classify the food photo and describe its nutrition.

    Args:
        image: the uploaded picture, passed unchanged to predict.
        weight: portion weight in grams, passed unchanged to get_cc.

    Returns:
        A pair (label, summary): the label from predict and the text that
        get_cc builds for that label and weight.
    """
    label = predict(image)
    return label, get_cc(label, weight)
# Gradio UI: an image plus a portion weight (grams) in; the predicted label
# and its nutrition summary out.
# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4, so the top-level components are used instead
# (`default=` became `value=`).
# The old shape=(224, 224) resize is dropped because gr.Image no longer
# accepts it in Gradio 4; the model's feature extractor is expected to
# resize inputs itself (ViT extractors default to do_resize=True — confirm
# for "nateraw/food").
interface = gr.Interface(
    fn=CC,
    inputs=[
        gr.Image(),
        gr.Number(value=100, label="Weight in grams (g):"),
    ],
    outputs=[
        gr.Textbox(label='Food Prediction:'),
        gr.Textbox(label='Nutritional Value:'),
    ],
    examples=[["pizza.jpg", 107], ["spaghetti.jpg", 205]],
)
interface.launch()