import tensorflow as tf

_CAP = 3501  # Cap for the number of notes (width of a "song map")


class Encoder_Z(tf.keras.layers.Layer):
    """Factory for the VAE encoder network.

    ``build()`` returns a ``tf.keras.Sequential`` that maps a song map of
    shape ``(3, _CAP, 1)`` to ``2 * dim_z`` unconstrained outputs. ``VAE.encode``
    splits them into a mean half and a pre-softplus scale half.

    NOTE(review): this class is only used as a builder (``.build()`` returning
    a model); it relies on tf.keras 2.x semantics (``InputLayer(input_shape=)``
    and ``build`` returning a value) — confirm before moving to Keras 3.
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dim_x = (3, _CAP, 1)
        self.dim_z = dim_z

    def build(self, input_shape=None):
        """Return the encoder as a ``tf.keras.Sequential`` model.

        Args:
            input_shape: Ignored. Accepted only so the signature is compatible
                with ``tf.keras.layers.Layer.build``; the input shape is fixed
                to ``self.dim_x``.

        Returns:
            A Sequential model producing ``self.dim_z * 2`` outputs per example.
        """
        return tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=self.dim_x),
            # 'valid' padding, stride 2: (3, _CAP, 1) -> (1, (_CAP - 1) // 2, 64).
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(2000),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(500),
            tf.keras.layers.ReLU(),
            # Linear head: first half = mean, second half = raw scale parameter.
            tf.keras.layers.Dense(self.dim_z * 2, activation=None, name="dist_params"),
        ])


class Decoder_X(tf.keras.layers.Layer):
    """Factory for the VAE decoder network.

    ``build()`` returns a ``tf.keras.Sequential`` that maps a latent vector of
    size ``dim_z`` back to a song map of shape ``(3, _CAP, 1)``.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dim_z = dim_z

    def build(self, input_shape=None):
        """Return the decoder as a ``tf.keras.Sequential`` model.

        Args:
            input_shape: Ignored. Accepted only so the signature is compatible
                with ``tf.keras.layers.Layer.build``; the input shape is fixed
                to ``(self.dim_z,)``.

        Returns:
            A Sequential model whose output shape is ``(3, _CAP, 1)``.
        """
        # Integer width of the feature map before the strided transpose conv;
        # the Dense unit count must be an int, not the float that `/` produces.
        width = (_CAP - 1) // 2
        return tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(self.dim_z,)),
            tf.keras.layers.Dense(500),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(2000),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(width * 32, activation=None),
            tf.keras.layers.Reshape((1, width, 32)),
            # 'valid', stride 2, kernel 3: (1, width) -> (3, 2 * width + 1) = (3, _CAP).
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='valid'),
            tf.keras.layers.ReLU(),
            # 'same', stride 1: keeps (3, _CAP) and collapses to one channel.
            tf.keras.layers.Conv2DTranspose(
                filters=1, kernel_size=3, strides=1, padding='same'),
        ])


# Weight of the KL term in the loss. It starts at 0.125 and is not modified in
# this module. Presumably a KL-annealing schedule elsewhere updates it (confirm).
kl_weight = tf.keras.backend.variable(0.125)


class VAECost:
    """VAE cost with a cyclical KL-annealing schedule.

    The schedule follows the Microsoft Research Blog article "Less pain, more
    gain: A simple method for VAE training with less of that KL-vanishing
    agony". The KL weight increases linearly until it reaches a threshold. It
    then stays constant for the same number of epochs. After that it drops
    abruptly back to zero and the cycle repeats.

    NOTE(review): ``kl_weight_increasing`` and ``epoch`` are not read in this
    module; presumably a callback elsewhere drives the schedule — confirm.
    """

    def __init__(self, model):
        self.model = model
        self.kl_weight_increasing = True
        self.epoch = 1

    # A Keras loss would normally take (y_true, y_pred). Here y_pred (the
    # reconstruction) is computed inside the cost function itself.
    @tf.function()
    def __call__(self, x_true):
        """Compute the negative ELBO and its components for a batch.

        Args:
            x_true: Batch of song maps (cast to float32).

        Returns:
            Tuple ``(neg_elbo, mean_kl_divergence, mean_recons_error)`` of
            scalar tensors, each averaged over the batch.
        """
        x_true = tf.cast(x_true, tf.float32)

        # Encode the song map into a sampled latent vector and the parameters
        # of its approximate posterior.
        z_sample, mu, sd = self.model.encode(x_true)

        # Decode back to song-map space. Ideally this reconstruction resembles
        # the input.
        x_recons = self.model.decoder(z_sample)

        # Per-example mean squared error against the input song map.
        recons_error = tf.reduce_mean(tf.square(x_true - x_recons), axis=[1, 2, 3])

        # Closed-form KL(N(mu, sd^2) || N(0, I)) per example, shape (batch_size,).
        kl_divergence = -0.5 * tf.reduce_sum(
            1.0 + tf.math.log(tf.square(sd)) - tf.square(mu) - tf.square(sd),
            axis=1)

        elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
        return -elbo, tf.reduce_mean(kl_divergence), tf.reduce_mean(recons_error)


class VAE(tf.keras.Model):
    """Variational autoencoder over song maps of shape ``(3, _CAP, 1)``.

    Args:
        dim_z: Size of the latent space.
        seed: Stored as ``self.seed``; not read in this module.
        analytic_kl: Stored as ``self.analytic_kl``; not read in this module
            (the KL term in ``VAECost`` is always the closed form).
        name: Keras model name.
    """

    def __init__(self, dim_z=120, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dim_x = (3, _CAP, 1)
        self.dim_z = dim_z
        self.seed = seed
        self.analytic_kl = analytic_kl
        self.encoder = Encoder_Z(dim_z=self.dim_z).build()
        self.decoder = Decoder_X(dim_z=self.dim_z).build()
        self.cost_func = VAECost(self)

    @tf.function()
    def train_step(self, data):
        """Run one gradient-descent step on the negative ELBO.

        Args:
            data: A batch of song maps. It may also be a tuple or list whose
                first element is the song maps (as Keras delivers
                ``(x,)``/``(x, y)`` batches); only that first element is used.

        Returns:
            Dict of scalar metrics for Keras progress reporting.
        """
        # The model is unsupervised: drop any targets or sample weights Keras
        # packed alongside x. Without this, tf.cast would stack the tuple.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)
        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"abs ELBO": neg_elbo,
                "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input):
        """Encode song maps and sample latents via the reparameterization trick.

        Args:
            x_input: Batch of song maps accepted by ``self.encoder``.

        Returns:
            Tuple ``(z_sample, mu, sd)``, each of shape ``(batch, dim_z)``.
        """
        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
        # softplus(rho) == log(1 + exp(rho)), but it does not overflow for large rho.
        sd = tf.math.softplus(rho)
        # Independent noise per example. A (dim_z,) draw would be broadcast,
        # and the whole batch would share one epsilon.
        epsilon = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
        z_sample = mu + sd * epsilon
        return z_sample, mu, sd

    def generate(self, z_sample=None):
        """Decode a latent vector into a song map.

        Args:
            z_sample: Latent batch of shape ``(batch, dim_z)``. If ``None``, a
                single latent is drawn from N(0, I).

        Returns:
            Decoder output for ``z_sample``.
        """
        # `is None`, not `== None`: the latter is elementwise for arrays/tensors.
        if z_sample is None:
            z_sample = tf.random.normal(shape=(1, self.dim_z))
        return self.decoder(z_sample)