yun committed on
Commit de2ef9d
1 Parent(s): d7ab428

fix(internlm): Prevent errors by padding the dimensions of wrap tokens.


The text_input in a batch can contain texts of different lengths. In that case, the per-sample wrap tokens also have different sequence lengths, and the torch.cat call fails because all dimensions other than the concatenation dimension must match. This change pads the wrapped sequences to a common length to resolve the error shown below. I would appreciate it if you could review this PR.

Error example:
```
ret_val = func(*args, **kwargs)
File "/tmp/ray/session_2024-02-05_14-30-33_881744_3780/runtime_resources/pip/40ae4806a7327971d7c077068c4b0a3019a14611/virtualenv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1801, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 98, in forward
output = self._forward_module.training_step(*inputs, **kwargs)
File "/tmp/ray/session_2024-02-05_14-30-33_881744_3780/runtime_resources/py_modules_files/_ray_pkg_33670290aabc83b3/ml/model/application/vlm/place_vlm/llava/system_stage2_internlm.py", line 56, in training_step
outputs = self.model(samples=batch)
File "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/internlm2-7b/modeling_internlm_xcomposer2.py", line 337, in forward
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
File "/root/.cache/huggingface/modules/transformers_modules/internlm2-7b/modeling_internlm_xcomposer2.py", line 266, in interleav_wrap
wrap_embeds = torch.cat(wrap_embeds_list)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 1230 but got size 3546 for tensor number 1 in the list.
```
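
For context, the sketch below (standalone and illustrative, not code from this repository; the hidden size of 4096 is an assumption, and the two sequence lengths are borrowed from the error message) reproduces the failure: per-sample embeddings of unequal length cannot be concatenated along the batch dimension.

```
import torch

hidden = 4096  # illustrative hidden size, not taken from the model config

# Two per-sample wrapped embeddings with different sequence lengths,
# each carrying a leading batch dimension of 1 as in interleav_wrap.
short = torch.randn(1, 1230, hidden)
long = torch.randn(1, 3546, hidden)

try:
    # torch.cat concatenates along dim 0 and requires every other
    # dimension to match, so the mismatched sequence lengths raise here.
    torch.cat([short, long])
except RuntimeError as err:
    print(err)  # Sizes of tensors must match except in dimension 0 ...
```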

Files changed (1)
  1. modeling_internlm_xcomposer2.py +9 -9
modeling_internlm_xcomposer2.py CHANGED

```
@@ -258,15 +258,15 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(0))
+            wrap_atts_list.append(wrap_atts.squeeze(0))
+            wrap_target_list.append(wrap_target.squeeze(0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(0))
+
+        wrap_embeds = torch.nn.utils.rnn.pad_sequence(wrap_embeds_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_atts = torch.nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_target = torch.nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_im_mask = torch.nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
 
     def mask_human_targets(self, input_ids, pure=False):
```
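
As a sanity check of the padding path, here is a minimal standalone sketch (the hidden size, sequence lengths, and padding value 0 are assumptions; the patch itself uses self.tokenizer._pad_token_type_id as the padding value): each per-sample tensor is squeezed down to (seq_len, ...) so that pad_sequence can right-pad the list to the longest sequence and stack it with batch_first=True.

```
import torch
from torch.nn.utils.rnn import pad_sequence

hidden = 4096  # illustrative hidden size

# Squeeze the leading batch dim so each element has shape (seq_len, hidden);
# pad_sequence pads along the first dimension and needs trailing dims to match.
embeds_list = [torch.randn(1, 1230, hidden).squeeze(0),
               torch.randn(1, 3546, hidden).squeeze(0)]
atts_list = [torch.ones(1, 1230).squeeze(0),
             torch.ones(1, 3546).squeeze(0)]

wrap_embeds = pad_sequence(embeds_list, batch_first=True, padding_value=0)
wrap_atts = pad_sequence(atts_list, batch_first=True, padding_value=0)

print(wrap_embeds.shape)  # torch.Size([2, 3546, 4096])
print(wrap_atts.shape)    # torch.Size([2, 3546])
```

With batch_first=True the shorter sample is right-padded to the longest sequence in the batch, so the four wrapped tensors can be returned with matching shapes instead of failing in torch.cat.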