Update modeling_chatglm.py
modeling_chatglm.py  +90 −29
CHANGED
@@ -157,7 +157,7 @@ class RotaryEmbedding(nn.Module):
         )


-@torch.jit.script
+# @torch.jit.script
 def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
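Note: with @torch.jit.script commented out, apply_rotary_pos_emb now runs as ordinary eager Python. If scripting were still wanted in some environments, one option (a hedged sketch only; the CHATGLM_JIT toggle below is hypothetical and not part of this change) would be a conditional decorator:

    import os
    import torch

    def maybe_jit_script(fn):
        # Hypothetical opt-in: script only when CHATGLM_JIT=1 is set in the environment.
        return torch.jit.script(fn) if os.environ.get("CHATGLM_JIT") == "1" else fn

apply_rotary_pos_emb could then be wrapped with maybe_jit_script instead of carrying a hard-coded decorator.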
@@ -223,8 +223,7 @@ class CoreAttention(torch.nn.Module):
         if pytorch_major_version >= 2:
             query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 is_causal=True)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=True)
             else:
                 if attention_mask is not None:
                     attention_mask = ~attention_mask
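Note: torch.nn.functional.scaled_dot_product_attention with is_causal=True builds the causal mask internally, so no explicit attention mask is needed on this fast path. A minimal, self-contained sketch of the call (toy shapes, not the model's tensors):

    import torch
    import torch.nn.functional as F

    # [batch, heads, seq, head_dim], matching the permuted q/k/v layout above
    q = torch.randn(1, 2, 5, 8)
    k = torch.randn(1, 2, 5, 8)
    v = torch.randn(1, 2, 5, 8)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([1, 2, 5, 8])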
@@ -237,7 +236,7 @@ class CoreAttention(torch.nn.Module):
         # Raw attention scores

         # [b, np, sq, sk]
-        output_size = (query_layer.size(
+        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(0))

         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
@@ -312,7 +311,6 @@ class CoreAttention(torch.nn.Module):

 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
-
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -448,7 +446,6 @@ class SelfAttention(torch.nn.Module):

         return output, kv_cache

-
 def _config_to_kwargs(args):
     common_kwargs = {
         "dtype": args.torch_dtype,
@@ -504,7 +501,6 @@ class MLP(torch.nn.Module):

 class GLMBlock(torch.nn.Module):
     """A single transformer layer.
-
     Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -597,7 +593,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

         self.gradient_checkpointing = False
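Note: the final norm of GLMTransformer is now stored as self.norm, selected as RMSNorm or LayerNorm depending on config.rmsnorm. A rough, self-contained sketch of that selection pattern (the RMSNorm below is a minimal stand-in, not this file's implementation):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        # Minimal stand-in for illustration only.
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x):
            variance = x.float().pow(2).mean(-1, keepdim=True)
            return (x.float() * torch.rsqrt(variance + self.eps)).type_as(x) * self.weight

    use_rmsnorm = True
    LayerNormFunc = RMSNorm if use_rmsnorm else nn.LayerNorm
    norm = LayerNormFunc(16, eps=1e-5)
    print(norm(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])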
@@ -653,7 +649,7 @@ class GLMTransformer(torch.nn.Module):

         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.
+            hidden_states = self.norm(hidden_states)

         return hidden_states, presents, all_hidden_states, all_self_attentions

@@ -740,7 +736,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_kwargs = {}
         if device is not None:
             init_kwargs["device"] = device
-
+
+        self.embed_tokens = nn.Embedding(
+            config.padded_vocab_size,
+            config.hidden_size,
+            dtype=config.torch_dtype,
+            device=device
+        )
+
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
@@ -753,9 +756,21 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
-
-
-
+
+        # Transformer layers.
+        def build_layer(layer_number):
+            return GLMBlock(config, layer_number, device=device)
+
+        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+        self.num_layers = config.num_layers
+        self.post_layer_norm = config.post_layer_norm
+
+        if self.post_layer_norm:
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+                                      dtype=config.torch_dtype)
+
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
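Note: ChatGLMModel now owns its embedding, transformer blocks, and final norm directly as embed_tokens, layers, and norm, rather than delegating to nested submodules. A hedged smoke-test sketch of the resulting structure (assumes an already-loaded config; illustrative only):

    import torch.nn as nn

    model = ChatGLMModel(config, device="cpu")
    assert isinstance(model.embed_tokens, nn.Embedding)   # padded_vocab_size x hidden_size
    assert len(model.layers) == config.num_layers         # one GLMBlock per layer
    if config.post_layer_norm:
        print(type(model.norm).__name__)                  # RMSNorm or LayerNorm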
@@ -765,6 +780,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.prefix_encoder = PrefixEncoder(config)
             self.dropout = torch.nn.Dropout(0.1)

+        self.gradient_checkpointing = False
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings

@@ -804,7 +821,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         batch_size, seq_length = input_ids.shape

         if inputs_embeds is None:
-            inputs_embeds = self.
+            inputs_embeds = self.embed_tokens(input_ids)

         if self.pre_seq_len is not None:
             if past_key_values is None:
@@ -827,10 +844,54 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

         # Run encoder.
-
-
-
-
+        if not past_key_values:
+            past_key_values = [None for _ in range(self.num_layers)]
+        presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_self_attentions = None
+        all_hidden_states = () if output_hidden_states else None
+
+        hidden_states = inputs_embeds
+        # To comply with former chat-glm format that expects (seqlen, bs, hd)
+        hidden_states = hidden_states.permute(1, 0, 2)
+
+        for index, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    full_attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    full_attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=past_key_values[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
+            if use_cache:
+                presents = presents + (kv_cache,)
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.norm(hidden_states)

         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
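Note: the layer loop is now inlined in ChatGLMModel.forward, and when gradient checkpointing is active each block runs under torch.utils.checkpoint.checkpoint (with use_cache forced off, as the warning above states). A generic, self-contained sketch of that pattern (toy module, not a GLMBlock):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x):
            return torch.relu(self.linear(x))

    block = ToyBlock()
    x = torch.randn(4, 8, requires_grad=True)
    # Activations inside the block are recomputed during backward instead of being stored.
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()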
@@ -844,7 +905,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

     def quantize(self, weight_bit_width: int):
         from .quantization import quantize
-        quantize(self
+        quantize(self, weight_bit_width)
         return self


@@ -853,7 +914,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         super().__init__(config)

         self.max_sequence_length = config.max_length
-        self.
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.config = config
         self.quantized = False
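Note: ChatGLMForConditionalGeneration now exposes a model backbone plus a separate lm_head projection, mirroring the common decoder-LM layout. A self-contained toy sketch of the head applied to hidden states (toy sizes; the backbone itself is omitted):

    import torch
    import torch.nn as nn

    hidden_size, vocab_size = 8, 32
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    hidden_states = torch.randn(2, 5, hidden_size)   # [batch, seq, hidden]
    lm_logits = lm_head(hidden_states)               # [batch, seq, vocab]
    print(lm_logits.shape)                           # torch.Size([2, 5, 32])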
@@ -934,7 +996,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        transformer_outputs = self.
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -948,8 +1010,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         hidden_states = transformer_outputs[0]
         if return_last_logit:
             hidden_states = hidden_states[-1:]
-        lm_logits = self.
-        lm_logits = lm_logits.transpose(0, 1).contiguous()
+        lm_logits = self.lm_head(hidden_states)

         loss = None
         if labels is not None:
@@ -1062,8 +1123,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
-            if self.
-                past_length -= self.
+            if self.model.pre_seq_len is not None:
+                past_length -= self.model.pre_seq_len
             inputs.position_ids += past_length
             attention_mask = inputs.attention_mask
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
@@ -1205,7 +1266,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

         self.config.quantization_bit = bits

-        self.
+        self.model = quantize(self.model, bits, empty_init=empty_init, device=device,
                               **kwargs)
         return self

@@ -1215,7 +1276,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
         super().__init__(config)

         self.num_labels = config.num_labels
-        self.
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)

         self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
         if config.classifier_dropout is not None:
@@ -1242,7 +1303,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        transformer_outputs = self.
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -1293,4 +1354,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
|