# Page-scrape residue (not code): "Spaces: Runtime error".
try:
    from jax import numpy as jnp
except ModuleNotFoundError:
    # jax doesn't support windows os yet.
    import numpy as jnp

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from layers.window_attention import WindowAttention
from utils.drop_path import DropPath
from utils.swin_window import window_partition
from utils.swin_window import window_reverse
class SwinTransformer(layers.Layer):
    """Swin Transformer block: windowed attention followed by an MLP.

    The input token sequence is normalised, reshaped to a ``(height, width)``
    grid, optionally rolled by ``shift_size`` along both spatial axes, split
    into ``window_size x window_size`` windows, passed through
    ``WindowAttention``, and merged back.  Two residual connections wrap the
    attention branch and the MLP branch respectively.

    Args:
        dim: Feature width handed to ``WindowAttention`` and used as the
            output size of the MLP's last ``Dense`` layer.
        num_patch: ``(height, width)`` pair describing the token grid.
        num_heads: Number of attention heads, forwarded to
            ``WindowAttention``.
        window_size: Side length of the square attention window.  Clamped to
            ``min(num_patch)`` when the grid is smaller than the window.
        shift_size: Cyclic shift applied before window partitioning; ``0``
            disables shifting.  Forced to ``0`` when the window is clamped.
        num_mlp: Hidden width of the MLP.
        qkv_bias: Forwarded to ``WindowAttention``.
        dropout_rate: Rate used for the MLP dropout layers, forwarded to
            ``WindowAttention``, and used for ``DropPath`` when positive.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_mlp = num_mlp

        # Clamp *before* creating the attention layer: `build` and `call`
        # partition with `self.window_size`, so the attention layer has to be
        # constructed with the same (possibly reduced) window.
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        # With no dropout the "drop path" degenerates to a pass-through.
        self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Create the attention mask used by shifted-window blocks.

        Sets ``self.attn_mask`` to ``None`` when ``shift_size`` is ``0``;
        otherwise stores a non-trainable ``tf.Variable`` holding ``-100.0``
        for token pairs that belong to different regions of the rolled grid
        and ``0.0`` for pairs within the same region.

        Args:
            input_shape: Unused; the mask depends only on ``num_patch``,
                ``window_size`` and ``shift_size``.
        """
        if self.shift_size == 0:
            self.attn_mask = None
            return

        height, width = self.num_patch
        # The same three bands are used along both spatial axes.
        bands = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )

        # Real NumPy is required here: the loop assigns into the array in
        # place, which jax arrays reject.  float32 matches the default dtype
        # of `jnp.zeros`.
        mask_array = np.zeros((1, height, width, 1), dtype=np.float32)
        region_id = 0
        for h in bands:
            for w in bands:
                mask_array[:, h, w, :] = region_id
                region_id += 1
        mask_array = tf.convert_to_tensor(mask_array)

        # mask array to windows
        mask_windows = window_partition(mask_array, self.window_size)
        mask_windows = tf.reshape(
            mask_windows, shape=[-1, self.window_size * self.window_size]
        )
        # Pairwise region-id difference: non-zero means the two tokens come
        # from different regions.
        attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
            mask_windows, axis=2
        )
        # Entries that are already 0 are left untouched, so a single pass
        # yields the final {-100.0, 0.0} mask.
        attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
        self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            x: Rank-3 tensor whose second axis can be reshaped to
                ``height * width`` tokens (``height, width = num_patch``).

        Returns:
            A ``float32`` tensor of shape ``(-1, height * width, channels)``.
        """
        height, width = self.num_patch
        # Unpacking three values keeps the implicit rank-3 requirement.
        _, _, channels = x.shape

        # --- attention branch -------------------------------------------
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            # Undo the cyclic shift applied before partitioning.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        # Both operands are cast so the residual sum is always float32.
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)

        # --- MLP branch ---------------------------------------------------
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x