"""Gradio app: classify a food image and report its nutritional value.

The image is classified with the pre-trained "nateraw/food" model. The predicted
label is then matched against database.xlsx (nutrition data downloaded from USDA),
and the nutrient values are scaled to the weight the user entered.
"""
import functools

import gradio as gr
import openpyxl
import torch
from transformers import pipeline
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

FOOD_MODEL_NAME = "nateraw/food"
DATABASE_PATH = './database.xlsx'


@functools.lru_cache(maxsize=1)
def _load_food_model():
    """Load the feature extractor and classifier once and reuse them across requests."""
    extractor = AutoFeatureExtractor.from_pretrained(FOOD_MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(FOOD_MODEL_NAME)
    return extractor, model


def predict(image):
    """Predict the food shown in `image` using the pre-trained "nateraw/food" model.

    Returns:
        The label string of the highest-scoring class (from ``model.config.id2label``).
    """
    extractor, model = _load_food_model()
    inputs = extractor(images=image, return_tensors='pt')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = logits.argmax(-1).item()
    return model.config.id2label[pred_class]


@functools.lru_cache(maxsize=1)
def _load_food_rows(path):
    """Read the nutrition sheet once.

    Returns a tuple of (name, cal, carb, prot, fat) rows taken from columns
    2, 5, 7, 6 and 10 of the active sheet. Data starts on row 3.
    """
    sheet = openpyxl.load_workbook(path).active
    # max_col=10 pads every row to 10 values, so short rows still index safely.
    return tuple(
        (row[1], row[4], row[6], row[5], row[9])
        for row in sheet.iter_rows(min_row=3, max_col=10, values_only=True)
    )


def _food_phrase(food, capitalize):
    """Join the first (up to) three words of `food` with spaces, optionally capitalizing the first word."""
    words = list(food[:3])
    if capitalize:
        words[0] = words[0].capitalize()
    return " ".join(words)


def _food_matcher(food, counter):
    """Return a predicate on database names for matching strategy `counter`.

    Returns None when the strategy does not apply, which means no match.
    """
    if not food:
        return None
    if counter == 0:
        # Exact leading entry, e.g. "Spaghetti bolognese," followed by details.
        prefix = _food_phrase(food, capitalize=True) + ","
        return lambda name: name.startswith(prefix)
    if counter == 1:
        prefix = _food_phrase(food, capitalize=True)
        return lambda name: name.startswith(prefix)
    if counter == 2:
        phrase = _food_phrase(food, capitalize=False)
        return lambda name: phrase in name
    if counter == 3 and len(food) > 1:
        first_word = food[0]
        return lambda name: first_word in name
    return None


def check_food(food, counter):
    """Retrieve the nutritional values of the first database row matching `food`.

    Args:
        food: list of words from the predicted label (split on "_").
        counter: matching strategy, from strictest (0) to loosest (3).
            0 = name starts with the capitalized phrase plus ",".
            1 = name starts with the capitalized phrase.
            2 = phrase occurs anywhere in the name.
            3 = first word occurs anywhere in the name (multi-word labels only).

    Returns:
        (name, cal, carb, prot, fat) of the first matching row,
        or five Nones when nothing matches.
    """
    matches = _food_matcher(food, counter)
    if matches is None:
        return None, None, None, None, None
    for name, cal, carb, prot, fat in _load_food_rows(DATABASE_PATH):
        # Skip empty or non-text name cells instead of crashing on them.
        if isinstance(name, str) and matches(name):
            return name, cal, carb, prot, fat
    return None, None, None, None, None


def _format_amount(value, weight, unit):
    """Scale `value` (assumed per 100 g) to `weight` grams, rounded to 0.01.

    Returns "N/A" when the value or weight is missing.
    """
    if not isinstance(value, (int, float)) or weight is None:
        return "N/A"
    return f"{round(value * weight) / 100} {unit}"


def get_cc(food, weight):
    """Build the nutrition text for the predicted label `food` at `weight` grams.

    Returns "No data for food" when no database entry matches.
    """
    # Turn labels like "spaghetti_bolognese" into words and naively drop a
    # trailing plural "s" so they match database entries.
    # NOTE(review): this also strips words like "hummus"; confirm it is acceptable.
    words = food.split("_")
    if words[-1].endswith("s"):
        words[-1] = words[-1][:-1]

    # Try progressively looser matching strategies until one finds a row.
    food_pred, cal, carb, prot, fat = None, None, None, None, None
    counter = 0
    while not food_pred and counter <= 3:
        food_pred, cal, carb, prot, fat = check_food(words, counter)
        counter += 1

    if not food_pred:
        return "No data for food"
    return (
        f"{food_pred}"
        f"\nCalories: {_format_amount(cal, weight, 'kJ')}"
        f"\nCarbohydrate: {_format_amount(carb, weight, 'g')}"
        f"\nProtein: {_format_amount(prot, weight, 'g')}"
        f"\nTotal Fat: {_format_amount(fat, weight, 'g')}"
    )


def CC(image, weight):
    """Gradio entry point: classify `image` and describe its nutrition for `weight` grams.

    Returns a (prediction label, nutrition text) pair for the two output textboxes.
    """
    pred = predict(image)
    cc = get_cc(pred, weight)
    return pred, cc


interface = gr.Interface(
    fn = CC,
    inputs = [gr.inputs.Image(shape=(224,224)), gr.inputs.Number(default = 100, label = "Weight in grams (g):")],
    outputs = [gr.outputs.Textbox(label='Food Prediction:'), gr.outputs.Textbox(label='Nutritional Value:')],
    examples = [["pizza.jpg", 107], ["spaghetti.jpg",205]])

interface.launch()