qianyuchen commited on
Commit
80ef065
1 Parent(s): b996e5b

Update modeling_minicpmv.py

Browse files

修改minicpmv的get vision embedding使其适应zero3训练

Files changed (1) hide show
  1. modeling_minicpmv.py +20 -10
modeling_minicpmv.py CHANGED
@@ -4,11 +4,12 @@ import json
4
  import timm
5
  import torch
6
  import torchvision
 
7
  from PIL import Image
8
  from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
9
  from torchvision import transforms
10
  from transformers import LlamaTokenizer
11
-
12
  from .configuration_minicpm import MiniCPMVConfig
13
  from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
14
  from .resampler import Resampler
@@ -74,15 +75,24 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
74
 
75
  def get_vision_embedding(self, pixel_values):
76
  res = []
77
- dtype = self.vpm.pos_embed.data.dtype
78
- for pixel_value in pixel_values:
79
  H, W = pixel_value.shape[-2:]
80
- tgt_size = (
81
- math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
82
  vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
83
- if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
84
- vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
85
- res.append(self.resampler(vision_embedding, tgt_size))
 
 
 
 
 
 
 
 
 
 
86
  return torch.vstack(res)
87
 
88
  def get_vllm_embedding(self, data):
@@ -93,8 +103,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
93
  if len(pixel_values) > 0:
94
  vision_hidden_states.append(self.get_vision_embedding(pixel_values))
95
  elif self.training:
96
- dtype = self.vpm.pos_embed.data.dtype
97
- device = self.vpm.pos_embed.data.device
98
  dummy_image = torch.zeros(
99
  (1, 3, 224, 224), device=device, dtype=dtype
100
  )
 
4
  import timm
5
  import torch
6
  import torchvision
7
+ import deepspeed
8
  from PIL import Image
9
  from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10
  from torchvision import transforms
11
  from transformers import LlamaTokenizer
12
+ from transformers.integrations import is_deepspeed_zero3_enabled
13
  from .configuration_minicpm import MiniCPMVConfig
14
  from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
15
  from .resampler import Resampler
 
75
 
76
  def get_vision_embedding(self, pixel_values):
77
  res = []
78
+ dtype = self.llm.lm_head.weight.dtype
79
+ def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
80
  H, W = pixel_value.shape[-2:]
81
+ target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
 
82
  vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
83
+ if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
84
+ vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
85
+ return resampler(vision_embedding, target_size)
86
+
87
+ if is_deepspeed_zero3_enabled():
88
+ with deepspeed.zero.GatheredParameters(self.vpm.pos_embed, modifier_rank=0):
89
+ for pixel_value in pixel_values:
90
+ result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
91
+ res.append(result)
92
+ else:
93
+ for pixel_value in pixel_values:
94
+ result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
95
+ res.append(result)
96
  return torch.vstack(res)
97
 
98
  def get_vllm_embedding(self, data):
 
103
  if len(pixel_values) > 0:
104
  vision_hidden_states.append(self.get_vision_embedding(pixel_values))
105
  elif self.training:
106
+ dtype = self.llm.lm_head.weight.dtype
107
+ device = self.llm.lm_head.weight.device
108
  dummy_image = torch.zeros(
109
  (1, 3, 224, 224), device=device, dtype=dtype
110
  )