Fix gmask
- modeling_chatglm.py +21 -11
- tokenization_chatglm.py +11 -14
modeling_chatglm.py
CHANGED
@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = MASK if MASK in input_ids else gMASK
-            use_gmask = False if MASK in input_ids else gMASK
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
 
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 input_ids,
                 device=input_ids.device,
                 mask_positions=mask_positions,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
            )
 
         return {
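For reference, a minimal standalone sketch of the per-sequence behaviour these modeling changes introduce: instead of one batch-wide gmask flag, each sequence now gets its own mask position and its own use_gmask flag. The token ids below are placeholders chosen only for this example; in the model they come from self.config.mask_token_id and self.config.gmask_token_id.

# Placeholder ids, for illustration only.
MASK, gMASK = 150000, 150001

def collect_mask_info(seqs):
    # Prefer gMASK when a sequence contains it, and record per sequence
    # both where the mask sits and whether it was a gMASK.
    mask_positions, use_gmasks = [], []
    for seq in seqs:
        mask_token = gMASK if gMASK in seq else MASK
        mask_positions.append(seq.index(mask_token))
        use_gmasks.append(mask_token == gMASK)
    return mask_positions, use_gmasks

# One sequence uses gMASK, the other MASK; each keeps its own flag.
print(collect_mask_info([[5, 6, gMASK, 7], [5, MASK, 6, 7]]))
# -> ([2, 1], [True, False])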
tokenization_chatglm.py
CHANGED
@@ -176,6 +176,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             mask_token='[MASK]',
             gmask_token='[gMASK]',
             padding_side="left",
+            pad_token="<pad>",
+            unk_token="<unk>",
             num_image_tokens=20000,
             **kwargs
     ) -> None:
@@ -188,6 +190,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             end_token=end_token,
             mask_token=mask_token,
             gmask_token=gmask_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
             num_image_tokens=num_image_tokens,
             **kwargs
         )
@@ -322,22 +326,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
-        mask_ids = self.sp_tokenizer[self.mask_token]
-        gmask_ids = self.sp_tokenizer[self.gmask_token]
+        gmask_id = self.sp_tokenizer[self.gmask_token]
         eos_id = self.sp_tokenizer[self.eos_token]
-        if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
-            token_ids_0 += [gmask_ids]
-
-        if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
-            token_ids_0 += [self.sp_tokenizer[self.end_token]]
-
-        token_ids_0 += [self.sp_tokenizer[self.bos_token]]
-
+        token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if not token_ids_1 or token_ids_1[-1] != eos_id:
-                token_ids_1 += [eos_id]
-            token_ids_0 += token_ids_1
-
+            token_ids_0 = token_ids_0 + token_ids_1 + [eos_id]
         return token_ids_0
 
     def _pad(
@@ -402,6 +395,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             encoded_inputs["attention_mask"] = attention_mask
 
         if "position_ids" not in encoded_inputs:
+            if bos_token_id in required_input:
+                context_length = required_input.index(bos_token_id)
+            else:
+                context_length = seq_length
             position_ids = np.arange(seq_length, dtype=np.int64)
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             if mask_token in required_input: