Update CLIP.py
CLIP.py
CHANGED
@@ -54,9 +54,14 @@ class Bottleneck(tf.keras.layers.Layer):
         return out
 
 
-class AttentionPool2d:
+class AttentionPool2d(tf.keras.layers.Layer):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
-        self.positional_embedding =
+        self.positional_embedding = self.add_weight(
+            name='positional_embedding',
+            shape=[self.spacial_dim ** 2 + 1, self.embed_dim],
+            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1./self.embed_dim**0.5),
+            trainable=True
+        )
         self.k_proj = Dense(embed_dim)
         self.q_proj = Dense(embed_dim)
         self.v_proj = Dense(embed_dim)
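Note that in the lines this hunk shows, add_weight is called before self.spacial_dim or self.embed_dim are assigned and before the parent Layer is initialized. A minimal sketch of how the same pattern can be written so it runs, assuming the constructor arguments are simply stored on self first (the class name AttentionPool2dSketch is illustrative, not from the file):

import tensorflow as tf
from tensorflow.keras.layers import Dense

class AttentionPool2dSketch(tf.keras.layers.Layer):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.output_dim = output_dim
        # One learnable position vector per spatial location, plus one for the pooled token.
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True,
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)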
@@ -213,15 +218,25 @@ class Transformer:
         return self.resblocks(x)
 
 
-class VisionTransformer:
+class VisionTransformer(tf.keras.layers.Layer):
     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
         self.input_resolution = input_resolution
         self.output_dim = output_dim
         self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
 
         scale = width ** -0.5
-        self.class_embedding =
-
+        self.class_embedding = self.add_weight(
+            name='class_embedding',
+            shape=[self.width],
+            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
+            trainable=True
+        )
+        self.positional_embedding = self.add_weight(
+            name='positional_embedding',
+            shape=[(self.input_resolution // self.patch_size) ** 2 + 1, self.width],
+            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.0) * self.scale,
+            trainable=True
+        )
         self.ln_pre = LayerNorm(width)
 
         self.transformer = Transformer(width, layers, heads)
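This hunk applies the scale factor by multiplying the RandomNormal initializer object by self.scale, which TensorFlow initializers do not support, and shape=[self.width] relies on attributes the constructor shown here never sets. A sketch of just these two weights with the scale folded into the initializer's stddev (the stub class and its reduced argument list are illustrative assumptions):

import tensorflow as tf

class VisionTransformerSketch(tf.keras.layers.Layer):
    def __init__(self, input_resolution: int, patch_size: int, width: int):
        super().__init__()
        scale = width ** -0.5
        # Fold the scale into stddev instead of multiplying an Initializer by a float,
        # which raises a TypeError; N(0, 1) * scale has the same distribution as N(0, scale).
        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[(input_resolution // patch_size) ** 2 + 1, width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )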
@@ -296,17 +311,32 @@ class CLIP(Model):
         )
 
         self.vocab_size = vocab_size
-        self.token_embedding =
-
-
-
-
+        self.token_embedding = self.add_weight(
+            name='token_embedding',
+            shape=(vocab_size, transformer_width),
+            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
+            trainable=True
+        )
+        self.positional_embedding = self.add_weight(
+            name='positional_embedding',
+            shape=(self.context_length, transformer_width),
+            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+            trainable=True
+        )
         self.ln_final = LayerNorm(transformer_width)
 
-        self.text_projection =
-
-
-
+        self.text_projection = self.add_weight(
+            name='text_projection',
+            shape=(transformer_width, embed_dim),
+            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
+            trainable=True
+        )
+        self.logit_scale = self.add_weight(
+            name='logit_scale',
+            shape=[],
+            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
+            trainable=True
+        )
 
     def build_attention_mask(self):
         mask = tf.ones((self.context_length, self.context_length))
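The last two weights in this hunk follow the original CLIP: text_projection maps the transformer width to the shared embedding dimension, and logit_scale stores the softmax temperature as a log of 1/0.07 so the raw parameter can be trained unconstrained while the effective scale stays positive. A small illustrative sketch of how such a logit_scale is typically applied; the feature tensors, shapes, and variable names below are made up for the example and are not identifiers from this file:

import numpy as np
import tensorflow as tf

# Stored as a log, exponentiated at forward time.
logit_scale = tf.Variable(np.log(1 / 0.07), dtype=tf.float32)

# Placeholder unit-normalized features (batch of 4, 512-dim embedding space).
image_features = tf.math.l2_normalize(tf.random.normal([4, 512]), axis=-1)
text_features = tf.math.l2_normalize(tf.random.normal([4, 512]), axis=-1)

# Temperature-scaled cosine-similarity logits, as in CLIP's contrastive loss.
logits_per_image = tf.exp(logit_scale) * tf.matmul(image_features, text_features, transpose_b=True)
logits_per_text = tf.transpose(logits_per_image)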