|
import numpy as np |
|
|
|
a = np.arange(12).reshape(2, 3, 2)

print(a)
# array([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]])

# Exchange channels 1 and 2 between the two batches with a permutation matrix.
# `a @ M` would only mix values along the last axis (inside each row), so it can
# never move a row from one batch to the other. Instead, flatten to
# (batch * channels, width) rows and left-multiply by a row permutation.
n_batch, n_chan, width = a.shape
swap_channel = [1, 2]
# perm[b, c] is the flattened source row for output batch b, channel c.
perm = np.arange(n_batch * n_chan).reshape(n_batch, n_chan)
# Reversing the batch axis of a 2-batch array exchanges the two batches.
perm[:, swap_channel] = perm[::-1, swap_channel]
swap_mat = np.eye(n_batch * n_chan, dtype=a.dtype)[perm.ravel()]

b = (swap_mat @ a.reshape(n_batch * n_chan, width)).reshape(a.shape)

print(b)
# array([[[0, 1], [8, 9], [10, 11]], [[6, 7], [2, 3], [4, 5]]])
|
|
|
import torch |
|
|
|
|
|
def swap_channels_between_batches(a_tensor, swap_channels):
    """Return a copy of ``a_tensor`` with some channels exchanged between batches 0 and 1.

    Args:
        a_tensor: Tensor indexed as ``a_tensor[batch, channel, ...]``. It must
            have at least two entries along dim 0 and is not modified.
        swap_channels: Iterable of channel indices (any count, negative indices
            allowed) whose data is exchanged between batch 0 and batch 1.
            Batches after index 1 are left untouched.

    Returns:
        A new tensor with the same shape and dtype as ``a_tensor``.

    NOTE(review): a second ``swap_channels_between_batches`` with a different
    signature is defined later in this file and shadows this one for any code
    that runs after that definition.
    """
    channels = list(swap_channels)
    result_tensor = a_tensor.clone()
    if not channels:
        return result_tensor

    # Both right-hand sides read from the untouched input, so the two writes
    # cannot observe each other and no temporary copies are required.
    result_tensor[0, channels] = a_tensor[1, channels]
    result_tensor[1, channels] = a_tensor[0, channels]
    return result_tensor
|
|
|
|
|
|
|
# Demo: exchange channels 1 and 2 between the two batches of a (2, 3, 2) tensor.
a_tensor = torch.tensor(
    [
        [[0, 1], [2, 3], [4, 5]],
        [[6, 7], [8, 9], [10, 11]],
    ],
    dtype=torch.float32,
)
swap_channels = (1, 2)

swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)

# Each caption and its tensor go on separate lines, exactly as two print calls would.
print("Original Tensor 'a_tensor':", a_tensor, sep="\n")
print("\nTensor after swapping channels between batches:", swapped_tensor, sep="\n")
|
|
|
|
|
|
|
import torch |
|
from einops import rearrange |
|
|
|
|
|
def shift(arr, num, fill_value=np.nan):
    """Shift ``arr`` along its first axis by ``num`` positions, padding with ``fill_value``.

    A positive ``num`` moves elements toward higher indices, and the first
    ``num`` slots become ``fill_value``. A negative ``num`` moves them toward
    lower indices, and the last ``-num`` slots are filled. ``num == 0`` returns
    a copy. Shifting by at least ``len(arr)`` fills the whole result.

    Args:
        arr: Array-like to shift (converted with ``np.asanyarray``).
        num: Signed shift amount along axis 0.
        fill_value: Value written into the vacated slots. Defaults to NaN.

    Returns:
        A new array with ``arr``'s shape. Its dtype is ``arr``'s dtype, except
        when integer or boolean input is shifted with a NaN fill. That case
        returns float64, because NaN cannot be stored in those dtypes: integers
        raise ValueError and booleans silently turn NaN into ``True``.
    """
    arr = np.asanyarray(arr)
    needs_float = (
        num != 0
        and arr.dtype.kind in "biu"
        and isinstance(fill_value, (float, np.floating))
        and np.isnan(fill_value)
    )
    # dtype=None keeps arr's own dtype.
    result = np.empty_like(arr, dtype=np.float64 if needs_float else None)

    if num > 0:
        result[:num] = fill_value
        result[num:] = arr[:-num]
    elif num < 0:
        result[num:] = fill_value
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result
|
|
|
|
|
def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a permutation matrix that swaps channels between batch 0 and batch 1.

    Rows are assumed to be laid out batch-major: row ``b * channels + c`` holds
    batch ``b``, channel ``c``. Left-multiplying a flattened
    ``(batch_size * channels, ...)`` matrix by the result swaps, for each
    channel in ``swap_channels``, the row of batch 0 with the row of batch 1.
    Rows of any further batches are left unchanged.

    NOTE(review): another ``create_batch_swap_matrix`` is defined immediately
    after this one and replaces it at import time, so this version is
    effectively dead code.

    Args:
        batch_size: Number of batches (must be at least 2).
        channels: Number of channels per batch.
        swap_channels: Iterable of channel indices in ``[0, channels)``.

    Returns:
        A float64 ndarray of shape ``(batch_size * channels, batch_size * channels)``.

    Raises:
        ValueError: if ``batch_size < 2`` or a channel index is out of range.
    """
    if batch_size < 2:
        raise ValueError(f"need at least 2 batches to swap, got batch_size={batch_size}")

    swap_mat = np.eye(batch_size * channels)
    for c in swap_channels:
        # A negative or too-large index would otherwise silently swap unrelated rows.
        if not 0 <= c < channels:
            raise ValueError(f"channel index {c} out of range for {channels} channels")
        idx1, idx2 = c, c + channels  # channel c of batch 0 and of batch 1
        swap_mat[[idx1, idx2], [idx1, idx2]] = 0
        swap_mat[[idx1, idx2], [idx2, idx1]] = 1
    return swap_mat
|
|
|
|
|
def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a permutation matrix that cycles selected channels across batches.

    Rows are assumed to be laid out batch-major: row ``b * channels + c`` holds
    batch ``b``, channel ``c``. For every channel ``c`` in ``swap_channels``,
    left-multiplying a flattened ``(batch_size * channels, ...)`` matrix by the
    result makes output batch ``b`` take channel ``c`` from batch
    ``(b + 1) % batch_size``. With two batches this is a plain swap; with one
    batch it is the identity. Channels not listed are left in place.

    Args:
        batch_size: Number of batches.
        channels: Number of channels per batch.
        swap_channels: Iterable of channel indices in ``[0, channels)``.

    Returns:
        A float64 ndarray of shape ``(batch_size * channels, batch_size * channels)``
        with exactly one 1 in every row and every column.

    Raises:
        ValueError: if a channel index is outside ``[0, channels)``.
    """
    n_rows = batch_size * channels
    swap_mat = np.eye(n_rows)

    for c in swap_channels:
        if not 0 <= c < channels:
            raise ValueError(f"channel index {c} out of range for {channels} channels")
        rows = np.arange(c, n_rows, channels)  # channel c of every batch
        sources = np.roll(rows, -1)  # same channel, next batch (wrapping around)
        # Move each row's single 1 to its source column. The mirrored
        # (sources, rows) entries must NOT also be set: for batch_size > 2 that
        # leaves two 1s per row, which sums two batches instead of permuting them.
        swap_mat[rows, rows] = 0
        swap_mat[rows, sources] = 1

    return swap_mat
|
|
|
|
|
def swap_channels_between_batches(input_tensor, swap_matrix):
    """Apply a row-mixing matrix to the flattened (batch, channel) rows of a 3-D tensor.

    ``input_tensor`` of shape ``(b, c, d)`` is flattened batch-major to
    ``(b * c, d)`` and left-multiplied by ``swap_matrix``. The product is then
    reshaped back to ``(b, rows // b, d)``. With a square permutation matrix
    (for example from ``create_batch_swap_matrix``) the output has the input's
    shape.

    Args:
        input_tensor: 3-D tensor of shape ``(b, c, d)``.
        swap_matrix: Tensor or array with ``b * c`` columns. It is converted to
            ``input_tensor``'s dtype and device, so a NumPy float64 matrix works
            against a float32 input.

    Returns:
        A new tensor of shape ``(b, swap_matrix.shape[0] // b, d)``.
    """
    batch, channels, width = input_tensor.shape
    flat = input_tensor.reshape(batch * channels, width)
    # Match dtype and device so the matmul does not fail on a mismatch.
    swap_matrix = torch.as_tensor(swap_matrix, dtype=flat.dtype, device=flat.device)
    swapped = swap_matrix @ flat
    # -1 lets the channel count follow the matrix's row count, like "(b c) d -> b c d".
    return swapped.reshape(batch, -1, width)
|
|
|
|
|
|
|
# Demo: swap channels 1 and 2 between the two batches via a permutation matrix.
input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)

# Derive the sizes from the tensor itself so they cannot drift out of sync with it.
batch_size, channels = input_tensor.shape[:2]

swap_channels = [1, 2]

swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
# torch.Tensor(ndarray) is the legacy constructor. as_tensor with an explicit
# dtype makes the float64 -> input dtype conversion visible.
swap_matrix = torch.as_tensor(swap_matrix, dtype=input_tensor.dtype)

swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)

print("Original Tensor:")
print(input_tensor)
print("\nSwapped Tensor:")
print(swapped_tensor)
|
|