yuzaa qianyuchen commited on
Commit
e978c4c
1 Parent(s): 0830407

Update modeling_minicpmv.py (#56)

Browse files

- Update modeling_minicpmv.py (da88bdc057fcaf87792be979f5f695fe12350716)


Co-authored-by: qianyu chen <[email protected]>

Files changed (1) hide show
  1. modeling_minicpmv.py +8 -8
modeling_minicpmv.py CHANGED
@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
42
 
43
  return model
44
 
45
- def init_resampler(self, embed_dim, vision_dim):
46
  return Resampler(
47
  num_queries=self.config.query_num,
48
  embed_dim=embed_dim,
49
  num_heads=embed_dim // 128,
50
  kv_dim=vision_dim,
51
- adaptive=True
52
  )
53
 
54
  def init_transform(self):
@@ -60,17 +60,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
60
  ),
61
  ]
62
  )
63
-
64
  def get_input_embeddings(self):
65
  return self.llm.get_input_embeddings()
66
 
67
  def set_input_embeddings(self, value):
68
  self.llm.embed_tokens = value
69
-
70
  def get_vllm_embedding(self, data):
71
  if 'vision_hidden_states' not in data:
72
- dtype = self.vpm.embeddings.position_embedding.weight.dtype
73
- device = self.vpm.embeddings.position_embedding.weight.device
74
  tgt_sizes = data['tgt_sizes']
75
  pixel_values_list = data['pixel_values']
76
  vision_hidden_states = []
@@ -107,6 +107,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
107
  single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
108
  single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
109
  single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
 
110
  vision_embedding.append(single_vision_embedding)
111
  vision_embedding = torch.vstack(vision_embedding)
112
 
@@ -152,14 +153,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
152
  image_indices = torch.stack(
153
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
154
  ).to(vllm_embedding.device)
155
-
156
  cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
157
  cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
158
  elif self.training:
159
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
160
 
161
  return vllm_embedding, vision_hidden_states
162
-
163
  def forward(self, data, **kwargs):
164
  vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
165
  position_ids = data["position_ids"]
 
42
 
43
  return model
44
 
45
+ def init_resampler(self, embed_dim, vision_dim,):
46
  return Resampler(
47
  num_queries=self.config.query_num,
48
  embed_dim=embed_dim,
49
  num_heads=embed_dim // 128,
50
  kv_dim=vision_dim,
51
+ adaptive=True,
52
  )
53
 
54
  def init_transform(self):
 
60
  ),
61
  ]
62
  )
63
+
64
  def get_input_embeddings(self):
65
  return self.llm.get_input_embeddings()
66
 
67
  def set_input_embeddings(self, value):
68
  self.llm.embed_tokens = value
69
+
70
  def get_vllm_embedding(self, data):
71
  if 'vision_hidden_states' not in data:
72
+ dtype = self.llm.model.embed_tokens.weight.dtype
73
+ device = self.llm.model.embed_tokens.weight.device
74
  tgt_sizes = data['tgt_sizes']
75
  pixel_values_list = data['pixel_values']
76
  vision_hidden_states = []
 
107
  single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
108
  single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
109
  single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
110
+
111
  vision_embedding.append(single_vision_embedding)
112
  vision_embedding = torch.vstack(vision_embedding)
113
 
 
153
  image_indices = torch.stack(
154
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
155
  ).to(vllm_embedding.device)
 
156
  cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
157
  cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
158
  elif self.training:
159
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
160
 
161
  return vllm_embedding, vision_hidden_states
162
+
163
  def forward(self, data, **kwargs):
164
  vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
165
  position_ids = data["position_ids"]