|
|
|
|
|
|
|
|
|
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from dataset_creation import normal_transforms
from model import MakiAlexNet
|
|
|
# --- Configuration, input listing and model loading ------------------------------------

# Single sample image path (not referenced by the batch loop that processes every file).
TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"

# Checkpoint holding the trained MakiAlexNet weights (a state_dict).
MODEL_PARAMS = "alexnet_cognitive.pth"

# Every entry of the training directory; sorted because os.listdir order is arbitrary
# and a reproducible processing order makes runs comparable.
all_processing_files = sorted(os.listdir(os.path.join(os.getcwd(), "./dataset/root/train")))

model = MakiAlexNet()

# This script never moves the model or its inputs to a GPU, so map the weights onto the
# CPU; without this, a checkpoint saved from a CUDA device fails to load on a CPU-only box.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))

# Inference mode (affects layers such as dropout / batch-norm, if the model has any).
model.eval()

print("Model armed and ready for evaluation.")

# Dump every parameter/buffer name with its shape. Iterating .items() builds the
# state_dict once instead of rebuilding it for every key.
print("Model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())
|
|
|
|
|
|
|
|
|
# --- Per-image conv1 visualisation ------------------------------------------------------
# For every file in the training directory: run model.conv1, sum all of its channel
# activations (each resized to MERGED_SIZE), then save a thresholded mask and a JET
# heatmap under DATASET_OUTPUT_PATH/<image name>/.

# Root directory for the per-image outputs (one sub-directory per input image).
DATASET_OUTPUT_PATH = "./dataset/visualisation"
# (width, height) every conv1 channel is resized to before the channels are summed.
MERGED_SIZE = (224, 224)
# A pixel is part of the mask when its summed activation is >= this fraction of the peak.
MASK_THRESHOLD = 0.8


def _merge_channels(feature_map, size=MERGED_SIZE):
    """Resize each channel of an (H, W, C) array to ``size`` and sum them.

    ``size`` is (width, height), as cv2.resize expects. Returns a float64 array of
    shape (height, width).
    """
    width, height = size
    merged = np.zeros((height, width), dtype=np.float64)
    for channel_idx in range(feature_map.shape[2]):
        # Contiguous copy so OpenCV receives a plain 2-D buffer.
        channel = np.ascontiguousarray(feature_map[:, :, channel_idx])
        merged += cv2.resize(channel, (width, height))
    return merged


def _activation_mask(merged, threshold=MASK_THRESHOLD):
    """Return a uint8 mask that is 255 where ``merged >= threshold * peak`` and 0 elsewhere.

    If the peak is not positive nothing can be highlighted, so the mask is all zeros.
    (Dividing by ``peak * threshold`` and casting to uint8 would instead divide by
    zero or a negative value and cast negative floats to uint8, which is undefined.)
    """
    peak = float(merged.max())
    if peak <= 0:
        return np.zeros(merged.shape, dtype=np.uint8)
    return np.where(merged >= threshold * peak, 255, 0).astype(np.uint8)


def _activation_heatmap(merged):
    """Min-max scale ``merged`` to 0..255, colour it with JET and return an RGB image."""
    low, high = float(merged.min()), float(merged.max())
    if high > low:
        scaled = (merged - low) / (high - low)
    else:
        # Constant map: no contrast to show.
        scaled = np.zeros_like(merged)
    gray = np.round(scaled * 255).astype(np.uint8)
    # cv2.applyColorMap returns BGR; matplotlib's imsave interprets channels as RGB.
    return cv2.cvtColor(cv2.applyColorMap(gray, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)


for image_file in tqdm(all_processing_files):
    abs_file_path = os.path.join(os.getcwd(), "./dataset/root/train", image_file)
    image = cv2.imread(abs_file_path)
    if image is None:
        # cv2.imread signals directories / unreadable / non-image files by returning None.
        tqdm.write(f"Skipping unreadable file: {abs_file_path}")
        continue

    print("Image input shape of the matrix before: ", image.shape)

    # NOTE(review): the raw cv2 pixels (BGR order, 0-255) are fed as float32 and
    # `normal_transforms` (imported at the top) is not applied — confirm this matches
    # the preprocessing the model was trained with.
    # (H, W, C) -> (1, C, H, W)
    image = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
    print("Image input shape of the matrix after: ", image.shape)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        conv1_output = model.conv1(image)
    print("Output shape of the matrix: ", conv1_output.shape)

    # Take the single image of the batch and move channels last: (1, C, H', W') -> (H', W', C).
    conv1_formatted = conv1_output[0].permute(1, 2, 0)
    print(f"Formatted shape of matrix is: {conv1_formatted.shape}")

    merged_frames = _merge_channels(conv1_formatted.numpy())

    # Output directory is named after the input file without its extension.
    image_file_dir = os.path.splitext(image_file)[0]
    output_dir = os.path.join(os.getcwd(), DATASET_OUTPUT_PATH, image_file_dir)
    # makedirs also creates DATASET_OUTPUT_PATH itself on the first image.
    os.makedirs(output_dir, exist_ok=True)

    mask_path = os.path.join(output_dir, image_file_dir + "conv1_mask.jpg")
    plt.imsave(mask_path, _activation_mask(merged_frames), cmap="gray", vmin=0, vmax=255)

    heatmap_path = os.path.join(output_dir, image_file_dir + "conv1_heatmap.jpg")
    plt.imsave(heatmap_path, _activation_heatmap(merged_frames))

sys.exit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|