qianyuchen
commited on
Commit
•
80ef065
1
Parent(s):
b996e5b
Update modeling_minicpmv.py
Browse files修改minicpmv的get vision embedding使其适应zero3训练
- modeling_minicpmv.py +20 -10
modeling_minicpmv.py
CHANGED
@@ -4,11 +4,12 @@ import json
|
|
4 |
import timm
|
5 |
import torch
|
6 |
import torchvision
|
|
|
7 |
from PIL import Image
|
8 |
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
9 |
from torchvision import transforms
|
10 |
from transformers import LlamaTokenizer
|
11 |
-
|
12 |
from .configuration_minicpm import MiniCPMVConfig
|
13 |
from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
|
14 |
from .resampler import Resampler
|
@@ -74,15 +75,24 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
74 |
|
75 |
def get_vision_embedding(self, pixel_values):
|
76 |
res = []
|
77 |
-
dtype = self.
|
78 |
-
|
79 |
H, W = pixel_value.shape[-2:]
|
80 |
-
|
81 |
-
math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
|
82 |
vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
|
83 |
-
if hasattr(
|
84 |
-
vision_embedding = vision_embedding[:,
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
return torch.vstack(res)
|
87 |
|
88 |
def get_vllm_embedding(self, data):
|
@@ -93,8 +103,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
93 |
if len(pixel_values) > 0:
|
94 |
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
|
95 |
elif self.training:
|
96 |
-
dtype = self.
|
97 |
-
device = self.
|
98 |
dummy_image = torch.zeros(
|
99 |
(1, 3, 224, 224), device=device, dtype=dtype
|
100 |
)
|
|
|
4 |
import timm
|
5 |
import torch
|
6 |
import torchvision
|
7 |
+
import deepspeed
|
8 |
from PIL import Image
|
9 |
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
10 |
from torchvision import transforms
|
11 |
from transformers import LlamaTokenizer
|
12 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
13 |
from .configuration_minicpm import MiniCPMVConfig
|
14 |
from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
|
15 |
from .resampler import Resampler
|
|
|
75 |
|
76 |
def get_vision_embedding(self, pixel_values):
|
77 |
res = []
|
78 |
+
dtype = self.llm.lm_head.weight.dtype
|
79 |
+
def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
|
80 |
H, W = pixel_value.shape[-2:]
|
81 |
+
target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
|
|
|
82 |
vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
|
83 |
+
if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
|
84 |
+
vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
|
85 |
+
return resampler(vision_embedding, target_size)
|
86 |
+
|
87 |
+
if is_deepspeed_zero3_enabled():
|
88 |
+
with deepspeed.zero.GatheredParameters(self.vpm.pos_embed, modifier_rank=0):
|
89 |
+
for pixel_value in pixel_values:
|
90 |
+
result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
|
91 |
+
res.append(result)
|
92 |
+
else:
|
93 |
+
for pixel_value in pixel_values:
|
94 |
+
result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
|
95 |
+
res.append(result)
|
96 |
return torch.vstack(res)
|
97 |
|
98 |
def get_vllm_embedding(self, data):
|
|
|
103 |
if len(pixel_values) > 0:
|
104 |
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
|
105 |
elif self.training:
|
106 |
+
dtype = self.llm.lm_head.weight.dtype
|
107 |
+
device = self.llm.lm_head.weight.device
|
108 |
dummy_image = torch.zeros(
|
109 |
(1, 3, 224, 224), device=device, dtype=dtype
|
110 |
)
|