DanielXu0208 committed on
Commit e073b0b
1 Parent(s): 9e56473

Update run_gradio.py

Files changed (1)
  1. run_gradio.py +246 -212
run_gradio.py CHANGED
@@ -1,215 +1,249 @@
- import matplotlib.pyplot as plt
  import torch
- from torchvision.transforms.functional import resize, normalize, to_pil_image
- from torchvision.io.image import read_image
- from torchvision.models import resnet50
- from sklearn.cluster import KMeans
- import numpy as np
  import os
- import logging
- from torchcam.utils import overlay_mask
  from PIL import Image
- from collections import Counter
- from scipy.spatial.distance import cdist
-
- # Initialize logger to monitor progress
- logging.basicConfig(level=logging.INFO)
-
- # Path to dataset and model
- dataset_path = "archive"
- model_path = "resnet50_finetuned_miniimagenet.pth"  # Update to your fine-tuned MiniImageNet weights
- n_clusters = 100  # Number of clusters for feature channels (can be adjusted)
- top_k_prototypes = 5  # Top-k most similar examples to select as prototypes for each cluster
- batch_size = 8  # Reduce batch size to limit memory usage
- output_folder = "examples"  # Folder to save the images and heatmaps
-
- # Limit to 100 images per class
- images_per_class = 100
-
- # Create the output folder if it doesn't exist
- if not os.path.exists(output_folder):
-     os.makedirs(output_folder)
-
- # Set device to GPU if available, otherwise CPU
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # Load the fine-tuned model weights from MiniImageNet
- model = resnet50(pretrained=False)
-
- # Modify the fully connected layer to match the number of classes in MiniImageNet (100 classes)
- model.fc = torch.nn.Linear(model.fc.in_features, 100)
-
- # Load the fine-tuned state_dict
- checkpoint = torch.load(model_path, weights_only=True)  # Updated to MiniImageNet weights
- model.load_state_dict(checkpoint)
- model.eval()
-
- # Move the model to the GPU
- model.to(device)
-
- # Hook to capture activations from 'layer4'
- activation = {}
-
-
- def get_activation(name):
-     def hook(model, input, output):
-         activation[name] = output.detach()
-
-     return hook
-
-
- model.layer4.register_forward_hook(get_activation('layer4'))
-
- # Collect activations of all feature channels across multiple images
- all_activations = []
- image_paths = []  # To keep track of the image paths
- image_labels = []  # To store the class labels of each image
-
- # Traverse the dataset, accessing each class folder and collecting images and labels
- for class_folder in os.listdir(dataset_path):
-     class_folder_path = os.path.join(dataset_path, class_folder)
-
-     # Ensure we are looking at a directory (class folder)
-     if os.path.isdir(class_folder_path):
-         class_label = class_folder  # Use the folder name as the class label
-
-         # Get only up to 'images_per_class' images from each class folder
-         class_images = os.listdir(class_folder_path)[:images_per_class]
-
-         for img_name in class_images:
-             img_path = os.path.join(class_folder_path, img_name)
-             image_paths.append(img_path)  # Store the image path for later use
-             image_labels.append(class_label)  # Store the corresponding class label
-
- # Log how many images we collected
- logging.info(f"Collected {len(image_paths)} images across {len(set(image_labels))} classes.")
-
- # Process the images for clustering and Grad-CAM calculation in batches
- for batch_idx in range(0, len(image_paths), batch_size):
-     batch_image_paths = image_paths[batch_idx: batch_idx + batch_size]
-
-     with torch.no_grad():  # Disable gradient calculations
-         for img_path in batch_image_paths:
-             # Read and preprocess the image
-             img = read_image(img_path)
-
-             # Ensure the image has 3 channels (convert grayscale or 4-channel images to RGB)
-             if img.shape[0] == 1:  # If the image is grayscale (1 channel), repeat the single channel to make it RGB
-                 img = img.repeat(3, 1, 1)  # Convert it to 3-channel by repeating the single channel
-             elif img.shape[0] == 4:  # If the image has 4 channels (e.g., RGBA), drop the alpha channel
-                 img = img[:3, :, :]  # Keep only the first 3 channels (RGB)
-
-             # Resize and normalize the image
-             input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-
-             # Move the input tensor to the GPU
-             input_tensor = input_tensor.to(device)
-
-             # Feed the preprocessed image to the model
-             out = model(input_tensor.unsqueeze(0))
-
-             # Get the activations from layer4 (with 2,048 feature channels)
-             layer4_activations = activation['layer4'].cpu().numpy()
-
-             # For each image, store the activation values across all channels (2048 channels)
-             all_activations.append(layer4_activations.squeeze())
-
-     # Log progress
-     logging.info(f"Processed batch {batch_idx // batch_size + 1}/{len(image_paths) // batch_size + 1}")
-
- # Convert the collected activations into a numpy array of shape (n_images, 2048, H, W)
- all_activations = np.array(all_activations)
-
- # Average over the spatial dimensions (H, W) to get the activation vector for each channel
- # This gives an array of shape (n_images, 2048), where each value is the averaged activation for that channel
- avg_activations_per_image = np.mean(all_activations, axis=(-2, -1))  # Average over spatial dimensions
-
- # Transpose the array to get the activations of each channel across all images
- # Shape will be (2048, n_images), where each row is the activation of a channel across all images
- channel_activation_vectors = avg_activations_per_image.T
-
- # Perform KMeans clustering on the feature channels
- kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(channel_activation_vectors)
-
- # Get cluster assignments for feature channels
- channel_clusters = kmeans.labels_
-
- # Find prototypes for each cluster based on channel activation similarities
- prototypes = {}
-
- for cluster_id in range(n_clusters):
-     cluster_indices = np.where(channel_clusters == cluster_id)[0]  # Get the feature channels in this cluster
-
-     if len(cluster_indices) == 0:
-         continue  # Skip empty clusters
-
-     # Find the majority class for the images that activate the feature channels in this cluster
-     cluster_activation_vectors = channel_activation_vectors[cluster_indices]  # Activation vectors for this cluster
-
-     # Use the majority class of the images for this cluster
-     majority_class = Counter([image_labels[i] for i in range(len(image_paths))]).most_common(1)[0][0]
-
-     # Filter the images by the majority class before selecting prototypes
-     majority_class_indices = [i for i, label in enumerate(image_labels) if label == majority_class]
-     filtered_cluster_activation_vectors = cluster_activation_vectors[:, majority_class_indices]  # Filtered activations
-
-     # Compute pairwise distances between the activation vectors of the feature channels
-     distances = cdist(filtered_cluster_activation_vectors.T, filtered_cluster_activation_vectors.T, 'euclidean')
-
-     # Sum the distances for each image (to find the closest/most representative sample)
-     distance_sums = distances.sum(axis=1)
-
-     # Get the indices of the top-k closest images from the filtered list
-     top_k_indices = np.argsort(distance_sums)[:top_k_prototypes]
-
-     # Store the prototypes for this cluster (top-k most representative images of the majority class)
-     prototypes[cluster_id] = [image_paths[majority_class_indices[i]] for i in top_k_indices]
-
-     # Log the top-k image paths for this cluster
-     logging.info(f"Cluster {cluster_id} Prototypes: {prototypes[cluster_id]}")
-
-     # Save the prototype images and their corresponding Grad-CAM heatmaps
-     for idx, img_path in enumerate(prototypes[cluster_id]):
-         # Read and preprocess the image
-         img = read_image(img_path)
-
-         # Ensure the image has 3 channels
-         if img.shape[0] == 1:
-             img = img.repeat(3, 1, 1)
-         elif img.shape[0] == 4:
-             img = img[:3, :, :]
-
-         # Resize and normalize the image
-         input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-
-         # Move the input tensor to the GPU
-         input_tensor = input_tensor.to(device)
-
-         # Feed the preprocessed image to the model
-         out = model(input_tensor.unsqueeze(0))
-
-         # Manually calculate Grad-CAM by averaging only the channels in the cluster
-         cam_activations = activation['layer4'].squeeze().cpu().numpy()
-         cluster_activations = cam_activations[cluster_indices]
-         averaged_cluster_activation = np.mean(cluster_activations, axis=0)
-
-         # Normalize the activation map
-         averaged_cluster_activation = (averaged_cluster_activation - averaged_cluster_activation.min()) / (
-                 averaged_cluster_activation.max() - averaged_cluster_activation.min())
-
-         # Overlay the CAM on the original image
-         overlayed_img = overlay_mask(to_pil_image(img.cpu()), to_pil_image(averaged_cluster_activation, mode='F'),
-                                      alpha=0.5)
-
-         # Save the original image and the heatmap overlay
-         img_name = f"cluster_{cluster_id}_prototype_{idx + 1}.png"
-         heatmap_name = f"cluster_{cluster_id}_prototype_{idx + 1}_heatmap.png"
-
-         # Save the original image
-         Image.fromarray(img.permute(1, 2, 0).cpu().numpy().astype(np.uint8)).save(os.path.join(output_folder, img_name))
-
-         # Save the heatmap overlay
-         overlayed_img.save(os.path.join(output_folder, heatmap_name))
-
- # Done
- logging.info("Saved all representative images and their corresponding heatmaps.")
+ import gradio as gr
  import torch
+ import torchvision
+ import pandas as pd
  import os
  from PIL import Image
+ from utils.experiment_utils import get_model
+
+ # File to store the visitor count
+ visitor_count_file = "visitor_count.txt"
+
+ # Function to update visitor count
+ def update_visitor_count():
+     if os.path.exists(visitor_count_file):
+         with open(visitor_count_file, "r") as file:
+             count = int(file.read())
+     else:
+         count = 0  # Start from zero if no file exists
+
+     # Increment visitor count
+     count += 1
+
+     # Save the updated count back to the file
+     with open(visitor_count_file, "w") as file:
+         file.write(str(count))
+
+     return count
+
+ # Custom flagging logic to save flagged data to a CSV file
+ class CustomFlagging(gr.FlaggingCallback):
+     def __init__(self, dir_name="flagged_data"):
+         self.dir = dir_name
+         self.image_dir = os.path.join(self.dir, "uploaded_images")
+         if not os.path.exists(self.dir):
+             os.makedirs(self.dir)
+         if not os.path.exists(self.image_dir):
+             os.makedirs(self.image_dir)
+
+     # Define setup as a no-op to fulfill abstract class requirement
+     def setup(self, *args, **kwargs):
+         pass
+
+     def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
+         # Extract data
+         classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data
+
+         # Save the uploaded image in the "uploaded_images" folder
+         image_filename = os.path.join(self.image_dir,
+                                       f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png")
+         image.save(image_filename)  # Save image in PNG format
+
+         # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
+         data = {
+             "Classification Mode": classification_mode,
+             "Image Path": image_filename,  # Save path to image in CSV
+             "Sensing Modality": sensing_modality,
+             "Predicted Class": predicted_class,
+             "Correct Class": correct_class,
+         }
+
+         df = pd.DataFrame([data])
+         csv_file = os.path.join(self.dir, "flagged_data.csv")
+
+         # Append to CSV, or create if it doesn't exist
+         if os.path.exists(csv_file):
+             df.to_csv(csv_file, mode='a', header=False, index=False)
+         else:
+             df.to_csv(csv_file, mode='w', header=True, index=False)
+
+
+ # Function to load the appropriate model based on the user's selection
+ def load_model(modality, mode):
+     # For Few-Shot classification, always use the DINOv2 model
+     if mode == "Few-Shot":
+         class Args:
+             model = 'DINOv2'
+             pretrained = 'pretrained'
+             frozen = 'unfrozen'
+
+         args = Args()
+         model = get_model(args)  # Load DINOv2 model for Few-Shot classification
+     else:
+         # For Fully-Supervised classification, choose model based on the sensing modality
+         if modality == "Texture":
+             class Args:
+                 model = 'DINOv2'
+                 pretrained = 'pretrained'
+                 frozen = 'unfrozen'
+
+             args = Args()
+             model = get_model(args)  # Load DINOv2 model for Texture modality
+         elif modality == "Heightmap":
+             class Args:
+                 model = 'ResNet152'
+                 pretrained = 'pretrained'
+                 frozen = 'unfrozen'
+
+             args = Args()
+             model = get_model(args)  # Load ResNet152 model for Heightmap modality
+         else:
+             raise ValueError("Invalid modality selected!")
+
+     model.eval()  # Set the model to evaluation mode
+     return model
+
+
+ # Prediction function that processes the image and returns the prediction results
+ def predict(image, modality, mode):
+     # Load the appropriate model based on the user's selections
+     model = load_model(modality, mode)
+
+     # Print the selected mode and modality for debugging purposes
+     print(f"User selected Mode: {mode}, Modality: {modality}")
+
+     # Preprocess the image
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.Resize((224, 224)),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
+     with torch.no_grad():
+         output = model(image_tensor)  # Get model predictions
+         probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
+
+     # Class names for the predictions
+     class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
+
+     # Pair class names with their corresponding probabilities
+     predicted_class = class_names[probabilities.index(max(probabilities))]  # Get the predicted class
+     results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
+
+     return predicted_class, results  # Return predicted class and probabilities
+
+
+ # Create the Gradio interface using gr.Blocks
+ def create_interface():
+     with gr.Blocks() as interface:
+         # Title at the top of the interface (centered and larger)
+         gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")
+
+         # Add description for the interface
+         description = """
+ ### Image Classification Options
+ - **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
+ - **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
+ ### **Don't forget to choose the Sensing Modality that matches your uploaded image.**
+ ### **If you know the correct class of your uploaded image, please flag it for us; this helps us further develop our dataset. If the correct class is not among the options, select 'Other' and type it in!**
+ """
+         gr.Markdown(description)
+
+         # Top-level selector for Fully-Supervised vs. Few-Shot classification
+         mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
+                                  value="Fully Supervised")
+
+         # Sensing modality selector
+         modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")
+
+         # Image upload input
+         image_input = gr.Image(type="pil", label="Image")
+
+         # Predicted classification output and class probabilities
+         with gr.Row():
+             predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
+             probabilities_output = gr.Label(label="Prediction Probabilities")
+
+         # Add the "Run Prediction" button under the Prediction Probabilities
+         predict_button = gr.Button("Run Prediction")
+
+         # Selector for the user to pick the correct class if the model prediction is wrong
+         correct_class_selector = gr.Radio(
+             choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
+             label="Select Correct Class"
+         )
+
+         # Text box for user to type the correct class if "Other" is selected
+         other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)
+
+         # Logic to dynamically update visibility of the "Other" class text box
+         def update_visibility(selected_class):
+             return gr.update(visible=selected_class == "Other")
+
+         correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)
+
+         # Create a flagging instance
+         flagging_instance = CustomFlagging(dir_name="flagged_data")
+
+         # Define function for the confirmation pop-up
+         def confirm_flag_selection(correct_class, other_class):
+             # Generate confirmation message
+             if correct_class == "Other":
+                 message = f"Are you sure the class you selected is '{other_class}' for this picture?"
+             else:
+                 message = f"Are you sure the class you selected is '{correct_class}' for this picture?"
+
+             return message, gr.update(visible=True), gr.update(visible=True)
+
+         # Final flag submission function
+         def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
+             if confirmed == "Yes":
+                 # Save the flagged data
+                 correct_class_final = correct_class if correct_class != "Other" else other_class
+                 flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
+                 return "Flagged successfully!"
+             else:
+                 return "No flag submitted, please select again."
+
+         # Flagging button
+         flag_button = gr.Button("Flag")
+
+         # Confirmation box for user input and confirmation flag
+         confirmation_text = gr.Textbox(visible=False)
+         yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
+         confirmation_button = gr.Button("Confirm Flag", visible=False)
+
+         # Prediction action
+         predict_button.click(
+             fn=predict,
+             inputs=[image_input, modality_selector, mode_selector],
+             outputs=[predicted_output, probabilities_output]
+         )
+
+         # Flagging action with confirmation
+         flag_button.click(
+             fn=confirm_flag_selection,
+             inputs=[correct_class_selector, other_class_input],
+             outputs=[confirmation_text, yes_no_choice, confirmation_button]
+         )
+
+         # Final flag submission after confirmation
+         confirmation_button.click(
+             fn=flag_data_save,
+             inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
+                     predicted_output, yes_no_choice],
+             outputs=gr.Textbox(label="Flagging Status")
+         )
+
+         # Visitor count displayed at the bottom
+         visitor_count = update_visitor_count()  # Update the visitor count
+         gr.Markdown(f"### Number of Visitors: {visitor_count}")  # Display visitor count
+
+     return interface
+
+
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch(share=True)