zxdu20 committed
Commit 3485994
1 Parent(s): 9333486
Files changed (2)
  1. modeling_chatglm.py +21 -11
  2. tokenization_chatglm.py +11 -14
modeling_chatglm.py CHANGED
@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
 
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 input_ids,
                 device=input_ids.device,
                 mask_positions=mask_positions,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
             return {
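In short, get_position_ids used to take a single gmask flag for the whole batch; it now accepts a per-sequence use_gmasks list (defaulting to [False] * batch_size when omitted), and both forward and prepare_inputs_for_generation build mask_positions and use_gmasks row by row, so a batch can mix [MASK] and [gMASK] prompts. A minimal sketch of the new calling convention; the token ids below are placeholders for illustration, not the real config values:

import torch

# Placeholder ids standing in for config.mask_token_id / config.gmask_token_id / config.bos_token_id.
MASK, gMASK, BOS = 130000, 130001, 130004

def build_mask_info(input_ids: torch.LongTensor):
    # Same per-sequence loop the commit introduces in forward / prepare_inputs_for_generation.
    mask_positions, use_gmasks = [], []
    for seq in input_ids.tolist():
        mask_token = gMASK if gMASK in seq else MASK
        mask_positions.append(seq.index(mask_token))
        use_gmasks.append(mask_token == gMASK)
    return mask_positions, use_gmasks

# One row uses gMASK, the other MASK; each row now carries its own flag
# instead of the old batch-wide gmask=True/False.
batch = torch.tensor([[11, 12, gMASK, BOS, 21],
                      [13, MASK, 14, BOS, 22]])
mask_positions, use_gmasks = build_mask_info(batch)
print(mask_positions)  # [2, 1]
print(use_gmasks)      # [True, False]
# These are then forwarded, e.g.:
# position_ids = model.get_position_ids(batch, mask_positions, batch.device, use_gmasks=use_gmasks)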
tokenization_chatglm.py CHANGED
@@ -176,6 +176,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             mask_token='[MASK]',
             gmask_token='[gMASK]',
             padding_side="left",
+            pad_token="<pad>",
+            unk_token="<unk>",
             num_image_tokens=20000,
             **kwargs
     ) -> None:
@@ -188,6 +190,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             end_token=end_token,
             mask_token=mask_token,
             gmask_token=gmask_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
             num_image_tokens=num_image_tokens,
             **kwargs
         )
@@ -322,22 +326,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
-        mask_ids = self.sp_tokenizer[self.mask_token]
-        gmask_ids = self.sp_tokenizer[self.gmask_token]
+        gmask_id = self.sp_tokenizer[self.gmask_token]
         eos_id = self.sp_tokenizer[self.eos_token]
-        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
-            token_ids_0 += [gmask_ids]
-
-        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
-            token_ids_0 += [self.sp_tokenizer[self.end_token]]
-
-        token_ids_0 += [self.sp_tokenizer[self.bos_token]]
-
+        token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if not token_ids_1 or token_ids_1[-1] != eos_id:
-                token_ids_1 += [eos_id]
-            token_ids_0 += token_ids_1
-
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
         return token_ids_0
 
     def _pad(
@@ -402,6 +395,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             encoded_inputs["attention_mask"] = attention_mask
 
         if "position_ids" not in encoded_inputs:
+            if bos_token_id in required_input:
+                context_length = required_input.index(bos_token_id)
+            else:
+                context_length = seq_length
             position_ids = np.arange(seq_length, dtype=np.int64)
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             if mask_token in required_input:
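For reference, a rough sketch of what the rewritten build_inputs_with_special_tokens produces, with placeholder ids standing in for the sp_tokenizer lookups of the gmask, bos and eos tokens:

# Placeholder special-token ids; in the tokenizer they come from self.sp_tokenizer[...].
GMASK_ID, BOS_ID, EOS_ID = 130001, 130004, 130005

def build_inputs(token_ids_0, token_ids_1=None):
    # Mirrors the simplified logic: the prompt always ends with gmask + bos,
    # and a second segment, when present, is appended and terminated with eos.
    token_ids_0 = token_ids_0 + [GMASK_ID, BOS_ID]
    if token_ids_1 is not None:
        token_ids_0 = token_ids_0 + token_ids_1 + [EOS_ID]
    return token_ids_0

print(build_inputs([1, 2, 3]))          # [1, 2, 3, 130001, 130004]
print(build_inputs([1, 2, 3], [7, 8]))  # [1, 2, 3, 130001, 130004, 7, 8, 130005]

Compared with the old implementation, the conditional end_token handling and the membership checks are gone: every encoding gets exactly one gmask and one bos, and only paired inputs receive an eos.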