Update modeling_molmo.py
modeling_molmo.py  +184 -205  CHANGED
@@ -32,13 +32,13 @@ import einops
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
-from olmo.util import resource_path
 from .configuration_molmo import (
     MolmoConfig,
     VisionBackboneConfig,
     VisionBackboneType,
     ImagePooling2DType,
-    ImageProjectType,
     AttentionType,
     MolmoConfigurationError,
 )
@@ -54,6 +54,20 @@ else:
 log = logging.getLogger(__name__)
 
 
 def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
     """
     Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
@@ -106,7 +120,7 @@ class Embedding(nn.Module):
|
|
106 |
def reset_parameters(self):
|
107 |
nn.init.normal_(self.embedding, std=self.initializer_range)
|
108 |
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
|
109 |
-
|
110 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
111 |
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
|
112 |
|
@@ -131,7 +145,7 @@ class Dropout(nn.Dropout):
|
|
131 |
if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
|
132 |
return input
|
133 |
else:
|
134 |
-
if self.mask_p > 0. and self.training:
|
135 |
assert drop_mask is not None
|
136 |
drop_mask = drop_mask.to(input.dtype)
|
137 |
keep_prob = 1.0 - self.p
|
@@ -143,7 +157,7 @@ class Dropout(nn.Dropout):
|
|
143 |
multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
|
144 |
multiplier.div_(keep_prob)
|
145 |
return input * multiplier
|
146 |
-
elif self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
|
147 |
keep_prob = 1.0 - self.p
|
148 |
dropout_shape = list(input.shape)
|
149 |
for dim in self.broadcast_dims:
|
@@ -212,7 +226,6 @@ class LayerNorm(LayerNormBase):
|
|
212 |
else:
|
213 |
return tensor
|
214 |
|
215 |
-
|
216 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
217 |
if self.low_precision:
|
218 |
module_device = x.device
|
@@ -227,7 +240,7 @@ class LayerNorm(LayerNormBase):
|
|
227 |
)
|
228 |
else:
|
229 |
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
230 |
-
|
231 |
def reset_parameters(self):
|
232 |
if self.weight is not None:
|
233 |
torch.nn.init.ones_(self.weight) # type: ignore
|
@@ -239,6 +252,7 @@ class RMSLayerNorm(LayerNormBase):
|
|
239 |
"""
|
240 |
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
241 |
"""
|
|
|
242 |
def __init__(
|
243 |
self,
|
244 |
config: MolmoConfig,
|
@@ -263,7 +277,7 @@ class RMSLayerNorm(LayerNormBase):
|
|
263 |
return self.weight * x
|
264 |
else:
|
265 |
return x
|
266 |
-
|
267 |
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
268 |
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
269 |
# `is_autocast_cpu_enabled()` for CPU autocast.
|
@@ -274,7 +288,7 @@ class RMSLayerNorm(LayerNormBase):
|
|
274 |
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
275 |
else:
|
276 |
return tensor
|
277 |
-
|
278 |
def reset_parameters(self):
|
279 |
if self.weight is not None:
|
280 |
torch.nn.init.ones_(self.weight) # type: ignore
|
@@ -293,8 +307,7 @@ class RotaryEmbedding(nn.Module):
         self.__cache = cache
         # Warm up cache.
         self.get_rotary_embedding(
-            config.max_position_embeddings or config.max_sequence_length,
-            _non_meta_init_device(config)
         )
 
     def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -313,8 +326,14 @@ class RotaryEmbedding(nn.Module):
|
|
313 |
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
314 |
|
315 |
with torch.autocast(device.type, enabled=False):
|
316 |
-
dim =
|
317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
319 |
freqs = einsum("i , j -> i j", seq, inv_freq)
|
320 |
if self.config.rope_impl == "cockatoo":
|
@@ -346,10 +365,7 @@ class RotaryEmbedding(nn.Module):
|
|
346 |
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
347 |
|
348 |
def forward(
|
349 |
-
self,
|
350 |
-
q: torch.Tensor,
|
351 |
-
k: torch.Tensor,
|
352 |
-
position_ids: Optional[torch.Tensor] = None
|
353 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
354 |
if self.config.rope_full_precision:
|
355 |
q_, k_ = q.float(), k.float()
|
@@ -360,7 +376,7 @@ class RotaryEmbedding(nn.Module):
|
|
360 |
batch_size = q_.shape[0]
|
361 |
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
362 |
if position_ids is not None:
|
363 |
-
freqs_cis_len =
|
364 |
else:
|
365 |
freqs_cis_len = key_len
|
366 |
pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
|
@@ -368,12 +384,8 @@ class RotaryEmbedding(nn.Module):
|
|
368 |
pos_cos = pos_cos.type_as(q_)
|
369 |
if position_ids is not None:
|
370 |
assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
|
371 |
-
pos_sin = pos_sin[0, 0][position_ids].view(
|
372 |
-
|
373 |
-
)
|
374 |
-
pos_cos = pos_cos[0, 0][position_ids].view(
|
375 |
-
(batch_size, 1, key_len, pos_cos.shape[-1])
|
376 |
-
)
|
377 |
q_ = self.apply_rotary_pos_emb(
|
378 |
pos_sin[:, :, key_len - query_len : key_len, :],
|
379 |
pos_cos[:, :, key_len - query_len : key_len, :],
|
@@ -466,11 +478,7 @@ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.de
|
|
466 |
|
467 |
|
468 |
class MolmoAttention(nn.Module):
|
469 |
-
def __init__(
|
470 |
-
self,
|
471 |
-
config: MolmoConfig,
|
472 |
-
cache: BufferCache
|
473 |
-
):
|
474 |
super().__init__()
|
475 |
self.config = config
|
476 |
self.__cache = cache
|
@@ -478,8 +486,7 @@ class MolmoAttention(nn.Module):
|
|
478 |
self.k_norm: Optional[LayerNormBase] = None
|
479 |
self.q_norm: Optional[LayerNormBase] = None
|
480 |
self.hidden_size = (
|
481 |
-
config.mlp_hidden_size if config.mlp_hidden_size is not None
|
482 |
-
else config.mlp_ratio * config.d_model
|
483 |
)
|
484 |
|
485 |
if config.attention_layer_norm:
|
@@ -508,29 +515,25 @@ class MolmoAttention(nn.Module):
|
|
508 |
config.n_kv_heads * head_dim,
|
509 |
)
|
510 |
self.att_proj = nn.Linear(
|
511 |
-
config.d_model,
|
|
|
512 |
bias=config.include_bias or config.qkv_bias,
|
513 |
-
device=config.init_device
|
514 |
-
)
|
515 |
-
self.attn_out = nn.Linear(
|
516 |
-
input_dim, config.d_model,
|
517 |
-
bias=config.include_bias,
|
518 |
-
device=config.init_device
|
519 |
)
|
520 |
-
self.
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
self.flash_attn_func = None
|
526 |
if self.config.attention_type == AttentionType.flash:
|
527 |
try:
|
528 |
from flash_attn import flash_attn_func
|
|
|
529 |
self.flash_attn_func = flash_attn_func
|
530 |
except ModuleNotFoundError:
|
531 |
pass
|
532 |
|
533 |
-
def attention(
|
|
|
534 |
q: torch.Tensor,
|
535 |
k: torch.Tensor,
|
536 |
v: torch.Tensor,
|
@@ -541,7 +544,7 @@ class MolmoAttention(nn.Module):
|
|
541 |
use_cache: bool = False,
|
542 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
543 |
B, T, C = q.size() # batch size, sequence length, d_model
|
544 |
-
dtype = k.dtype
|
545 |
|
546 |
# Optionally apply layer norm to keys and queries.
|
547 |
if self.q_norm is not None and self.k_norm is not None:
|
@@ -658,15 +661,7 @@ class MolmoAttention(nn.Module):
|
|
658 |
is_causal=is_causal,
|
659 |
)
|
660 |
|
661 |
-
def forward(
|
662 |
-
self,
|
663 |
-
x,
|
664 |
-
attention_bias,
|
665 |
-
position_ids,
|
666 |
-
drop_mask,
|
667 |
-
layer_past,
|
668 |
-
use_cache
|
669 |
-
):
|
670 |
if not self.config.norm_after:
|
671 |
atten_in = self.attn_norm(x)
|
672 |
else:
|
@@ -678,54 +673,45 @@ class MolmoAttention(nn.Module):
|
|
678 |
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
679 |
|
680 |
q, k, v = qkv.split(self.fused_dims, dim=-1)
|
681 |
-
|
682 |
# Get attention scores.
|
683 |
att, cache = self.attention(
|
684 |
-
q,
|
|
|
|
|
685 |
attention_bias,
|
686 |
position_ids=position_ids,
|
687 |
drop_mask=drop_mask,
|
688 |
layer_past=layer_past,
|
689 |
-
use_cache=use_cache
|
690 |
)
|
691 |
-
|
692 |
if self.config.norm_after:
|
693 |
att = self.attn_norm(att)
|
694 |
-
|
695 |
return att, cache
|
696 |
|
697 |
|
698 |
class MolmoMLP(nn.Module):
|
699 |
-
def __init__(
|
700 |
-
self,
|
701 |
-
config: MolmoConfig
|
702 |
-
):
|
703 |
# Feed-forward input projection.
|
704 |
super().__init__()
|
705 |
self.config = config
|
706 |
self.hidden_size = (
|
707 |
-
config.mlp_hidden_size if config.mlp_hidden_size is not None
|
708 |
-
else config.mlp_ratio * config.d_model
|
709 |
)
|
710 |
self.act = SwiGLU(config)
|
711 |
self.ff_proj = nn.Linear(
|
712 |
-
config.d_model,
|
713 |
-
|
714 |
-
bias=config.include_bias,
|
715 |
-
device=config.init_device
|
716 |
-
)
|
717 |
self.ff_out = nn.Linear(
|
718 |
int(self.act.output_multiplier * self.hidden_size),
|
719 |
config.d_model,
|
720 |
bias=config.include_bias,
|
721 |
device=config.init_device,
|
722 |
)
|
723 |
-
self.ff_norm = RMSLayerNorm(
|
724 |
-
|
725 |
-
size=config.d_model,
|
726 |
-
eps=config.layer_norm_eps
|
727 |
-
)
|
728 |
-
|
729 |
def forward(self, x):
|
730 |
if not self.config.norm_after:
|
731 |
x = self.ff_norm(x)
|
@@ -744,12 +730,8 @@ class MolmoDecoderLayer(nn.Module):
|
|
744 |
"""
|
745 |
A base class for transformer block implementations.
|
746 |
"""
|
747 |
-
|
748 |
-
|
749 |
-
layer_id: int,
|
750 |
-
config: MolmoConfig,
|
751 |
-
cache: BufferCache
|
752 |
-
):
|
753 |
super().__init__()
|
754 |
self.self_attn = MolmoAttention(config, cache)
|
755 |
self.mlp = MolmoMLP(config)
|
@@ -763,10 +745,7 @@ class MolmoDecoderLayer(nn.Module):
|
|
763 |
assert config.d_model % config.n_heads == 0
|
764 |
|
765 |
# Dropout.
|
766 |
-
self.dropout = Dropout(
|
767 |
-
config.residual_dropout,
|
768 |
-
mask_p=config.response_residual_dropout
|
769 |
-
)
|
770 |
|
771 |
def forward(
|
772 |
self,
|
@@ -787,12 +766,12 @@ class MolmoDecoderLayer(nn.Module):
|
|
787 |
"""
|
788 |
|
789 |
att, cache = self.self_attn(
|
790 |
-
x,
|
791 |
attention_bias=attention_bias,
|
792 |
position_ids=position_ids,
|
793 |
drop_mask=drop_mask,
|
794 |
layer_past=layer_past,
|
795 |
-
use_cache=use_cache
|
796 |
)
|
797 |
x = x + self.dropout(att, drop_mask=drop_mask)
|
798 |
og_x = x
|
@@ -822,7 +801,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
|
822 |
super().__init__()
|
823 |
self.config = config
|
824 |
self.use_bias = use_bias
|
825 |
-
|
826 |
v_cfg = config.vision_backbone
|
827 |
self.embed_dim = v_cfg.image_emb_dim
|
828 |
self.num_heads = v_cfg.image_num_heads
|
@@ -862,7 +841,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
|
862 |
if v_cfg.attention_dropout > 0:
|
863 |
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
|
864 |
self.residual_dropout = Dropout(v_cfg.residual_dropout)
|
865 |
-
|
866 |
def reset_parameters(self):
|
867 |
nn.init.normal_(self.wq.weight, std=self.initializer_range)
|
868 |
nn.init.normal_(self.wk.weight, std=self.initializer_range)
|
@@ -879,15 +858,15 @@ class MultiHeadDotProductAttention(nn.Module):
|
|
879 |
|
880 |
def _merge_heads(self, hidden_states) -> torch.Tensor:
|
881 |
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
882 |
-
|
883 |
-
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
884 |
if inputs_kv is not None:
|
885 |
inputs_k = inputs_kv
|
886 |
inputs_v = inputs_kv
|
887 |
else:
|
888 |
inputs_k = inputs_q
|
889 |
inputs_v = inputs_q
|
890 |
-
|
891 |
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
|
892 |
|
893 |
xq = self._split_heads(xq, self.num_heads)
|
@@ -918,7 +897,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
|
918 |
xk.transpose(1, 2).contiguous(),
|
919 |
xv.transpose(1, 2).contiguous(),
|
920 |
is_causal=False,
|
921 |
-
dropout_p=self.config.vision_backbone.attention_dropout
|
922 |
).transpose(1, 2)
|
923 |
else:
|
924 |
raise NotImplementedError(self.config.attention_type)
|
@@ -940,7 +919,7 @@ class MultiHeadAttentionPool(nn.Module):
|
|
940 |
output_layer: bool = True,
|
941 |
mean_residual: bool = False,
|
942 |
query: str = "mean",
|
943 |
-
is_vit_layer: Optional[bool] = True
|
944 |
):
|
945 |
super().__init__()
|
946 |
self.config = config
|
@@ -950,7 +929,7 @@ class MultiHeadAttentionPool(nn.Module):
|
|
950 |
self.output_layer = output_layer
|
951 |
self.mean_residual = mean_residual
|
952 |
self.query = query
|
953 |
-
|
954 |
v_cfg = config.vision_backbone
|
955 |
input_dim = v_cfg.image_emb_dim
|
956 |
self.embed_dim = v_cfg.image_emb_dim * factor
|
@@ -985,7 +964,9 @@ class MultiHeadAttentionPool(nn.Module):
|
|
985 |
if query == "vector":
|
986 |
self.attention_query = nn.Parameter(
|
987 |
torch.zeros(
|
988 |
-
1,
|
|
|
|
|
989 |
),
|
990 |
)
|
991 |
|
@@ -1024,7 +1005,6 @@ class MultiHeadAttentionPool(nn.Module):
|
|
1024 |
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
1025 |
|
1026 |
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
|
1027 |
-
|
1028 |
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
|
1029 |
|
1030 |
if self.query == "mean":
|
@@ -1093,14 +1073,14 @@ class ViTMLP(nn.Module):
|
|
1093 |
bias=True,
|
1094 |
device=config.init_device,
|
1095 |
)
|
1096 |
-
|
1097 |
def reset_parameters(self):
|
1098 |
v_cfg = self.config.vision_backbone
|
1099 |
nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
|
1100 |
nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
|
1101 |
nn.init.zeros_(self.w1.bias)
|
1102 |
nn.init.zeros_(self.w2.bias)
|
1103 |
-
|
1104 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1105 |
x = self.w1(x)
|
1106 |
x = self.act(x)
|
@@ -1111,7 +1091,7 @@ class ViTMLP(nn.Module):
|
|
1111 |
class MLP(nn.Module):
|
1112 |
def __init__(self, config: MolmoConfig, input_dim: int, dropout: float = 0.0):
|
1113 |
super().__init__()
|
1114 |
-
self.config = config
|
1115 |
self.hidden_size = (
|
1116 |
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
1117 |
)
|
@@ -1135,15 +1115,15 @@ class MLP(nn.Module):
|
|
1135 |
bias=False,
|
1136 |
device=config.init_device,
|
1137 |
)
|
1138 |
-
|
1139 |
self.act = LlamaSwiGLU(config)
|
1140 |
self.dropout = Dropout(dropout)
|
1141 |
-
|
1142 |
def reset_parameters(self):
|
1143 |
nn.init.normal_(self.w1.weight, std=self.initializer_range)
|
1144 |
nn.init.normal_(self.w2.weight, std=self.initializer_range)
|
1145 |
nn.init.normal_(self.w3.weight, std=self.initializer_range)
|
1146 |
-
|
1147 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1148 |
x = self.w2(self.act(self.w1(x), self.w3(x)))
|
1149 |
x = self.dropout(x)
|
@@ -1154,26 +1134,26 @@ class Residual(nn.Module):
|
|
1154 |
def __init__(self, submodule: nn.Module):
|
1155 |
super().__init__()
|
1156 |
self.submodule = submodule
|
1157 |
-
|
1158 |
def reset_parameters(self):
|
1159 |
self.submodule.reset_parameters()
|
1160 |
-
|
1161 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1162 |
return x + self.submodule(x)
|
1163 |
|
1164 |
|
1165 |
class LayerNormFp32(nn.LayerNorm):
|
1166 |
-
|
1167 |
-
|
1168 |
-
|
1169 |
-
|
1170 |
-
|
1171 |
-
|
1172 |
-
|
1173 |
-
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
|
1178 |
|
1179 |
class ResidualAttentionBlock(nn.Module):
|
@@ -1200,7 +1180,7 @@ class ResidualAttentionBlock(nn.Module):
|
|
1200 |
self.feed_forward.reset_parameters()
|
1201 |
self.attention_norm.reset_parameters()
|
1202 |
self.ffn_norm.reset_parameters()
|
1203 |
-
|
1204 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1205 |
x = x + self.attention(self.attention_norm(x))
|
1206 |
x = x + self.feed_forward(self.ffn_norm(x))
|
@@ -1213,10 +1193,8 @@ class BlockCollection(nn.Module):
|
|
1213 |
self.config = config
|
1214 |
|
1215 |
v_cfg = config.vision_backbone
|
1216 |
-
self.resblocks = nn.ModuleList([
|
1217 |
-
|
1218 |
-
])
|
1219 |
-
|
1220 |
def reset_parameters(self):
|
1221 |
for r in self.resblocks:
|
1222 |
r.reset_parameters()
|
@@ -1240,7 +1218,7 @@ class VisionTransformer(nn.Module):
|
|
1240 |
|
1241 |
v_cfg = config.vision_backbone
|
1242 |
# class embeddings and positional embeddings
|
1243 |
-
self.scale = v_cfg.image_emb_dim
|
1244 |
self.class_embedding = nn.Parameter(
|
1245 |
torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
|
1246 |
)
|
@@ -1264,14 +1242,14 @@ class VisionTransformer(nn.Module):
|
|
1264 |
)
|
1265 |
|
1266 |
self.transformer = BlockCollection(config)
|
1267 |
-
|
1268 |
def reset_parameters(self):
|
1269 |
nn.init.normal_(self.class_embedding, std=self.scale)
|
1270 |
nn.init.normal_(self.positional_embedding, std=self.scale)
|
1271 |
nn.init.normal_(self.patch_embedding.weight, std=0.02)
|
1272 |
self.pre_ln.reset_parameters()
|
1273 |
self.transformer.reset_parameters()
|
1274 |
-
|
1275 |
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
1276 |
cls_emb = self.positional_embedding[0:1]
|
1277 |
pos_emb = self.positional_embedding[1:]
|
@@ -1279,7 +1257,7 @@ class VisionTransformer(nn.Module):
|
|
1279 |
pos_emb = pos_emb.reshape(
|
1280 |
(int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
|
1281 |
)
|
1282 |
-
|
1283 |
(patch_num_0, patch_num_1) = patch_num
|
1284 |
|
1285 |
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
@@ -1287,7 +1265,11 @@ class VisionTransformer(nn.Module):
|
|
1287 |
# antialias: default True in jax.image.resize
|
1288 |
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
1289 |
pos_emb = F.interpolate(
|
1290 |
-
pos_emb,
|
|
|
|
|
|
|
|
|
1291 |
)
|
1292 |
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
1293 |
|
@@ -1355,7 +1337,7 @@ class MolmoVisionBackbone(nn.Module):
|
|
1355 |
input_dim = nlayers * config.vision_backbone.image_emb_dim
|
1356 |
else:
|
1357 |
raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
|
1358 |
-
|
1359 |
self.input_dim = input_dim
|
1360 |
|
1361 |
self.image_projector = MLP(config, input_dim)
|
@@ -1380,9 +1362,11 @@ class MolmoVisionBackbone(nn.Module):
|
|
1380 |
self.image_projector.reset_parameters()
|
1381 |
|
1382 |
@abstractmethod
|
1383 |
-
def forward(
|
|
|
|
|
1384 |
raise NotImplementedError
|
1385 |
-
|
1386 |
|
1387 |
class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
1388 |
def __init__(self, config: MolmoConfig):
|
@@ -1408,13 +1392,11 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1408 |
|
1409 |
self.pad_embed = None
|
1410 |
if config.image_padding_embed:
|
1411 |
-
image_dim = v_cfg.image_emb_dim*len(self.config.vit_layers)
|
1412 |
if config.image_padding_embed in ["pad_embed", "regress"]:
|
1413 |
-
self.pad_embed = nn.Parameter(
|
1414 |
-
torch.zeros((image_dim,), device=config.init_device))
|
1415 |
elif config.image_padding_embed == "pad_and_partial_pad":
|
1416 |
-
self.pad_embed = nn.Parameter(
|
1417 |
-
torch.zeros((2, image_dim), device=config.init_device))
|
1418 |
else:
|
1419 |
raise ValueError(config.image_padding_embed)
|
1420 |
|
@@ -1423,7 +1405,8 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1423 |
if self.config.vit_load_path:
|
1424 |
vit_load_path = Path(self.config.vit_load_path)
|
1425 |
state_dict_path = resource_path(
|
1426 |
-
vit_load_path.parent,
|
|
|
1427 |
local_cache=vit_load_path.parent,
|
1428 |
)
|
1429 |
assert state_dict_path.is_file(), f"Model file {str(state_dict_path)} not found"
|
@@ -1441,7 +1424,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1441 |
self.image_vit.reset_parameters()
|
1442 |
if self.config.use_cls_feature:
|
1443 |
nn.init.xavier_uniform_(self.cls_projector.weight)
|
1444 |
-
|
1445 |
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
1446 |
"""
|
1447 |
: param images: (batch_size, num_crops, num_patch, n_pixels)
|
@@ -1469,15 +1452,17 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1469 |
if self.num_prefix_tokens > 0:
|
1470 |
cls_embed = image_features[:, 0]
|
1471 |
image_features = image_features[:, 1:]
|
1472 |
-
|
1473 |
image_features = image_features * mask
|
1474 |
image_features = image_features.view(B, T, N, -1)
|
1475 |
|
1476 |
cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
|
1477 |
|
1478 |
return image_features, cls_embed
|
1479 |
-
|
1480 |
-
def forward(
|
|
|
|
|
1481 |
cfg = self.config
|
1482 |
|
1483 |
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
@@ -1493,12 +1478,16 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1493 |
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
|
1494 |
elif cfg.image_padding_embed == "regress":
|
1495 |
pad_embed = self.pad_embed[None, None, None, :]
|
1496 |
-
image_features = image_features + pad_embed * torch.unsqueeze(
|
|
|
|
|
1497 |
elif cfg.image_padding_embed == "pad_and_partial_pad":
|
1498 |
og_dtype = image_features.dtype
|
1499 |
pad_embed = self.pad_embed[:, None, None, None, :]
|
1500 |
all_pad = image_masks == 0
|
1501 |
-
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
|
|
|
|
|
1502 |
all_pad = all_pad.to(dtype=torch.float32)
|
1503 |
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
1504 |
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
|
@@ -1509,7 +1498,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1509 |
image_features = self.image_feature_dropout(image_features)
|
1510 |
if cls_embed is not None:
|
1511 |
cls_embed = self.image_feature_dropout(cls_embed)
|
1512 |
-
|
1513 |
image_features = image_features.reshape(
|
1514 |
(batch_size, num_image) + cfg.vision_backbone.image_num_patch + (-1,),
|
1515 |
)
|
@@ -1520,11 +1509,11 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1520 |
image_features,
|
1521 |
(0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
|
1522 |
)
|
1523 |
-
|
1524 |
# image pooling
|
1525 |
image_features = einops.rearrange(
|
1526 |
image_features,
|
1527 |
-
|
1528 |
dh=cfg.image_pooling_h,
|
1529 |
dw=cfg.image_pooling_w,
|
1530 |
)
|
@@ -1546,7 +1535,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1546 |
image_features = module(image_features)
|
1547 |
else:
|
1548 |
image_features = self.image_projector(image_features)
|
1549 |
-
|
1550 |
if self.config.use_cls_feature:
|
1551 |
cls_embed = self.cls_projector(cls_embed)
|
1552 |
if cfg.image_projector == ImageProjectType.mlpx2:
|
@@ -1554,7 +1543,7 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
|
1554 |
cls_embed = module(cls_embed)
|
1555 |
else:
|
1556 |
cls_embed = self.image_projector(cls_embed)
|
1557 |
-
|
1558 |
# image_features: (batch_size, num_image, num_patch, d_model)
|
1559 |
# cls_embed: (batch_size, num_image, d_model)
|
1560 |
return image_features, cls_embed
|
@@ -1579,11 +1568,7 @@ class MolmoPretrainedModel(PreTrainedModel):
|
|
1579 |
|
1580 |
|
1581 |
class MolmoModel(MolmoPretrainedModel):
|
1582 |
-
def __init__(
|
1583 |
-
self,
|
1584 |
-
config: MolmoConfig,
|
1585 |
-
init_params: bool = True
|
1586 |
-
):
|
1587 |
super().__init__(config)
|
1588 |
self.config = config
|
1589 |
self.__cache = BufferCache()
|
@@ -1616,10 +1601,10 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1616 |
config.d_model,
|
1617 |
device=config.init_device,
|
1618 |
initializer_range=config.initializer_range,
|
1619 |
-
new_embed_initializer_range=config.new_embedding_init_range
|
1620 |
)
|
1621 |
else:
|
1622 |
-
wte=nn.Embedding(
|
1623 |
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
1624 |
)
|
1625 |
|
@@ -1627,26 +1612,20 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1627 |
dict(
|
1628 |
wte=wte,
|
1629 |
emb_drop=Dropout(config.embedding_dropout),
|
1630 |
-
ln_f=RMSLayerNorm(
|
1631 |
-
config,
|
1632 |
-
size=config.d_model,
|
1633 |
-
eps=config.layer_norm_eps),
|
1634 |
)
|
1635 |
)
|
1636 |
|
1637 |
-
layers = [
|
1638 |
-
MolmoDecoderLayer(i, config, self.__cache) \
|
1639 |
-
for i in range(config.n_layers)
|
1640 |
-
]
|
1641 |
self.transformer.update({"layers": nn.ModuleList(layers)})
|
1642 |
-
|
1643 |
self.vision_backbone: Optional[MolmoVisionBackbone] = None
|
1644 |
if config.vision_backbone is not None:
|
1645 |
self.vision_backbone = MolmoVisionBackbone.build(config)
|
1646 |
|
1647 |
if self.vision_backbone is not None:
|
1648 |
self.vision_backbone.reset_with_pretrained_weights()
|
1649 |
-
|
1650 |
@property
|
1651 |
def device(self) -> torch.device:
|
1652 |
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
@@ -1655,7 +1634,6 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1655 |
else:
|
1656 |
return device
|
1657 |
|
1658 |
-
|
1659 |
def forward(
|
1660 |
self,
|
1661 |
input_ids: torch.LongTensor,
|
@@ -1716,7 +1694,9 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1716 |
has_image = images is not None
|
1717 |
|
1718 |
assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
|
1719 |
-
assert not (
|
|
|
|
|
1720 |
|
1721 |
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
1722 |
if past_key_values is None:
|
@@ -1730,16 +1710,17 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1730 |
|
1731 |
if self.config.use_position_ids and attention_mask is None:
|
1732 |
attention_mask = input_ids != -1
|
1733 |
-
|
1734 |
if subsegment_ids is not None:
|
1735 |
assert not use_cache, "Subsegment_ids cannot be used with cache."
|
1736 |
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
|
1737 |
attention_mask = (
|
1738 |
-
subsegment_mask.to(attention_mask.dtype)
|
1739 |
-
attention_mask.unsqueeze(2)
|
1740 |
-
attention_mask.unsqueeze(1)
|
|
|
1741 |
if position_ids is None:
|
1742 |
-
raise ValueError(
|
1743 |
else:
|
1744 |
if self.config.use_position_ids and position_ids is None:
|
1745 |
position_ids = torch.clamp(
|
@@ -1776,10 +1757,8 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1776 |
|
1777 |
if self.config.use_cls_feature:
|
1778 |
x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
|
1779 |
-
|
1780 |
-
valid_images = torch.any(
|
1781 |
-
(image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1
|
1782 |
-
)
|
1783 |
valid_images = valid_images.to(attention_mask.dtype)
|
1784 |
attention_mask = torch.cat(
|
1785 |
[attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
|
@@ -1796,13 +1775,13 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1796 |
|
1797 |
# normalized
|
1798 |
if self.config.normalize_input_embeds:
|
1799 |
-
x = x * (self.config.d_model
|
1800 |
|
1801 |
# Transform the attention mask into what the blocks expect.
|
1802 |
if attention_mask is not None:
|
1803 |
# shape: (batch_size, 1, 1, seq_len)
|
1804 |
if len(attention_mask.shape) == 2:
|
1805 |
-
attention_mask = attention_mask[:, :past_length + seq_len]
|
1806 |
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
1807 |
else:
|
1808 |
attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
|
@@ -1852,16 +1831,23 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1852 |
|
1853 |
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
1854 |
# shape: (batch_size, seq_len, d_model)
|
1855 |
-
x, cache = layer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1856 |
|
1857 |
if attn_key_values is not None:
|
1858 |
assert cache is not None
|
1859 |
attn_key_values.append(cache)
|
1860 |
-
|
1861 |
if images is not None and self.config.use_cls_feature:
|
1862 |
assert num_image is not None
|
1863 |
x = torch.cat(
|
1864 |
-
[x[:, :1], x[:, num_image+1:], torch.zeros_like(x[:, :num_image])],
|
1865 |
dim=1,
|
1866 |
)
|
1867 |
|
@@ -1869,7 +1855,8 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1869 |
# shape: (batch_size, 1, d_model)
|
1870 |
if append_last_valid_logits is not None:
|
1871 |
last_valid_output = x[
|
1872 |
-
torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)
|
|
|
1873 |
x = last_valid_output.unsqueeze(1)
|
1874 |
else:
|
1875 |
x = x[:, -1, :].unsqueeze(1)
|
@@ -1886,23 +1873,20 @@ class MolmoModel(MolmoPretrainedModel):
|
|
1886 |
return MolmoOutput(
|
1887 |
last_hidden_states=x,
|
1888 |
attn_key_values=attn_key_values,
|
1889 |
-
hidden_states=tuple(all_hidden_states)
|
1890 |
-
|
1891 |
-
)
|
1892 |
|
1893 |
|
1894 |
class MolmoForCausalLM(PreTrainedModel):
|
1895 |
"""
|
1896 |
Extremely barebones HF model wrapper.
|
1897 |
"""
|
|
|
1898 |
config_class = MolmoConfig
|
1899 |
base_model_prefix = "model"
|
1900 |
_no_split_modules = ["MolmoDecoderLayer"]
|
1901 |
|
1902 |
-
def __init__(
|
1903 |
-
self,
|
1904 |
-
config: MolmoConfig
|
1905 |
-
):
|
1906 |
super().__init__(config)
|
1907 |
# model_config = create_model_config_from_pretrained_config(config)
|
1908 |
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
@@ -1972,7 +1956,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
1972 |
output_hidden_states=output_hidden_states,
|
1973 |
append_last_valid_logits=append_last_valid_logits,
|
1974 |
)
|
1975 |
-
|
1976 |
x = outputs.last_hidden_states
|
1977 |
if self.config.weight_tying:
|
1978 |
logits = F.linear(x, self.model.transformer.wte.weight, None) # type: ignore
|
@@ -1981,15 +1965,16 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
1981 |
|
1982 |
if self.config.scale_logits:
|
1983 |
logits.mul_(1 / math.sqrt(self.config.d_model))
|
1984 |
-
|
1985 |
if self.config.final_logit_softcapping is not None:
|
1986 |
logits = logits / self.config.final_logit_softcapping
|
1987 |
logits = torch.tanh(logits)
|
1988 |
logits = logits * self.config.final_logit_softcapping
|
1989 |
-
|
1990 |
if not last_logits_only and append_last_valid_logits is not None:
|
1991 |
last_valid_logit = logits[
|
1992 |
-
torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits
|
|
|
1993 |
logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
|
1994 |
|
1995 |
loss = None
|
@@ -2001,7 +1986,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2001 |
labels.masked_fill_(~(loss_masks > 0), -100)
|
2002 |
labels = labels.view(-1)
|
2003 |
logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
|
2004 |
-
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction=
|
2005 |
loss = loss_fct(logits_for_loss, labels)
|
2006 |
loss = loss.view(input_ids.shape[0], -1)
|
2007 |
loss = loss * loss_masks
|
@@ -2063,10 +2048,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2063 |
append_last_valid_logits: Optional[torch.Tensor] = None
|
2064 |
if self.config.use_position_ids and attention_mask is None:
|
2065 |
attention_mask = input_ids != -1
|
2066 |
-
position_ids = torch.clamp(
|
2067 |
-
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
|
2068 |
-
min=0
|
2069 |
-
)
|
2070 |
append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
|
2071 |
attention_mask = torch.cat(
|
2072 |
[attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
|
@@ -2074,7 +2056,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2074 |
)
|
2075 |
if attention_mask is not None:
|
2076 |
assert attention_mask.shape == (batch_size, mask_len)
|
2077 |
-
|
2078 |
out = super().generate(
|
2079 |
input_ids,
|
2080 |
generation_config,
|
@@ -2088,7 +2070,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2088 |
)
|
2089 |
|
2090 |
return out
|
2091 |
-
|
2092 |
def prepare_inputs_for_generation(
|
2093 |
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
2094 |
):
|
@@ -2116,7 +2098,7 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2116 |
model_inputs["image_masks"] = image_masks
|
2117 |
model_inputs["image_input_idx"] = image_input_idx
|
2118 |
model_inputs["append_last_valid_logits"] = append_last_valid_logits
|
2119 |
-
else:
|
2120 |
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
2121 |
|
2122 |
model_inputs.update(kwargs)
|
@@ -2236,7 +2218,4 @@ class MolmoForCausalLM(PreTrainedModel):
|
|
2236 |
# Tie weights again if needed
|
2237 |
self.tie_weights()
|
2238 |
|
2239 |
-
return model_embeds
|
2240 |
-
|
2241 |
-
# Always register for multi-modal features
|
2242 |
-
AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
|
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
+# from olmo.util import resource_path
 from .configuration_molmo import (
     MolmoConfig,
     VisionBackboneConfig,
     VisionBackboneType,
     ImagePooling2DType,
+    ImageProjectType,
     AttentionType,
     MolmoConfigurationError,
 )
 
 log = logging.getLogger(__name__)
 
 
+def resource_path(
+    folder: Union[str, Path],
+    fname: str,
+    local_cache: Optional[Union[str, Path]] = None,
+) -> Path:
+    if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
+        log.info(f"Found local cache of {fname} at {local_path}")
+        return local_path
+    else:
+        from cached_path import cached_path
+
+        return cached_path(f"{str(folder).rstrip('/')}/{fname}")
+
+
 def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
     """
     Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
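The inlined helper above removes the runtime dependency on `olmo.util`: it prefers an existing local copy and otherwise falls back to `cached_path`, which handles URLs and remote storage. A minimal usage sketch mirroring the `vit_load_path` call later in this diff (the checkpoint path here is a made-up example, not one shipped with the repo):

from pathlib import Path

# Resolve a checkpoint file, preferring a local cache hit.
vit_load_path = Path("/data/checkpoints/vit/model.pt")  # hypothetical path
state_dict_path = resource_path(
    vit_load_path.parent,              # folder containing the file
    vit_load_path.name,                # file name to resolve
    local_cache=vit_load_path.parent,  # reuse the folder as the local cache
)
print(state_dict_path)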
|
|
120 |
def reset_parameters(self):
|
121 |
nn.init.normal_(self.embedding, std=self.initializer_range)
|
122 |
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
|
123 |
+
|
124 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
125 |
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
|
126 |
|
|
|
145 |
if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
|
146 |
return input
|
147 |
else:
|
148 |
+
if self.mask_p > 0.0 and self.training:
|
149 |
assert drop_mask is not None
|
150 |
drop_mask = drop_mask.to(input.dtype)
|
151 |
keep_prob = 1.0 - self.p
|
|
|
157 |
multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
|
158 |
multiplier.div_(keep_prob)
|
159 |
return input * multiplier
|
160 |
+
elif self.p > 0.0 and len(self.broadcast_dims) > 0 and self.training:
|
161 |
keep_prob = 1.0 - self.p
|
162 |
dropout_shape = list(input.shape)
|
163 |
for dim in self.broadcast_dims:
|
|
|
226 |
else:
|
227 |
return tensor
|
228 |
|
|
|
229 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
230 |
if self.low_precision:
|
231 |
module_device = x.device
|
|
|
240 |
)
|
241 |
else:
|
242 |
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
243 |
+
|
244 |
def reset_parameters(self):
|
245 |
if self.weight is not None:
|
246 |
torch.nn.init.ones_(self.weight) # type: ignore
|
|
|
252 |
"""
|
253 |
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
254 |
"""
|
255 |
+
|
256 |
def __init__(
|
257 |
self,
|
258 |
config: MolmoConfig,
|
|
|
277 |
return self.weight * x
|
278 |
else:
|
279 |
return x
|
280 |
+
|
281 |
def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
282 |
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
283 |
# `is_autocast_cpu_enabled()` for CPU autocast.
|
|
|
288 |
return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
|
289 |
else:
|
290 |
return tensor
|
291 |
+
|
292 |
def reset_parameters(self):
|
293 |
if self.weight is not None:
|
294 |
torch.nn.init.ones_(self.weight) # type: ignore
|
|
|
307 |
self.__cache = cache
|
308 |
# Warm up cache.
|
309 |
self.get_rotary_embedding(
|
310 |
+
config.max_position_embeddings or config.max_sequence_length, _non_meta_init_device(config)
|
|
|
311 |
)
|
312 |
|
313 |
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
         return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
 
     with torch.autocast(device.type, enabled=False):
+        dim = (
+            self.config.head_dim
+            if self.config.head_dim is not None
+            else self.config.d_model // self.config.n_heads
+        )
+        inv_freq = 1.0 / (
+            self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
+        )
         seq = torch.arange(seq_len, device=device, dtype=torch.float)
         freqs = einsum("i , j -> i j", seq, inv_freq)
         if self.config.rope_impl == "cockatoo":
 
     return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
 
     def forward(
+        self, q: torch.Tensor, k: torch.Tensor, position_ids: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.config.rope_full_precision:
             q_, k_ = q.float(), k.float()
 
         batch_size = q_.shape[0]
         query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
         if position_ids is not None:
+            freqs_cis_len = self.config.max_position_embeddings or self.config.max_sequence_length
         else:
             freqs_cis_len = key_len
         pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
 
         pos_cos = pos_cos.type_as(q_)
         if position_ids is not None:
             assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
+            pos_sin = pos_sin[0, 0][position_ids].view((batch_size, 1, key_len, pos_sin.shape[-1]))
+            pos_cos = pos_cos[0, 0][position_ids].view((batch_size, 1, key_len, pos_cos.shape[-1]))
         q_ = self.apply_rotary_pos_emb(
             pos_sin[:, :, key_len - query_len : key_len, :],
             pos_cos[:, :, key_len - query_len : key_len, :],
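For context, `inv_freq` above is the standard rotary-embedding frequency vector, 1 / theta^(2i/d): positions are outer-multiplied with it to build per-position sin/cos tables, and queries/keys are rotated pairwise. A self-contained sketch of that computation under the concatenated-halves layout (one of the layouts the `rope_impl` switch selects between; the sizes are made up):

import torch

d_head, seq_len, rope_theta = 64, 16, 10000.0

# inv_freq[i] = 1 / theta^(2i/d), one frequency per pair of channels
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, d_head, 2, dtype=torch.float) / d_head))
positions = torch.arange(seq_len, dtype=torch.float)
freqs = torch.einsum("i,j->ij", positions, inv_freq)   # (seq_len, d_head/2)
emb = torch.cat([freqs, freqs], dim=-1)                # (seq_len, d_head)
pos_sin, pos_cos = emb.sin(), emb.cos()

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # swap the two channel halves and negate the second one
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat([-t2, t1], dim=-1)

q = torch.randn(seq_len, d_head)
q_rot = q * pos_cos + rotate_half(q) * pos_sin         # rotary-embedded queries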
|
|
|
 class MolmoAttention(nn.Module):
+    def __init__(self, config: MolmoConfig, cache: BufferCache):
         super().__init__()
         self.config = config
         self.__cache = cache
 
         self.k_norm: Optional[LayerNormBase] = None
         self.q_norm: Optional[LayerNormBase] = None
         self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
         )
 
         if config.attention_layer_norm:
 
             config.n_kv_heads * head_dim,
         )
         self.att_proj = nn.Linear(
+            config.d_model,
+            sum(self.fused_dims),
             bias=config.include_bias or config.qkv_bias,
+            device=config.init_device,
         )
+        self.attn_out = nn.Linear(input_dim, config.d_model, bias=config.include_bias, device=config.init_device)
+        self.attn_norm = RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps)
+
+        self.flash_attn_func = None
         if self.config.attention_type == AttentionType.flash:
             try:
                 from flash_attn import flash_attn_func
+
                 self.flash_attn_func = flash_attn_func
             except ModuleNotFoundError:
                 pass
 
+    def attention(
+        self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
 
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         B, T, C = q.size()  # batch size, sequence length, d_model
+        dtype = k.dtype
 
         # Optionally apply layer norm to keys and queries.
         if self.q_norm is not None and self.k_norm is not None:
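`sum(self.fused_dims)` spells out the fused QKV width: one full-size query block plus two key/value blocks that shrink when `n_kv_heads < n_heads` (grouped-query attention), and `qkv.split(self.fused_dims, dim=-1)` undoes the fusion later in `forward`. A standalone shape sketch with assumed example sizes rather than a real `MolmoConfig`:

import torch
import torch.nn as nn

d_model, n_heads, n_kv_heads = 1024, 16, 4
head_dim = d_model // n_heads

# (query width, key width, value width) -- keys/values shrink with fewer KV heads
fused_dims = (n_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim)
att_proj = nn.Linear(d_model, sum(fused_dims), bias=False)

x = torch.randn(2, 8, d_model)                   # (batch, seq, d_model)
q, k, v = att_proj(x).split(fused_dims, dim=-1)  # one matmul, three tensors
print(q.shape, k.shape, v.shape)                 # (2, 8, 1024) (2, 8, 256) (2, 8, 256)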
|
|
|
             is_causal=is_causal,
         )
 
+    def forward(self, x, attention_bias, position_ids, drop_mask, layer_past, use_cache):
         if not self.config.norm_after:
             atten_in = self.attn_norm(x)
         else:
 
             qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
 
         q, k, v = qkv.split(self.fused_dims, dim=-1)
+
         # Get attention scores.
         att, cache = self.attention(
+            q,
+            k,
+            v,
             attention_bias,
             position_ids=position_ids,
             drop_mask=drop_mask,
             layer_past=layer_past,
+            use_cache=use_cache,
         )
+
         if self.config.norm_after:
             att = self.attn_norm(att)
+
         return att, cache
 
 
 class MolmoMLP(nn.Module):
+    def __init__(self, config: MolmoConfig):
         # Feed-forward input projection.
         super().__init__()
         self.config = config
         self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
         )
         self.act = SwiGLU(config)
         self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
         self.ff_out = nn.Linear(
             int(self.act.output_multiplier * self.hidden_size),
             config.d_model,
             bias=config.include_bias,
             device=config.init_device,
         )
+        self.ff_norm = RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps)
+
     def forward(self, x):
         if not self.config.norm_after:
             x = self.ff_norm(x)
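`output_multiplier` exists because SwiGLU-style activations split their input into a gate half and a value half, so the activation emits half as many features as `ff_proj` produces, which is exactly the width `ff_out` consumes. A minimal sketch assuming the usual chunk-in-two formulation (the `SwiGLU` class in this file may order or scale the halves differently):

import torch
import torch.nn.functional as F

class TinySwiGLU(torch.nn.Module):
    output_multiplier = 0.5  # output width is half of the input width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value

hidden_size = 4096
act = TinySwiGLU()
h = torch.randn(2, 8, hidden_size)
print(act(h).shape[-1], int(act.output_multiplier * hidden_size))  # 2048 2048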
|
|
|
     """
     A base class for transformer block implementations.
     """
+
+    def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
         super().__init__()
         self.self_attn = MolmoAttention(config, cache)
         self.mlp = MolmoMLP(config)
 
         assert config.d_model % config.n_heads == 0
 
         # Dropout.
+        self.dropout = Dropout(config.residual_dropout, mask_p=config.response_residual_dropout)
 
     def forward(
         self,
 
         """
 
         att, cache = self.self_attn(
+            x,
             attention_bias=attention_bias,
             position_ids=position_ids,
             drop_mask=drop_mask,
             layer_past=layer_past,
+            use_cache=use_cache,
         )
         x = x + self.dropout(att, drop_mask=drop_mask)
         og_x = x
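The `mask_p=config.response_residual_dropout` argument gives the residual dropout a second rate that, judging from the `drop_mask=response_mask` call sites later in this diff, is applied only at response-token positions. A simplified sketch of that position-masked dropout idea, not the exact `Dropout` class defined in this file:

import torch

def masked_dropout(x: torch.Tensor, drop_mask: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    # Drop activations with probability p, but only where drop_mask is True.
    if not training or p == 0.0:
        return x
    keep = (torch.rand_like(x) > p).to(x.dtype) / (1.0 - p)  # inverted-dropout scaling
    mask = drop_mask.unsqueeze(-1).to(x.dtype)               # (batch, seq, 1)
    return x * (1.0 - mask) + x * keep * mask

x = torch.randn(2, 6, 16)
response_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]], dtype=torch.bool)
y = masked_dropout(x, response_mask, p=0.1)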
|
|
|
801 |
super().__init__()
|
802 |
self.config = config
|
803 |
self.use_bias = use_bias
|
804 |
+
|
805 |
v_cfg = config.vision_backbone
|
806 |
self.embed_dim = v_cfg.image_emb_dim
|
807 |
self.num_heads = v_cfg.image_num_heads
|
|
|
841 |
if v_cfg.attention_dropout > 0:
|
842 |
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
|
843 |
self.residual_dropout = Dropout(v_cfg.residual_dropout)
|
844 |
+
|
845 |
def reset_parameters(self):
|
846 |
nn.init.normal_(self.wq.weight, std=self.initializer_range)
|
847 |
nn.init.normal_(self.wk.weight, std=self.initializer_range)
|
|
|
858 |
|
859 |
def _merge_heads(self, hidden_states) -> torch.Tensor:
|
860 |
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
861 |
+
|
862 |
+
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
863 |
if inputs_kv is not None:
|
864 |
inputs_k = inputs_kv
|
865 |
inputs_v = inputs_kv
|
866 |
else:
|
867 |
inputs_k = inputs_q
|
868 |
inputs_v = inputs_q
|
869 |
+
|
870 |
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
|
871 |
|
872 |
xq = self._split_heads(xq, self.num_heads)
|
|
|
897 |
xk.transpose(1, 2).contiguous(),
|
898 |
xv.transpose(1, 2).contiguous(),
|
899 |
is_causal=False,
|
900 |
+
dropout_p=self.config.vision_backbone.attention_dropout,
|
901 |
).transpose(1, 2)
|
902 |
else:
|
903 |
raise NotImplementedError(self.config.attention_type)
|
|
|
919 |
output_layer: bool = True,
|
920 |
mean_residual: bool = False,
|
921 |
query: str = "mean",
|
922 |
+
is_vit_layer: Optional[bool] = True,
|
923 |
):
|
924 |
super().__init__()
|
925 |
self.config = config
|
|
|
929 |
self.output_layer = output_layer
|
930 |
self.mean_residual = mean_residual
|
931 |
self.query = query
|
932 |
+
|
933 |
v_cfg = config.vision_backbone
|
934 |
input_dim = v_cfg.image_emb_dim
|
935 |
self.embed_dim = v_cfg.image_emb_dim * factor
|
|
|
964 |
if query == "vector":
|
965 |
self.attention_query = nn.Parameter(
|
966 |
torch.zeros(
|
967 |
+
1,
|
968 |
+
self.num_key_value_heads * self.head_dim,
|
969 |
+
device=config.init_device,
|
970 |
),
|
971 |
)
|
972 |
|
|
|
1005 |
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
1006 |
|
1007 |
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
|
|
|
1008 |
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
|
1009 |
|
1010 |
if self.query == "mean":
|
|
|
1073 |
bias=True,
|
1074 |
device=config.init_device,
|
1075 |
)
|
1076 |
+
|
1077 |
def reset_parameters(self):
|
1078 |
v_cfg = self.config.vision_backbone
|
1079 |
nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
|
1080 |
nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
|
1081 |
nn.init.zeros_(self.w1.bias)
|
1082 |
nn.init.zeros_(self.w2.bias)
|
1083 |
+
|
1084 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1085 |
x = self.w1(x)
|
1086 |
x = self.act(x)
|
|
|
1091 |
class MLP(nn.Module):
|
1092 |
def __init__(self, config: MolmoConfig, input_dim: int, dropout: float = 0.0):
|
1093 |
super().__init__()
|
1094 |
+
self.config = config
|
1095 |
self.hidden_size = (
|
1096 |
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
1097 |
)
|
|
|
1115 |
bias=False,
|
1116 |
device=config.init_device,
|
1117 |
)
|
1118 |
+
# `MLP` assume the activation takes two inputs, so it must be a 'llama' version.
|
1119 |
self.act = LlamaSwiGLU(config)
|
1120 |
self.dropout = Dropout(dropout)
|
1121 |
+
|
1122 |
def reset_parameters(self):
|
1123 |
nn.init.normal_(self.w1.weight, std=self.initializer_range)
|
1124 |
nn.init.normal_(self.w2.weight, std=self.initializer_range)
|
1125 |
nn.init.normal_(self.w3.weight, std=self.initializer_range)
|
1126 |
+
|
1127 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1128 |
x = self.w2(self.act(self.w1(x), self.w3(x)))
|
1129 |
x = self.dropout(x)
|
|
|
1134 |
def __init__(self, submodule: nn.Module):
|
1135 |
super().__init__()
|
1136 |
self.submodule = submodule
|
1137 |
+
|
1138 |
def reset_parameters(self):
|
1139 |
self.submodule.reset_parameters()
|
1140 |
+
|
1141 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1142 |
return x + self.submodule(x)
|
1143 |
|
1144 |
|
1145 |
 class LayerNormFp32(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
+    Derived from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_type = x.dtype
+        if self.training:
+            x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x.to(orig_type)
 
 
 class ResidualAttentionBlock(nn.Module):
|
|
|
1180 |
self.feed_forward.reset_parameters()
|
1181 |
self.attention_norm.reset_parameters()
|
1182 |
self.ffn_norm.reset_parameters()
|
1183 |
+
|
1184 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1185 |
x = x + self.attention(self.attention_norm(x))
|
1186 |
x = x + self.feed_forward(self.ffn_norm(x))
|
|
|
1193 |
self.config = config
|
1194 |
|
1195 |
v_cfg = config.vision_backbone
|
1196 |
+
self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
|
1197 |
+
|
|
|
|
|
1198 |
def reset_parameters(self):
|
1199 |
for r in self.resblocks:
|
1200 |
r.reset_parameters()
|
|
|
1218 |
|
1219 |
v_cfg = config.vision_backbone
|
1220 |
# class embeddings and positional embeddings
|
1221 |
+
self.scale = v_cfg.image_emb_dim**-0.5
|
1222 |
self.class_embedding = nn.Parameter(
|
1223 |
torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
|
1224 |
)
|
|
|
1242 |
)
|
1243 |
|
1244 |
self.transformer = BlockCollection(config)
|
1245 |
+
|
1246 |
def reset_parameters(self):
|
1247 |
nn.init.normal_(self.class_embedding, std=self.scale)
|
1248 |
nn.init.normal_(self.positional_embedding, std=self.scale)
|
1249 |
nn.init.normal_(self.patch_embedding.weight, std=0.02)
|
1250 |
self.pre_ln.reset_parameters()
|
1251 |
self.transformer.reset_parameters()
|
1252 |
+
|
1253 |
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
1254 |
cls_emb = self.positional_embedding[0:1]
|
1255 |
pos_emb = self.positional_embedding[1:]
|
|
|
1257 |
pos_emb = pos_emb.reshape(
|
1258 |
(int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
|
1259 |
)
|
1260 |
+
|
1261 |
(patch_num_0, patch_num_1) = patch_num
|
1262 |
|
1263 |
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
|
|
1265 |
# antialias: default True in jax.image.resize
|
1266 |
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
1267 |
pos_emb = F.interpolate(
|
1268 |
+
pos_emb,
|
1269 |
+
size=(patch_num_0, patch_num_1),
|
1270 |
+
mode="bicubic",
|
1271 |
+
align_corners=False,
|
1272 |
+
antialias=True,
|
1273 |
)
|
1274 |
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
1275 |
|
|
|
1337 |
input_dim = nlayers * config.vision_backbone.image_emb_dim
|
1338 |
else:
|
1339 |
raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
|
1340 |
+
|
1341 |
self.input_dim = input_dim
|
1342 |
|
1343 |
self.image_projector = MLP(config, input_dim)
|
|
|
1362 |
self.image_projector.reset_parameters()
|
1363 |
|
1364 |
@abstractmethod
|
1365 |
+
def forward(
|
1366 |
+
self, images: torch.Tensor, image_masks: torch.Tensor
|
1367 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
1368 |
raise NotImplementedError
|
1369 |
+
|
1370 |
|
1371 |
class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
|
1372 |
def __init__(self, config: MolmoConfig):
|
|
|
1392 |
|
1393 |
self.pad_embed = None
|
1394 |
if config.image_padding_embed:
|
1395 |
+
image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
|
1396 |
if config.image_padding_embed in ["pad_embed", "regress"]:
|
1397 |
+
self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
|
|
|
1398 |
elif config.image_padding_embed == "pad_and_partial_pad":
|
1399 |
+
self.pad_embed = nn.Parameter(torch.zeros((2, image_dim), device=config.init_device))
|
|
|
1400 |
else:
|
1401 |
raise ValueError(config.image_padding_embed)
|
1402 |
|
|
|
1405 |
if self.config.vit_load_path:
|
1406 |
vit_load_path = Path(self.config.vit_load_path)
|
1407 |
state_dict_path = resource_path(
|
1408 |
+
vit_load_path.parent,
|
1409 |
+
vit_load_path.name,
|
1410 |
local_cache=vit_load_path.parent,
|
1411 |
)
|
1412 |
assert state_dict_path.is_file(), f"Model file {str(state_dict_path)} not found"
|
|
|
1424 |
self.image_vit.reset_parameters()
|
1425 |
if self.config.use_cls_feature:
|
1426 |
nn.init.xavier_uniform_(self.cls_projector.weight)
|
1427 |
+
|
1428 |
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
1429 |
"""
|
1430 |
: param images: (batch_size, num_crops, num_patch, n_pixels)
|
|
|
1452 |
if self.num_prefix_tokens > 0:
|
1453 |
cls_embed = image_features[:, 0]
|
1454 |
image_features = image_features[:, 1:]
|
1455 |
+
|
1456 |
image_features = image_features * mask
|
1457 |
image_features = image_features.view(B, T, N, -1)
|
1458 |
|
1459 |
cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
|
1460 |
|
1461 |
return image_features, cls_embed
|
1462 |
+
|
1463 |
+
def forward(
|
1464 |
+
self, images: torch.Tensor, image_masks: torch.Tensor
|
1465 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
1466 |
cfg = self.config
|
1467 |
|
1468 |
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
|
|
1478 |
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
|
1479 |
elif cfg.image_padding_embed == "regress":
|
1480 |
pad_embed = self.pad_embed[None, None, None, :]
|
1481 |
+
image_features = image_features + pad_embed * torch.unsqueeze(
|
1482 |
+
torch.maximum(image_masks, torch.zeros_like(image_masks)), -1
|
1483 |
+
)
|
1484 |
elif cfg.image_padding_embed == "pad_and_partial_pad":
|
1485 |
og_dtype = image_features.dtype
|
1486 |
pad_embed = self.pad_embed[:, None, None, None, :]
|
1487 |
all_pad = image_masks == 0
|
1488 |
+
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
|
1489 |
+
dtype=torch.float32
|
1490 |
+
)
|
1491 |
all_pad = all_pad.to(dtype=torch.float32)
|
1492 |
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
1493 |
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
|
|
|
1498 |
image_features = self.image_feature_dropout(image_features)
|
1499 |
if cls_embed is not None:
|
1500 |
cls_embed = self.image_feature_dropout(cls_embed)
|
1501 |
+
|
1502 |
image_features = image_features.reshape(
|
1503 |
(batch_size, num_image) + cfg.vision_backbone.image_num_patch + (-1,),
|
1504 |
)
|
|
|
1509 |
image_features,
|
1510 |
(0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
|
1511 |
)
|
1512 |
+
|
1513 |
# image pooling
|
1514 |
image_features = einops.rearrange(
|
1515 |
image_features,
|
1516 |
+
"b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
|
1517 |
dh=cfg.image_pooling_h,
|
1518 |
dw=cfg.image_pooling_w,
|
1519 |
)
|
|
|
1535 |
image_features = module(image_features)
|
1536 |
else:
|
1537 |
image_features = self.image_projector(image_features)
|
1538 |
+
|
1539 |
if self.config.use_cls_feature:
|
1540 |
cls_embed = self.cls_projector(cls_embed)
|
1541 |
if cfg.image_projector == ImageProjectType.mlpx2:
|
|
|
1543 |
cls_embed = module(cls_embed)
|
1544 |
else:
|
1545 |
cls_embed = self.image_projector(cls_embed)
|
1546 |
+
|
1547 |
# image_features: (batch_size, num_image, num_patch, d_model)
|
1548 |
# cls_embed: (batch_size, num_image, d_model)
|
1549 |
return image_features, cls_embed
|
|
|
1568 |
|
1569 |
|
1570 |
 class MolmoModel(MolmoPretrainedModel):
+    def __init__(self, config: MolmoConfig, init_params: bool = True):
         super().__init__(config)
         self.config = config
         self.__cache = BufferCache()
 
                 config.d_model,
                 device=config.init_device,
                 initializer_range=config.initializer_range,
+                new_embed_initializer_range=config.new_embedding_init_range,
             )
         else:
+            wte = nn.Embedding(
                 config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
             )
 
             dict(
                 wte=wte,
                 emb_drop=Dropout(config.embedding_dropout),
+                ln_f=RMSLayerNorm(config, size=config.d_model, eps=config.layer_norm_eps),
             )
         )
 
+        layers = [MolmoDecoderLayer(i, config, self.__cache) for i in range(config.n_layers)]
         self.transformer.update({"layers": nn.ModuleList(layers)})
+
         self.vision_backbone: Optional[MolmoVisionBackbone] = None
         if config.vision_backbone is not None:
             self.vision_backbone = MolmoVisionBackbone.build(config)
 
         if self.vision_backbone is not None:
             self.vision_backbone.reset_with_pretrained_weights()
|
1628 |
+
|
1629 |
@property
|
1630 |
def device(self) -> torch.device:
|
1631 |
device: torch.device = self.transformer.wte.weight.device # type: ignore
|
|
|
1634 |
else:
|
1635 |
return device
|
1636 |
|
|
|
1637 |
def forward(
|
1638 |
self,
|
1639 |
input_ids: torch.LongTensor,
|
|
|
         has_image = images is not None
 
         assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
+        assert not (
+            has_image and past_key_values is not None
+        ), "Cached key and values should not be used with images."
 
         batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
         if past_key_values is None:
 
         if self.config.use_position_ids and attention_mask is None:
             attention_mask = input_ids != -1
+
         if subsegment_ids is not None:
             assert not use_cache, "Subsegment_ids cannot be used with cache."
             subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
             attention_mask = (
+                subsegment_mask.to(attention_mask.dtype)
+                * attention_mask.unsqueeze(2)
+                * attention_mask.unsqueeze(1)
+            )
             if position_ids is None:
+                raise ValueError("Positioned ids must be given if using subsegment_ids")
         else:
             if self.config.use_position_ids and position_ids is None:
                 position_ids = torch.clamp(
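The subsegment branch above builds one attention mask out of three pieces: a pairwise comparison of subsegment ids, and the padding mask applied along both the query and key axes. A small standalone illustration with toy ids (no model involved):

import torch

subsegment_ids = torch.tensor([[1, 1, 2, 2, 2]])          # two subsegments packed in one sequence
attention_mask = torch.tensor([[1, 1, 1, 1, 0]]).float()  # last position is padding

# entry [b, i, j] is True when subsegment_ids[b, i] <= subsegment_ids[b, j]
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
combined = (
    subsegment_mask.to(attention_mask.dtype)
    * attention_mask.unsqueeze(2)
    * attention_mask.unsqueeze(1)
)
print(combined[0])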
1757 |
|
1758 |
if self.config.use_cls_feature:
|
1759 |
x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
|
1760 |
+
|
1761 |
+
valid_images = torch.any((image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1)
|
|
|
|
|
1762 |
valid_images = valid_images.to(attention_mask.dtype)
|
1763 |
attention_mask = torch.cat(
|
1764 |
[attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
|
|
|
1775 |
|
1776 |
# normalized
|
1777 |
if self.config.normalize_input_embeds:
|
1778 |
+
x = x * (self.config.d_model**0.5)
|
1779 |
|
1780 |
# Transform the attention mask into what the blocks expect.
|
1781 |
if attention_mask is not None:
|
1782 |
# shape: (batch_size, 1, 1, seq_len)
|
1783 |
if len(attention_mask.shape) == 2:
|
1784 |
+
attention_mask = attention_mask[:, : past_length + seq_len]
|
1785 |
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
1786 |
else:
|
1787 |
attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
|
|
|
             layer_past = None if past_key_values is None else past_key_values[block_idx]
             # shape: (batch_size, seq_len, d_model)
+            x, cache = layer(
+                x,
+                attention_bias=attention_bias,
+                position_ids=position_ids,
+                drop_mask=response_mask,
+                layer_past=layer_past,
+                use_cache=use_cache,
+            )
 
             if attn_key_values is not None:
                 assert cache is not None
                 attn_key_values.append(cache)
+
         if images is not None and self.config.use_cls_feature:
             assert num_image is not None
             x = torch.cat(
+                [x[:, :1], x[:, num_image + 1 :], torch.zeros_like(x[:, :num_image])],
                 dim=1,
             )
|
1853 |
|
|
|
1855 |
# shape: (batch_size, 1, d_model)
|
1856 |
if append_last_valid_logits is not None:
|
1857 |
last_valid_output = x[
|
1858 |
+
torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)
|
1859 |
+
]
|
1860 |
x = last_valid_output.unsqueeze(1)
|
1861 |
else:
|
1862 |
x = x[:, -1, :].unsqueeze(1)
|
|
|
         return MolmoOutput(
             last_hidden_states=x,
             attn_key_values=attn_key_values,
+            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+        )
 
 
 class MolmoForCausalLM(PreTrainedModel):
     """
     Extremely barebones HF model wrapper.
     """
+
     config_class = MolmoConfig
     base_model_prefix = "model"
     _no_split_modules = ["MolmoDecoderLayer"]
 
+    def __init__(self, config: MolmoConfig):
         super().__init__(config)
         # model_config = create_model_config_from_pretrained_config(config)
         # Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
|
|
1956 |
output_hidden_states=output_hidden_states,
|
1957 |
append_last_valid_logits=append_last_valid_logits,
|
1958 |
)
|
1959 |
+
|
1960 |
x = outputs.last_hidden_states
|
1961 |
if self.config.weight_tying:
|
1962 |
logits = F.linear(x, self.model.transformer.wte.weight, None) # type: ignore
|
|
|
         if self.config.scale_logits:
             logits.mul_(1 / math.sqrt(self.config.d_model))
+
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping
+
         if not last_logits_only and append_last_valid_logits is not None:
             last_valid_logit = logits[
+                torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits
+            ]
             logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
 
         loss = None
 
             labels.masked_fill_(~(loss_masks > 0), -100)
             labels = labels.view(-1)
             logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
             loss = loss_fct(logits_for_loss, labels)
             loss = loss.view(input_ids.shape[0], -1)
             loss = loss * loss_masks
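The `final_logit_softcapping` block is the tanh soft-cap used by recent LMs such as Gemma 2: logits are squashed smoothly into (-cap, +cap) instead of being hard-clipped. As a worked illustration:

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(logits / cap): close to identity for |logits| << cap,
    # saturating smoothly toward +/- cap for large magnitudes
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([0.5, 10.0, 50.0, -200.0])
print(softcap(logits, cap=30.0))  # ~[0.50, 9.64, 27.9, -30.0]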
         append_last_valid_logits: Optional[torch.Tensor] = None
         if self.config.use_position_ids and attention_mask is None:
             attention_mask = input_ids != -1
+            position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0)
             append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
             attention_mask = torch.cat(
                 [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
             )
         if attention_mask is not None:
             assert attention_mask.shape == (batch_size, mask_len)
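The added one-liner derives position ids from the attention mask itself: a running count of real tokens minus one, clamped at zero, so left padding never shifts the positions of real tokens. A small worked example:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# cumulative count of real tokens minus one, clamped so padding stays at position 0
position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0)
print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])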
|
2059 |
+
|
2060 |
out = super().generate(
|
2061 |
input_ids,
|
2062 |
generation_config,
|
|
|
2070 |
)
|
2071 |
|
2072 |
return out
|
2073 |
+
|
2074 |
def prepare_inputs_for_generation(
|
2075 |
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
2076 |
):
|
|
|
2098 |
model_inputs["image_masks"] = image_masks
|
2099 |
model_inputs["image_input_idx"] = image_input_idx
|
2100 |
model_inputs["append_last_valid_logits"] = append_last_valid_logits
|
2101 |
+
else:
|
2102 |
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
2103 |
|
2104 |
model_inputs.update(kwargs)
|
|
|
2218 |
# Tie weights again if needed
|
2219 |
self.tie_weights()
|
2220 |
|
2221 |
+
return model_embeds
|
|
|
|
|
|