pivaenist / model.py
TomRB22's picture
Update model.py
523a819
raw
history blame
5.21 kB
import tensorflow as tf
_CAP = 3501 # Cap for the number of notes
class Encoder_Z(tf.keras.layers.Layer):
    """Factory for the VAE encoder network.

    ``build()`` returns a ``tf.keras.Sequential`` mapping an input of shape
    ``(3, _CAP, 1)`` to ``2 * dim_z`` unconstrained outputs (the parameters
    of the latent distribution, split in half by the caller).
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Expected shape of a single input "song map" (no batch dimension).
        self.dim_x = (3, _CAP, 1)
        # Dimensionality of the latent space.
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the encoder as a new ``tf.keras.Sequential``.

        NOTE: this shadows ``Layer.build(input_shape)`` with a zero-argument
        factory method; it is called directly (e.g. by ``VAE.__init__``), not
        by Keras.
        """
        relu = tf.keras.layers.ReLU
        dense = tf.keras.layers.Dense
        return tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=self.dim_x),
            # 3x3 kernel, stride 2, default 'valid' padding.
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            relu(),
            tf.keras.layers.Flatten(),
            dense(2000),
            relu(),
            dense(500),
            relu(),
            # Linear head; VAE.encode splits it into (mu, rho).
            dense(self.dim_z * 2, activation=None, name="dist_params"),
        ])
class Decoder_X(tf.keras.layers.Layer):
    """Factory for the VAE decoder network.

    ``build()`` returns a ``tf.keras.Sequential`` that maps a latent vector of
    shape ``(dim_z,)`` to a reconstructed "song map" of shape ``(3, _CAP, 1)``.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        super(Decoder_X, self).__init__(name=name, **kwargs)
        # Size of the latent input vector.
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the decoder as a new ``tf.keras.Sequential``.

        NOTE: this shadows ``Layer.build(input_shape)`` with a zero-argument
        factory method; it is called directly (e.g. by ``VAE.__init__``), not
        by Keras.
        """
        # Feature-map width before upsampling. Integer division keeps the
        # Dense ``units`` and the Reshape target as ints, rather than passing
        # a float (``(_CAP - 1) / 2 * 32`` == 56000.0) and relying on Keras to
        # cast it.
        half_width = (_CAP - 1) // 2

        layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]
        layers.append(tf.keras.layers.Dense(500))
        layers.append(tf.keras.layers.ReLU())
        layers.append(tf.keras.layers.Dense(2000))
        layers.append(tf.keras.layers.ReLU())
        layers.append(tf.keras.layers.Dense(half_width * 32, activation=None))
        layers.append(tf.keras.layers.Reshape((1, half_width, 32)))
        # 'valid' transposed conv, kernel 3, stride 2:
        #   height: (1 - 1) * 2 + 3 = 3
        #   width:  (half_width - 1) * 2 + 3 = 2 * half_width + 1 == _CAP (odd _CAP)
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=64, kernel_size=3, strides=2, padding='valid'))
        layers.append(tf.keras.layers.ReLU())
        # Reduce to a single output channel; 'same' padding with stride 1
        # keeps the spatial size at (3, _CAP).
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=1, kernel_size=3, strides=1, padding='same'))
        return tf.keras.Sequential(layers)
# Global weight of the KL term in VAECost's ELBO. It is read inside a
# tf.function, so it must be a variable (not a Python float) for updates made
# via ``kl_weight.assign(...)`` to take effect without retracing. It is a
# scheduled hyperparameter rather than a learned one, hence trainable=False.
# A plain tf.Variable replaces the legacy Keras-2 helper
# ``tf.keras.backend.variable``.
kl_weight = tf.Variable(0.125, dtype=tf.float32, trainable=False, name="kl_weight")
class VAECost:
    """Negative-ELBO loss for a VAE, with the KL term scaled by ``kl_weight``.

    Intended for a cyclical KL-annealing schedule, following the Microsoft
    Research Blog article "Less pain, more gain: A simple method for VAE
    training with less of that KL-vanishing agony". The KL weight increases
    linearly until it meets a threshold and stays constant for the same
    number of epochs. It then drops abruptly back to zero, and the cycle
    repeats.

    NOTE(review): this class only *reads* the module-level ``kl_weight``
    variable. Nothing here updates it: ``kl_weight_increasing`` and ``epoch``
    are stored but never used in this file. The schedule is presumably
    driven elsewhere (e.g. a callback); confirm.
    """

    def __init__(self, model):
        # Model providing ``encode(x)`` and ``decoder(z)``.
        self.model = model
        # Schedule state; not read by ``__call__``.
        self.kl_weight_increasing = True
        self.epoch = 1

    # A Keras loss has the form loss(y_true, y_pred); here y_pred (the
    # reconstruction) is computed inside the cost function itself.
    @tf.function()
    def __call__(self, x_true):
        """Compute the loss terms for a batch of inputs.

        Args:
            x_true: batch of "song maps", cast to float32. The reduction over
                axes 1-3 requires rank >= 4; presumably (batch, 3, _CAP, 1).

        Returns:
            Tuple ``(neg_elbo, mean_kl_divergence, mean_recons_error)`` of
            scalar tensors.
        """
        x_true = tf.cast(x_true, tf.float32)
        # Latent sample plus the parameters of q(z|x).
        z_sample, mu, sd = self.model.encode(x_true)
        # Reconstruct the input from the sampled latent code.
        x_recons = self.model.decoder(z_sample)

        # Per-example mean squared reconstruction error, shape (batch,).
        recons_error = tf.cast(
            tf.reduce_mean((x_true - x_recons) ** 2, axis=[1, 2, 3]),
            tf.float32)

        # Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian, shape (batch,).
        # log(sd^2) is computed as 2 * log(sd). This is the same value
        # mathematically, but it avoids sd**2 underflowing to 0 for small sd,
        # which would give log(0) = -inf and NaN gradients.
        kl_divergence = -0.5 * tf.math.reduce_sum(
            1.0 + 2.0 * tf.math.log(sd) - tf.math.square(mu) - tf.math.square(sd),
            axis=1)

        # ELBO averaged over the batch, with the KL term scaled by the
        # current (scheduled) weight.
        elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
        mean_kl_divergence = tf.reduce_mean(kl_divergence)
        mean_recons_error = tf.reduce_mean(recons_error)
        return -elbo, mean_kl_divergence, mean_recons_error
class VAE(tf.keras.Model):
    """Variational autoencoder over "song maps" of shape ``(3, _CAP, 1)``."""

    def __init__(self, dim_z, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        # Input shape without the batch dimension. This must use the
        # module-level _CAP; ``CAP`` is undefined.
        self.dim_x = (3, _CAP, 1)
        self.dim_z = dim_z
        # NOTE(review): ``seed`` and ``analytic_kl`` are stored but not used
        # anywhere in this file.
        self.seed = seed
        self.analytic_kl = analytic_kl
        self.encoder = Encoder_Z(dim_z=self.dim_z).build()
        self.decoder = Decoder_X(dim_z=self.dim_z).build()
        self.cost_func = VAECost(self)

    @tf.function()
    def train_step(self, data):
        """Run one gradient-descent step on ``data`` and return logged metrics.

        ``data`` is passed straight to the cost function as the input batch,
        so this presumably assumes the dataset yields inputs only (no
        ``(x, y)`` tuples); confirm against the training code.
        """
        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)
        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # "abs ELBO" holds the negative ELBO, i.e. the minimized loss.
        return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input):
        """Encode a batch and draw a latent sample via the reparameterisation trick.

        Returns:
            ``(z_sample, mu, sd)``, each with the shape of ``mu``
            (batch, dim_z).
        """
        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
        # softplus(rho) == log(1 + exp(rho)), computed without overflowing to
        # inf for large rho.
        sd = tf.math.softplus(rho)
        # Independent N(0, I) noise for every example in the batch, rather
        # than a single (dim_z,) vector shared by the whole batch.
        eps = tf.random.normal(shape=tf.shape(mu))
        z_sample = mu + sd * eps
        return z_sample, mu, sd

    def generate(self, z_sample=None):
        """Decode ``z_sample``, or a fresh N(0, I) sample of batch size 1 if omitted."""
        # ``is None`` rather than ``== None``: ``==`` on a numpy array is
        # elementwise and makes the ``if`` raise an ambiguous-truth error.
        if z_sample is None:
            z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
        return self.decoder(z_sample)