Fixed dim_z to 120, modified path to be dynamic and cleaned VAE constructor of unnecessary parameters
Browse files
model.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import tensorflow as tf
|
|
|
|
|
2 |
|
3 |
|
4 |
_CAP = 3501 # Cap for the number of notes
|
@@ -112,16 +114,24 @@ class VAECost:
|
|
112 |
|
113 |
class VAE(tf.keras.Model):
|
114 |
|
115 |
-
def __init__(self,
|
116 |
super(VAE, self).__init__(name=name, **kwargs)
|
117 |
self.dim_x = (3, _CAP, 1)
|
118 |
-
self.
|
119 |
-
self.
|
120 |
-
self.analytic_kl = analytic_kl
|
121 |
-
self.encoder = Encoder_Z(dim_z=self.dim_z).build()
|
122 |
-
self.decoder = Decoder_X(dim_z=self.dim_z).build()
|
123 |
self.cost_func = VAECost(self)
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
@tf.function()
|
127 |
def train_step(self, data):
|
@@ -143,12 +153,12 @@ class VAE(tf.keras.Model):
|
|
143 |
|
144 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
145 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
146 |
-
z_sample = mu + sd * tf.random.normal(shape=(
|
147 |
return z_sample, mu, sd
|
148 |
|
149 |
def generate(self, z_sample=None):
|
150 |
# Decode a latent representation of a song, which is provided or sampled
|
151 |
|
152 |
if z_sample == None:
|
153 |
-
z_sample = tf.expand_dims(tf.random.normal(shape=(
|
154 |
return self.decoder(z_sample)
|
|
|
1 |
import tensorflow as tf
|
2 |
+
import os
|
3 |
+
import inspect
|
4 |
|
5 |
|
6 |
_CAP = 3501 # Cap for the number of notes
|
|
|
114 |
|
115 |
class VAE(tf.keras.Model):
|
116 |
|
117 |
+
def __init__(self, **kwargs):
|
118 |
super(VAE, self).__init__(name=name, **kwargs)
|
119 |
self.dim_x = (3, _CAP, 1)
|
120 |
+
self.encoder = Encoder_Z(dim_z=120).build()
|
121 |
+
self.decoder = Decoder_X(dim_z=120).build()
|
|
|
|
|
|
|
122 |
self.cost_func = VAECost(self)
|
123 |
+
|
124 |
+
# Get the path of the script that defines this method
|
125 |
+
script_path = inspect.getfile(inspect.currentframe())
|
126 |
+
|
127 |
+
# Get the directory containing the script
|
128 |
+
script_dir = os.path.dirname(os.path.abspath(script_path))
|
129 |
+
|
130 |
+
# Construct the path to the weights folder
|
131 |
+
weights_dir = os.path.join(script_dir, 'weights') + os.sep
|
132 |
+
|
133 |
+
# Load pretrained weights
|
134 |
+
self.load_weights(weights_dir)
|
135 |
|
136 |
@tf.function()
|
137 |
def train_step(self, data):
|
|
|
153 |
|
154 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
155 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
156 |
+
z_sample = mu + sd * tf.random.normal(shape=(120,))
|
157 |
return z_sample, mu, sd
|
158 |
|
159 |
def generate(self, z_sample=None):
|
160 |
# Decode a latent representation of a song, which is provided or sampled
|
161 |
|
162 |
if z_sample == None:
|
163 |
+
z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
|
164 |
return self.decoder(z_sample)
|