# pivaenist / model.py
# Last commit 3f36060 by TomRB22: "Fixed indentation issue" (file size 6.29 kB).
# (Header text originally copied from a repository web viewer.)
# Deep learning
import tensorflow as tf
# Methods for loading the weights into the model
import os
import inspect
_CAP = 3501 # Cap for the number of notes
class Encoder_Z(tf.keras.layers.Layer):
    """Encoder part of the VAE.

    Describes a network that maps a song map of shape ``(3, _CAP, 1)`` to
    ``2 * dim_z`` unactivated outputs (layer ``"dist_params"``), which the
    caller splits into the parameters of the latent distribution.
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        super(Encoder_Z, self).__init__(name=name, **kwargs)
        # Shape of one input song map (batch dimension excluded).
        self.dim_x = (3, _CAP, 1)
        # Dimensionality of the latent space.
        self.dim_z = dim_z

    def build(self, input_shape=None):
        """Create and return the encoder network as a ``tf.keras.Sequential``.

        Parameters:
            input_shape: Ignored. It is accepted only so the signature matches
                         ``tf.keras.layers.Layer.build(input_shape)``; the
                         input shape is fixed by ``self.dim_x``.

        Returns:
            tf.keras.Sequential: Conv2D(64, 3x3, stride 2) -> ReLU -> Flatten
            -> Dense(2000) -> ReLU -> Dense(500) -> ReLU -> Dense(2 * dim_z).
        """
        layers = [
            tf.keras.layers.InputLayer(input_shape=self.dim_x),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(2000),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(500),
            tf.keras.layers.ReLU(),
            # Linear output: first half is the mean, second half the
            # pre-activation scale (split by the caller).
            tf.keras.layers.Dense(self.dim_z * 2, activation=None, name="dist_params"),
        ]
        return tf.keras.Sequential(layers)
class Decoder_X(tf.keras.layers.Layer):
    """Decoder part of the VAE.

    Describes a network that maps a latent vector of size ``dim_z`` back to a
    song map. With the transposed-convolution arithmetic below, the output
    shape is ``(3, _CAP, 1)`` when ``_CAP`` is odd.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        super(Decoder_X, self).__init__(name=name, **kwargs)
        # Dimensionality of the latent space (decoder input size).
        self.dim_z = dim_z

    def build(self, input_shape=None):
        """Create and return the decoder network as a ``tf.keras.Sequential``.

        Parameters:
            input_shape: Ignored. It is accepted only so the signature matches
                         ``tf.keras.layers.Layer.build(input_shape)``; the
                         input shape is fixed by ``self.dim_z``.

        Returns:
            tf.keras.Sequential: Dense(500) -> ReLU -> Dense(2000) -> ReLU ->
            Dense -> Reshape(1, (_CAP - 1) // 2, 32) -> Conv2DTranspose(64,
            stride 2) -> ReLU -> Conv2DTranspose(1, stride 1).
        """
        # Width of the feature map before upsampling. Integer division keeps
        # Dense units and Reshape dimensions as ints (true division produced
        # a float unit count).
        half_width = (_CAP - 1) // 2
        layers = [
            tf.keras.layers.InputLayer(input_shape=(self.dim_z,)),
            tf.keras.layers.Dense(500),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(2000),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(half_width * 32, activation=None),
            tf.keras.layers.Reshape((1, half_width, 32)),
            # 'valid' transposed conv: height (1 - 1) * 2 + 3 = 3,
            # width (half_width - 1) * 2 + 3 = _CAP (for odd _CAP).
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='valid'),
            tf.keras.layers.ReLU(),
            # Collapse to a single channel without changing height/width.
            tf.keras.layers.Conv2DTranspose(
                filters=1, kernel_size=3, strides=1, padding='same'),
        ]
        return tf.keras.Sequential(layers)
# Module-wide weight of the KL term in the VAE cost, starting at 0.125.
# It is read by VAECost.__call__ and reported by VAE.train_step.
# NOTE(review): nothing visible here updates it; the annealing schedule
# described in VAECost is presumably applied elsewhere — confirm.
kl_weight = tf.keras.backend.variable(value=0.125)
class VAECost:
    """
    VAE cost with a schedule based on the Microsoft Research Blog's article
    "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony"
    The KL weight increases linearly, until it meets a certain threshold and keeps constant
    for the same number of epochs. After that, it decreases abruptly to zero again, and the
    cycle repeats.

    This class only reads the current weight from the module-level
    ``kl_weight`` variable. NOTE(review): the schedule itself is not applied
    here; it is presumably driven elsewhere (e.g. a training callback) —
    confirm.
    """

    def __init__(self, model):
        # Model exposing ``encode(x)`` and a ``decoder`` network.
        self.model = model
        # Schedule state; not read or modified by this class.
        self.kl_weight_increasing = True
        self.epoch = 1

    # The loss should have the form loss(y_true, y_pred), but in this
    # case y_pred is computed in the cost function.
    @tf.function()
    def __call__(self, x_true):
        """Compute the KL-weighted training objective for a batch.

        Parameters:
            x_true: Batch of song maps; cast to float32. Assumed rank 4,
                    (batch, 3, _CAP, 1), given the reduction axes below.

        Returns:
            tuple: (negative weighted ELBO averaged over the batch,
                    mean KL divergence, mean reconstruction error).
        """
        x_true = tf.cast(x_true, tf.float32)
        # Encode the song map to get its latent sample and the parameters of
        # the approximate posterior.
        z_sample, mu, sd = self.model.encode(x_true)
        # Decode the sample. Cast so the subtraction below is always float32,
        # whatever dtype the decoder emits.
        x_recons = tf.cast(self.model.decoder(z_sample), tf.float32)
        # Per-sample mean squared error against the input song map.
        recons_error = tf.reduce_mean((x_true - x_recons) ** 2, axis=[1, 2, 3])
        # log(sd^2) written as 2 * log(sd): squaring first underflows to 0
        # for very small sd, which would turn the log into -inf.
        log_var = 2.0 * tf.math.log(sd)
        # Closed-form KL(N(mu, sd^2) || N(0, I)), summed over latent dims.
        kl_divergence = -0.5 * tf.math.reduce_sum(
            1 + log_var - tf.math.square(mu) - tf.math.square(sd),
            axis=1)  # shape=(batch_size,)
        # Weighted ELBO surrogate (MSE stands in for the likelihood term).
        elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
        mean_kl_divergence = tf.reduce_mean(kl_divergence)
        mean_recons_error = tf.reduce_mean(recons_error)
        return -elbo, mean_kl_divergence, mean_recons_error
class VAE(tf.keras.Model):
    """Main architecture, which connects the encoder with the decoder.

    On construction it builds the encoder/decoder networks for the latent
    size ``self.dim_z`` and loads pretrained weights from the ``weights``
    folder located next to this file.
    """

    def __init__(self, name="variational autoencoder", **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        # Shape of one input song map (batch dimension excluded).
        self.dim_x = (3, _CAP, 1)
        # Latent size. The stored weights must match it, so it is fixed here
        # rather than exposed as a parameter.
        self.dim_z = 120
        self.encoder = Encoder_Z(dim_z=self.dim_z).build()
        self.decoder = Decoder_X(dim_z=self.dim_z).build()
        self.cost_func = VAECost(self)

        # Locate the weights folder relative to this script, not the CWD.
        script_path = inspect.getfile(inspect.currentframe())
        script_dir = os.path.dirname(os.path.abspath(script_path))
        # The trailing separator is part of the checkpoint path handed to
        # load_weights.
        weights_dir = os.path.join(script_dir, 'weights') + os.sep
        # Load pretrained weights
        self.load_weights(weights_dir)

    @tf.function()
    def train_step(self, data):
        """Run one gradient-descent step and return the metrics to log.

        ``data`` is passed unchanged to the cost function. NOTE(review): this
        assumes each batch contains only song maps (no labels) — confirm.
        """
        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)
        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Get a "song map" and make a forward pass through the encoder, in order
        to return the latent representation and the distribution's parameters.

        Parameters:
            x_input (tf.Tensor): Song map to be encoded by the VAE.

        Returns:
            tuple[tf.Tensor, tf.Tensor, tf.Tensor]: A sampled latent
            representation (z_sample) followed by the distribution parameters
            mu and sd, each of shape (batch, dim_z).
        """
        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
        # Softplus keeps sd positive. tf.math.softplus is the numerically
        # stable form of log(1 + exp(rho)), which overflows to inf for large rho.
        sd = tf.math.softplus(rho)
        # Reparameterization trick with independent noise for every sample
        # and latent dimension. A fixed (dim_z,) draw would be broadcast, so
        # the whole batch would share one noise vector.
        epsilon = tf.random.normal(shape=tf.shape(mu))
        z_sample = mu + sd * epsilon
        return z_sample, mu, sd

    def generate(self, z_sample: tf.Tensor = None) -> tf.Tensor:
        """
        Decode a latent representation of a song.

        Parameters:
            z_sample (tf.Tensor): Song encoding output by the encoder. If
                                  None, a single latent vector is sampled
                                  from a unit Gaussian distribution.

        Returns:
            tf.Tensor: Song map corresponding to the encoding.
        """
        # Identity check: `==` on a tensor/array is elementwise and can be
        # ambiguous in a boolean context.
        if z_sample is None:
            z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
        return self.decoder(z_sample)