TomRB22 commited on
Commit
92ac48c
1 Parent(s): f605e64

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +153 -0
model.py CHANGED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
+ CAP = 3501 # Cap for the number of notes
5
+
6
+ class Encoder_Z(tf.keras.layers.Layer):
7
+
8
+ def __init__(self, dim_z, name="encoder", **kwargs):
9
+ super(Encoder_Z, self).__init__(name=name, **kwargs)
10
+ self.dim_x = (3, CAP, 1)
11
+ self.dim_z = dim_z
12
+
13
+ def build(self):
14
+ layers = [tf.keras.layers.InputLayer(input_shape=self.dim_x)]
15
+
16
+ layers.append(tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)))
17
+ layers.append(tf.keras.layers.ReLU())
18
+ layers.append(tf.keras.layers.Flatten())
19
+
20
+ layers.append(tf.keras.layers.Dense(2000))
21
+ layers.append(tf.keras.layers.ReLU())
22
+
23
+ layers.append(tf.keras.layers.Dense(500))
24
+ layers.append(tf.keras.layers.ReLU())
25
+
26
+ layers.append(tf.keras.layers.Dense(self.dim_z * 2, activation=None, name="dist_params"))
27
+
28
+ return tf.keras.Sequential(layers)
29
+
30
+
31
+ class Decoder_X(tf.keras.layers.Layer):
32
+
33
+ def __init__(self, dim_z, name="decoder", **kwargs):
34
+ super(Decoder_X, self).__init__(name=name, **kwargs)
35
+ self.dim_z = dim_z
36
+
37
+ def build(self):
38
+ # Build architecture
39
+
40
+ layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]
41
+
42
+ layers.append(tf.keras.layers.Dense(500))
43
+ layers.append(tf.keras.layers.ReLU())
44
+
45
+ layers.append(tf.keras.layers.Dense(2000))
46
+ layers.append(tf.keras.layers.ReLU())
47
+
48
+ layers.append(tf.keras.layers.Dense((CAP - 1) / 2 * 32, activation=None))
49
+ layers.append(tf.keras.layers.Reshape((1, int((CAP - 1) / 2), 32)))
50
+
51
+ layers.append(tf.keras.layers.Conv2DTranspose(
52
+ filters=64, kernel_size=3, strides=2, padding='valid'))
53
+ layers.append(tf.keras.layers.ReLU())
54
+
55
+ layers.append(tf.keras.layers.Conv2DTranspose(
56
+ filters=1, kernel_size=3, strides=1, padding='same'))
57
+
58
+ return tf.keras.Sequential(layers)
59
+
60
+ kl_weight = tf.keras.backend.variable(0.125)
61
+
62
+
63
+
64
+ class VAECost:
65
+ # VAE cost with a schedule based on the Microsoft Research Blog's article
66
+ # "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony"
67
+ #
68
+ # The KL weight increases linearly, until it meets a certain threshold and keeps constant
69
+ # for the same number of epochs. After that, it decreases abruptly to zero again, and the
70
+ # cycle repeats.
71
+
72
+ def __init__(self, model):
73
+ self.model = model
74
+ self.kl_weight_increasing = True
75
+ self.epoch = 1
76
+
77
+
78
+ # The loss should have the form loss(y_true, y_pred), but in this
79
+ # case y_pred is computed in the cost function
80
+
81
+ @tf.function()
82
+ def __call__(self, x_true):
83
+ x_true = tf.cast(x_true, tf.float32)
84
+
85
+ # Encode "song map" to get its latent representation and the parameters
86
+ # of the distribution
87
+ z_sample, mu, sd = self.model.encode(x_true)
88
+
89
+ # Decode the latent representation. Due to the VAE architecture, we should
90
+ # ideally get a reconstructed song map similar to the input.
91
+ x_recons = self.model.decoder(z_sample)
92
+
93
+ # Compute mean squared error, where our ground truth is the song map
94
+ # we pass as input, so we "compare" the reconstruction to it.
95
+
96
+ recons_error = tf.cast(
97
+ tf.reduce_mean((x_true - x_recons) ** 2, axis=[1, 2, 3]),
98
+ tf.float32)
99
+
100
+ # Compute reverse KL divergence
101
+ kl_divergence = -0.5 * tf.math.reduce_sum(
102
+ 1 + tf.math.log(tf.math.square(sd)) - tf.math.square(mu) - tf.math.square(sd),
103
+ axis=1) # shape=(batch_size,)
104
+
105
+ # Return metrics
106
+ elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
107
+ mean_kl_divergence = tf.reduce_mean(kl_divergence)
108
+ mean_recons_error = tf.reduce_mean(recons_error)
109
+
110
+ return -elbo, mean_kl_divergence, mean_recons_error
111
+
112
+
113
+ class VAE(tf.keras.Model):
114
+
115
+ def __init__(self, dim_z, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
116
+ super(VAE, self).__init__(name=name, **kwargs)
117
+ self.dim_x = (3, CAP, 1)
118
+ self.dim_z = dim_z
119
+ self.seed = seed
120
+ self.analytic_kl = analytic_kl
121
+ self.encoder = Encoder_Z(dim_z=self.dim_z).build()
122
+ self.decoder = Decoder_X(dim_z=self.dim_z).build()
123
+ self.cost_func = VAECost(self)
124
+
125
+ @tf.function()
126
+ def train_step(self, data):
127
+ # Gradient descent
128
+
129
+ with tf.GradientTape() as tape:
130
+ neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)
131
+
132
+ gradients = tape.gradient(neg_elbo, self.trainable_variables)
133
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
134
+
135
+ return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
136
+ "mean recons": mean_recons_error,
137
+ "kl weight": kl_weight}
138
+
139
+ def encode(self, x_input):
140
+ # Get a "song map" and make a forward pass through the encoder, in order
141
+ # to return the latent representation and the distribution's parameters
142
+
143
+ mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
144
+ sd = tf.math.log(1 + tf.math.exp(rho))
145
+ z_sample = mu + sd * tf.random.normal(shape=(self.dim_z,))
146
+ return z_sample, mu, sd
147
+
148
+ def generate(self, z_sample=None):
149
+ # Decode a latent representation of a song, which is provided or sampled
150
+
151
+ if z_sample == None:
152
+ z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
153
+ return self.decoder(z_sample)