# Michael Peres (c) 2024
# Inspiration from code tutorial mentioned here:
# https://tree.rocks/get-heatmap-from-cnn-convolution-neural-network-aka-grad-cam-222e08f57a34
import cv2
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.ndimage import zoom
from tqdm import tqdm

from model_two import MakiAlexNet

# from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

TOP_ACCURACY_PERCENTILE = 10
TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"
MODEL_PARAMS = "alexnet_2.0.pth"
GIF_STORE = "dataset/gifs2/"
TRAIN_STORE = "dataset/root/train/"

model = MakiAlexNet()
model.load_state_dict(torch.load(MODEL_PARAMS))
model.eval()

# Run the model on CUDA if available.
if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")

print(dir(model))
for name, module in model.named_modules():
    # Print the layer name.
    print(name)


def extract_file_paths(filename):
    """Parse a frame filename into its parts, e.g. "left1_frame_10.jpg"
    -> ("10", "left", "1"). Regex built with the aid of https://regex101.com/.
    """
    extractor_reg = r"(left|right)([0-9]+)(_frame_)([0-9]+)"
    result = re.search(extractor_reg, filename)
    frame_no = result.group(4)
    frame_name = result.group(1)
    video_no = result.group(2)
    return frame_no, frame_name, video_no


def create_mp4_from_frames(file_name, frames):
    """Write the given collection of frames to an MP4 file; at 20 FPS the
    clip lasts len(frames) / 20 seconds."""
    print("Sorted frames: ", sorted(frames))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    height, width, _ = cv2.imread(frames[0]).shape
    fps = 20  # Adjust the frames per second (FPS) as needed.
    video_path = os.path.join(os.getcwd(), "dataset", "gifs2", f"{file_name}.mp4")
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    for frame_path in sorted(frames):
        # cv2.imread returns BGR, which is exactly what VideoWriter expects,
        # so the frame is written as-is (converting to RGB here would swap
        # the red and blue channels in the output video).
        image = cv2.imread(frame_path)
        # if image.dtype != np.uint8:
        #     image = (image * 255).astype(np.uint8)  # Convert to uint8.
        video.write(image)
    # Release the VideoWriter.
    video.release()


# Make sure the raw-frame output directory exists before writing into it.
os.makedirs(os.path.join(os.getcwd(), "dataset/gifs2/raw/"), exist_ok=True)

current_video_name = None
selected_frames = []  # Stores the file paths of the rendered heatmap frames.

for image_filename in tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images"):
    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name
    if current_video_name != obtained_video_name:
        # We have hit a new video sequence, so write out the frames collected
        # for the previous one before starting fresh.
        if selected_frames and current_video_name:
            create_mp4_from_frames(current_video_name, selected_frames)
        # Clear frames and hand off to the new sequence name.
        selected_frames = []
        current_video_name = obtained_video_name

    # With the frame number and name extracted from the file path, we can
    # determine which frames belong to which specific video file.
    img = cv2.imread(os.path.join(TRAIN_STORE, image_filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Convert the image to a float32 tensor and prepend a batch dimension.
    img = torch.unsqueeze(torch.tensor(img.astype(np.float32)), 0)
    # Reorder (Batch, H, W, Channel) -> (Batch, Channel, H, W); the einsum
    # labels are just axis names, so this is a pure permutation.
    X = torch.einsum("BHWC->BCHW", img)
    if torch.cuda.is_available():
        X = X.cuda()
    output = model(X)
    # print(output)
    # print("Model layer outputs: ")
    # print(model.layer_outputs)

    # layer_outputs is assumed to be populated by MakiAlexNet during forward
    # (e.g. via hooks): the last conv feature map and the final linear logits.
    conv = model.layer_outputs['Conv2d']
    pred = model.layer_outputs["Linear"]
    pred_weights, pred_bias = model.f_linear.weight, model.f_linear.bias
    # print(pred_weights.shape)

    # (1, 256, 12, 12) -> (1, 12, 12, 256), then to numpy.
    conv = torch.einsum("BCHW->BHWC", conv).cpu().detach().numpy()

    # Class-activation map: weight each of the 256 feature maps by the final
    # linear layer's weights for the predicted class and sum over channels.
    target = np.argmax(pred.cpu().detach().numpy(), axis=1).squeeze()
    weights = pred_weights[target, :].cpu().detach().numpy()
    # print("weights", weights.shape, "conv", conv.squeeze(0).shape)
    heatmap = conv.squeeze(0) @ weights  # (12, 12, 256) @ (256,) -> (12, 12)

    # Upsample the 12x12 map to the 224x224 network input size.
    scale = 224 / 12
    plt.figure(figsize=(12, 12))
    img = cv2.imread(os.path.join(TRAIN_STORE, image_filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.imshow(zoom(heatmap, zoom=(scale, scale)), cmap='jet', alpha=0.5)

    # Zero-pad single-digit frame numbers so lexicographic sort matches frame order.
    if len(frame_no) == 1:
        frame_no = "0" + frame_no
    filename = video_no + frame_name + frame_no + ".jpg"
    file_path = os.path.join(os.getcwd(), "dataset/gifs2/raw/", filename)
    plt.savefig(file_path)
    selected_frames.append(file_path)
    plt.close()

    # Legacy per-filter thresholding approach, kept for reference:
    # mat = zoom(conv[0, :, :, i], zoom=(scale, scale))
    # threshold = np.percentile(mat.flatten(), TOP_ACCURACY_PERCENTILE)
    # # The closer the threshold is to zero, the more specific the region shown.
    # mask = mat > threshold
    # # OR: filter_map = np.where(filter_map <= threshold, 0, filter_map)
    #
    # # Rescale remaining values (adjust new_range if needed).
    # new_range = 1  # Adjust based on your desired final range.
    # filter_map = np.where(mask, (mat - threshold) / (mat.max() - threshold) * new_range, 0)
    #
    # # Add all the maps together, which is really noisy.
    # if total_mat is not None:
    #     total_mat += filter_map
    # else:
    #     total_mat = filter_map
    #
    # # Normalize based on the largest value, then store the image in a
    # # collection from which a GIF lasting at least 2 seconds will be made.
    # total_mat = total_mat / abs(np.max(total_mat))
    #
    # # image = img.squeeze(0)  # .detach().numpy().astype(np.float32)
    #
    # plt.imshow(plt.imread(os.path.join(os.getcwd(), "dataset/root/train", image_filename)))  # full path needed
    # plt.imshow(total_mat, cmap='jet', alpha=0.3)
    #
    # filename = frame_name + frame_no + video_no + ".jpg"
    # file_path = os.path.join(os.getcwd(), "dataset/gifs/raw/", filename)
    # plt.savefig(file_path)
    # selected_frames.append(file_path)

# The loop above only flushes on a sequence change, so write out the final
# sequence explicitly.
if selected_frames and current_video_name:
    create_mp4_from_frames(current_video_name, selected_frames)

exit()

# plt.figure(figsize=(16, 16))
# for i in range(36):
#     plt.subplot(6, 6, i + 1)
#     plt.imshow(cv2.imread(TEST_IMAGE))
#     plt.imshow(zoom(conv[0, :, :, i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
#
# plt.show()
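
# --- Optional sketch, not part of the original pipeline above. ---
# The heatmap computed in the loop is the raw class-activation map,
# CAM_c(x, y) = sum_k w_k^c * A_k(x, y), which can contain negative values;
# plt.imshow only looks sensible because matplotlib normalises scalar data to
# the colormap range internally. If an explicit [0, 1] overlay is preferred
# (as in the Grad-CAM paper), a helper along these lines would do it.
# `normalize_cam` is a hypothetical name; nothing above calls it.
def normalize_cam(cam):
    """ReLU the map (keep only positive class evidence), then scale to [0, 1]."""
    cam = np.maximum(cam, 0)  # Discard negative (counter-evidence) activations.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Hypothetical usage inside the loop, replacing the raw-overlay imshow call:
# plt.imshow(normalize_cam(zoom(heatmap, zoom=(scale, scale))),
#            cmap='jet', alpha=0.5, vmin=0.0, vmax=1.0)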