|
import tensorflow as tf |
|
|
|
|
|
_CAP = 3501 |
|
|
|
class Encoder_Z(tf.keras.layers.Layer):
    """Factory for the VAE encoder network.

    Calling ``build()`` returns a fresh ``tf.keras.Sequential`` that maps a
    ``(3, _CAP, 1)`` input to ``2 * dim_z`` linear outputs (final layer named
    ``"dist_params"``). The caller splits these outputs into distribution
    parameters.

    NOTE(review): ``build`` is used here as a model factory and shadows
    ``Layer.build(input_shape)``. Keras 3 wraps ``Layer.build`` and discards
    its return value, so this pattern appears to rely on tf.keras 2.x.
    Confirm the target Keras version.
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        """Store the input shape and latent size.

        Args:
            dim_z: size of the latent space; the network emits ``2 * dim_z``
                values per example.
            name: Keras layer name.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(name=name, **kwargs)
        self.dim_x = (3, _CAP, 1)
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the encoder as a ``tf.keras.Sequential``."""
        stack = [
            tf.keras.layers.InputLayer(input_shape=self.dim_x),
            # Single strided convolution ('valid' padding by default).
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Flatten(),
            # Fully connected funnel: 2000 -> 500 units.
            tf.keras.layers.Dense(2000),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(500),
            tf.keras.layers.ReLU(),
            # Linear head; no activation so both halves are unconstrained.
            tf.keras.layers.Dense(self.dim_z * 2, activation=None, name="dist_params"),
        ]
        return tf.keras.Sequential(stack)
|
|
|
|
|
class Decoder_X(tf.keras.layers.Layer):
    """Factory for the VAE decoder network.

    Calling ``build()`` returns a fresh ``tf.keras.Sequential`` that maps a
    ``(dim_z,)`` latent vector to a ``(3, _CAP, 1)`` output with a linear
    (no activation) final layer.

    NOTE(review): ``build`` is used here as a model factory and shadows
    ``Layer.build(input_shape)``. Keras 3 wraps ``Layer.build`` and discards
    its return value, so this pattern appears to rely on tf.keras 2.x.
    Confirm the target Keras version.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        """Store the latent size.

        Args:
            dim_z: size of the latent input vector.
            name: Keras layer name.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super(Decoder_X, self).__init__(name=name, **kwargs)
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the decoder as a ``tf.keras.Sequential``."""
        # Seed feature map is (1, seed_width, 32). A kernel-3, stride-2,
        # 'valid' transposed conv maps size s to (s - 1) * 2 + 3, so height
        # 1 -> 3 and width (_CAP - 1) // 2 -> _CAP (for odd _CAP).
        # Integer division keeps Dense `units` an int. The previous true
        # division produced a float, which only some Keras versions accept,
        # and it disagreed with the int(...) used for the Reshape.
        seed_width = (_CAP - 1) // 2
        seed_channels = 32

        layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]

        layers.append(tf.keras.layers.Dense(500))
        layers.append(tf.keras.layers.ReLU())

        layers.append(tf.keras.layers.Dense(2000))
        layers.append(tf.keras.layers.ReLU())

        layers.append(tf.keras.layers.Dense(seed_width * seed_channels, activation=None))
        layers.append(tf.keras.layers.Reshape((1, seed_width, seed_channels)))

        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=64, kernel_size=3, strides=2, padding='valid'))
        layers.append(tf.keras.layers.ReLU())

        # 'same' padding with stride 1 preserves (3, _CAP); collapses to 1 channel.
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=1, kernel_size=3, strides=1, padding='same'))

        return tf.keras.Sequential(layers)
|
|
|
# Global multiplier on the KL term in VAECost, also reported by VAE.train_step.
# It is only read in this module; presumably an external callback adjusts it
# (e.g. KL annealing). TODO confirm.
# A plain tf.Variable is used because `tf.keras.backend.variable` no longer
# exists in Keras 3. The dtype matches the Keras float type, as before.
kl_weight = tf.Variable(0.125, dtype=tf.keras.backend.floatx(), trainable=False)
|
|
|
|
|
|
|
class VAECost:
    """Callable that computes the KL-weighted negative ELBO of a VAE batch.

    The wrapped ``model`` must expose ``encode(x) -> (z_sample, mu, sd)`` and a
    ``decoder`` callable mapping ``z_sample`` back to input space.
    """

    def __init__(self, model):
        """Bind the cost to ``model``.

        Args:
            model: object providing ``encode`` and ``decoder`` as described
                on the class.
        """
        self.model = model
        # Neither flag is read by this class; presumably used by an external
        # KL-annealing schedule. TODO confirm.
        self.kl_weight_increasing = True
        self.epoch = 1

    @tf.function()
    def __call__(self, x_true):
        """Compute the loss for one batch.

        Args:
            x_true: input batch; the reconstruction error reduces over axes
                1, 2 and 3, so a rank-4 tensor is expected.

        Returns:
            A tuple ``(neg_elbo, mean_kl, mean_recons)``:
            - ``neg_elbo``: negated batch mean of
              ``-kl_weight * KL - recons_error``.
            - ``mean_kl``: batch-mean KL divergence.
            - ``mean_recons``: batch-mean per-example MSE.
        """
        x = tf.cast(x_true, tf.float32)

        z, mean, std = self.model.encode(x)
        reconstruction = self.model.decoder(z)

        # Per-example mean squared error.
        per_example_recons = tf.cast(
            tf.reduce_mean((x - reconstruction) ** 2, axis=[1, 2, 3]),
            tf.float32)

        # Closed-form KL( N(mean, std^2) || N(0, 1) ), summed over latent dims.
        variance = tf.math.square(std)
        kl_terms = 1 + tf.math.log(variance) - tf.math.square(mean) - variance
        per_example_kl = -0.5 * tf.math.reduce_sum(kl_terms, axis=1)

        # `kl_weight` is the module-level KL multiplier.
        elbo = tf.reduce_mean(-kl_weight * per_example_kl - per_example_recons)

        return (
            -elbo,
            tf.reduce_mean(per_example_kl),
            tf.reduce_mean(per_example_recons),
        )
|
|
|
|
|
class VAE(tf.keras.Model):
    """Variational autoencoder over ``(3, _CAP, 1)`` inputs.

    The encoder and decoder are built from ``Encoder_Z`` / ``Decoder_X``.
    The training loss comes from ``VAECost``.
    """

    def __init__(self, dim_z, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
        """Build the encoder/decoder pair and the cost function.

        Args:
            dim_z: latent dimensionality.
            seed: stored as ``self.seed``; not read by this class.
            analytic_kl: stored as ``self.analytic_kl``; not read by this
                class (VAECost always uses the closed-form KL).
            name: Keras model name.
            **kwargs: forwarded to ``tf.keras.Model``.
        """
        super().__init__(name=name, **kwargs)
        # Must use the module constant `_CAP` (a bare `CAP` is undefined here)
        # so this matches the input shape Encoder_Z is built for.
        self.dim_x = (3, _CAP, 1)
        self.dim_z = dim_z
        self.seed = seed
        self.analytic_kl = analytic_kl
        self.encoder = Encoder_Z(dim_z=self.dim_z).build()
        self.decoder = Decoder_X(dim_z=self.dim_z).build()
        self.cost_func = VAECost(self)

    @tf.function()
    def train_step(self, data):
        """Run one optimisation step on ``data`` and return metrics for logging.

        ``data`` is passed directly to ``VAECost``. It is not unpacked, so
        presumably the dataset yields inputs only, without targets. Confirm
        against the caller.
        """
        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)

        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input):
        """Encode ``x_input`` and draw a reparameterised latent sample.

        Returns:
            ``(z_sample, mu, sd)``, each of shape ``(batch, dim_z)``. The
            encoder's ``2 * dim_z`` outputs are split in half along axis 1.
        """
        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
        # softplus(rho) = log(1 + exp(rho)). tf.math.softplus computes it
        # without overflowing exp() for large rho.
        sd = tf.math.softplus(rho)
        # Independent standard-normal noise for every example and latent dim.
        # A (dim_z,) draw would broadcast one noise vector over the batch.
        eps = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
        z_sample = mu + sd * eps
        return z_sample, mu, sd

    def generate(self, z_sample=None):
        """Decode ``z_sample``, or a single standard-normal latent if omitted.

        Args:
            z_sample: latent batch of shape ``(batch, dim_z)``. When ``None``,
                one vector of shape ``(1, dim_z)`` is drawn.
        """
        # `is None` rather than `== None`: the latter is elementwise on
        # NumPy arrays and makes the `if` raise.
        if z_sample is None:
            z_sample = tf.random.normal(shape=(1, self.dim_z))
        return self.decoder(z_sample)
|
|