Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
def window_partition(x, window_size):
    """Split a rank-4 feature map into non-overlapping square windows.

    Args:
        x: Tensor whose static shape unpacks into exactly four dimensions,
            read here as (batch, height, width, channels). height and width
            are presumably multiples of ``window_size``; otherwise the
            reshape below fails (TODO confirm callers guarantee this).
        window_size: Side length of each square window.

    Returns:
        Tensor of shape (-1, window_size, window_size, channels). Windows
        from every batch element are stacked along the first axis, ordered
        by (batch, window row, window column).
    """
    _batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    # Expose the window grid: (B, rows, ws, cols, ws, C).
    grid = tf.reshape(
        x,
        shape=(-1, rows, window_size, cols, window_size, channels),
    )
    # Put the two grid axes next to each other: (B, rows, cols, ws, ws, C).
    grid = tf.transpose(grid, perm=(0, 1, 3, 2, 4, 5))
    # Merge batch and grid axes so that each window becomes one leading entry.
    return tf.reshape(grid, shape=(-1, window_size, window_size, channels))
def window_reverse(windows, window_size, height, width, channels):
    """Stitch square windows back into full (height, width) feature maps.

    Applies the inverse permutation of ``window_partition``. Given that
    function's output for the same ``window_size``, ``height``, ``width``
    and ``channels``, this returns the original layout.

    Args:
        windows: Tensor reshapeable to
            (-1, height // window_size, width // window_size,
            window_size, window_size, channels).
        window_size: Side length of each square window.
        height: Height of the reassembled map.
        width: Width of the reassembled map.
        channels: Channel count of the reassembled map.

    Returns:
        Tensor of shape (-1, height, width, channels).
    """
    grid_shape = (
        -1,
        height // window_size,
        width // window_size,
        window_size,
        window_size,
        channels,
    )
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C), which is
    # row-major compatible with (B, height, width, C).
    tiled = tf.transpose(
        tf.reshape(windows, shape=grid_shape),
        perm=(0, 1, 3, 2, 4, 5),
    )
    return tf.reshape(tiled, shape=(-1, height, width, channels))