fix(internlm): Prevent shape-mismatch errors by padding wrap token sequences to a common length.

#2
by yun - opened
Files changed (1)
  1. modeling_internlm_xcomposer2.py +9 -9
modeling_internlm_xcomposer2.py CHANGED
@@ -258,15 +258,15 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(0))
+            wrap_atts_list.append(wrap_atts.squeeze(0))
+            wrap_target_list.append(wrap_target.squeeze(0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(0))
+
+        wrap_embeds = torch.nn.utils.rnn.pad_sequence(wrap_embeds_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_atts = torch.nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_target = torch.nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_im_mask = torch.nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
 
     def mask_human_targets(self, input_ids, pure=False):
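For context, a minimal standalone sketch (not part of this PR) of the failure mode the change addresses: `torch.cat` requires every per-sample tensor to share the same sequence length, while `torch.nn.utils.rnn.pad_sequence` right-pads each sample to the longest one before batching. The sequence lengths and hidden size below are illustrative assumptions, not values from the model.

```python
# Minimal sketch (not part of this PR): why torch.cat fails on
# variable-length samples and how pad_sequence avoids the error.
import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples with different wrapped-token lengths (5 vs. 7); the
# (1, seq_len, 4096) shapes mimic what the loop produces per sample.
a = torch.randn(1, 5, 4096)
b = torch.randn(1, 7, 4096)

# Old path: concatenating along dim 0 requires all other dims to match,
# so mismatched sequence lengths raise a RuntimeError.
try:
    torch.cat([a, b])
except RuntimeError as err:
    print("torch.cat failed:", err)

# New path: drop the singleton batch dim, then right-pad every sample
# to the longest sequence length before stacking into one batch.
batch = pad_sequence([a.squeeze(0), b.squeeze(0)],
                     batch_first=True, padding_value=0.0)
print(batch.shape)  # torch.Size([2, 7, 4096])
```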