Update model.py
Browse files
model.py
CHANGED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
|
4 |
+
# Cap for the number of notes. Used as the second dimension of the
# (3, CAP, 1) input shape declared by Encoder_Z and VAE; Decoder_X derives
# its reshape width from (CAP - 1) / 2.
CAP = 3501
|
5 |
+
|
6 |
+
class Encoder_Z(tf.keras.layers.Layer):
    """Factory for the encoder network.

    The object stores the input shape ``(3, CAP, 1)`` and the latent size
    ``dim_z``; calling :meth:`build` returns the actual network as a
    ``tf.keras.Sequential``.

    NOTE(review): ``build`` here takes no ``input_shape`` and returns a model,
    which differs from the usual ``Layer.build`` contract -- confirm this
    behaves as intended with the installed Keras version.
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        """Store the latent dimensionality; extra kwargs go to ``Layer``."""
        super().__init__(name=name, **kwargs)
        self.dim_x = (3, CAP, 1)
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the encoder as a ``tf.keras.Sequential``.

        The final dense layer ("dist_params") has ``2 * dim_z`` linear
        outputs; ``VAE.encode`` splits them into two halves of size ``dim_z``.
        """
        kl = tf.keras.layers
        n_dist_params = self.dim_z * 2

        # Layers are instantiated in the same order as they are stacked.
        stack = [
            kl.InputLayer(input_shape=self.dim_x),
            # Convolutional feature extractor
            kl.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            kl.ReLU(),
            kl.Flatten(),
            # Fully connected bottleneck
            kl.Dense(2000),
            kl.ReLU(),
            kl.Dense(500),
            kl.ReLU(),
            # Linear head producing the distribution parameters
            kl.Dense(n_dist_params, activation=None, name="dist_params"),
        ]

        return tf.keras.Sequential(stack)
|
29 |
+
|
30 |
+
|
31 |
+
class Decoder_X(tf.keras.layers.Layer):
    """Factory for the decoder network.

    The object stores the latent size ``dim_z``; calling :meth:`build`
    returns the actual network as a ``tf.keras.Sequential`` that maps a
    ``(dim_z,)`` latent vector back to a single-channel map.

    NOTE(review): ``build`` here takes no ``input_shape`` and returns a model,
    which differs from the usual ``Layer.build`` contract -- confirm this
    behaves as intended with the installed Keras version.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        """Store the latent dimensionality; extra kwargs go to ``Layer``."""
        super(Decoder_X, self).__init__(name=name, **kwargs)
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the decoder as a ``tf.keras.Sequential``."""
        # Width of the feature map fed to the transposed convolutions.
        # Integer division keeps the Dense unit count an int (the original
        # true division produced a float) and guarantees it matches the
        # Reshape target below.
        half_width = (CAP - 1) // 2
        channels = 32

        layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]

        layers.append(tf.keras.layers.Dense(500))
        layers.append(tf.keras.layers.ReLU())

        layers.append(tf.keras.layers.Dense(2000))
        layers.append(tf.keras.layers.ReLU())

        layers.append(tf.keras.layers.Dense(half_width * channels, activation=None))
        layers.append(tf.keras.layers.Reshape((1, half_width, channels)))

        # With kernel 3, stride 2 and 'valid' padding, a (1, half_width) map
        # is upsampled to (3, 2 * half_width + 1), i.e. (3, CAP) for odd CAP.
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=64, kernel_size=3, strides=2, padding='valid'))
        layers.append(tf.keras.layers.ReLU())

        # Final projection to a single output channel; spatial size unchanged.
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=1, kernel_size=3, strides=1, padding='same'))

        return tf.keras.Sequential(layers)
|
59 |
+
|
60 |
+
# Weight applied to the KL term of the loss in VAECost; also reported by
# VAE.train_step under the "kl weight" key. Held in a backend variable,
# presumably so a schedule outside this code can update it in place --
# TODO confirm against the training script.
kl_weight = tf.keras.backend.variable(0.125)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
class VAECost:
    """Callable that evaluates the VAE loss for a batch.

    VAE cost with a schedule based on the Microsoft Research Blog's article
    "Less pain, more gain: A simple method for VAE training with less of that
    KL-vanishing agony".

    The KL weight increases linearly, until it meets a certain threshold and
    keeps constant for the same number of epochs. After that, it decreases
    abruptly to zero again, and the cycle repeats.

    NOTE(review): the schedule itself is not implemented in this class; it
    only stores ``kl_weight_increasing`` and ``epoch`` and reads the
    module-level ``kl_weight``. Presumably a callback elsewhere drives the
    cycle -- confirm.
    """

    def __init__(self, model):
        """Keep a reference to the model whose ``encode``/``decoder`` are used."""
        self.model = model
        # Schedule state; not read by the code in this class.
        self.kl_weight_increasing = True
        self.epoch = 1

    # The loss should have the form loss(y_true, y_pred), but in this
    # case y_pred is computed in the cost function

    @tf.function()
    def __call__(self, x_true):
        """Compute the loss terms for a batch.

        Args:
            x_true: batch of inputs; cast to float32 before use.

        Returns:
            A tuple ``(-elbo, mean_kl_divergence, mean_recons_error)`` where
            each element is averaged over the batch.
        """
        x_true = tf.cast(x_true, tf.float32)

        # Encode "song map" to get its latent representation and the parameters
        # of the distribution
        z_sample, mu, sd = self.model.encode(x_true)

        # Decode the latent representation. Due to the VAE architecture, we should
        # ideally get a reconstructed song map similar to the input.
        x_recons = self.model.decoder(z_sample)

        # Compute mean squared error, where our ground truth is the song map
        # we pass as input, so we "compare" the reconstruction to it.
        recons_error = tf.cast(
            tf.reduce_mean((x_true - x_recons) ** 2, axis=[1, 2, 3]),
            tf.float32)

        # Closed-form KL divergence between N(mu, sd^2) and a standard normal.
        # log(sd^2) is written as 2 * log(sd): squaring first underflows to 0
        # (and log to -inf) for small sd long before sd itself does.
        log_var = 2.0 * tf.math.log(sd)
        kl_divergence = -0.5 * tf.math.reduce_sum(
            1 + log_var - tf.math.square(mu) - tf.math.square(sd),
            axis=1)  # shape=(batch_size,)

        # Return metrics
        elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
        mean_kl_divergence = tf.reduce_mean(kl_divergence)
        mean_recons_error = tf.reduce_mean(recons_error)

        return -elbo, mean_kl_divergence, mean_recons_error
|
111 |
+
|
112 |
+
|
113 |
+
class VAE(tf.keras.Model):
    """Variational autoencoder built from ``Encoder_Z`` and ``Decoder_X``.

    Args:
        dim_z: size of the latent vector.
        seed: stored on the instance; not used by the code in this class.
        analytic_kl: stored on the instance; not used by the code in this class.
        name: Keras model name.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, dim_z, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
        super(VAE, self).__init__(name=name, **kwargs)
        self.dim_x = (3, CAP, 1)
        self.dim_z = dim_z
        self.seed = seed
        self.analytic_kl = analytic_kl
        # build() returns the Sequential networks, not the factory objects.
        self.encoder = Encoder_Z(dim_z=self.dim_z).build()
        self.decoder = Decoder_X(dim_z=self.dim_z).build()
        self.cost_func = VAECost(self)

    @tf.function()
    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the metrics dict."""
        # Gradient descent

        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)

        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input):
        """Encode a batch and draw a latent sample.

        Get a "song map" and make a forward pass through the encoder, in order
        to return the latent representation and the distribution's parameters.

        Returns:
            ``(z_sample, mu, sd)``: the sampled latent vectors, and the two
            halves of the encoder output (the second passed through softplus
            so it is positive).
        """
        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)

        # softplus(rho) == log(1 + exp(rho)), but does not overflow for large rho.
        sd = tf.math.softplus(rho)

        # Draw independent noise for every example in the batch. A (dim_z,)
        # sample would be broadcast, making all examples share one noise vector.
        eps = tf.random.normal(shape=tf.shape(mu))
        z_sample = mu + sd * eps
        return z_sample, mu, sd

    def generate(self, z_sample=None):
        """Decode a latent representation of a song, which is provided or sampled.

        Args:
            z_sample: batch of latent vectors to decode. When ``None``, a
                single vector is drawn from a standard normal.

        Returns:
            The decoder output for ``z_sample``.
        """
        # Identity check: ``== None`` is elementwise/ambiguous for array inputs.
        if z_sample is None:
            z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
        return self.decoder(z_sample)
|