YourMT3 / amt /src /extras /swap_channel.py
mimbres's picture
.
a03c9b4
raw
history blame
3.82 kB
import numpy as np
a = np.arange(12).reshape(2, 3, 2) # (batch, channel, dim)
print(a)
array([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]])
swap_mat = create_swap_channel_mat(input_shape, swap_channel=(1, 2))
# will swap channel 1 and 2 of batch 0 with channel 1 and 2 of batch 1
b = a @ swap_mat
print(b)
# expected output
array([[[0, 1], [8, 9], [10, 11]], [[6, 7], [2, 3], [4, 5]]])
import torch
def swap_channels_between_batches(a_tensor, swap_channels):
# Copy the tensor to avoid modifying the original tensor
result_tensor = a_tensor.clone()
# Unpack the channels to be swapped
ch1, ch2 = swap_channels
# Swap the specified channels between batches
result_tensor[0, ch1, :], result_tensor[1, ch1, :] = a_tensor[1, ch1, :].clone(), a_tensor[0, ch1, :].clone()
result_tensor[0, ch2, :], result_tensor[1, ch2, :] = a_tensor[1, ch2, :].clone(), a_tensor[0, ch2, :].clone()
return result_tensor
# Define a sample tensor 'a_tensor'
a_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
# Define channels to swap
swap_channels = (1, 2) # Channels to swap between batches
# Swap the channels between batches
swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)
# Print the original tensor and the tensor after swapping channels between batches
print("Original Tensor 'a_tensor':")
print(a_tensor)
print("\nTensor after swapping channels between batches:")
print(swapped_tensor)
#-------------------------------------------------
import torch
from einops import rearrange
def shift(arr, num, fill_value=np.nan):
result = np.empty_like(arr)
if num > 0:
result[:num] = fill_value
result[num:] = arr[:-num]
elif num < 0:
result[num:] = fill_value
result[:num] = arr[-num:]
else:
result[:] = arr
return result
def create_batch_swap_matrix(batch_size, channels, swap_channels):
swap_mat = np.eye(batch_size * channels)
for c in swap_channels:
idx1 = c # 첫 번째 배치의 κ΅ν™˜ν•  채널 인덱슀
idx2 = c + channels # 두 번째 배치의 κ΅ν™˜ν•  채널 인덱슀
swap_mat[idx1, idx1], swap_mat[idx2, idx2] = 0, 0 # λŒ€κ°μ„  값을 0으둜 μ„€μ •
swap_mat[idx1, idx2], swap_mat[idx2, idx1] = 1, 1 # ν•΄λ‹Ή 채널을 κ΅ν™˜
return swap_mat
def create_batch_swap_matrix(batch_size, channels, swap_channels):
swap_mat = np.eye(batch_size * channels)
# λͺ¨λ“  채널에 λŒ€ν•΄ κ΅ν™˜ μˆ˜ν–‰
for c in swap_channels:
idx1 = np.arange(c, batch_size * channels, channels) # ν˜„μž¬ μ±„λ„μ˜ λͺ¨λ“  배치 인덱슀
idx2 = (idx1 + channels) % (batch_size * channels) # μˆœν™˜μ„ μœ„ν•΄ modulo μ‚¬μš©
swap_mat[idx1, idx1] = 0
swap_mat[idx2, idx2] = 0
swap_mat[idx1, idx2] = 1
swap_mat[idx2, idx1] = 1
return swap_mat
def swap_channels_between_batches(input_tensor, swap_matrix):
reshaped_tensor = rearrange(input_tensor, 'b c d -> (b c) d')
swapped_tensor = swap_matrix @ reshaped_tensor
return rearrange(swapped_tensor, '(b c) d -> b c d', b=input_tensor.shape[0])
# 예제 νŒŒλΌλ―Έν„°
batch_size = 2
channels = 3
# swap_info = {
# : [1, 2] # batch_index: [channel_indices]
# }
swap_channels = [1, 2] # κ΅ν™˜ν•  채널
# 예제 ν…μ„œ 생성
input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
# swap matrix 생성
swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
swap_matrix = torch.Tensor(swap_matrix)
# 채널 κ΅ν™˜ μˆ˜ν–‰
swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)
# κ²°κ³Ό 좜λ ₯
print("Original Tensor:")
print(input_tensor)
print("\nSwapped Tensor:")
print(swapped_tensor)