TomRB22 commited on
Commit
5a0bd26
1 Parent(s): 3d7d565

Fixed dim_z to 120, modified path to be dynamic and cleaned VAE constructor of unnecessary parameters

Browse files
Files changed (1) hide show
  1. model.py +19 -9
model.py CHANGED
@@ -1,4 +1,6 @@
1
  import tensorflow as tf
 
 
2
 
3
 
4
  _CAP = 3501 # Cap for the number of notes
@@ -112,16 +114,24 @@ class VAECost:
112
 
113
  class VAE(tf.keras.Model):
114
 
115
- def __init__(self, dim_z=120, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
116
  super(VAE, self).__init__(name=name, **kwargs)
117
  self.dim_x = (3, _CAP, 1)
118
- self.dim_z = dim_z
119
- self.seed = seed
120
- self.analytic_kl = analytic_kl
121
- self.encoder = Encoder_Z(dim_z=self.dim_z).build()
122
- self.decoder = Decoder_X(dim_z=self.dim_z).build()
123
  self.cost_func = VAECost(self)
124
- self.load_weights("./weights/")
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  @tf.function()
127
  def train_step(self, data):
@@ -143,12 +153,12 @@ class VAE(tf.keras.Model):
143
 
144
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
145
  sd = tf.math.log(1 + tf.math.exp(rho))
146
- z_sample = mu + sd * tf.random.normal(shape=(self.dim_z,))
147
  return z_sample, mu, sd
148
 
149
  def generate(self, z_sample=None):
150
  # Decode a latent representation of a song, which is provided or sampled
151
 
152
  if z_sample == None:
153
- z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
154
  return self.decoder(z_sample)
 
1
  import tensorflow as tf
2
+ import os
3
+ import inspect
4
 
5
 
6
  _CAP = 3501 # Cap for the number of notes
 
114
 
115
  class VAE(tf.keras.Model):
116
 
117
+ def __init__(self, **kwargs):
118
  super(VAE, self).__init__(name=name, **kwargs)
119
  self.dim_x = (3, _CAP, 1)
120
+ self.encoder = Encoder_Z(dim_z=120).build()
121
+ self.decoder = Decoder_X(dim_z=120).build()
 
 
 
122
  self.cost_func = VAECost(self)
123
+
124
+ # Get the path of the script that defines this method
125
+ script_path = inspect.getfile(inspect.currentframe())
126
+
127
+ # Get the directory containing the script
128
+ script_dir = os.path.dirname(os.path.abspath(script_path))
129
+
130
+ # Construct the path to the weights folder
131
+ weights_dir = os.path.join(script_dir, 'weights') + os.sep
132
+
133
+ # Load pretrained weights
134
+ self.load_weights(weights_dir)
135
 
136
  @tf.function()
137
  def train_step(self, data):
 
153
 
154
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
155
  sd = tf.math.log(1 + tf.math.exp(rho))
156
+ z_sample = mu + sd * tf.random.normal(shape=(120,))
157
  return z_sample, mu, sd
158
 
159
  def generate(self, z_sample=None):
160
  # Decode a latent representation of a song, which is provided or sampled
161
 
162
  if z_sample == None:
163
+ z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
164
  return self.decoder(z_sample)