NoteDance committed on
Commit
c86eba0
1 Parent(s): beed2af

Update CLIP.py

Browse files
Files changed (1) hide show
  1. CLIP.py +44 -14
CLIP.py CHANGED
@@ -54,9 +54,14 @@ class Bottleneck(tf.keras.layers.Layer):
54
  return out
55
 
56
 
57
- class AttentionPool2d:
58
  def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59
- self.positional_embedding = tf.Variable(tf.random.normal([spacial_dim ** 2 + 1, embed_dim]) / embed_dim ** 0.5)
 
 
 
 
 
60
  self.k_proj = Dense(embed_dim)
61
  self.q_proj = Dense(embed_dim)
62
  self.v_proj = Dense(embed_dim)
@@ -213,15 +218,25 @@ class Transformer:
213
  return self.resblocks(x)
214
 
215
 
216
- class VisionTransformer:
217
  def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
218
  self.input_resolution = input_resolution
219
  self.output_dim = output_dim
220
  self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
221
 
222
  scale = width ** -0.5
223
- self.class_embedding = tf.Variable(scale * tf.random.normal([width]))
224
- self.positional_embedding = tf.Variable(scale * tf.random.normal((input_resolution // patch_size) ** 2 + 1, width))
 
 
 
 
 
 
 
 
 
 
225
  self.ln_pre = LayerNorm(width)
226
 
227
  self.transformer = Transformer(width, layers, heads)
@@ -296,17 +311,32 @@ class CLIP(Model):
296
  )
297
 
298
  self.vocab_size = vocab_size
299
- self.token_embedding = tf.Variable(tf.random.normal((vocab_size, transformer_width),
300
- stddev=0.02))
301
- self.positional_embedding = tf.Variable(tf.random.normal((self.context_length, transformer_width),
302
- stddev=0.01
303
- ))
 
 
 
 
 
 
 
304
  self.ln_final = LayerNorm(transformer_width)
305
 
306
- self.text_projection = tf.Variable(tf.random.normal((transformer_width, embed_dim),
307
- stddev=self.transformer.width ** -0.5,
308
- ))
309
- self.logit_scale = tf.Variable(tf.ones([]) * np.log(1 / 0.07))
 
 
 
 
 
 
 
 
310
 
311
  def build_attention_mask(self):
312
  mask = tf.ones((self.context_length, self.context_length))
 
54
  return out
55
 
56
 
57
+ class AttentionPool2d(tf.keras.layers.Layer):
58
  def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59
+ self.positional_embedding = self.add_weight(
60
+ name='positional_embedding',
61
+ shape=[self.spacial_dim ** 2 + 1, self.embed_dim],
62
+ initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1./self.embed_dim**0.5),
63
+ trainable=True
64
+ )
65
  self.k_proj = Dense(embed_dim)
66
  self.q_proj = Dense(embed_dim)
67
  self.v_proj = Dense(embed_dim)
 
218
  return self.resblocks(x)
219
 
220
 
221
+ class VisionTransformer(tf.keras.layers.Layer):
222
  def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
223
  self.input_resolution = input_resolution
224
  self.output_dim = output_dim
225
  self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
226
 
227
  scale = width ** -0.5
228
+ self.class_embedding = self.add_weight(
229
+ name='class_embedding',
230
+ shape=[self.width],
231
+ initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
232
+ trainable=True
233
+ )
234
+ self.positional_embedding = self.add_weight(
235
+ name='positional_embedding',
236
+ shape=[(self.input_resolution // self.patch_size) ** 2 + 1, self.width],
237
+ initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
238
+ trainable=True
239
+ )
240
  self.ln_pre = LayerNorm(width)
241
 
242
  self.transformer = Transformer(width, layers, heads)
 
311
  )
312
 
313
  self.vocab_size = vocab_size
314
+ self.token_embedding = self.add_weight(
315
+ name='token_embedding',
316
+ shape=(vocab_size, transformer_width),
317
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
318
+ trainable=True
319
+ )
320
+ self.positional_embedding = self.add_weight(
321
+ name='positional_embedding',
322
+ shape=(self.context_length, transformer_width),
323
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
324
+ trainable=True
325
+ )
326
  self.ln_final = LayerNorm(transformer_width)
327
 
328
+ self.text_projection = self.add_weight(
329
+ name='text_projection',
330
+ shape=(transformer_width, embed_dim),
331
+ initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
332
+ trainable=True
333
+ )
334
+ self.logit_scale = self.add_weight(
335
+ name='logit_scale',
336
+ shape=[],
337
+ initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
338
+ trainable=True
339
+ )
340
 
341
  def build_attention_mask(self):
342
  mask = tf.ones((self.context_length, self.context_length))