Commit: update laion
app.py CHANGED
@@ -17,7 +17,7 @@ for k, v in ckpt.items():
     k = k[len('image_encoder.model.'):]
     new_dict.update({k: v})
 
-model = mae_vit_base_patch16(uni_dim=768, less_u=True)
+model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)
 
 msg = model.load_state_dict(new_dict, strict=False)
 print(msg)
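For context on the print(msg) at the end of the hunk: with strict=False, load_state_dict does not raise on key mismatches left over from the prefix stripping; it returns a named tuple listing missing and unexpected keys. A minimal sketch of that behavior, using a stand-in module rather than the repo's mae_vit_base_patch16:

import torch.nn as nn

# Stand-in module; the real app loads mae_vit_base_patch16.
model = nn.Linear(4, 2)

# Deliberately mismatched state dict: 'weight' is absent, 'extra' is unknown.
msg = model.load_state_dict(
    {'bias': model.bias.data, 'extra': model.bias.data},
    strict=False,
)

# Prints the missing/unexpected keys, e.g.:
# _IncompatibleKeys(missing_keys=['weight'], unexpected_keys=['extra'])
print(msg)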
model.py CHANGED
@@ -143,6 +143,7 @@ class ParallelTransformerBlock(nn.Module):
 
         attn_inner_dim = dim_head * heads
         ff_inner_dim = dim * ff_mult
+        # import ipdb; ipdb.set_trace()
         self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
 
         self.heads = heads
@@ -431,7 +432,7 @@ class MaskedAutoencoderViT(nn.Module):
         # NOTE: +1 for mask token used by MLM objective
         # self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
 
-        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
+        self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
         self.text_cls_token = nn.Parameter(torch.randn(uni_dim))
 
         self.embed_dim = embed_dim
@@ -528,7 +529,7 @@ class MaskedAutoencoderViT(nn.Module):
         # self.text_mask_token = nn.Parameter(torch.randn(embed_dim))
         self.mask_token_id = len(self.tokenizer.vocab)
 
-
+        self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
         self.text_length = text_length
 
         self.latent_projector_layer = projector_layer
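The token_emb change lines up the embedding table with mask_token_id = len(self.tokenizer.vocab): the MLM mask token gets the id one past the last vocabulary entry, so the table needs len(vocab) + 1 rows, exactly as the NOTE above it says. Without the extra row, embedding a masked sequence fails at lookup time. A minimal sketch of the failure mode (vocab size and width are illustrative, not the repo's values):

import torch
import torch.nn as nn

vocab_size = 49408          # illustrative; the real value is len(self.tokenizer.vocab)
uni_dim = 768
mask_token_id = vocab_size  # mask token sits one past the last vocab id

# Before this commit: exactly vocab_size rows, so the mask id is out of range.
emb = nn.Embedding(vocab_size, uni_dim)
try:
    emb(torch.tensor([mask_token_id]))
except IndexError as err:
    print('lookup fails:', err)

# After: one extra row reserves a slot for the mask token.
emb = nn.Embedding(vocab_size + 1, uni_dim)
print(emb(torch.tensor([mask_token_id])).shape)  # torch.Size([1, 768])

The new text_position_embed follows the same pattern as MAE's image positional embeddings: a frozen parameter (requires_grad=False) created as zeros, presumably filled with fixed sinusoidal values elsewhere during initialization.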