# Based on the learnt CNN kernels, this script helps generate a visualisation of the learnt kernel patterns.
# Attempt 1, did not work well.
import matplotlib.pyplot as plt
# Here we should be able to determine how much each part of the image contributes to detecting the goal,
# and how these contributions change over time.
# https://www.youtube.com/watch?v=ST9NjnKKvT8
# This video tackles the same problem by walking through the heatmaps of CNNs.
from torchvision import transforms
from dataset_creation import normal_transforms
from model import MakiAlexNet
import numpy as np
import cv2, torch, os
from tqdm import tqdm
import time
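
# Pipeline overview:
#   1. Load the trained MakiAlexNet weights and switch the model to eval mode.
#   2. Run each training image through the first convolution layer (conv1).
#   3. Resize every conv1 feature map to the input resolution and sum them into one map.
#   4. Normalise the summed map and save it as a grayscale mask and as a JET heatmap.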
TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
all_processing_files = os.listdir(os.path.join(os.getcwd(), "./dataset/root/train"))
model = MakiAlexNet()
model.load_state_dict(torch.load(MODEL_PARAMS))
model.eval()
print("Model armed and ready for evaluation.")
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
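# The listing above is a quick sanity check that the state_dict contains the
# expected layers (conv1 in particular) before model.conv1 is used directly below.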
for image_file in tqdm(all_processing_files):
    # Load the image from file.
    abs_file_path = os.path.join(os.getcwd(), "./dataset/root/train", image_file)
    image = cv2.imread(abs_file_path)
    # print(image.shape)
    # cv2.imshow("test", image)
    # cv2.waitKey(5000)
    print("Image input shape of the matrix before: ", image.shape)
    # Convert the image to a float32 tensor and add a batch dimension: (B, H, W, C).
    image = torch.unsqueeze(torch.tensor(image.astype(np.float32)), 0)
    # Reorder to channels-first (B, C, H, W), the layout nn.Conv2d expects.
    image = torch.einsum("BHWC->BCHW", image)
    print("Image input shape of the matrix after: ", image.shape)
    conv1_output = model.conv1(image)
    print("Output shape of the matrix: ", conv1_output.shape)
    # Handling image convolutions: summing over the batch axis (size 1) drops it,
    # leaving the feature maps in (H, W, C) order.
    conv1_formatted = torch.einsum("BCHW->HWC", conv1_output)
    print(f"Formatted shape of matrix is: {conv1_formatted.shape}")
    # Lay the conv1 channels out on a grid of subplots.
    num_channels = conv1_formatted.shape[2]  # Number of conv1 feature maps (96).
    max_rows = 5  # Cap the number of rows.
    rows = min(max_rows, int(np.sqrt(num_channels)))
    cols = int(np.ceil(num_channels / rows))  # e.g. 96 channels -> 5 rows x 20 cols.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 12))  # Create a grid of subplots.
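    # The subplot grid is not actually drawn into below; a per-channel view could
    # be added with something like axes[i, j].imshow(channel_data) inside the
    # loop (sketch only, not part of the original flow).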
    DATASET_OUTPUT_PATH = "./dataset/visualisation"
    merged_frames = np.zeros((224, 224))
    # Output directory named after the image file, e.g. "left1_frame_0".
    image_file_dir = os.path.splitext(os.path.basename(abs_file_path))[0]
    output_dir = os.path.join(os.getcwd(), DATASET_OUTPUT_PATH, image_file_dir)
    os.makedirs(output_dir, exist_ok=True)  # Make the directory (and any parents) if missing.
    for i in range(rows):
        for j in range(cols):
            channel_idx = i * cols + j  # Flatten (row, column) into a channel index.
            if channel_idx < num_channels:  # Skip grid cells past the last channel.
                channel_data = conv1_formatted[:, :, channel_idx]
                channel_data = channel_data.detach().numpy()
                print(f"Channel Data shape dimension: {channel_data.shape}")
                # Upscale the feature map (54x54 from conv1) to the input resolution.
                channel_data = cv2.resize(channel_data, (224, 224))
                # Accumulate the channel activations into a single merged map.
                # A binary threshold could be applied first to keep only strong activations:
                # ret, channel_data = cv2.threshold(channel_data, 120, 255, cv2.THRESH_BINARY)
                merged_frames += channel_data
                # # Save the individual channel map.
                # image_filename = f"{int(time.time())}_output_{channel_idx}.jpg"
                # image_path = os.path.join(output_dir, image_filename)
                # plt.imsave(image_path, channel_data)
                # print(f"Image path saved at {image_path}")
    # Normalise the merged map into the 0-255 range before casting to uint8.
    # Dividing by 80% of the maximum deliberately saturates the brightest regions;
    # clipping prevents values above 255 from wrapping around in the uint8 cast.
    merged_frames = np.clip(merged_frames / (np.max(merged_frames) * .8) * 255, 0, 255)
    merged_frames_gray = merged_frames.astype(np.uint8)
    # merged_frames = cv2.adaptiveThreshold(merged_frames_gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    image_path = os.path.join(output_dir, image_file_dir + "conv1_mask.jpg")
    plt.imsave(image_path, merged_frames_gray, cmap='gray')
    heatmap_color = cv2.applyColorMap(merged_frames_gray, cv2.COLORMAP_JET)  # Apply a colormap.
    # cv2.imshow("merged", heatmap_color)
    # OpenCV colormaps return BGR; convert to RGB so matplotlib saves the colours correctly.
    image_path = os.path.join(output_dir, image_file_dir + "conv1_heatmap.jpg")
    plt.imsave(image_path, cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB))
    # cv2.waitKey(5000)
    plt.close(fig)  # Close the per-image figure so figures do not accumulate across iterations.
exit()
#
# image_tensor = normal_transforms(torch.tensor(image))
# print(image_tensor.shape)
# plt.imshow(image_tensor.squeeze())