DanielXu0208 committed on
Commit 9e56473
1 Parent(s): ca68817

Update run_gradio.py

Files changed (1)
  1. run_gradio.py +215 -233
run_gradio.py CHANGED
@@ -1,233 +1,215 @@
- import gradio as gr
- import torch
- import torchvision
- import pandas as pd
- import os
- from PIL import Image
- from utils.experiment_utils import get_model
-
-
- # Custom flagging logic to save flagged data to a CSV file
- class CustomFlagging(gr.FlaggingCallback):
-     def __init__(self, dir_name="flagged_data"):
-         self.dir = dir_name
-         self.image_dir = os.path.join(self.dir, "uploaded_images")
-         if not os.path.exists(self.dir):
-             os.makedirs(self.dir)
-         if not os.path.exists(self.image_dir):
-             os.makedirs(self.image_dir)
-
-     # Define setup as a no-op to fulfill abstract class requirement
-     def setup(self, *args, **kwargs):
-         pass
-
-     def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
-         # Extract data
-         classification_mode, image, sensing_modality, predicted_class, correct_class = flag_data
-
-         # Save the uploaded image in the "uploaded_images" folder
-         image_filename = os.path.join(self.image_dir,
-                                       f"flagged_image_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.png")
-         image.save(image_filename)  # Save image in PNG format
-
-         # Columns: Classification, Image Path, Sensing Modality, Predicted Class, Correct Class
-         data = {
-             "Classification Mode": classification_mode,
-             "Image Path": image_filename,  # Save path to image in CSV
-             "Sensing Modality": sensing_modality,
-             "Predicted Class": predicted_class,
-             "Correct Class": correct_class,
-         }
-
-         df = pd.DataFrame([data])
-         csv_file = os.path.join(self.dir, "flagged_data.csv")
-
-         # Append to CSV, or create if it doesn't exist
-         if os.path.exists(csv_file):
-             df.to_csv(csv_file, mode='a', header=False, index=False)
-         else:
-             df.to_csv(csv_file, mode='w', header=True, index=False)
-
-
- # Function to load the appropriate model based on the user's selection
- def load_model(modality, mode):
-     # For Few-Shot classification, always use the DINOv2 model
-     if mode == "Few-Shot":
-         class Args:
-             model = 'DINOv2'
-             pretrained = 'pretrained'
-             frozen = 'unfrozen'
-
-         args = Args()
-         model = get_model(args)  # Load DINOv2 model for Few-Shot classification
-     else:
-         # For Fully-Supervised classification, choose model based on the sensing modality
-         if modality == "Texture":
-             class Args:
-                 model = 'DINOv2'
-                 pretrained = 'pretrained'
-                 frozen = 'unfrozen'
-
-             args = Args()
-             model = get_model(args)  # Load DINOv2 model for Texture modality
-         elif modality == "Heightmap":
-             class Args:
-                 model = 'ResNet152'
-                 pretrained = 'pretrained'
-                 frozen = 'unfrozen'
-
-             args = Args()
-             model = get_model(args)  # Load ResNet152 model for Heightmap modality
-         else:
-             raise ValueError("Invalid modality selected!")
-
-     model.eval()  # Set the model to evaluation mode
-     return model
-
-
- # Prediction function that processes the image and returns the prediction results
- def predict(image, modality, mode):
-     # Load the appropriate model based on the user's selections
-     model = load_model(modality, mode)
-
-     # Print the selected mode and modality for debugging purposes
-     print(f"User selected Mode: {mode}, Modality: {modality}")
-
-     # Preprocess the image
-     transform = torchvision.transforms.Compose([
-         torchvision.transforms.Resize((224, 224)),
-         torchvision.transforms.ToTensor(),
-         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-     ])
-
-     image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
-     with torch.no_grad():
-         output = model(image_tensor)  # Get model predictions
-         probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
-
-     # Class names for the predictions
-     class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
-
-     # Pair class names with their corresponding probabilities
-     predicted_class = class_names[probabilities.index(max(probabilities))]  # Get the predicted class
-     results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
-
-     return predicted_class, results  # Return predicted class and probabilities
-
-
- # Create the Gradio interface using gr.Blocks
- def create_interface():
-     with gr.Blocks() as interface:
-         # Title at the top of the interface (centered and larger)
-         gr.Markdown("<h1 style='text-align: center; font-size: 36px;'>LUWA Dataset Image Classification</h1>")
-
-         # Add description for the interface
-         description = """
-         ### Image Classification Options
-         - **Fully-Supervised Classification**: Choose this for common or well-known materials with plenty of data (e.g., bone, wood).
-         - **Few-Shot Classification**: Choose this for rare or newly discovered materials where only a few examples exist.
-         ### **Don't forget to choose the Sensing Modality based on your uploaded images.**
-         ### **Please help us to flag the correct class for your uploaded image if you know it, it will help us to further develop our dataset. If you cannot find the correct class in the option, please click on the option 'Other' and type the correct class for us!**
-         """
-         gr.Markdown(description)
-
-         # Top-level selector for Fully-Supervised vs. Few-Shot classification
-         mode_selector = gr.Radio(choices=["Fully Supervised", "Few-Shot"], label="Classification Mode",
-                                  value="Fully Supervised")
-
-         # Sensing modality selector
-         modality_selector = gr.Radio(choices=["Texture", "Heightmap"], label="Sensing Modality", value="Texture")
-
-         # Image upload input
-         image_input = gr.Image(type="pil", label="Image")
-
-         # Predicted classification output and class probabilities
-         with gr.Row():
-             predicted_output = gr.Label(num_top_classes=1, label="Predicted Classification")
-             probabilities_output = gr.Label(label="Prediction Probabilities")
-
-         # Add the "Run Prediction" button under the Prediction Probabilities
-         predict_button = gr.Button("Run Prediction")
-
-         # Dropdown for user to select the correct class if the model prediction is wrong
-         correct_class_selector = gr.Radio(
-             choices=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD", "Other"],
-             label="Select Correct Class"
-         )
-
-         # Text box for user to type the correct class if "Other" is selected
-         other_class_input = gr.Textbox(label="If Other, enter the correct class", visible=False)
-
-         # Logic to dynamically update visibility of the "Other" class text box
-         def update_visibility(selected_class):
-             return gr.update(visible=selected_class == "Other")
-
-         correct_class_selector.change(fn=update_visibility, inputs=correct_class_selector, outputs=other_class_input)
-
-
-         # Create a flagging instance
-         flagging_instance = CustomFlagging(dir_name="flagged_data")
-
-         # Define function for the confirmation pop-up
-         def confirm_flag_selection(correct_class, other_class):
-             # Generate confirmation message
-             if correct_class == "Other":
-                 message = f"Are you sure the class you selected is '{other_class}' for this picture?"
-             else:
-                 message = f"Are you sure the class you selected is '{correct_class}' for this picture?"
-
-             return message, gr.update(visible=True), gr.update(visible=True)
-
-         # Final flag submission function
-         def flag_data_save(correct_class, other_class, mode, image, modality, predicted_class, confirmed):
-             if confirmed == "Yes":
-                 # Save the flagged data
-                 correct_class_final = correct_class if correct_class != "Other" else other_class
-                 flagging_instance.flag([mode, image, modality, predicted_class, correct_class_final])
-                 return "Flagged successfully!"
-             else:
-                 return "No flag submitted, please select again."
-
-         # Flagging button
-         flag_button = gr.Button("Flag")
-
-         # Confirmation box for user input and confirmation flag
-         confirmation_text = gr.Textbox(visible=False)
-         yes_no_choice = gr.Radio(choices=["Yes", "No"], label="Are you sure?", visible=False)
-         confirmation_button = gr.Button("Confirm Flag", visible=False)
-
-         # Prediction action
-         predict_button.click(
-             fn=predict,
-             inputs=[image_input, modality_selector, mode_selector],
-             outputs=[predicted_output, probabilities_output]
-         )
-
-         # Flagging action with confirmation
-         flag_button.click(
-             fn=confirm_flag_selection,
-             inputs=[correct_class_selector, other_class_input],
-             outputs=[confirmation_text, yes_no_choice, confirmation_button]
-         )
-
-         # Final flag submission after confirmation
-         confirmation_button.click(
-             fn=flag_data_save,
-             inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
-                     predicted_output, yes_no_choice],
-             outputs=gr.Textbox(label="Flagging Status")
-         )
-
-     return interface
-
-
- if __name__ == "__main__":
-     interface = create_interface()
-     interface.launch(share=True)
-
-
-
-
-
-
-
 
+ import matplotlib.pyplot as plt
+ import torch
+ from torchvision.transforms.functional import resize, normalize, to_pil_image
+ from torchvision.io.image import read_image
+ from torchvision.models import resnet50
+ from sklearn.cluster import KMeans
+ import numpy as np
+ import os
+ import logging
+ from torchcam.utils import overlay_mask
+ from PIL import Image
+ from collections import Counter
+ from scipy.spatial.distance import cdist
+
+ # Initialize logger to monitor progress
+ logging.basicConfig(level=logging.INFO)
+
+ # Path to dataset and model
+ dataset_path = "archive"
+ model_path = "resnet50_finetuned_miniimagenet.pth"  # Update to your fine-tuned MiniImagenet weights
+ n_clusters = 100  # Number of clusters for feature channels (can be adjusted)
+ top_k_prototypes = 5  # Top k most similar examples to select as prototypes for each cluster
+ batch_size = 8  # Reduce batch size to limit memory usage
+ output_folder = "examples"  # Folder to save the images and heatmaps
+
+ # Limit to 100 images per class
+ images_per_class = 100
+
+ # Create the output folder if it doesn't exist
+ if not os.path.exists(output_folder):
+     os.makedirs(output_folder)
+
+ # Set device to GPU if available, otherwise CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the fine-tuned model weights from MiniImagenet
+ model = resnet50(pretrained=False)
+
+ # Modify the fully connected layer to match the number of classes in MiniImagenet (100 classes)
+ model.fc = torch.nn.Linear(model.fc.in_features, 100)
+
+ # Load the fine-tuned state_dict
+ checkpoint = torch.load(model_path, weights_only=True)  # Updated to MiniImagenet weights
+ model.load_state_dict(checkpoint)
+ model.eval()
+
+ # Move the model to GPU
+ model.to(device)
+
+ # Hook to capture activations from 'layer4'
+ activation = {}
+
+
+ def get_activation(name):
+     def hook(model, input, output):
+         activation[name] = output.detach()
+
+     return hook
+
+
+ model.layer4.register_forward_hook(get_activation('layer4'))
+
+ # Collecting activations of all feature channels across multiple images
+ all_activations = []
+ image_paths = []  # To keep track of the image paths
+ image_labels = []  # To store the class labels of each image
+
+ # Traverse through the dataset, accessing each class folder and collecting images and labels
+ for class_folder in os.listdir(dataset_path):
+     class_folder_path = os.path.join(dataset_path, class_folder)
+
+     # Ensure we are looking at a directory (class folder)
+     if os.path.isdir(class_folder_path):
+         class_label = class_folder  # Use the folder name as the class label
+
+         # Get only up to 'images_per_class' images from each class folder
+         class_images = os.listdir(class_folder_path)[:images_per_class]
+
+         for img_name in class_images:
+             img_path = os.path.join(class_folder_path, img_name)
+             image_paths.append(img_path)  # Store image path for later use
+             image_labels.append(class_label)  # Store the corresponding class label
+
+ # Log how many images we collected
+ logging.info(f"Collected {len(image_paths)} images across {len(set(image_labels))} classes.")
+
+ # Process the images for clustering and Grad-CAM calculation in batches
+ for batch_idx in range(0, len(image_paths), batch_size):
+     batch_image_paths = image_paths[batch_idx: batch_idx + batch_size]
+
+     with torch.no_grad():  # Disable gradient calculations
+         for img_path in batch_image_paths:
+             # Read and preprocess the image
+             img = read_image(img_path)
+
+             # Ensure the image has 3 channels (convert grayscale or 4-channel images to RGB)
+             if img.shape[0] == 1:  # If the image is grayscale (1 channel), repeat the single channel to make it RGB
+                 img = img.repeat(3, 1, 1)  # Convert it to 3-channel by repeating the single channel
+             elif img.shape[0] == 4:  # If the image has 4 channels (e.g., RGBA), drop the alpha channel
+                 img = img[:3, :, :]  # Keep only the first 3 channels (RGB)
+
+             # Resize and normalize the image
+             input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+
+             # Move the input tensor to the GPU
+             input_tensor = input_tensor.to(device)
+
+             # Feed the preprocessed image to the model
+             out = model(input_tensor.unsqueeze(0))
+
+             # Get the activations from layer4 (with 2,048 feature channels)
+             layer4_activations = activation['layer4'].cpu().numpy()
+
+             # For each image, store the activation values across all channels (2048 channels)
+             all_activations.append(layer4_activations.squeeze())
+
+     # Log progress
+     logging.info(f"Processed batch {batch_idx // batch_size + 1}/{len(image_paths) // batch_size + 1}")
+
+ # Convert the collected activations into a numpy array of shape (n_images, 2048, H, W)
+ all_activations = np.array(all_activations)
+
+ # Now we average the spatial dimensions (H*W) to get the activation vector for each channel
+ # This gives us an array of shape (n_images, 2048), where each value is the averaged activation for that channel
+ avg_activations_per_image = np.mean(all_activations, axis=(-2, -1))  # Average over spatial dimensions
+
+ # Now we want to transpose the array to get activations for each channel across all images
+ # Shape will be (2048, n_images), where each row is the activation of a channel across all images
+ channel_activation_vectors = avg_activations_per_image.T
+
+ # Perform KMeans clustering on the feature channels
+ kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(channel_activation_vectors)
+
+ # Get cluster assignments for feature channels
+ channel_clusters = kmeans.labels_
+
+ # Find prototypes for each cluster based on channel activation similarities
+ prototypes = {}
+
+ for cluster_id in range(n_clusters):
+     cluster_indices = np.where(channel_clusters == cluster_id)[0]  # Get the feature channels in this cluster
+
+     if len(cluster_indices) == 0:
+         continue  # Skip empty clusters
+
+     # Find the majority class for the images that activate the feature channels in this cluster
+     cluster_activation_vectors = channel_activation_vectors[cluster_indices]  # Activation vectors for this cluster
+
+     # Use the majority class of the images for this cluster
+     majority_class = Counter([image_labels[i] for i in range(len(image_paths))]).most_common(1)[0][0]
+
+     # Filter the images by the majority class before selecting prototypes
+     majority_class_indices = [i for i, label in enumerate(image_labels) if label == majority_class]
+     filtered_cluster_activation_vectors = cluster_activation_vectors[:, majority_class_indices]  # Filtered activations
+
+     # Compute pairwise distances between the activation vectors of the feature channels
+     distances = cdist(filtered_cluster_activation_vectors.T, filtered_cluster_activation_vectors.T, 'euclidean')
+
+     # Sum the distances for each image (to find the closest/most representative sample)
+     distance_sums = distances.sum(axis=1)
+
+     # Get the indices of the top-5 closest images from the filtered list
+     top_k_indices = np.argsort(distance_sums)[:top_k_prototypes]
+
+     # Store the prototypes for this cluster (top-k most representative images of the majority class)
+     prototypes[cluster_id] = [image_paths[majority_class_indices[i]] for i in top_k_indices]
+
+     # Print the top 5 image paths for this cluster
+     logging.info(f"Cluster {cluster_id} Prototypes: {prototypes[cluster_id]}")
+
+     # Now download and save the images and their corresponding Grad-CAM heatmaps
+     for idx, img_path in enumerate(prototypes[cluster_id]):
+         # Read and preprocess the image
+         img = read_image(img_path)
+
+         # Ensure the image has 3 channels
+         if img.shape[0] == 1:
+             img = img.repeat(3, 1, 1)
+         elif img.shape[0] == 4:
+             img = img[:3, :, :]
+
+         # Resize and normalize the image
+         input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+
+         # Move the input tensor to the GPU
+         input_tensor = input_tensor.to(device)
+
+         # Feed the preprocessed image to the model
+         out = model(input_tensor.unsqueeze(0))
+
+         # Manually calculate Grad-CAM by averaging only the channels in the cluster
+         cam_activations = activation['layer4'].squeeze().cpu().numpy()
+         cluster_activations = cam_activations[cluster_indices]
+         averaged_cluster_activation = np.mean(cluster_activations, axis=0)
+
+         # Normalize the activation map
+         averaged_cluster_activation = (averaged_cluster_activation - averaged_cluster_activation.min()) / (
+                 averaged_cluster_activation.max() - averaged_cluster_activation.min())
+
+         # Overlay the CAM on the original image
+         overlayed_img = overlay_mask(to_pil_image(img.cpu()), to_pil_image(averaged_cluster_activation, mode='F'),
+                                      alpha=0.5)
+
+         # Save the original image and the heatmap overlay
+         img_name = f"cluster_{cluster_id}_prototype_{idx + 1}.png"
+         heatmap_name = f"cluster_{cluster_id}_prototype_{idx + 1}_heatmap.png"
+
+         # Save original image
+         Image.fromarray(img.permute(1, 2, 0).cpu().numpy().astype(np.uint8)).save(os.path.join(output_folder, img_name))
+
+         # Save the heatmap overlay
+         overlayed_img.save(os.path.join(output_folder, heatmap_name))
+
+ # Done
+ logging.info("Saved all representative images and their corresponding heatmaps.")