Spaces:
Sleeping
Sleeping
DanielXu0208
committed on
Commit
•
9e56473
1
Parent(s):
ca68817
Update run_gradio.py
Browse files- run_gradio.py +215 -233
run_gradio.py
CHANGED
@@ -1,233 +1,215 @@
|
|
1 |
-
import
|
2 |
-
import torch
|
3 |
-
import
|
4 |
-
|
5 |
-
import
|
6 |
-
from
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
#
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
#
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
inputs=[correct_class_selector, other_class_input, mode_selector, image_input, modality_selector,
|
217 |
-
predicted_output, yes_no_choice],
|
218 |
-
outputs=gr.Textbox(label="Flagging Status")
|
219 |
-
)
|
220 |
-
|
221 |
-
return interface
|
222 |
-
|
223 |
-
|
224 |
-
if __name__ == "__main__":
|
225 |
-
interface = create_interface()
|
226 |
-
interface.launch(share=True)
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import torch
|
3 |
+
from torchvision.transforms.functional import resize, normalize, to_pil_image
|
4 |
+
from torchvision.io.image import read_image
|
5 |
+
from torchvision.models import resnet50
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
from torchcam.utils import overlay_mask
|
11 |
+
from PIL import Image
|
12 |
+
from collections import Counter
|
13 |
+
from scipy.spatial.distance import cdist
|
14 |
+
|
15 |
+
# Initialize logging so the long-running steps below can report progress
logging.basicConfig(level=logging.INFO)

# Locations of the input dataset and the fine-tuned weights
dataset_path = "archive"
model_path = "resnet50_finetuned_miniimagenet.pth"  # Update to your fine-tuned MiniImagenet weights

# Clustering / prototype-selection settings
n_clusters = 100  # Number of clusters for feature channels (can be adjusted)
top_k_prototypes = 5  # Top k most similar examples to select as prototypes for each cluster
batch_size = 8  # Images handled per batch; lower it to limit memory usage
output_folder = "examples"  # Folder to save the images and heatmaps

# Limit to 100 images per class
images_per_class = 100

# exist_ok=True makes this a no-op when the folder is already there and avoids
# the race between a separate "exists?" check and the creation itself
os.makedirs(output_folder, exist_ok=True)
|
32 |
+
|
33 |
+
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a ResNet-50 without pretrained weights; `weights=None` is the
# replacement for the deprecated `pretrained=False` flag
model = resnet50(weights=None)

# Modify the fully connected layer to match the number of classes in MiniImagenet (100 classes)
model.fc = torch.nn.Linear(model.fc.in_features, 100)

# Load the fine-tuned state_dict. map_location remaps the stored tensors onto
# the selected device, so a checkpoint saved on a GPU still loads on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)

# Inference only: put dropout / batch-norm layers into evaluation behaviour
model.eval()

# Move the model to the selected device
model.to(device)
|
49 |
+
|
50 |
+
# Most recent output of each hooked layer, keyed by the name passed to get_activation()
activation = {}


def get_activation(name):
    """Build a forward hook that records a layer's output in ``activation``.

    Args:
        name: Key under which the hooked layer's output is stored.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. Every call overwrites ``activation[name]`` with the
        detached output tensor.
    """

    def hook(_module, _inputs, output):
        # detach() so the stored tensor does not keep the autograd graph alive
        activation[name] = output.detach()

    return hook
|
59 |
+
|
60 |
+
|
61 |
+
# Store layer4's detached output in activation['layer4'] on every forward pass
model.layer4.register_forward_hook(get_activation('layer4'))
|
62 |
+
|
63 |
+
# Accumulators: per-image activations (filled later) plus the path and class
# label of every image, kept index-aligned with each other
all_activations = []
image_paths = []  # To keep track of the image paths
image_labels = []  # To store the class labels of each image

# Walk the dataset: every sub-directory is one class, named after its label.
# os.listdir() returns entries in arbitrary order, so both levels are sorted to
# make the selected subset (and everything derived from it) reproducible.
for class_folder in sorted(os.listdir(dataset_path)):
    class_folder_path = os.path.join(dataset_path, class_folder)

    # Ignore stray files that sit next to the class folders
    if not os.path.isdir(class_folder_path):
        continue

    # Keep regular files only, then cap at 'images_per_class' so nested
    # directories neither reach the image reader nor use up the quota
    candidate_paths = [
        os.path.join(class_folder_path, img_name)
        for img_name in sorted(os.listdir(class_folder_path))
    ]
    class_images = [path for path in candidate_paths if os.path.isfile(path)][:images_per_class]

    image_paths.extend(class_images)
    image_labels.extend([class_folder] * len(class_images))  # Folder name is the class label

# Log how many images we collected
logging.info(f"Collected {len(image_paths)} images across {len(set(image_labels))} classes.")
|
86 |
+
|
87 |
+
def _load_model_input(img_path):
    """Read an image file and return it as a normalised 3x224x224 float tensor.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The image resized to 224x224, scaled to [0, 1] and normalised with the
        per-channel mean/std constants used throughout this script.
    """
    img = read_image(img_path)

    # The model expects exactly three channels
    if img.shape[0] == 1:  # Grayscale: repeat the single channel
        img = img.repeat(3, 1, 1)
    elif img.shape[0] == 2:  # Grayscale + alpha: drop alpha, repeat the gray channel
        img = img[:1].repeat(3, 1, 1)
    elif img.shape[0] == 4:  # RGBA: drop the alpha channel
        img = img[:3, :, :]

    return normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


# Fail early with a clear message instead of an obscure axis error further down
if not image_paths:
    raise RuntimeError(f"No images found under {dataset_path!r}; nothing to cluster.")

# Ceiling division, so the logged total is right when the image count is an
# exact multiple of the batch size
total_batches = (len(image_paths) + batch_size - 1) // batch_size

# Run the images through the model batch by batch. Only the spatial mean of each
# layer4 channel is needed for clustering, so it is reduced right away; keeping
# the full feature maps costs about 400 KB per image (2048 x 7 x 7 float32).
with torch.no_grad():  # Disable gradient calculations
    for batch_number, batch_start in enumerate(range(0, len(image_paths), batch_size), start=1):
        batch_image_paths = image_paths[batch_start: batch_start + batch_size]

        # One forward pass per batch instead of one per image
        batch_tensor = torch.stack([_load_model_input(path) for path in batch_image_paths]).to(device)

        # The logits are not used; the forward pass only has to fire the layer4 hook
        model(batch_tensor)

        # (batch, channels, H, W) -> (batch, channels): average over the spatial dimensions
        pooled_activations = activation['layer4'].mean(dim=(2, 3))
        all_activations.append(pooled_activations.cpu().numpy())

        # Log progress
        logging.info(f"Processed batch {batch_number}/{total_batches}")

# Shape (n_images, 2048): the averaged activation of every layer4 channel for each image
avg_activations_per_image = np.concatenate(all_activations, axis=0)

# Transpose to (2048, n_images) so each row describes one channel across all images
channel_activation_vectors = avg_activations_per_image.T

# Perform KMeans clustering on the feature channels
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(channel_activation_vectors)

# Get cluster assignments for feature channels
channel_clusters = kmeans.labels_
|
136 |
+
|
137 |
+
def _read_rgb_image(img_path):
    """Read an image file as a tensor with exactly three channels."""
    img = read_image(img_path)
    if img.shape[0] == 1:  # Grayscale: repeat the single channel
        img = img.repeat(3, 1, 1)
    elif img.shape[0] == 2:  # Grayscale + alpha: drop alpha, repeat the gray channel
        img = img[:1].repeat(3, 1, 1)
    elif img.shape[0] == 4:  # RGBA: drop the alpha channel
        img = img[:3, :, :]
    return img


def _normalize_map(activation_map):
    """Rescale a 2-D activation map to [0, 1]; a constant map becomes all zeros."""
    lowest = activation_map.min()
    value_range = activation_map.max() - lowest
    if value_range == 0:
        # Avoid 0/0, which would fill the heatmap with NaN
        return np.zeros_like(activation_map)
    return (activation_map - lowest) / value_range


# Find prototypes for each cluster based on channel activation similarities
prototypes = {}

# Number of most strongly responding images that vote on a cluster's class
n_voting_images = min(len(image_paths), images_per_class)

for cluster_id in range(n_clusters):
    cluster_indices = np.where(channel_clusters == cluster_id)[0]  # Get the feature channels in this cluster

    if len(cluster_indices) == 0:
        continue  # Skip empty clusters

    # Rows: channels of this cluster, columns: images
    cluster_activation_vectors = channel_activation_vectors[cluster_indices]

    # Majority class of the images that activate this cluster most strongly.
    # Each image is scored by its mean response over the cluster's channels and
    # the strongest images vote. Counting labels over the whole dataset would
    # give every cluster the same class.
    image_scores = cluster_activation_vectors.mean(axis=0)
    strongest_images = np.argsort(image_scores)[::-1][:n_voting_images]
    majority_class = Counter(image_labels[i] for i in strongest_images).most_common(1)[0][0]

    # Filter the images by the majority class before selecting prototypes
    majority_class_indices = [i for i, label in enumerate(image_labels) if label == majority_class]
    class_image_vectors = cluster_activation_vectors[:, majority_class_indices].T  # One row per image

    # Pairwise distances between the images, measured in the cluster's channel space
    distances = cdist(class_image_vectors, class_image_vectors, 'euclidean')

    # The images with the smallest total distance to all others are the most representative
    distance_sums = distances.sum(axis=1)
    top_k_indices = np.argsort(distance_sums)[:top_k_prototypes]

    # Store the prototypes for this cluster (top-k most representative images of the majority class)
    prototypes[cluster_id] = [image_paths[majority_class_indices[i]] for i in top_k_indices]

    logging.info(f"Cluster {cluster_id} Prototypes: {prototypes[cluster_id]}")

    # Save each prototype together with a heatmap of the cluster's response to it
    for idx, img_path in enumerate(prototypes[cluster_id]):
        img = _read_rgb_image(img_path)

        # Resize, normalize and move the input to the model's device
        input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        input_tensor = input_tensor.to(device)

        # Only the hooked layer4 output is needed, so no autograd graph is built
        with torch.no_grad():
            model(input_tensor.unsqueeze(0))

        # Activation map (no gradients involved): average only the channels in this cluster
        cam_activations = activation['layer4'].squeeze(0).cpu().numpy()
        cluster_activations = cam_activations[cluster_indices]
        averaged_cluster_activation = _normalize_map(np.mean(cluster_activations, axis=0))

        # Overlay the map on the original image
        overlayed_img = overlay_mask(to_pil_image(img.cpu()), to_pil_image(averaged_cluster_activation, mode='F'),
                                     alpha=0.5)

        img_name = f"cluster_{cluster_id}_prototype_{idx + 1}.png"
        heatmap_name = f"cluster_{cluster_id}_prototype_{idx + 1}_heatmap.png"

        # Save original image
        Image.fromarray(img.permute(1, 2, 0).cpu().numpy().astype(np.uint8)).save(os.path.join(output_folder, img_name))

        # Save the heatmap overlay
        overlayed_img.save(os.path.join(output_folder, heatmap_name))

# Done
logging.info("Saved all representative images and their corresponding heatmaps.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|