File size: 3,823 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import numpy as np
# Target behaviour: an illustrative spec for the implementations further down.
a = np.arange(12).reshape(2, 3, 2)  # (batch, channel, dim)
print(a)
# Printed value (kept as a comment; a bare `array(...)` line would raise NameError):
# array([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]])
# NOTE(review): `create_swap_channel_mat` and `input_shape` are not defined in
# the visible source -- TODO confirm where they are expected to come from.
swap_mat = create_swap_channel_mat(input_shape, swap_channel=(1, 2))
# will swap channel 1 and 2 of batch 0 with channel 1 and 2 of batch 1
b = a @ swap_mat
print(b)
# expected output
# array([[[0, 1], [8, 9], [10, 11]], [[6, 7], [2, 3], [4, 5]]])
import torch
def swap_channels_between_batches(a_tensor, swap_channels):
    """Return a copy of ``a_tensor`` with channels exchanged between batch 0 and batch 1.

    NOTE: a later definition in this file reuses this name with a different
    signature (tensor, swap matrix); this version is the direct-indexing one.

    Args:
        a_tensor: Tensor indexed as ``[batch, channel, ...]`` holding at least
            two batches.
        swap_channels: Iterable of channel indices to exchange. Any number of
            channels is accepted (the common case is a pair such as ``(1, 2)``).

    Returns:
        A new tensor; ``a_tensor`` itself is not modified.

    Raises:
        IndexError: If a batch or channel index is out of range for ``a_tensor``.
    """
    # Work on a copy so the caller's tensor is left untouched.
    result_tensor = a_tensor.clone()
    for ch in swap_channels:
        # Values are read from the unmodified input, so no temporaries are
        # needed to perform the exchange.
        result_tensor[0, ch] = a_tensor[1, ch]
        result_tensor[1, ch] = a_tensor[0, ch]
    return result_tensor
# Sample input: 2 batches x 3 channels x 2 values per channel.
a_tensor = torch.tensor(
    [
        [[0, 1], [2, 3], [4, 5]],
        [[6, 7], [8, 9], [10, 11]],
    ],
    dtype=torch.float32,
)

# Channels whose contents are exchanged between the two batches.
swap_channels = (1, 2)

# The swap function works on a clone, so a_tensor can still be printed below.
swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)

# Show the input followed by the swapped result.
print("Original Tensor 'a_tensor':")
print(a_tensor)
print("\nTensor after swapping channels between batches:")
print(swapped_tensor)
#-------------------------------------------------
import torch
from einops import rearrange
def shift(arr, num, fill_value=np.nan):
    """Shift ``arr`` along its first axis by ``num`` positions, padding with ``fill_value``.

    A positive ``num`` moves elements towards higher indices (padding at the
    front); a negative ``num`` moves them towards lower indices (padding at
    the back); zero yields a plain copy.

    Args:
        arr: Array-like input to shift.
        num: Number of positions to shift by.
        fill_value: Value written into the vacated positions. It is stored in
            an array created with ``np.empty_like(arr)``, so it has to be
            assignable to that array's dtype.

    Returns:
        A new array from ``np.empty_like(arr)`` holding the shifted data.
    """
    shifted = np.empty_like(arr)

    if num > 0:
        # Pad the leading slots, then copy everything except the tail.
        shifted[:num] = fill_value
        shifted[num:] = arr[:-num]
        return shifted

    if num < 0:
        # Pad the trailing slots, then copy everything except the head.
        shifted[num:] = fill_value
        shifted[:num] = arr[-num:]
        return shifted

    # No shift requested: straight copy.
    shifted[:] = arr
    return shifted
def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a matrix that exchanges the given channels between batch 0 and batch 1.

    The matrix acts on data flattened to ``(batch * channel)`` rows: row ``c``
    is channel ``c`` of the first batch and row ``c + channels`` is the same
    channel of the second batch.

    NOTE: a later definition with the same name in this file overrides this one.

    Args:
        batch_size: Number of batches.
        channels: Number of channels per batch.
        swap_channels: Iterable of channel indices to exchange.

    Returns:
        A ``(batch_size * channels)``-square NumPy array: the identity with the
        selected row pairs swapped.
    """
    matrix = np.eye(batch_size * channels)
    for ch in swap_channels:
        first = ch              # channel index within the first batch
        second = ch + channels  # same channel within the second batch
        # Clear the two diagonal entries ...
        matrix[[first, second], [first, second]] = 0
        # ... and cross-link the pair so the rows trade places.
        matrix[[first, second], [second, first]] = 1
    return matrix
def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a permutation matrix that rotates the given channels across batches.

    The matrix acts on data flattened to ``(batch * channel)`` rows. For every
    selected channel, batch ``b`` receives that channel from batch
    ``(b + 1) % batch_size``. With two batches this is a plain swap; with one
    batch it is the identity.

    Args:
        batch_size: Number of batches.
        channels: Number of channels per batch.
        swap_channels: Iterable of channel indices (``0 <= c < channels``) to
            rotate.

    Returns:
        A ``(batch_size * channels)``-square NumPy permutation matrix.

    Raises:
        IndexError: If a channel index lies outside ``[0, channels)``; such an
            index would otherwise silently address a different channel.
    """
    size = batch_size * channels
    swap_mat = np.eye(size)
    # Apply the exchange for every requested channel.
    for c in swap_channels:
        if not 0 <= c < channels:
            raise IndexError(
                f"channel index {c} is out of range for {channels} channels"
            )
        # Flattened row index of this channel in every batch.
        rows = np.arange(c, size, channels)
        # Same channel one batch further on; modulo wraps the last batch
        # around to the first.
        cols = (rows + channels) % size
        swap_mat[rows, rows] = 0
        # Only the forward link is set. Setting the reverse link as well
        # would leave two 1s per row once batch_size exceeds 2, summing
        # channels instead of permuting them.
        swap_mat[rows, cols] = 1
    return swap_mat
def swap_channels_between_batches(input_tensor, swap_matrix):
    """Apply ``swap_matrix`` to the flattened (batch, channel) rows of ``input_tensor``.

    Args:
        input_tensor: Three-axis input, matched against the einops pattern
            ``'b c d'``.
        swap_matrix: Matrix left-multiplied onto the ``(b c) d`` view of the
            input.

    Returns:
        The product reshaped back to ``b c d``, using the input's leading
        dimension as ``b``.
    """
    # Merge the batch and channel axes so one matrix product can mix rows.
    flat_rows = rearrange(input_tensor, 'b c d -> (b c) d')
    mixed_rows = swap_matrix @ flat_rows
    # Split the merged axis again, keeping the original batch count.
    num_batches = input_tensor.shape[0]
    return rearrange(mixed_rows, '(b c) d -> b c d', b=num_batches)
# Example parameters
batch_size = 2
channels = 3
# swap_info = {
#     : [1, 2]  # batch_index: [channel_indices]
# }
swap_channels = [1, 2]  # channels to exchange

# Create the example tensor
input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)

# Build the swap matrix (a NumPy array) and convert it to a tensor whose dtype
# matches the input so the matrix product is well-defined.
swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
swap_matrix = torch.as_tensor(swap_matrix, dtype=input_tensor.dtype)

# Perform the channel swap
swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)

# Print the results
print("Original Tensor:")
print(input_tensor)
print("\nSwapped Tensor:")
print(swapped_tensor)
|