vikhyatk committed on
Commit 6eb5bbc
1 Parent(s): a4becc9

Upload Moondream

config.json CHANGED
@@ -2,7 +2,11 @@
   "architectures": [
     "Moondream"
   ],
-  "model_type": "moondream",
+  "auto_map": {
+    "AutoConfig": "configuration_moondream.MoondreamConfig",
+    "AutoModelForCausalLM": "moondream.Moondream"
+  },
+  "model_type": "moondream1",
   "phi_config": {
     "model_type": "phi-msft"
   },
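With the new `auto_map` entries, the checkpoint resolves `MoondreamConfig` and `Moondream` from this repository when loaded through the `transformers` Auto classes with `trust_remote_code=True`. A minimal loading sketch; the repository id `vikhyatk/moondream1` is inferred from the commit author and `model_type` (an assumption), and it assumes the repo also ships a tokenizer:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code lets auto_map resolve configuration_moondream.MoondreamConfig
# and moondream.Moondream from the files added in this commit.
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream1", trust_remote_code=True  # assumed repo id
)
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream1")
```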
configuration_moondream.py ADDED
@@ -0,0 +1,74 @@
+from transformers import PretrainedConfig
+
+from typing import Optional
+import math
+
+
+class PhiConfig(PretrainedConfig):
+    model_type = "phi-msft"
+
+    def __init__(
+        self,
+        vocab_size: int = 51200,
+        n_positions: int = 2048,
+        n_embd: int = 2048,
+        n_layer: int = 24,
+        n_inner: Optional[int] = None,
+        n_head: int = 32,
+        n_head_kv: Optional[int] = None,
+        rotary_dim: Optional[int] = 32,
+        activation_function: Optional[str] = "gelu_new",
+        flash_attn: bool = False,
+        flash_rotary: bool = False,
+        fused_dense: bool = False,
+        attn_pdrop: float = 0.0,
+        embd_pdrop: float = 0.0,
+        resid_pdrop: float = 0.0,
+        layer_norm_epsilon: float = 1e-5,
+        initializer_range: float = 0.02,
+        tie_word_embeddings: bool = False,
+        pad_vocab_size_multiple: int = 64,
+        gradient_checkpointing: bool = False,
+        **kwargs
+    ):
+        pad_vocab_size = (
+            math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
+        )
+        super().__init__(
+            vocab_size=pad_vocab_size,
+            n_positions=n_positions,
+            n_embd=n_embd,
+            n_layer=n_layer,
+            n_inner=n_inner,
+            n_head=n_head,
+            n_head_kv=n_head_kv,
+            activation_function=activation_function,
+            attn_pdrop=attn_pdrop,
+            embd_pdrop=embd_pdrop,
+            resid_pdrop=resid_pdrop,
+            layer_norm_epsilon=layer_norm_epsilon,
+            initializer_range=initializer_range,
+            pad_vocab_size_multiple=pad_vocab_size_multiple,
+            tie_word_embeddings=tie_word_embeddings,
+            gradient_checkpointing=gradient_checkpointing,
+            **kwargs
+        )
+        self.rotary_dim = min(rotary_dim, n_embd // n_head)
+        self.flash_attn = flash_attn
+        self.flash_rotary = flash_rotary
+        self.fused_dense = fused_dense
+
+    attribute_map = {
+        "max_position_embeddings": "n_positions",
+        "hidden_size": "n_embd",
+        "num_attention_heads": "n_head",
+        "num_hidden_layers": "n_layer",
+    }
+
+
+class MoondreamConfig(PretrainedConfig):
+    model_type = "moondream1"
+
+    def __init__(self, **kwargs):
+        self.phi_config = PhiConfig(**kwargs)
+        super().__init__(**kwargs)
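`MoondreamConfig` simply forwards its keyword arguments into a nested `PhiConfig`, which pads `vocab_size` up to the next multiple of `pad_vocab_size_multiple` and clamps `rotary_dim` to the per-head dimension. A small illustrative sketch (values chosen for the example; the import assumes the file sits on the local path):

```python
from configuration_moondream import MoondreamConfig  # assumes the file is importable locally

config = MoondreamConfig(vocab_size=50257, pad_vocab_size_multiple=64)
print(config.phi_config.vocab_size)  # 50304, i.e. math.ceil(50257 / 64) * 64
print(config.phi_config.rotary_dim)  # 32 == min(rotary_dim, n_embd // n_head)
```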
generation_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "_from_model_config": true,
+  "transformers_version": "4.36.2"
+}
modeling_phi.py ADDED
@@ -0,0 +1,720 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+#
+# Copyright (c) 2022, Tri Dao, [email protected].
+# Licensed under the BSD 3-Clause License.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Union, Tuple
+
+import math
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_moondream import PhiConfig
+
+FusedDense = None
+
+
+@dataclass
+class InferenceParams:
+    max_seqlen: int
+    max_batch_size: int
+    seqlen_offset: int = 0
+    batch_size_offset: int = 0
+    key_value_memory_dict: Dict[str, Any] = field(default_factory=dict)
+    lengths_per_sample: torch.Tensor = None
+
+
+class Embedding(nn.Module):
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.drop = nn.Dropout(config.embd_pdrop)
+
+    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+        return self.drop(self.wte(input_ids.view(-1, input_ids.size(-1))))
+
+
+def _apply_rotary_emb(x, cos, sin):
+    seqlen, rotary_dim = x.size(1), cos.size(1) * 2
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x1, x2 = x_rot.chunk(2, dim=-1)
+    c, s = cos[:seqlen].unsqueeze(1), sin[:seqlen].unsqueeze(1)
+    x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
+    return torch.cat([x_rot.to(x.dtype), x_pass], dim=-1)
+
+
+def _apply_rotary_emb_kv(
+    kv: torch.FloatTensor, cos: torch.FloatTensor, sin: torch.FloatTensor
+) -> torch.FloatTensor:
+    seqlen, rotary_dim = kv.shape[1], cos.shape[-1] * 2
+    k_rot = kv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+    k_pass = kv[:, :, 0, :, rotary_dim:]
+    c, s = cos[:seqlen].unsqueeze(1), sin[:seqlen].unsqueeze(1)
+    k_rot = torch.cat(
+        [k_rot[0] * c - k_rot[1] * s, k_rot[0] * s + k_rot[1] * c], dim=-1
+    )
+    return torch.cat(
+        [torch.cat([k_rot, k_pass], dim=-1).unsqueeze(2), kv[:, :, 1:2, :, :]], dim=2
+    )
+
+
+def _apply_rotary_emb_qkv(
+    qkv: torch.FloatTensor, cos: torch.FloatTensor, sin: torch.FloatTensor
+) -> torch.FloatTensor:
+    seqlen, rotary_dim = qkv.shape[1], cos.shape[1] * 2
+
+    c = cos[:seqlen].unsqueeze(1)
+    s = sin[:seqlen].unsqueeze(1)
+
+    qkv_rot = torch.stack(
+        [
+            torch.cat(
+                [
+                    qkv[:, :, i, :, : rotary_dim // 2] * c
+                    - qkv[:, :, i, :, rotary_dim // 2 : rotary_dim] * s,
+                    qkv[:, :, i, :, : rotary_dim // 2] * s
+                    + qkv[:, :, i, :, rotary_dim // 2 : rotary_dim] * c,
+                ],
+                dim=-1,
+            ).to(qkv.dtype)
+            for i in range(2)
+        ],
+        dim=2,
+    )
+
+    qkv_pass = qkv[:, :, :2, :, rotary_dim:].unsqueeze(2)
+    qkv_v = qkv[:, :, 2:3, :, :]
+    return torch.cat([qkv_rot, qkv_pass, qkv_v], dim=2)
+
+
+class RotaryEmbedding(nn.Module):
+    # Enhanced Transformer with Rotary Position Embedding (https://arxiv.org/pdf/2104.09864.pdf)
+    def __init__(
+        self,
+        dim: int,
+        base: int = 10000,
+        scale_base: Optional[float] = None,
+        pos_idx_in_fp32: bool = True,
+        max_position_embeddings: int = 2048,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__()
+        # fp32 is preferred since the output of `torch.arange` can be quite large and bf16 would lose a lot of precision
+        self.dim, self.base, self.pos_idx_in_fp32, self.device = (
+            dim,
+            float(base),
+            pos_idx_in_fp32,
+            device,
+        )
+        self.max_position_embeddings = max_position_embeddings
+        if scale_base is not None:
+            raise NotImplementedError
+
+        # Generate and register the non-trainable buffers
+        self.register_buffer(
+            "inv_freq", self._compute_inv_freq(device), persistent=False
+        )
+        self.register_buffer(
+            "scale", self._calculate_scale(dim, scale_base, device), persistent=False
+        )
+        self._update_cos_sin_cache(
+            max_position_embeddings, device=device, dtype=torch.float32
+        )
+
+    def _calculate_scale(self, dim, scale_base, device):
+        return (
+            (
+                (
+                    torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                    + 0.4 * dim
+                )
+                / (1.4 * dim)
+            )
+            if scale_base is not None
+            else None
+        )
+
+    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+        return 1.0 / (
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
+                / self.dim
+            )
+        )
+
+    def _update_cos_sin_cache(
+        self,
+        seqlen: int,
+        device: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        self._seq_len_cached = seqlen
+        t = torch.arange(
+            seqlen,
+            device=device,
+            dtype=torch.float32 if self.pos_idx_in_fp32 else self.inv_freq.dtype,
+        )
+        inv_freq = (
+            self._compute_inv_freq(device=device)
+            if self.pos_idx_in_fp32 and self.inv_freq.dtype != torch.float32
+            else self.inv_freq
+        )
+
+        freqs = torch.outer(t, inv_freq)
+
+        def apply_scale(freqs, scale, operator, dtype):
+            result = operator(freqs)
+            return (result / scale).to(dtype) if scale is not None else result.to(dtype)
+
+        if scale := self.scale:
+            power = (
+                torch.arange(seqlen, dtype=scale.dtype, device=scale.device)
+                - seqlen // 2
+            ) / self.scale_base
+            scale = scale.to(device=power.device) ** power.unsqueeze(1)
+
+        self._cos_cached = apply_scale(
+            freqs, 1 / scale if scale is not None else None, torch.cos, dtype
+        )
+        self._sin_cached = apply_scale(
+            freqs, 1 / scale if scale is not None else None, torch.sin, dtype
+        )
+        if scale is not None:
+            self._cos_k_cached = apply_scale(freqs, scale, torch.cos, dtype)
+            self._sin_k_cached = apply_scale(freqs, scale, torch.sin, dtype)
+
+    def forward(
+        self,
+        qkv: torch.Tensor,
+        kv: Optional[torch.Tensor] = None,
+        seqlen_offset: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        should_update = (
+            self._seq_len_cached < qkv.shape[1] + seqlen_offset
+            or self._cos_cached.device != qkv.device
+            or self._cos_cached.dtype != qkv.dtype
+            or (self.training and self._cos_cached.is_inference())
+        )
+
+        if should_update:
+            self._update_cos_sin_cache(
+                qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
+            )
+
+        offset_cos = self._cos_cached[seqlen_offset:]
+        offset_sin = self._sin_cached[seqlen_offset:]
+
+        if kv is None:
+            return _apply_rotary_emb_qkv(qkv, offset_cos, offset_sin)
+        else:
+            return _apply_rotary_emb(qkv, offset_cos, offset_sin), _apply_rotary_emb_kv(
+                kv, offset_cos, offset_sin
+            )
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        n_inner: Optional[int] = None,
+        act_fn: Optional[str] = None,
+    ) -> None:
+        super().__init__()
+        n_inner = n_inner or getattr(config, "n_inner", None) or 4 * config.n_embd
+        act_fn = act_fn or config.activation_function
+
+        self.fc1 = nn.Linear(config.n_embd, n_inner)
+        self.fc2 = nn.Linear(n_inner, config.n_embd)
+        self.act = ACT2FN[act_fn]
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        return self.fc2(self.act(self.fc1(hidden_states)))
+
+
+# Flash Attention (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py)
+class SelfAttention(nn.Module):
+    def __init__(
+        self,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        attention_dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
+    def forward(
+        self,
+        qkv: torch.FloatTensor,
+        causal: Optional[bool] = None,
+        key_padding_mask: Optional[torch.BoolTensor] = None,
+    ):
+        q, k, v = qkv.chunk(3, dim=-1)
+        scale = self.softmax_scale or 1.0 / q.size(-1) ** 0.5
+
+        scores = (
+            torch.einsum("bthd,bshd->bhts", q.to(torch.float32), k.to(torch.float32))
+            * scale
+        )
+        if causal or self.causal:
+            scores.triu_(1).fill_(-10000.0)
+        if key_padding_mask is not None:
+            scores.masked_fill_(key_padding_mask[:, None, None, :], -10000.0)
+
+        attn = self.drop(torch.softmax(scores, dim=-1).to(v.dtype))
+        return torch.einsum("bhts,bshd->bthd", attn, v)
+
+
+# Flash Attention (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py)
+class CrossAttention(nn.Module):
+    def __init__(self, causal=True, softmax_scale=None, attention_dropout=0.0):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    @torch.autocast("cpu", enabled=False)
+    @torch.autocast("cuda", enabled=False)
+    def forward(
+        self,
+        q: torch.FloatTensor,
+        kv: torch.FloatTensor,
+        causal: bool = None,
+        key_padding_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = kv.shape[1]
+
+        if kv.shape[3] != q.shape[2]:
+            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
+        k, v = kv.unbind(dim=2)
+
+        q = q.to(torch.float32)
+        k = k.to(torch.float32)
+
+        causal = self.causal if causal is None else causal
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+
+        # Autocast is manually disabled to avoid `torch.einsum` performing the operation using float16, which might lead to overflow
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+        if key_padding_mask is not None:
+            padding_mask = torch.full(
+                (batch_size, seqlen_k),
+                -10000.0,
+                dtype=scores.dtype,
+                device=scores.device,
+            )
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
+        if causal:
+            rows = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
+            )
+            cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
+            causal_mask = cols > rows + seqlen_k - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
+
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+        attention = self.drop(attention)
+        output = torch.einsum("bhts,bshd->bthd", attention, v)
+
+        return output
+
+
+def _find_mha_dims(
+    config: PretrainedConfig,
+    n_head: Optional[int] = None,
+    n_head_kv: Optional[int] = None,
+    head_dim: Optional[int] = None,
+) -> Tuple[int, int]:
+    if n_head is None and head_dim is None:
+        head_dim = config.n_embd // config.n_head
+        n_head = config.n_head
+    elif n_head is None or head_dim is None:
+        raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
+    if n_head_kv is None:
+        n_head_kv = getattr(config, "n_head_kv", None) or n_head
+    return n_head, n_head_kv, head_dim
+
+
+def _update_kv_cache(
+    kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int
+) -> torch.FloatTensor:
+    num_heads, head_dim = kv.shape[-2:]
+    layer_memory = inference_params.key_value_memory_dict.setdefault(
+        layer_idx,
+        torch.empty(
+            inference_params.max_batch_size,
+            inference_params.max_seqlen,
+            2,
+            num_heads,
+            head_dim,
+            dtype=kv.dtype,
+            device=kv.device,
+        ),
+    )
+
+    batch_slice = slice(
+        inference_params.batch_size_offset,
+        inference_params.batch_size_offset + kv.shape[0],
+    )
+    seqlen_slice = slice(
+        inference_params.seqlen_offset, inference_params.seqlen_offset + kv.shape[1]
+    )
+
+    if seqlen_slice.stop >= inference_params.max_seqlen:
+        layer_memory = torch.cat((layer_memory, kv), dim=1)
+        inference_params.key_value_memory_dict[layer_idx] = layer_memory
+
+    layer_memory[batch_slice, seqlen_slice, ...] = kv
+    return layer_memory[batch_slice, : seqlen_slice.stop, ...]
+
+
+# Multi-head attention layer with rotary embeddings
+class MHA(nn.Module):
+    def __init__(
+        self,
+        config,
+        dtype=None,
+        device=None,
+        rotary_dim=None,
+        rotary_base=10000.0,
+        rotary_scale_base=None,
+        n_head=None,
+        n_head_kv=None,
+        head_dim=None,
+        bias=True,
+        causal=True,
+        softmax_scale=None,
+        layer_idx=None,
+        return_residual=False,
+        checkpointing=False,
+    ):
+        super().__init__()
+
+        # Set rotary embedding if specified
+        self.rotary_dim = rotary_dim or getattr(config, "rotary_dim", 0)
+        if self.rotary_dim:
+            self.rotary_emb = RotaryEmbedding(
+                self.rotary_dim,
+                base=rotary_base,
+                scale_base=rotary_scale_base,
+                device=device,
+                max_position_embeddings=config.n_positions,
+            )
+
+        # Determine MHA dims from arguments or config
+        self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
+            config, n_head, n_head_kv, head_dim
+        )
+        op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
+        hidden_size = config.n_embd
+
+        # Choose Linear class based on config, FusedDense is optional
+        LinearClass = (
+            FusedDense if config.fused_dense and FusedDense is not None else nn.Linear
+        )
+        self.Wqkv = LinearClass(
+            hidden_size, op_size, bias=bias, device=device, dtype=dtype
+        )
+        self.out_proj = LinearClass(
+            hidden_size, hidden_size, bias=bias, device=device, dtype=dtype
+        )
+
+        # Initialize attention mechanisms
+        attn_kwargs = {
+            "causal": causal,
+            "softmax_scale": softmax_scale,
+            "attention_dropout": config.attn_pdrop,
+        }
+        self.inner_attn = SelfAttention(**attn_kwargs)
+        self.inner_cross_attn = CrossAttention(**attn_kwargs)
+
+        self.layer_idx = layer_idx
+        self.return_residual = return_residual
+        self.checkpointing = checkpointing
+
+    def _forward_self_attn(
+        self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
+    ) -> torch.FloatTensor:
+        qkv = rearrange(
+            self.Wqkv(x), "... (three h d) -> ... three h d", three=3, d=self.head_dim
+        )
+        if self.rotary_dim > 0:
+            qkv = self.rotary_emb(qkv)
+        attn_func = (
+            torch.utils.checkpoint.checkpoint
+            if self.checkpointing
+            else lambda f, *args, **kwargs: f(*args, **kwargs)
+        )
+        return attn_func(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
+
+    def _forward_cross_attn(
+        self,
+        x: torch.FloatTensor,
+        past_key_values: Optional[InferenceParams],
+        key_padding_mask: Optional[torch.BoolTensor],
+    ) -> torch.FloatTensor:
+        qkv = self.Wqkv(x)
+        q, kv = (
+            qkv[..., : self.n_head * self.head_dim],
+            qkv[..., self.n_head * self.head_dim :],
+        )
+        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+
+        seqlen_offset = (
+            past_key_values.seqlen_offset if past_key_values is not None else 0
+        )
+        causal = None if seqlen_offset == 0 else False
+        if self.rotary_dim > 0:
+            q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
+
+        if past_key_values is not None:
+            kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
+
+        attn_func = (
+            torch.utils.checkpoint.checkpoint
+            if self.checkpointing
+            else lambda fn, *args, **kwargs: fn(*args, **kwargs)
+        )
+
+        return attn_func(
+            self.inner_cross_attn,
+            q,
+            kv,
+            key_padding_mask=key_padding_mask,
+            causal=causal,
+        )
+
+    def forward(
+        self,
+        x: torch.FloatTensor,
+        past_key_values: Optional[InferenceParams] = None,
+        attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        attention_mask = attention_mask.bool() if attention_mask is not None else None
+        use_cross_attn = self.n_head != self.n_head_kv or past_key_values is not None
+        attn_output_function = (
+            self._forward_cross_attn if use_cross_attn else self._forward_self_attn
+        )
+        attn_output = (
+            attn_output_function(x, past_key_values, attention_mask)
+            if use_cross_attn
+            else attn_output_function(x, attention_mask)
+        )
+        output = self.out_proj(rearrange(attn_output, "... h d -> ... (h d)"))
+        return (output, x) if self.return_residual else output
+
+
+# Parallel block. This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
+class ParallelBlock(nn.Module):
+    def __init__(self, config: PretrainedConfig, block_idx: Optional[int] = None):
+        super().__init__()
+        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.block_idx = block_idx
+        self.mixer = MHA(config, layer_idx=block_idx)
+        self.mlp = MLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        hidden_states = self.ln(hidden_states)
+
+        attn_outputs = self.mixer(
+            hidden_states,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+        )
+        if isinstance(attn_outputs, tuple):
+            attn_outputs = attn_outputs[0]
+
+        attn_outputs = self.resid_dropout(attn_outputs)
+        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+        return attn_outputs + feed_forward_hidden_states + residual
+
+
+class CausalLMHead(nn.Module):
+    """Causal Language Modeling head. Simplified version."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.linear = nn.Linear(config.n_embd, config.vocab_size)
+
+    def forward(self, hidden_states):
+        return self.linear(self.ln(hidden_states)).to(torch.float32)
+
+
+# Improving Language Understanding by Generative Pre-Training
+# (https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
+class CausalLMLoss(nn.Module):
+    def __init__(self, shift_labels: bool = True) -> None:
+        super().__init__()
+        self.shift_labels = shift_labels
+        self.loss_fct = nn.CrossEntropyLoss()
+
+    def forward(
+        self, logits: torch.FloatTensor, labels: torch.LongTensor
+    ) -> torch.FloatTensor:
+        if self.shift_labels:
+            logits, labels = logits[..., :-1, :], labels[..., 1:]
+        return self.loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
+
+
+class PhiPreTrainedModel(PreTrainedModel):
+    config_class = PhiConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = False
+    _no_split_modules = ["ParallelBlock"]
+
+    def __init__(self, *inputs, **kwargs) -> None:
+        super().__init__(*inputs, **kwargs)
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.FloatTensor = None,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError(
+                "You have to specify either `input_ids` or `inputs_embeds`."
+            )
+
+        max_batch_size = (
+            inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
+        )
+        seqlen_offset = (
+            inputs_embeds.shape[1] + input_ids.shape[1] - 2
+            if inputs_embeds is not None
+            else input_ids.shape[1] - 1
+        )
+
+        args = (
+            {"inputs_embeds": inputs_embeds}
+            if inputs_embeds is not None
+            else {"input_ids": input_ids}
+        )
+
+        if not isinstance(past_key_values, InferenceParams):
+            past_key_values = InferenceParams(
+                max_seqlen=self.config.n_positions,
+                max_batch_size=max_batch_size,
+                seqlen_offset=0,
+                batch_size_offset=0,
+                key_value_memory_dict={},
+                lengths_per_sample=None,
+            )
+        else:
+            past_key_values.seqlen_offset = seqlen_offset
+            args = {"input_ids": input_ids[:, -1].unsqueeze(-1)}
+
+        return {
+            **args,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask,
+        }
+
+
+class PhiModel(PhiPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [""]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
+
+    def __init__(self, config: PhiConfig) -> None:
+        super().__init__(config)
+        self.embd = Embedding(config)
+        self.h = nn.ModuleList(
+            [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
+        )
+        self.gradient_checkpointing = config.gradient_checkpointing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embd.wte
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.embd.wte = new_embeddings
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.FloatTensor = None,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.FloatTensor:
+        if (input_ids is None) == (inputs_embeds is None):
+            raise ValueError("Specify exactly one of `input_ids` or `inputs_embeds`.")
+        hidden_states = self.embd(input_ids) if input_ids is not None else inputs_embeds
+
+        for layer in self.h:
+            func = layer.__call__ if self.gradient_checkpointing else layer
+            args = (hidden_states, past_key_values, attention_mask)
+            hidden_states = (
+                torch.utils.checkpoint.checkpoint(func, *args, use_reentrant=True)
+                if self.gradient_checkpointing
+                else func(*args)
+            )
+
+        return hidden_states
+
+
+class PhiForCausalLM(PhiPreTrainedModel):
+    _keys_to_ignore_on_load_missing, _keys_to_ignore_on_load_unexpected = (
+        [""],
+        [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"],
+    )
+
+    def __init__(self, config: PhiConfig) -> None:
+        super().__init__(config)
+        self.transformer = PhiModel(config)
+        self.lm_head = CausalLMHead(config)
+        self.loss = CausalLMLoss()
+        self.post_init()
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.lm_head.linear
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.lm_head.linear = new_embeddings
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.FloatTensor = None,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        hidden_states = self.transformer(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+        )
+        lm_logits = self.lm_head(hidden_states)
+        loss = self.loss(lm_logits, labels) if labels is not None else None

+        return CausalLMOutputWithPast(
+            loss=loss, logits=lm_logits, past_key_values=past_key_values
+        )
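The rotary helpers above rotate only the first `rotary_dim` channels of each head and pass the remaining channels through unchanged. A self-contained sketch of that idea (illustrative only, not part of the commit), using the same `(batch, seqlen, heads, head_dim)` layout:

```python
import torch

def rotary_demo(x: torch.Tensor, base: float = 10000.0, rotary_dim: int = 32) -> torch.Tensor:
    # Frequencies and angles, mirroring _compute_inv_freq / _update_cos_sin_cache.
    seqlen = x.shape[1]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, rotary_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()

    # Rotate the first rotary_dim channels, pass the rest through (as in _apply_rotary_emb).
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the heads dimension
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

q = torch.randn(1, 8, 32, 64)  # (batch, seqlen, n_head, head_dim)
print(rotary_demo(q).shape)    # torch.Size([1, 8, 32, 64])
```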
moondream.py ADDED
@@ -0,0 +1,100 @@
+import torch
+from .vision_encoder import VisionEncoder
+from .text_model import TextModel
+from .configuration_moondream import MoondreamConfig
+from transformers import PreTrainedModel
+import re
+
+
+class Moondream(PreTrainedModel):
+    config_class = MoondreamConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vision_encoder = VisionEncoder()
+        self.text_model = TextModel(config, None)
+
+    @property
+    def device(self):
+        return self.text_model.model.device
+
+    def encode_image(self, image):
+        return self.vision_encoder(image)
+
+    def input_embeds(self, prompt, image_embeds, tokenizer):
+        def _tokenize(txt):
+            return tokenizer(
+                txt, return_tensors="pt", add_special_tokens=False
+            ).input_ids.to(self.device)
+
+        # Add BOS token
+        embeds = []
+        embeds.append(
+            self.text_model.text_emb(
+                (torch.tensor([[tokenizer.bos_token_id]], device=self.device))
+            )
+        )
+
+        if "<image>" not in prompt:
+            embeds.append(self.text_model.text_emb(_tokenize(prompt)))
+        else:
+            assert prompt.count("<image>") == 1
+            before, after = prompt.split("<image>")
+            embeds.append(self.text_model.text_emb(_tokenize(f"{before}<image>")))
+            embeds.append(image_embeds.to(self.device))
+            embeds.append(self.text_model.text_emb(_tokenize(f"</image>{after}")))
+
+        return torch.cat(embeds, dim=1)
+
+    def generate(
+        self,
+        image_embeds,
+        prompt,
+        tokenizer,
+        eos_text="Human:",
+        max_new_tokens=128,
+        **kwargs,
+    ):
+        eos_tokens = tokenizer(eos_text, add_special_tokens=False)[0].ids
+
+        generate_config = {
+            "eos_token_id": eos_tokens,
+            "bos_token_id": tokenizer.bos_token_id,
+            "pad_token_id": tokenizer.eos_token_id,
+            "max_new_tokens": max_new_tokens,
+            **kwargs,
+        }
+
+        with torch.no_grad():
+            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+            output_ids = self.text_model.model.generate(
+                inputs_embeds=inputs_embeds, **generate_config
+            )
+
+        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+    def answer_question(
+        self,
+        image_embeds,
+        question,
+        tokenizer,
+        chat_history="",
+        result_queue=None,
+        **kwargs,
+    ):
+        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
+        answer = self.generate(
+            image_embeds,
+            prompt,
+            eos_text="<END>",
+            tokenizer=tokenizer,
+            max_new_tokens=128,
+            **kwargs,
+        )[0]
+        cleaned_answer = re.sub("<$", "", re.sub("END$", "", answer)).strip()
+
+        # Use the result_queue to pass the result if it is provided
+        if result_queue:
+            result_queue.put(cleaned_answer)
+        else:
+            return cleaned_answer
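For reference, a minimal inference sketch against the API above; it assumes `model` and `tokenizer` were loaded as in the earlier snippet and that `demo.jpg` is a hypothetical local image:

```python
from PIL import Image

image = Image.open("demo.jpg")            # hypothetical local file
image_embeds = model.encode_image(image)  # run the vision encoder once per image
answer = model.answer_question(image_embeds, "What is in this picture?", tokenizer)
print(answer)
```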
text_model.py ADDED
@@ -0,0 +1,20 @@
+from torch import nn
+import transformers
+from .modeling_phi import PhiForCausalLM
+from .configuration_moondream import PhiConfig
+
+transformers.logging.set_verbosity_error()
+
+
+class TextModel(nn.Module):
+    def __init__(self, config, tokenizer) -> None:
+        super().__init__()
+
+        if type(config.phi_config) == dict:
+            phi_config = PhiConfig(**config.phi_config)
+        else:
+            phi_config = config.phi_config
+
+        self.model = PhiForCausalLM(phi_config)
+        self.text_emb = self.model.get_input_embeddings()
+        self.tokenizer = tokenizer
vision_encoder.py ADDED
@@ -0,0 +1,150 @@
+import torch
+from torch import nn
+from PIL import Image
+from einops import rearrange
+from torchvision.transforms.v2 import (
+    Compose,
+    Resize,
+    InterpolationMode,
+    ToImage,
+    ToDtype,
+    Normalize,
+)
+import timm
+
+
+class VisualHolder(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.visual = model
+
+    def forward(self, x):
+        return self.visual(x)
+
+
+class ModelHolder(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class LinearPatchEmbedding(nn.Module):
+    def __init__(self, conv):
+        super().__init__()
+        self.linear = nn.Linear(588, 1152)
+        self.linear.weight.data = conv.weight.data.view(1152, -1)
+        if conv.bias is not None:
+            self.linear.bias.data = conv.bias.data
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int = None,
+        out_features: int = None,
+        act_layer: nn.Module = nn.GELU,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+
+        torch.nn.init.kaiming_normal_(
+            self.fc1.weight, mode="fan_in", nonlinearity="relu"
+        )
+        torch.nn.init.kaiming_normal_(
+            self.fc2.weight, mode="fan_in", nonlinearity="relu"
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+
+class VisionProjection(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        image_embedding_dim = 1152
+        model_dim = 2048
+        hidden_dim = model_dim * 4
+
+        self.mlp1 = MLP(image_embedding_dim, hidden_dim, model_dim)
+        self.mlp2 = MLP(model_dim, hidden_dim, model_dim)
+        self.ln = nn.LayerNorm(model_dim)
+
+    @property
+    def device(self):
+        return self.mlp1.fc1.weight.device
+
+    def forward(self, x):
+        x = self.mlp1(x)
+        x = self.ln(x)
+        x = x + self.mlp2(x)
+        return x
+
+
+class VisionTower(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.encoder = ModelHolder(
+            VisualHolder(timm.create_model("vit_so400m_patch14_siglip_384"))
+        )
+        self.encoder.model.visual.patch_embed = LinearPatchEmbedding(
+            self.encoder.model.visual.patch_embed.proj
+        )
+        self.encoder.model.visual.attn_pool = nn.Identity()
+
+        self.projection = VisionProjection()
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.projection(x)
+        return x
+
+
+class VisionEncoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.model = VisionTower()
+        self.preprocess = Compose(
+            [
+                Resize(size=(378, 378), interpolation=InterpolationMode.BICUBIC),
+                ToImage(),
+                ToDtype(torch.float32, scale=True),
+                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            ]
+        )
+
+    @property
+    def device(self):
+        return self.model.projection.mlp1.fc1.weight.device
+
+    @property
+    def dtype(self):
+        return self.model.projection.mlp1.fc1.weight.dtype
+
+    def __call__(self, image: Image) -> torch.Tensor:
+        with torch.no_grad():
+            image_vec = (
+                self.preprocess(image.convert("RGB"))
+                .unsqueeze(0)
+                .to(self.device, dtype=self.dtype)
+            )
+            image_vec = rearrange(
+                image_vec, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14
+            )
+            return self.model(image_vec)
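As a shape check for the preprocessing above (illustrative, not part of the commit): a 378x378 RGB image splits into a 27x27 grid of 14x14 patches, i.e. 729 tokens of dimension 3 * 14 * 14 = 588, which `LinearPatchEmbedding` then maps to 1152 channels and `VisionProjection` to the 2048-dim text space:

```python
import torch
from einops import rearrange

pixels = torch.randn(1, 3, 378, 378)  # (batch, channels, height, width) after preprocessing
patches = rearrange(pixels, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
print(patches.shape)  # torch.Size([1, 729, 588])
```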