fix(internlm): Prevent errors by padding the dimensions of wrap tokens.
#2
by yun - opened
modeling_internlm_xcomposer2.py
CHANGED
@@ -258,15 +258,15 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(0))
+            wrap_atts_list.append(wrap_atts.squeeze(0))
+            wrap_target_list.append(wrap_target.squeeze(0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(0))
+
+        wrap_embeds = torch.nn.utils.rnn.pad_sequence(wrap_embeds_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_atts = torch.nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_target = torch.nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
+        wrap_im_mask = torch.nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=self.tokenizer._pad_token_type_id)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
 
     def mask_human_targets(self, input_ids, pure=False):
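For context, a minimal standalone sketch (not taken from the PR) of the failure mode this change appears to address: when the per-sample wrapped sequences have different lengths, concatenating them cannot produce a well-formed batch, while pad_sequence with batch_first=True pads every sample to the longest length and stacks them. The tensor shapes and padding value below are illustrative assumptions.

    # Sketch only: shapes are assumptions, not values from the PR.
    import torch
    from torch.nn.utils.rnn import pad_sequence

    hidden = 4
    # Two samples whose wrapped sequences end up with different lengths
    # (e.g. different numbers of interleaved image/text tokens).
    embeds_a = torch.randn(7, hidden)   # sample A: 7 tokens
    embeds_b = torch.randn(5, hidden)   # sample B: 5 tokens

    # Concatenating along dim 0 merges both samples into a single 12-token
    # sequence, losing the batch structure (and raises a size-mismatch error
    # if each sample still carries its own batch dimension of 1).
    merged = torch.cat([embeds_a, embeds_b], dim=0)
    print(merged.shape)                 # torch.Size([12, 4])

    # pad_sequence pads each sample to the longest length and stacks them,
    # yielding a proper [batch, seq_len, hidden] tensor.
    batch = pad_sequence([embeds_a, embeds_b], batch_first=True, padding_value=0.0)
    print(batch.shape)                  # torch.Size([2, 7, 4])

The PR applies the same idea to the embeddings, attention mask, targets, and image mask, squeezing the per-sample batch dimension before padding and using the tokenizer's pad token type id as the padding value.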