aaronherrera commited on
Commit
06b1c65
1 Parent(s): 7ca342d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
+ import openpyxl
5
+
6
+ #Function to predict the food from the image using the pre-trained model "nateraw/food"
7
+ def predict(image):
8
+ extractor = AutoFeatureExtractor.from_pretrained("nateraw/food")
9
+ model = AutoModelForImageClassification.from_pretrained("nateraw/food")
10
+
11
+ input = extractor(images=image, return_tensors='pt')
12
+ output = model(**input)
13
+ logits = output.logits
14
+
15
+ pred_class = logits.argmax(-1).item()
16
+ return(model.config.id2label[pred_class])
17
+
18
+ #Function to retrieve the Nutritional Value from database.xlsx which is downloaded from USDA
19
+ def check_food(food, counter):
20
+ path = './database.xlsx'
21
+ wb_obj = openpyxl.load_workbook(path)
22
+ sheet_obj = wb_obj.active
23
+
24
+ foodPred, cal, carb, prot, fat = None, None, None, None, None
25
+
26
+ #Filter to prioritize the most probable match between the prediction and the entries in the database
27
+ for i in range(3, sheet_obj.max_row+1):
28
+ cell_obj = sheet_obj.cell(row = i, column = 2)
29
+ if counter == 0:
30
+ if len(food) >= 3:
31
+ foodName = food[0].capitalize() + " " + food[1] + " " + food[2] + ","
32
+ elif len(food) == 2:
33
+ foodName = food[0].capitalize() + " " + food[1] + ","
34
+ elif len(food) == 1:
35
+ foodName = food[0].capitalize() + ","
36
+ condition = foodName == cell_obj.value[0:len(foodName):]
37
+ elif counter == 1:
38
+ if len(food) >= 3:
39
+ foodName = food[0].capitalize() + " " + food[1] + " " + food[2]
40
+ elif len(food) == 2:
41
+ foodName = food[0].capitalize() + " " + food[1]
42
+ elif len(food) == 1:
43
+ foodName = food[0].capitalize()
44
+ condition = foodName == cell_obj.value[0:len(foodName):]
45
+ elif counter == 2:
46
+ if len(food) >= 3:
47
+ foodName = food[0] + " " + food[1] + " " + food[2]
48
+ elif len(food) == 2:
49
+ foodName = food[0] + " " + food[1]
50
+ elif len(food) == 1:
51
+ foodName = food[0]
52
+ condition = foodName in cell_obj.value
53
+ elif (counter == 3) & (len(food) > 1):
54
+ condition = food[0] in cell_obj.value
55
+ else:
56
+ break
57
+
58
+ #Update values if conditions are met
59
+ if condition:
60
+ foodPred = cell_obj.value
61
+ cal = sheet_obj.cell(row = i, column = 5).value
62
+ carb = sheet_obj.cell(row = i, column = 7).value
63
+ prot = sheet_obj.cell(row = i, column = 6).value
64
+ fat = sheet_obj.cell(row = i, column = 10).value
65
+ break
66
+
67
+ return foodPred, cal, carb, prot, fat
68
+
69
+ #Function to prepare the output
70
+ def get_cc(food, weight):
71
+
72
+ #Configure the food string to match the entries in the database
73
+ food = food.split("_")
74
+ if food[-1][-1] == "s":
75
+ food[-1] = food[-1][:-1]
76
+
77
+ foodPred, cal, carb, prot, fat = None, None, None, None, None
78
+ counter = 0
79
+
80
+ #Try for the most probable match between the prediction and the entries in the database
81
+ while (not foodPred) & (counter <= 3):
82
+ foodPred, cal, carb, prot, fat = check_food(food,counter)
83
+ counter += 1
84
+
85
+ #Check if there is a match
86
+ if food:
87
+ output = foodPred + "\nCalories: " + str(round(cal * weight)/100) + " kJ\nCarbohydrate: " + str(round(carb * weight)/100) + " g\nProtein: " + str(round(prot * weight)/100) + " g\nTotal Fat: " + str(round(fat * weight)/100) + " g"
88
+ elif not food:
89
+ output = "No data for food"
90
+
91
+ return(output)
92
+
93
+ #Main function
94
+ def CC(image, weight):
95
+ pred = predict(image)
96
+ cc = get_cc(pred, weight)
97
+ return(pred, cc)
98
+
99
+ interface = gr.Interface(
100
+ fn = CC,
101
+ inputs = [gr.inputs.Image(shape=(224,224)), gr.inputs.Number(default = 100, label = "Weight in grams (g):")],
102
+ outputs = [gr.outputs.Textbox(label='Food Prediction:'), gr.outputs.Textbox(label='Nutritional Value:')],
103
+ examples = [["pizza.jpg", 107], ["spaghetti.jpg",205]])
104
+
105
+ interface.launch()