fix(internlm): Prevent errors by padding the dimensions of wrap tokens.
The text_input in a batch can contain texts of various lengths.
In that case, the per-sample wrap tensors also have different sequence lengths, and torch.cat fails because tensor sizes must match in every dimension except the concatenation dimension.
I added padding to resolve the issue shown below.
I would appreciate it if you could review this PR.
Error example:
```
ret_val = func(*args, **kwargs)
File "/tmp/ray/session_2024-02-05_14-30-33_881744_3780/runtime_resources/pip/40ae4806a7327971d7c077068c4b0a3019a14611/virtualenv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1801, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 98, in forward
output = self._forward_module.training_step(*inputs, **kwargs)
File "/tmp/ray/session_2024-02-05_14-30-33_881744_3780/runtime_resources/py_modules_files/_ray_pkg_33670290aabc83b3/ml/model/application/vlm/place_vlm/llava/system_stage2_internlm.py", line 56, in training_step
outputs = self.model(samples=batch)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/internlm2-7b/modeling_internlm_xcomposer2.py", line 337, in forward
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
File "/root/.cache/huggingface/modules/transformers_modules/internlm2-7b/modeling_internlm_xcomposer2.py", line 266, in interleav_wrap
wrap_embeds = torch.cat(wrap_embeds_list)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 1230 but got size 3546 for tensor number 1 in the list.
```
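For context, here is a minimal standalone repro of the mismatch (not part of the patch; the hidden size of 4096 and the random tensors are made up for illustration, while the lengths 1230 and 3546 come from the traceback above):

```python
import torch

# Per-sample wrap embeddings as built in interleav_wrap: shape (1, seq_len, hidden).
# Different texts in the batch produce different seq_len values.
a = torch.randn(1, 1230, 4096)
b = torch.randn(1, 3546, 4096)

# torch.cat concatenates along dim 0 and requires every other dim to match,
# so mismatched sequence lengths raise the RuntimeError shown above.
try:
    torch.cat([a, b])
except RuntimeError as err:
    print(err)
```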
Proposed change in modeling_internlm_xcomposer2.py (interleav_wrap):

```diff
@@ -258,15 +258,15 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(0))
+            wrap_atts_list.append(wrap_atts.squeeze(0))
+            wrap_target_list.append(wrap_target.squeeze(0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(0))
+
+        wrap_embeds = torch.nn.utils.rnn.pad_sequence(wrap_embeds_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_atts = torch.nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_target = torch.nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_im_mask = torch.nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
 
     def mask_human_targets(self, input_ids, pure=False):
```
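As a sanity check of the new path, a small sketch (illustrative shapes; padding_value is set to 0.0 here rather than the tokenizer's _pad_token_type_id used in the patch) shows that squeeze(0) plus pad_sequence yields a single batched tensor:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Drop the leading batch dim of 1 so each sample is (seq_len, hidden).
wrap_embeds_list = [
    torch.randn(1, 1230, 4096).squeeze(0),
    torch.randn(1, 3546, 4096).squeeze(0),
]

# pad_sequence right-pads every sample to the longest length and, with
# batch_first=True, stacks them into (batch, max_seq_len, hidden).
wrap_embeds = pad_sequence(wrap_embeds_list, batch_first=True, padding_value=0.0)
print(wrap_embeds.shape)  # torch.Size([2, 3546, 4096])
```

wrap_atts, wrap_target, and wrap_im_mask are padded the same way, so all four returned tensors keep matching batch and sequence dimensions.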