DanielXu0208 commited on
Commit
ca68817
1 Parent(s): 3ebb60a

Update Interface

Browse files
run_gradio.py CHANGED
@@ -1,55 +1,233 @@
1
  import gradio as gr
2
  import torch
3
  import torchvision
 
 
 
4
  from utils.experiment_utils import get_model
5
 
6
 
7
- # 加载DINOv2模型
8
- def load_model():
9
- class Args:
10
- model = 'DINOv2'
11
- pretrained = 'pretrained'
12
- frozen = 'unfrozen'
 
 
 
13
 
14
- args = Args()
15
- model = get_model(args)
16
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return model
18
 
19
 
20
- model = load_model()
 
 
 
21
 
 
 
22
 
23
- # 预测函数,返回每个类别的概率
24
- def predict(image):
25
  transform = torchvision.transforms.Compose([
26
  torchvision.transforms.Resize((224, 224)),
27
  torchvision.transforms.ToTensor(),
28
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
  ])
30
 
31
- image = transform(image).unsqueeze(0)
32
  with torch.no_grad():
33
- output = model(image)
34
  probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
35
 
36
- # 类别名称列表
37
  class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
38
 
39
- # 将类别和对应的概率配对
 
40
  results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
41
 
42
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # 创建Gradio界面
46
- interface = gr.Interface(
47
- fn=predict,
48
- inputs=gr.Image(type="pil"),
49
- outputs=gr.Label(num_top_classes=len(["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])),
50
- title="LUWA DINOv2 Prediction",
51
- description="Upload an image to get the probabilities for each class using the DINOv2 model."
52
- )
53
 
54
  if __name__ == "__main__":
55
- interface.launch(share=True)
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import torchvision
4
+ import pandas as pd
5
+ import os
6
+ from PIL import Image
7
  from utils.experiment_utils import get_model
8
 
9
 
10
+ # Custom flagging logic to save flagged data to a CSV file
11
+ class CustomFlagging(gr.FlaggingCallback):
12
+ def __init__(self, dir_name="flagged_data"):
13
+ self.dir = dir_name
14
+ self.image_dir = os.path.join(self.dir, "uploaded_images")
15
+ if not os.path.exists(self.dir):
16
+ os.makedirs(self.dir)
17
+ if not os.path.exists(self.image_dir):
18
+ os.makedirs(self.image_dir)
19
 
20
+ # Define setup as a no-op to fulfill abstract class requirement
21
+ def setup(self, *args, **kwargs):
22
+ pass
23
+
24
+ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
25
+ # Extract data
26
+ classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data
27
+
28
+ # Save the uploaded image in the "uploaded_images" folder
29
+ image_filename = os.path.join(self.image_dir,
30
+ f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png")
31
+ image.save(image_filename) # Save image in PNG format
32
+
33
+ # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
34
+ data = {
35
+ "Classification Mode": classification_mode,
36
+ "Image Path": image_filename, # Save path to image in CSV
37
+ "Sensing Modality": sensing_modality,
38
+ "Predicted Class": predicted_class,
39
+ "Correct Class": correct_class,
40
+ }
41
+
42
+ df = pd.DataFrame([data])
43
+ csv_file = os.path.join(self.dir, "flagged_data.csv")
44
+
45
+ # Append to CSV, or create if it doesn't exist
46
+ if os.path.exists(csv_file):
47
+ df.to_csv(csv_file, mode='a', header=False, index=False)
48
+ else:
49
+ df.to_csv(csv_file, mode='w', header=True, index=False)
50
+
51
+
52
+ # Function to load the appropriate model based on the user's selection
53
+ def load_model(modality, mode):
54
+ # For Few-Shot classification, always use the DINOv2 model
55
+ if mode == "Few-Shot":
56
+ class Args:
57
+ model = 'DINOv2'
58
+ pretrained = 'pretrained'
59
+ frozen = 'unfrozen'
60
+
61
+ args = Args()
62
+ model = get_model(args) # Load DINOv2 model for Few-Shot classification
63
+ else:
64
+ # For Fully-Supervised classification, choose model based on the sensing modality
65
+ if modality == "Texture":
66
+ class Args:
67
+ model = 'DINOv2'
68
+ pretrained = 'pretrained'
69
+ frozen = 'unfrozen'
70
+
71
+ args = Args()
72
+ model = get_model(args) # Load DINOv2 model for Texture modality
73
+ elif modality == "Heightmap":
74
+ class Args:
75
+ model = 'ResNet152'
76
+ pretrained = 'pretrained'
77
+ frozen = 'unfrozen'
78
+
79
+ args = Args()
80
+ model = get_model(args) # Load ResNet152 model for Heightmap modality
81
+ else:
82
+ raise ValueError("Invalid modality selected!")
83
+
84
+ model.eval() # Set the model to evaluation mode
85
  return model
86
 
87
 
88
+ # Prediction function that processes the image and returns the prediction results
89
+ def predict(image, modality, mode):
90
+ # Load the appropriate model based on the user's selections
91
+ model = load_model(modality, mode)
92
 
93
+ # Print the selected mode and modality for debugging purposes
94
+ print(f"User selected Mode: {mode}, Modality: {modality}")
95
 
96
+ # Preprocess the image
 
97
  transform = torchvision.transforms.Compose([
98
  torchvision.transforms.Resize((224, 224)),
99
  torchvision.transforms.ToTensor(),
100
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
101
  ])
102
 
103
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
104
  with torch.no_grad():
105
+ output = model(image_tensor) # Get model predictions
106
  probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
107
 
108
+ # Class names for the predictions
109
  class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
110
 
111
+ # Pair class names with their corresponding probabilities
112
+ predicted_class = class_names[probabilities.index(max(probabilities))] # Get the predicted class
113
  results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
114
 
115
+ return predicted_class, results # Return predicted class and probabilities
116
+
117
+
118
+ # Create the Gradio interface using gr.Blocks
119
+ def create_interface():
120
+ with gr.Blocks() as interface:
121
+ # Title at the top of the interface (centered and larger)
122
+ gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")
123
+
124
+ # Add description for the interface
125
+ description = """
126
+ ### Image Classification Options
127
+ - **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
128
+ - **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
129
+ ### **Don't forget to choose the Sensing Modality based on your uploaded images.**
130
+ ### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
131
+ """
132
+ gr.Markdown(description)
133
+
134
+ # Top-level selector for Fully-Supervised vs. Few-Shot classification
135
+ mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
136
+ value="Fully Supervised")
137
+
138
+ # Sensing modality selector
139
+ modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")
140
+
141
+ # Image upload input
142
+ image_input = gr.Image(type="pil", label="Image")
143
+
144
+ # Predicted classification output and class probabilities
145
+ with gr.Row():
146
+ predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
147
+ probabilities_output = gr.Label(label="Prediction Probabilities")
148
+
149
+ # Add the "Run Prediction" button under the Prediction Probabilities
150
+ predict_button = gr.Button("Run Prediction")
151
+
152
+ # Dropdown for user to select the correct class if the model prediction is wrong
153
+ correct_class_selector = gr.Radio(
154
+ choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
155
+ label="Select Correct Class"
156
+ )
157
 
158
+ # Text box for user to type the correct class if "Other" is selected
159
+ other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)
160
+
161
+ # Logic to dynamically update visibility of the "Other" class text box
162
+ def update_visibility(selected_class):
163
+ return gr.update(visible=selected_class == "Other")
164
+
165
+ correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)
166
+
167
+
168
+ # Create a flagging instance
169
+ flagging_instance = CustomFlagging(dir_name="flagged_data")
170
+
171
+ # Define function for the confirmation pop-up
172
+ def confirm_flag_selection(correct_class, other_class):
173
+ # Generate confirmation message
174
+ if correct_class == "Other":
175
+ message = f"Are you sure the class you selected is '{other_class}' for this picture?"
176
+ else:
177
+ message = f"Are you sure the class you selected is '{correct_class}' for this picture?"
178
+
179
+ return message, gr.update(visible=True), gr.update(visible=True)
180
+
181
+ # Final flag submission function
182
+ def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
183
+ if confirmed == "Yes":
184
+ # Save the flagged data
185
+ correct_class_final = correct_class if correct_class != "Other" else other_class
186
+ flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
187
+ return "Flagged successfully!"
188
+ else:
189
+ return "No flag submitted, please select again."
190
+
191
+ # Flagging button
192
+ flag_button = gr.Button("Flag")
193
+
194
+ # Confirmation box for user input and confirmation flag
195
+ confirmation_text = gr.Textbox(visible=False)
196
+ yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
197
+ confirmation_button = gr.Button("Confirm Flag", visible=False)
198
+
199
+ # Prediction action
200
+ predict_button.click(
201
+ fn=predict,
202
+ inputs=[image_input, modality_selector, mode_selector],
203
+ outputs=[predicted_output, probabilities_output]
204
+ )
205
+
206
+ # Flagging action with confirmation
207
+ flag_button.click(
208
+ fn=confirm_flag_selection,
209
+ inputs=[correct_class_selector, other_class_input],
210
+ outputs=[confirmation_text, yes_no_choice, confirmation_button]
211
+ )
212
+
213
+ # Final flag submission after confirmation
214
+ confirmation_button.click(
215
+ fn=flag_data_save,
216
+ inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
217
+ predicted_output, yes_no_choice],
218
+ outputs=gr.Textbox(label="Flagging Status")
219
+ )
220
+
221
+ return interface
222
 
 
 
 
 
 
 
 
 
223
 
224
  if __name__ == "__main__":
225
+ interface = create_interface()
226
+ interface.launch(share=True)
227
+
228
+
229
+
230
+
231
+
232
+
233
+
utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
 
 
 
 
utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier DELETED
@@ -1,3 +0,0 @@
1
- [ZoneTransfer]
2
- ZoneId=3
3
- ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip