Spaces:
Build error
Build error
import tensorflow as tf | |
import tensorflow.python.keras.backend as K | |
from tensorflow.python.eager import context | |
from tensorflow.python.ops import ( | |
gen_math_ops, | |
math_ops, | |
sparse_ops, | |
standard_ops, | |
) | |
def l2normalize(v, eps=1e-12):
    """Return ``v`` rescaled to unit L2 norm.

    The norm is ``tf.norm(v)`` with its defaults, i.e. taken over all elements
    of ``v``. ``eps`` is added to the denominator so an all-zero input does not
    produce a division by zero.
    """
    denominator = tf.norm(v) + eps
    return v / denominator
class ConvSN2D(tf.keras.layers.Conv2D):
    """``Conv2D`` whose kernel is divided by an estimate of its spectral norm.

    On every call the kernel is flattened to a matrix of shape
    ``(-1, kernel.shape[-1])``. Its largest singular value is estimated with
    ``power_iterations`` steps of power iteration, warm-started from the
    non-trainable vector ``self.u``. The convolution is then applied with the
    normalized kernel.

    Args:
        filters: Forwarded to ``Conv2D``.
        kernel_size: Forwarded to ``Conv2D``.
        power_iterations: Power-iteration steps per call; must be >= 1.
        datatype: Stored and serialized only. This layer creates ``u`` with
            the layer's own ``dtype``, not with ``datatype``.
        **kwargs: Any other ``Conv2D`` keyword arguments.

    Raises:
        ValueError: If ``power_iterations`` is less than 1 (the power
            iteration would otherwise never define its ``v`` vector).
    """

    def __init__(self, filters, kernel_size, power_iterations=1, datatype=tf.float32, **kwargs):
        if power_iterations < 1:
            raise ValueError(
                "power_iterations must be >= 1, got {}".format(power_iterations)
            )
        super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
        self.power_iterations = power_iterations
        self.datatype = datatype

    def build(self, input_shape):
        """Build the ``Conv2D`` weights plus the power-iteration vector ``u``."""
        super(ConvSN2D, self).build(input_shape)
        # Shape (1, kernel.shape[-1]): one entry per column of the flattened
        # kernel. Not trainable; it is overwritten by the power iteration.
        # ``name`` is passed by keyword because Keras 3's ``add_weight`` takes
        # ``shape`` as its first positional parameter.
        self.u = self.add_weight(
            name=self.name + "_u",
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=tf.initializers.RandomNormal(0, 1),
            trainable=False,
            dtype=self.dtype,
        )

    def compute_spectral_norm(self, W, new_u, W_shape):
        """Return ``W`` divided by its estimated largest singular value.

        Args:
            W: 2-D kernel matrix of shape ``(rows, W_shape[-1])``.
            new_u: Starting vector of shape ``(1, W_shape[-1])``.
            W_shape: Shape the normalized matrix is reshaped back to.

        Returns:
            The normalized kernel, reshaped to ``W_shape``. The refined ``u``
            is assigned back to ``self.u`` as a control dependency of the
            result.
        """
        for _ in range(self.power_iterations):
            new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
            new_u = l2normalize(tf.matmul(new_v, W))
        # sigma = v W u^T, a 1x1 estimate of the largest singular value.
        # NOTE(review): no tf.stop_gradient is applied, so gradients also flow
        # through u and v; confirm this is intended.
        sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
        W_bar = W / sigma
        # Tie the u update to the returned tensor so it also runs in graph mode.
        with tf.control_dependencies([self.u.assign(new_u)]):
            W_bar = tf.reshape(W_bar, W_shape)
        return W_bar

    def call(self, inputs):
        """Convolve ``inputs`` with the spectrally normalized kernel."""
        W_shape = self.kernel.shape.as_list()
        W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
        new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
        outputs = self._convolution_op(inputs, new_kernel)
        if self.use_bias:
            if self.data_format == "channels_first":
                outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
            else:
                outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Extend the ``Conv2D`` config with this layer's extra arguments."""
        config = super(ConvSN2D, self).get_config()
        config.update(
            {
                "power_iterations": self.power_iterations,
                # Stored as a string so the config stays JSON-serializable.
                "datatype": tf.as_dtype(self.datatype).name,
            }
        )
        return config
class DenseSN(tf.keras.layers.Dense):
    """``Dense`` layer whose kernel is divided by an estimate of its spectral norm.

    On every call the largest singular value of the kernel is estimated with
    ``power_iterations`` steps of power iteration, warm-started from the
    non-trainable vector ``self.u``. The layer is then applied with the
    normalized kernel.

    Note: the first positional parameter is ``datatype``. ``units`` and every
    other ``Dense`` argument must therefore be passed by keyword, e.g.
    ``DenseSN(units=64)``.

    Args:
        datatype: dtype used to create ``u``.
        power_iterations: Power-iteration steps per call; must be >= 1.
            Defaults to 1, which matches the previous single-step behavior.
        **kwargs: ``Dense`` keyword arguments (``units`` is required by
            ``Dense``).

    Raises:
        ValueError: If ``power_iterations`` is less than 1.
    """

    def __init__(self, datatype=tf.float32, power_iterations=1, **kwargs):
        if power_iterations < 1:
            raise ValueError(
                "power_iterations must be >= 1, got {}".format(power_iterations)
            )
        super(DenseSN, self).__init__(**kwargs)
        self.datatype = datatype
        self.power_iterations = power_iterations

    def build(self, input_shape):
        """Build the ``Dense`` weights plus the power-iteration vector ``u``."""
        super(DenseSN, self).build(input_shape)
        # Shape (1, kernel.shape[-1]); overwritten by the power iteration.
        # ``name`` is passed by keyword because Keras 3's ``add_weight`` takes
        # ``shape`` as its first positional parameter.
        self.u = self.add_weight(
            name=self.name + "_u",
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=tf.initializers.RandomNormal(0, 1),
            trainable=False,
            dtype=self.datatype,
        )

    def compute_spectral_norm(self, W, new_u, W_shape):
        """Return ``W`` divided by its estimated largest singular value.

        Args:
            W: 2-D kernel matrix of shape ``(rows, W_shape[-1])``.
            new_u: Starting vector of shape ``(1, W_shape[-1])``.
            W_shape: Shape the normalized matrix is reshaped back to.

        Returns:
            The normalized kernel, reshaped to ``W_shape``. The refined ``u``
            is assigned back to ``self.u`` as a control dependency of the
            result.
        """
        for _ in range(self.power_iterations):
            new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
            new_u = l2normalize(tf.matmul(new_v, W))
        # sigma = v W u^T, a 1x1 estimate of the largest singular value.
        sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
        W_bar = W / sigma
        # Tie the u update to the returned tensor so it also runs in graph mode.
        with tf.control_dependencies([self.u.assign(new_u)]):
            W_bar = tf.reshape(W_bar, W_shape)
        return W_bar

    def call(self, inputs):
        """Apply the dense transform with the spectrally normalized kernel."""
        W_shape = self.kernel.shape.as_list()
        W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
        new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
        # ``shape.rank`` is None for unknown-rank inputs, where
        # ``len(inputs.shape)`` would raise instead.
        rank = inputs.shape.rank
        if rank is not None and rank > 2:
            # Contract the last input axis with the kernel's first axis.
            outputs = standard_ops.tensordot(inputs, new_kernel, [[rank - 1], [0]])
            if not context.executing_eagerly():
                # Restore the static output shape in graph mode.
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            inputs = math_ops.cast(inputs, self._compute_dtype)
            if K.is_sparse(inputs):
                outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, new_kernel)
            else:
                outputs = gen_math_ops.mat_mul(inputs, new_kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Extend the ``Dense`` config with this layer's extra arguments."""
        config = super(DenseSN, self).get_config()
        config.update(
            {
                # Stored as a string so the config stays JSON-serializable.
                "datatype": tf.as_dtype(self.datatype).name,
                "power_iterations": self.power_iterations,
            }
        )
        return config
class AddNoise(tf.keras.layers.Layer):
    """Add standard-normal noise scaled by one learned scalar weight.

    The noise has the input's shape with the last axis set to 1, so a single
    noise value is broadcast across that axis at every position. The scale
    ``b`` is initialized to zero, so the layer starts out as the identity.

    Args:
        datatype: dtype of the sampled noise.
            NOTE(review): ``b`` is created with the layer's own dtype. If that
            differs from ``datatype``, ``b * rand`` will mismatch; confirm
            callers keep them equal.
        **kwargs: Standard ``Layer`` keyword arguments.
    """

    def __init__(self, datatype=tf.float32, **kwargs):
        super(AddNoise, self).__init__(**kwargs)
        self.datatype = datatype

    def build(self, input_shape):
        """Create the scalar noise-scale weight ``b``."""
        self.b = self.add_weight(
            shape=[1],
            initializer=tf.keras.initializers.zeros(),
            trainable=True,
            name="noise_weight",
        )

    def call(self, inputs):
        """Return ``inputs + b * noise``."""
        # Use the dynamic shape so unknown (None) dimensions work. For rank-4
        # input this is [batch, d1, d2, 1], matching the previous behavior.
        noise_shape = tf.concat([tf.shape(inputs)[:-1], [1]], axis=0)
        rand = tf.random.normal(
            noise_shape,
            mean=0.0,
            stddev=1.0,
            dtype=self.datatype,
        )
        return inputs + self.b * rand

    def get_config(self):
        """Add ``datatype`` to the base ``Layer`` config."""
        config = super(AddNoise, self).get_config()
        # Stored as a string so the config stays JSON-serializable.
        config.update({"datatype": tf.as_dtype(self.datatype).name})
        return config
class PosEnc(tf.keras.layers.Layer):
    """Append a normalized position channel along axis -3.

    For an input of length ``L`` along axis -3, the appended last-axis channel
    at position ``i`` on that axis equals ``i / L``. The value is the same for
    every batch entry and every index along axis -2.

    Assumes a rank-4, channels-last input ``[batch, L, W, C]`` (TODO confirm
    with callers); the output then has ``C + 1`` channels.

    Args:
        datatype: dtype of the position channel. It must match the input's
            dtype for the final concatenation.
        **kwargs: Standard ``Layer`` keyword arguments.
    """

    def __init__(self, datatype=tf.float32, **kwargs):
        super(PosEnc, self).__init__(**kwargs)
        self.datatype = datatype

    def call(self, inputs):
        """Concatenate the position channel onto ``inputs`` along the last axis."""
        length = tf.shape(inputs)[-3]
        # Positions 0..L-1 scaled to [0, 1), shaped [1, L, 1, 1] for
        # broadcasting. Both numerator and denominator use ``self.datatype``;
        # the previous mix of ``self.dtype`` and ``self.datatype`` failed
        # whenever the two differed.
        pos = tf.cast(tf.reshape(tf.range(length), [1, -1, 1, 1]), self.datatype)
        pos = pos / tf.cast(length, self.datatype)
        # Broadcast against zeros shaped like one input channel. This tiles
        # the positions over batch and axis -2 without needing static dims.
        pos = pos + tf.zeros_like(inputs[..., :1], dtype=self.datatype)
        return tf.concat([inputs, pos], -1)

    def get_config(self):
        """Add ``datatype`` to the base ``Layer`` config."""
        config = super(PosEnc, self).get_config()
        # Stored as a string so the config stays JSON-serializable.
        config.update({"datatype": tf.as_dtype(self.datatype).name})
        return config
def flatten_hw(x, data_format="channels_last"): | |
if data_format == "channels_last": | |
x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` | |
old_shape = tf.shape(x) | |