NeoZ123 commited on
Commit
35afe79
1 Parent(s): 7c96587

Upload 14 files

Browse files
README.md CHANGED
@@ -1,3 +1,91 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ - zh
5
+ library_name: transformers
6
+ tags:
7
+ - Long Context
8
+ - chatglm
9
+ - llama
10
+ datasets:
11
+ - THUDM/LongReward-10k
12
+ pipeline_tag: text-generation
13
+ ---
14
+ # LongReward-glm4-9b-SFT
15
+
16
+ <p align="center">
17
+ 🤗 <a href="https://huggingface.co/datasets/THUDM/LongReward-10k" target="_blank">[LongReward Dataset] </a> • 💻 <a href="https://github.com/THUDM/LongReward" target="_blank">[Github Repo]</a> • 📃 <a href="https://arxiv.org/abs/" target="_blank">[LongReward Paper]</a>
18
+ </p>
19
+
20
+ LongReward-glm4-9b-SFT is supervisedly fined-tuned from [glm-4-9b](https://huggingface.co/THUDM/glm-4-9b) using the `sft` split of [LongReward-10k](https://huggingface.co/datasets/THUDM/LongReward-45) dataset, and supports a maximum context window of up to 64K tokens.
21
+
22
+ Environment: Same environment requirement as [glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) (`transforemrs>=4.43.0`).
23
+
24
+ A simple demo for deployment of the model:
25
+ ```python
26
+ import torch
27
+ from transformers import AutoTokenizer, AutoModelForCausalLM
28
+
29
+ model_path = "THUDM/LongReward-glm4-9b-SFT"
30
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
31
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map='auto')
32
+ context = '''
33
+ W. Russell Todd, 94, United States Army general (b. 1928). February 13. Tim Aymar, 59, heavy metal singer (Pharaoh) (b. 1963). Marshall \"Eddie\" Conway, 76, Black Panther Party leader (b. 1946). Roger Bonk, 78, football player (North Dakota Fighting Sioux, Winnipeg Blue Bombers) (b. 1944). Conrad Dobler, 72, football player (St. Louis Cardinals, New Orleans Saints, Buffalo Bills) (b. 1950). Brian DuBois, 55, baseball player (Detroit Tigers) (b. 1967). Robert Geddes, 99, architect, dean of the Princeton University School of Architecture (1965–1982) (b. 1923). Tom Luddy, 79, film producer (Barfly, The Secret Garden), co-founder of the Telluride Film Festival (b. 1943). David Singmaster, 84, mathematician (b. 1938).
34
+ '''
35
+ query = "What was Robert Geddes' profession?"
36
+ prompt = context + '\n\n' + query
37
+ response, _ = model.chat(tokenizer, prompt, temprature=1, max_new_tokens=1024)
38
+ print(response)
39
+ ```
40
+
41
+ You can also deploy the model with [vllm](https://github.com/vllm-project/vllm) for faster inference:
42
+ ```python
43
+ import torch
44
+ from vllm import LLM, SamplingParams
45
+
46
+ model_path = "THUDM/LongReward-glm4-9b-SFT"
47
+ model = LLM(
48
+ model= model_path,
49
+ dtype=torch.bfloat16,
50
+ trust_remote_code=True,
51
+ tensor_parallel_size=1,
52
+ max_model_len=65536,
53
+ gpu_memory_utilization=1,
54
+ )
55
+ tokenizer = model.get_tokenizer()
56
+ context = '''
57
+ W. Russell Todd, 94, United States Army general (b. 1928). February 13. Tim Aymar, 59, heavy metal singer (Pharaoh) (b. 1963). Marshall \"Eddie\" Conway, 76, Black Panther Party leader (b. 1946). Roger Bonk, 78, football player (North Dakota Fighting Sioux, Winnipeg Blue Bombers) (b. 1944). Conrad Dobler, 72, football player (St. Louis Cardinals, New Orleans Saints, Buffalo Bills) (b. 1950). Brian DuBois, 55, baseball player (Detroit Tigers) (b. 1967). Robert Geddes, 99, architect, dean of the Princeton University School of Architecture (1965–1982) (b. 1923). Tom Luddy, 79, film producer (Barfly, The Secret Garden), co-founder of the Telluride Film Festival (b. 1943). David Singmaster, 84, mathematician (b. 1938).
58
+ '''
59
+ query = "What was Robert Geddes' profession?"
60
+ prompt = context + '\n\n' + query
61
+ inputs = tokenizer.build_chat_input(prompt, history=[], role='user')
62
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), tokenizer.get_command("<|observation|>")]
63
+ generation_params = SamplingParams(
64
+ temperature=0.95,
65
+ top_p=0.7,
66
+ max_tokens=1024,
67
+ stop_token_ids=eos_token_id,
68
+ )
69
+ input_ids = inputs.input_ids[0].tolist()
70
+ outputs = model.generate(sampling_params=generation_params, prompt_token_ids=[input_ids])
71
+ response = tokenizer.decode(outputs[0].outputs[0].token_ids[:-1])
72
+ print(response)
73
+ ```
74
+
75
+
76
+ ## License
77
+ [glm-4-9b License](https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/LICENSE)
78
+
79
+ ## Citation
80
+
81
+ If you find our work useful, please consider citing LongReward:
82
+
83
+ ```
84
+ @article{zhang2024longreward,
85
+ title = {LongReward: Improving Long-context Large Language Models
86
+ with AI Feedback}
87
+ author={Jiajie Zhang and Zhongni Hou and Xin Lv and Shulin Cao and Zhenyu Hou and Yilin Niu and Lei Hou and Lei Hou and Yuxiao Dong and Ling Feng and Juanzi Li},
88
+ journal={arXiv preprint arXiv:},
89
+ year={2024}
90
+ }
91
+ ```
config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "THUDM/glm-4-9b-chat",
3
+ "model_type": "chatglm",
4
+ "architectures": [
5
+ "ChatGLMModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_chatglm.ChatGLMConfig",
9
+ "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
11
+ "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
12
+ "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
13
+ },
14
+ "add_bias_linear": false,
15
+ "add_qkv_bias": true,
16
+ "apply_query_key_layer_scaling": true,
17
+ "apply_residual_connection_post_layernorm": false,
18
+ "attention_dropout": 0.0,
19
+ "attention_softmax_in_fp32": true,
20
+ "attn_implementation": "sdpa",
21
+ "bias_dropout_fusion": true,
22
+ "ffn_hidden_size": 13696,
23
+ "fp32_residual_connection": false,
24
+ "hidden_dropout": 0.0,
25
+ "hidden_size": 4096,
26
+ "kv_channels": 128,
27
+ "layernorm_epsilon": 1e-5,
28
+ "multi_query_attention": true,
29
+ "multi_query_group_num": 2,
30
+ "num_attention_heads": 32,
31
+ "num_hidden_layers": 40,
32
+ "num_layers": 40,
33
+ "rope_ratio": 500,
34
+ "original_rope": true,
35
+ "padded_vocab_size": 151552,
36
+ "post_layer_norm": true,
37
+ "rmsnorm": true,
38
+ "seq_length": 65536,
39
+ "use_cache": true,
40
+ "torch_dtype": "bfloat16",
41
+ "transformers_version": "4.43.0",
42
+ "tie_word_embeddings": false,
43
+ "eos_token_id": [151329, 151336, 151338],
44
+ "pad_token_id": 151329
45
+ }
configuration_chatglm.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class ChatGLMConfig(PretrainedConfig):
5
+ model_type = "chatglm"
6
+
7
+ def __init__(
8
+ self,
9
+ num_layers=28,
10
+ padded_vocab_size=65024,
11
+ hidden_size=4096,
12
+ ffn_hidden_size=13696,
13
+ kv_channels=128,
14
+ num_attention_heads=32,
15
+ seq_length=2048,
16
+ hidden_dropout=0.0,
17
+ classifier_dropout=None,
18
+ attention_dropout=0.0,
19
+ layernorm_epsilon=1e-5,
20
+ rmsnorm=True,
21
+ apply_residual_connection_post_layernorm=False,
22
+ post_layer_norm=True,
23
+ add_bias_linear=False,
24
+ add_qkv_bias=False,
25
+ bias_dropout_fusion=True,
26
+ multi_query_attention=False,
27
+ multi_query_group_num=1,
28
+ rope_ratio=1,
29
+ apply_query_key_layer_scaling=True,
30
+ attention_softmax_in_fp32=True,
31
+ fp32_residual_connection=False,
32
+ **kwargs
33
+ ):
34
+ self.num_layers = num_layers
35
+ self.vocab_size = padded_vocab_size
36
+ self.padded_vocab_size = padded_vocab_size
37
+ self.hidden_size = hidden_size
38
+ self.ffn_hidden_size = ffn_hidden_size
39
+ self.kv_channels = kv_channels
40
+ self.num_attention_heads = num_attention_heads
41
+ self.seq_length = seq_length
42
+ self.hidden_dropout = hidden_dropout
43
+ self.classifier_dropout = classifier_dropout
44
+ self.attention_dropout = attention_dropout
45
+ self.layernorm_epsilon = layernorm_epsilon
46
+ self.rmsnorm = rmsnorm
47
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
48
+ self.post_layer_norm = post_layer_norm
49
+ self.add_bias_linear = add_bias_linear
50
+ self.add_qkv_bias = add_qkv_bias
51
+ self.bias_dropout_fusion = bias_dropout_fusion
52
+ self.multi_query_attention = multi_query_attention
53
+ self.multi_query_group_num = multi_query_group_num
54
+ self.rope_ratio = rope_ratio
55
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
56
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
57
+ self.fp32_residual_connection = fp32_residual_connection
58
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token_id": [
3
+ 151329,
4
+ 151336,
5
+ 151338
6
+ ],
7
+ "pad_token_id": 151329,
8
+ "do_sample": true,
9
+ "temperature": 0.8,
10
+ "max_length": 128000,
11
+ "top_p": 0.8,
12
+ "transformers_version": "4.40.2"
13
+ }
model-00000-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff4c8ea839805cc9f7aeb57c4832525fdc04a8f56bb34dff28f7004987fd6de5
3
+ size 4079226064
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8a0597aa5b9fa44d01c89e31a51f095e144eec1da7cd15ae2392b5a6b4a65f6
3
+ size 4079226136
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f01577fffd10041270665020ac02bd038202cc57a4437ad1e8998c7720b28a4
3
+ size 4079226136
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e925cb1224b1c4cd1e70451be4953d741f3007b62ec80843f6135ae115efb295
3
+ size 4079226136
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d80b9b3c8b19a3605f8991d050a6f1acb4c83d9185fc22dc49339093997c44a6
3
+ size 2483036544
model.safetensors.index.json ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 21043855360
4
+ },
5
+ "weight_map": {
6
+ "transformer.encoder.layers.0.input_layernorm.weight": "model-00000-of-00005.safetensors",
7
+ "transformer.encoder.layers.0.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
8
+ "transformer.encoder.layers.0.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
9
+ "transformer.encoder.layers.0.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
10
+ "transformer.encoder.layers.0.self_attention.dense.weight": "model-00000-of-00005.safetensors",
11
+ "transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
12
+ "transformer.encoder.layers.0.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
13
+ "transformer.encoder.layers.2.input_layernorm.weight": "model-00000-of-00005.safetensors",
14
+ "transformer.encoder.layers.2.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
15
+ "transformer.encoder.layers.2.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
16
+ "transformer.encoder.layers.2.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
17
+ "transformer.encoder.layers.2.self_attention.dense.weight": "model-00000-of-00005.safetensors",
18
+ "transformer.encoder.layers.2.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
19
+ "transformer.encoder.layers.2.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
20
+ "transformer.encoder.layers.10.input_layernorm.weight": "model-00001-of-00005.safetensors",
21
+ "transformer.encoder.layers.10.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
22
+ "transformer.encoder.layers.10.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
23
+ "transformer.encoder.layers.10.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
24
+ "transformer.encoder.layers.10.self_attention.dense.weight": "model-00001-of-00005.safetensors",
25
+ "transformer.encoder.layers.10.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
26
+ "transformer.encoder.layers.10.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
27
+ "transformer.encoder.layers.14.input_layernorm.weight": "model-00001-of-00005.safetensors",
28
+ "transformer.encoder.layers.14.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
29
+ "transformer.encoder.layers.14.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
30
+ "transformer.encoder.layers.14.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
31
+ "transformer.encoder.layers.14.self_attention.dense.weight": "model-00001-of-00005.safetensors",
32
+ "transformer.encoder.layers.14.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
33
+ "transformer.encoder.layers.14.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
34
+ "transformer.encoder.layers.32.input_layernorm.weight": "model-00003-of-00005.safetensors",
35
+ "transformer.encoder.layers.32.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
36
+ "transformer.encoder.layers.32.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
37
+ "transformer.encoder.layers.32.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
38
+ "transformer.encoder.layers.32.self_attention.dense.weight": "model-00003-of-00005.safetensors",
39
+ "transformer.encoder.layers.32.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
40
+ "transformer.encoder.layers.32.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
41
+ "transformer.encoder.layers.27.input_layernorm.weight": "model-00002-of-00005.safetensors",
42
+ "transformer.encoder.layers.27.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
43
+ "transformer.encoder.layers.27.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
44
+ "transformer.encoder.layers.27.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
45
+ "transformer.encoder.layers.27.self_attention.dense.weight": "model-00002-of-00005.safetensors",
46
+ "transformer.encoder.layers.27.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
47
+ "transformer.encoder.layers.27.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
48
+ "transformer.encoder.layers.25.input_layernorm.weight": "model-00002-of-00005.safetensors",
49
+ "transformer.encoder.layers.25.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
50
+ "transformer.encoder.layers.25.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
51
+ "transformer.encoder.layers.25.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
52
+ "transformer.encoder.layers.25.self_attention.dense.weight": "model-00002-of-00005.safetensors",
53
+ "transformer.encoder.layers.25.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
54
+ "transformer.encoder.layers.25.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
55
+ "transformer.encoder.layers.20.input_layernorm.weight": "model-00002-of-00005.safetensors",
56
+ "transformer.encoder.layers.20.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
57
+ "transformer.encoder.layers.20.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
58
+ "transformer.encoder.layers.20.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
59
+ "transformer.encoder.layers.20.self_attention.dense.weight": "model-00002-of-00005.safetensors",
60
+ "transformer.encoder.layers.20.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
61
+ "transformer.encoder.layers.20.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
62
+ "transformer.encoder.layers.7.input_layernorm.weight": "model-00000-of-00005.safetensors",
63
+ "transformer.encoder.layers.7.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
64
+ "transformer.encoder.layers.7.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
65
+ "transformer.encoder.layers.7.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
66
+ "transformer.encoder.layers.7.self_attention.dense.weight": "model-00000-of-00005.safetensors",
67
+ "transformer.encoder.layers.7.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
68
+ "transformer.encoder.layers.7.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
69
+ "transformer.encoder.layers.36.input_layernorm.weight": "model-00003-of-00005.safetensors",
70
+ "transformer.encoder.layers.36.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
71
+ "transformer.encoder.layers.36.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
72
+ "transformer.encoder.layers.36.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
73
+ "transformer.encoder.layers.36.self_attention.dense.weight": "model-00003-of-00005.safetensors",
74
+ "transformer.encoder.layers.36.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
75
+ "transformer.encoder.layers.36.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
76
+ "transformer.encoder.layers.1.input_layernorm.weight": "model-00000-of-00005.safetensors",
77
+ "transformer.encoder.layers.1.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
78
+ "transformer.encoder.layers.1.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
79
+ "transformer.encoder.layers.1.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
80
+ "transformer.encoder.layers.1.self_attention.dense.weight": "model-00000-of-00005.safetensors",
81
+ "transformer.encoder.layers.1.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
82
+ "transformer.encoder.layers.1.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
83
+ "transformer.encoder.layers.38.input_layernorm.weight": "model-00003-of-00005.safetensors",
84
+ "transformer.encoder.layers.38.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
85
+ "transformer.encoder.layers.38.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
86
+ "transformer.encoder.layers.38.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
87
+ "transformer.encoder.layers.38.self_attention.dense.weight": "model-00003-of-00005.safetensors",
88
+ "transformer.encoder.layers.38.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
89
+ "transformer.encoder.layers.38.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
90
+ "transformer.encoder.layers.28.input_layernorm.weight": "model-00002-of-00005.safetensors",
91
+ "transformer.encoder.layers.28.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
92
+ "transformer.encoder.layers.28.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
93
+ "transformer.encoder.layers.28.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
94
+ "transformer.encoder.layers.28.self_attention.dense.weight": "model-00002-of-00005.safetensors",
95
+ "transformer.encoder.layers.28.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
96
+ "transformer.encoder.layers.28.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
97
+ "transformer.encoder.layers.39.input_layernorm.weight": "model-00003-of-00005.safetensors",
98
+ "transformer.encoder.layers.39.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
99
+ "transformer.encoder.layers.39.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
100
+ "transformer.encoder.layers.39.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
101
+ "transformer.encoder.layers.39.self_attention.dense.weight": "model-00003-of-00005.safetensors",
102
+ "transformer.encoder.layers.39.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
103
+ "transformer.encoder.layers.39.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
104
+ "transformer.encoder.layers.24.input_layernorm.weight": "model-00002-of-00005.safetensors",
105
+ "transformer.encoder.layers.24.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
106
+ "transformer.encoder.layers.24.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
107
+ "transformer.encoder.layers.24.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
108
+ "transformer.encoder.layers.24.self_attention.dense.weight": "model-00002-of-00005.safetensors",
109
+ "transformer.encoder.layers.24.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
110
+ "transformer.encoder.layers.24.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
111
+ "transformer.encoder.layers.9.input_layernorm.weight": "model-00000-of-00005.safetensors",
112
+ "transformer.encoder.layers.9.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
113
+ "transformer.encoder.layers.9.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
114
+ "transformer.encoder.layers.9.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
115
+ "transformer.encoder.layers.9.self_attention.dense.weight": "model-00000-of-00005.safetensors",
116
+ "transformer.encoder.layers.9.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
117
+ "transformer.encoder.layers.9.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
118
+ "transformer.encoder.layers.34.input_layernorm.weight": "model-00003-of-00005.safetensors",
119
+ "transformer.encoder.layers.34.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
120
+ "transformer.encoder.layers.34.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
121
+ "transformer.encoder.layers.34.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
122
+ "transformer.encoder.layers.34.self_attention.dense.weight": "model-00003-of-00005.safetensors",
123
+ "transformer.encoder.layers.34.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
124
+ "transformer.encoder.layers.34.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
125
+ "transformer.encoder.layers.26.input_layernorm.weight": "model-00002-of-00005.safetensors",
126
+ "transformer.encoder.layers.26.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
127
+ "transformer.encoder.layers.26.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
128
+ "transformer.encoder.layers.26.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
129
+ "transformer.encoder.layers.26.self_attention.dense.weight": "model-00002-of-00005.safetensors",
130
+ "transformer.encoder.layers.26.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
131
+ "transformer.encoder.layers.26.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
132
+ "transformer.encoder.layers.18.input_layernorm.weight": "model-00001-of-00005.safetensors",
133
+ "transformer.encoder.layers.18.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
134
+ "transformer.encoder.layers.18.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
135
+ "transformer.encoder.layers.18.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
136
+ "transformer.encoder.layers.18.self_attention.dense.weight": "model-00001-of-00005.safetensors",
137
+ "transformer.encoder.layers.18.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
138
+ "transformer.encoder.layers.18.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
139
+ "transformer.encoder.layers.8.input_layernorm.weight": "model-00000-of-00005.safetensors",
140
+ "transformer.encoder.layers.8.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
141
+ "transformer.encoder.layers.8.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
142
+ "transformer.encoder.layers.8.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
143
+ "transformer.encoder.layers.8.self_attention.dense.weight": "model-00000-of-00005.safetensors",
144
+ "transformer.encoder.layers.8.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
145
+ "transformer.encoder.layers.8.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
146
+ "transformer.encoder.layers.37.input_layernorm.weight": "model-00003-of-00005.safetensors",
147
+ "transformer.encoder.layers.37.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
148
+ "transformer.encoder.layers.37.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
149
+ "transformer.encoder.layers.37.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
150
+ "transformer.encoder.layers.37.self_attention.dense.weight": "model-00003-of-00005.safetensors",
151
+ "transformer.encoder.layers.37.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
152
+ "transformer.encoder.layers.37.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
153
+ "transformer.encoder.layers.35.input_layernorm.weight": "model-00003-of-00005.safetensors",
154
+ "transformer.encoder.layers.35.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
155
+ "transformer.encoder.layers.35.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
156
+ "transformer.encoder.layers.35.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
157
+ "transformer.encoder.layers.35.self_attention.dense.weight": "model-00003-of-00005.safetensors",
158
+ "transformer.encoder.layers.35.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
159
+ "transformer.encoder.layers.35.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
160
+ "transformer.encoder.layers.22.input_layernorm.weight": "model-00002-of-00005.safetensors",
161
+ "transformer.encoder.layers.22.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
162
+ "transformer.encoder.layers.22.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
163
+ "transformer.encoder.layers.22.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
164
+ "transformer.encoder.layers.22.self_attention.dense.weight": "model-00002-of-00005.safetensors",
165
+ "transformer.encoder.layers.22.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
166
+ "transformer.encoder.layers.22.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
167
+ "transformer.encoder.layers.4.input_layernorm.weight": "model-00000-of-00005.safetensors",
168
+ "transformer.encoder.layers.4.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
169
+ "transformer.encoder.layers.4.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
170
+ "transformer.encoder.layers.4.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
171
+ "transformer.encoder.layers.4.self_attention.dense.weight": "model-00000-of-00005.safetensors",
172
+ "transformer.encoder.layers.4.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
173
+ "transformer.encoder.layers.4.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
174
+ "transformer.embedding.word_embeddings.weight": "model-00004-of-00005.safetensors",
175
+ "transformer.output_layer.weight": "model-00004-of-00005.safetensors",
176
+ "transformer.encoder.final_layernorm.weight": "model-00004-of-00005.safetensors",
177
+ "transformer.encoder.layers.6.input_layernorm.weight": "model-00000-of-00005.safetensors",
178
+ "transformer.encoder.layers.6.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
179
+ "transformer.encoder.layers.6.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
180
+ "transformer.encoder.layers.6.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
181
+ "transformer.encoder.layers.6.self_attention.dense.weight": "model-00000-of-00005.safetensors",
182
+ "transformer.encoder.layers.6.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
183
+ "transformer.encoder.layers.6.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
184
+ "transformer.encoder.layers.33.input_layernorm.weight": "model-00003-of-00005.safetensors",
185
+ "transformer.encoder.layers.33.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
186
+ "transformer.encoder.layers.33.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
187
+ "transformer.encoder.layers.33.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
188
+ "transformer.encoder.layers.33.self_attention.dense.weight": "model-00003-of-00005.safetensors",
189
+ "transformer.encoder.layers.33.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
190
+ "transformer.encoder.layers.33.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
191
+ "transformer.encoder.layers.31.input_layernorm.weight": "model-00003-of-00005.safetensors",
192
+ "transformer.encoder.layers.31.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
193
+ "transformer.encoder.layers.31.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
194
+ "transformer.encoder.layers.31.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
195
+ "transformer.encoder.layers.31.self_attention.dense.weight": "model-00003-of-00005.safetensors",
196
+ "transformer.encoder.layers.31.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
197
+ "transformer.encoder.layers.31.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
198
+ "transformer.encoder.layers.21.input_layernorm.weight": "model-00002-of-00005.safetensors",
199
+ "transformer.encoder.layers.21.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
200
+ "transformer.encoder.layers.21.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
201
+ "transformer.encoder.layers.21.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
202
+ "transformer.encoder.layers.21.self_attention.dense.weight": "model-00002-of-00005.safetensors",
203
+ "transformer.encoder.layers.21.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
204
+ "transformer.encoder.layers.21.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
205
+ "transformer.encoder.layers.12.input_layernorm.weight": "model-00001-of-00005.safetensors",
206
+ "transformer.encoder.layers.12.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
207
+ "transformer.encoder.layers.12.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
208
+ "transformer.encoder.layers.12.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
209
+ "transformer.encoder.layers.12.self_attention.dense.weight": "model-00001-of-00005.safetensors",
210
+ "transformer.encoder.layers.12.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
211
+ "transformer.encoder.layers.12.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
212
+ "transformer.encoder.layers.16.input_layernorm.weight": "model-00001-of-00005.safetensors",
213
+ "transformer.encoder.layers.16.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
214
+ "transformer.encoder.layers.16.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
215
+ "transformer.encoder.layers.16.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
216
+ "transformer.encoder.layers.16.self_attention.dense.weight": "model-00001-of-00005.safetensors",
217
+ "transformer.encoder.layers.16.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
218
+ "transformer.encoder.layers.16.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
219
+ "transformer.encoder.layers.17.input_layernorm.weight": "model-00001-of-00005.safetensors",
220
+ "transformer.encoder.layers.17.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
221
+ "transformer.encoder.layers.17.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
222
+ "transformer.encoder.layers.17.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
223
+ "transformer.encoder.layers.17.self_attention.dense.weight": "model-00001-of-00005.safetensors",
224
+ "transformer.encoder.layers.17.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
225
+ "transformer.encoder.layers.17.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
226
+ "transformer.encoder.layers.30.input_layernorm.weight": "model-00003-of-00005.safetensors",
227
+ "transformer.encoder.layers.30.post_attention_layernorm.weight": "model-00003-of-00005.safetensors",
228
+ "transformer.encoder.layers.30.self_attention.query_key_value.weight": "model-00003-of-00005.safetensors",
229
+ "transformer.encoder.layers.30.self_attention.query_key_value.bias": "model-00003-of-00005.safetensors",
230
+ "transformer.encoder.layers.30.self_attention.dense.weight": "model-00003-of-00005.safetensors",
231
+ "transformer.encoder.layers.30.mlp.dense_h_to_4h.weight": "model-00003-of-00005.safetensors",
232
+ "transformer.encoder.layers.30.mlp.dense_4h_to_h.weight": "model-00003-of-00005.safetensors",
233
+ "transformer.encoder.layers.5.input_layernorm.weight": "model-00000-of-00005.safetensors",
234
+ "transformer.encoder.layers.5.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
235
+ "transformer.encoder.layers.5.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
236
+ "transformer.encoder.layers.5.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
237
+ "transformer.encoder.layers.5.self_attention.dense.weight": "model-00000-of-00005.safetensors",
238
+ "transformer.encoder.layers.5.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
239
+ "transformer.encoder.layers.5.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
240
+ "transformer.encoder.layers.11.input_layernorm.weight": "model-00001-of-00005.safetensors",
241
+ "transformer.encoder.layers.11.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
242
+ "transformer.encoder.layers.11.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
243
+ "transformer.encoder.layers.11.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
244
+ "transformer.encoder.layers.11.self_attention.dense.weight": "model-00001-of-00005.safetensors",
245
+ "transformer.encoder.layers.11.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
246
+ "transformer.encoder.layers.11.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
247
+ "transformer.encoder.layers.3.input_layernorm.weight": "model-00000-of-00005.safetensors",
248
+ "transformer.encoder.layers.3.post_attention_layernorm.weight": "model-00000-of-00005.safetensors",
249
+ "transformer.encoder.layers.3.self_attention.query_key_value.weight": "model-00000-of-00005.safetensors",
250
+ "transformer.encoder.layers.3.self_attention.query_key_value.bias": "model-00000-of-00005.safetensors",
251
+ "transformer.encoder.layers.3.self_attention.dense.weight": "model-00000-of-00005.safetensors",
252
+ "transformer.encoder.layers.3.mlp.dense_h_to_4h.weight": "model-00000-of-00005.safetensors",
253
+ "transformer.encoder.layers.3.mlp.dense_4h_to_h.weight": "model-00000-of-00005.safetensors",
254
+ "transformer.encoder.layers.29.input_layernorm.weight": "model-00002-of-00005.safetensors",
255
+ "transformer.encoder.layers.29.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
256
+ "transformer.encoder.layers.29.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
257
+ "transformer.encoder.layers.29.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
258
+ "transformer.encoder.layers.29.self_attention.dense.weight": "model-00002-of-00005.safetensors",
259
+ "transformer.encoder.layers.29.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
260
+ "transformer.encoder.layers.29.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
261
+ "transformer.encoder.layers.23.input_layernorm.weight": "model-00002-of-00005.safetensors",
262
+ "transformer.encoder.layers.23.post_attention_layernorm.weight": "model-00002-of-00005.safetensors",
263
+ "transformer.encoder.layers.23.self_attention.query_key_value.weight": "model-00002-of-00005.safetensors",
264
+ "transformer.encoder.layers.23.self_attention.query_key_value.bias": "model-00002-of-00005.safetensors",
265
+ "transformer.encoder.layers.23.self_attention.dense.weight": "model-00002-of-00005.safetensors",
266
+ "transformer.encoder.layers.23.mlp.dense_h_to_4h.weight": "model-00002-of-00005.safetensors",
267
+ "transformer.encoder.layers.23.mlp.dense_4h_to_h.weight": "model-00002-of-00005.safetensors",
268
+ "transformer.encoder.layers.13.input_layernorm.weight": "model-00001-of-00005.safetensors",
269
+ "transformer.encoder.layers.13.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
270
+ "transformer.encoder.layers.13.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
271
+ "transformer.encoder.layers.13.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
272
+ "transformer.encoder.layers.13.self_attention.dense.weight": "model-00001-of-00005.safetensors",
273
+ "transformer.encoder.layers.13.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
274
+ "transformer.encoder.layers.13.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
275
+ "transformer.encoder.layers.15.input_layernorm.weight": "model-00001-of-00005.safetensors",
276
+ "transformer.encoder.layers.15.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
277
+ "transformer.encoder.layers.15.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
278
+ "transformer.encoder.layers.15.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
279
+ "transformer.encoder.layers.15.self_attention.dense.weight": "model-00001-of-00005.safetensors",
280
+ "transformer.encoder.layers.15.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
281
+ "transformer.encoder.layers.15.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors",
282
+ "transformer.encoder.layers.19.input_layernorm.weight": "model-00001-of-00005.safetensors",
283
+ "transformer.encoder.layers.19.post_attention_layernorm.weight": "model-00001-of-00005.safetensors",
284
+ "transformer.encoder.layers.19.self_attention.query_key_value.weight": "model-00001-of-00005.safetensors",
285
+ "transformer.encoder.layers.19.self_attention.query_key_value.bias": "model-00001-of-00005.safetensors",
286
+ "transformer.encoder.layers.19.self_attention.dense.weight": "model-00001-of-00005.safetensors",
287
+ "transformer.encoder.layers.19.mlp.dense_h_to_4h.weight": "model-00001-of-00005.safetensors",
288
+ "transformer.encoder.layers.19.mlp.dense_4h_to_h.weight": "model-00001-of-00005.safetensors"
289
+ }
290
+ }
modeling_chatglm.py ADDED
@@ -0,0 +1,1064 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch ChatGLM model. """
2
+
3
+ import math
4
+ import sys
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
10
+ from torch.nn.utils import skip_init
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast,
16
+ SequenceClassifierOutputWithPast,
17
+ )
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging, is_torch_npu_available
20
+ from transformers.generation.logits_process import LogitsProcessor
21
+ from transformers.generation.utils import ModelOutput
22
+
23
+ from .configuration_chatglm import ChatGLMConfig
24
+
25
+ try:
26
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ except:
32
+ pass
33
+
34
+ # flags required to enable jit fusion kernels
35
+
36
+ if sys.platform != 'darwin' and not is_torch_npu_available():
37
+ torch._C._jit_set_profiling_mode(False)
38
+ torch._C._jit_set_profiling_executor(False)
39
+ torch._C._jit_override_can_fuse_on_cpu(True)
40
+ torch._C._jit_override_can_fuse_on_gpu(True)
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
45
+ _CONFIG_FOR_DOC = "ChatGLMConfig"
46
+
47
+
48
+ def default_init(cls, *args, **kwargs):
49
+ return cls(*args, **kwargs)
50
+
51
+
52
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
53
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
54
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
55
+ scores.zero_()
56
+ scores[..., 198] = 5e4
57
+ return scores
58
+
59
+
60
+ def split_tensor_along_last_dim(
61
+ tensor: torch.Tensor,
62
+ num_partitions: int,
63
+ contiguous_split_chunks: bool = False,
64
+ ) -> List[torch.Tensor]:
65
+ """Split a tensor along its last dimension.
66
+ Arguments:
67
+ tensor: input tensor.
68
+ num_partitions: number of partitions to split the tensor
69
+ contiguous_split_chunks: If True, make each chunk contiguous
70
+ in memory.
71
+ Returns:
72
+ A list of Tensors
73
+ """
74
+ # Get the size and dimension.
75
+ last_dim = tensor.dim() - 1
76
+ last_dim_size = tensor.size()[last_dim] // num_partitions
77
+ # Split.
78
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
79
+ # Note: torch.split does not create contiguous tensors by default.
80
+ if contiguous_split_chunks:
81
+ return tuple(chunk.contiguous() for chunk in tensor_list)
82
+
83
+ return tensor_list
84
+
85
+
86
+ class RotaryEmbedding(nn.Module):
87
+ def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
88
+ super().__init__()
89
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
90
+ self.register_buffer("inv_freq", inv_freq)
91
+ self.dim = dim
92
+ self.original_impl = original_impl
93
+ self.rope_ratio = rope_ratio
94
+
95
+ def forward_impl(
96
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
97
+ ):
98
+ """Enhanced Transformer with Rotary Position Embedding.
99
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
100
+ transformers/rope/__init__.py. MIT License:
101
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
102
+ """
103
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
104
+ base = base * self.rope_ratio
105
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
106
+
107
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
108
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
109
+
110
+ # Calculate the product of position index and $\theta_i$
111
+ idx_theta = torch.outer(seq_idx, theta).float()
112
+
113
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
114
+
115
+ # this is to mimic the behaviour of complex32, else we will get different results
116
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
117
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
118
+ return cache
119
+
120
+ def forward(self, max_seq_len, offset=0):
121
+ return self.forward_impl(
122
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
123
+ )
124
+
125
+
126
+ @torch.jit.script
127
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
128
+ # x: [b, np, sq, hn]
129
+ b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
130
+ rot_dim = rope_cache.shape[-2] * 2
131
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
132
+ # truncate to support variable sizes
133
+ rope_cache = rope_cache[:, :sq]
134
+ xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
135
+ rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
136
+ x_out2 = torch.stack(
137
+ [
138
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
139
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
140
+ ],
141
+ -1,
142
+ )
143
+ x_out2 = x_out2.flatten(3)
144
+ return torch.cat((x_out2, x_pass), dim=-1)
145
+
146
+
147
+ class RMSNorm(torch.nn.Module):
148
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
149
+ super().__init__()
150
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
151
+ self.eps = eps
152
+
153
+ def forward(self, hidden_states: torch.Tensor):
154
+ input_dtype = hidden_states.dtype
155
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
156
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
157
+
158
+ return (self.weight * hidden_states).to(input_dtype)
159
+
160
+
161
+ class CoreAttention(torch.nn.Module):
162
+ def __init__(self, config: ChatGLMConfig, layer_number):
163
+ super(CoreAttention, self).__init__()
164
+ self.config = config
165
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
166
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
167
+ if self.apply_query_key_layer_scaling:
168
+ self.attention_softmax_in_fp32 = True
169
+ self.layer_number = max(1, layer_number)
170
+ self.is_causal = True
171
+
172
+ projection_size = config.kv_channels * config.num_attention_heads
173
+
174
+ # Per attention head and per partition values.
175
+ self.hidden_size_per_partition = projection_size
176
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
177
+ self.num_attention_heads_per_partition = config.num_attention_heads
178
+
179
+ coeff = None
180
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
181
+ if self.apply_query_key_layer_scaling:
182
+ coeff = self.layer_number
183
+ self.norm_factor *= coeff
184
+ self.coeff = coeff
185
+
186
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
187
+
188
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
189
+ # [b, np, sq, sk]
190
+ output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
191
+
192
+ # [b, np, sq, hn] -> [b * np, sq, hn]
193
+ query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
194
+ # [b, np, sk, hn] -> [b * np, sk, hn]
195
+ key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
196
+
197
+ # preallocting input tensor: [b * np, sq, sk]
198
+ matmul_input_buffer = torch.empty(
199
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
200
+ device=query_layer.device
201
+ )
202
+
203
+ # Raw attention scores. [b * np, sq, sk]
204
+ matmul_result = torch.baddbmm(
205
+ matmul_input_buffer,
206
+ query_layer, # [b * np, sq, hn]
207
+ key_layer.transpose(1, 2), # [b * np, hn, sk]
208
+ beta=0.0,
209
+ alpha=(1.0 / self.norm_factor),
210
+ )
211
+
212
+ # change view to [b, np, sq, sk]
213
+ attention_scores = matmul_result.view(*output_size)
214
+
215
+ # ===========================
216
+ # Attention probs and dropout
217
+ # ===========================
218
+
219
+ # attention scores and attention mask [b, np, sq, sk]
220
+ if self.attention_softmax_in_fp32:
221
+ attention_scores = attention_scores.float()
222
+ if self.coeff is not None:
223
+ attention_scores = attention_scores * self.coeff
224
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
225
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
226
+ device=attention_scores.device, dtype=torch.bool)
227
+ attention_mask.tril_()
228
+ attention_mask = ~attention_mask
229
+ if attention_mask is not None:
230
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
231
+ attention_probs = F.softmax(attention_scores, dim=-1)
232
+ attention_probs = attention_probs.type_as(value_layer)
233
+
234
+ # This is actually dropping out entire tokens to attend to, which might
235
+ # seem a bit unusual, but is taken from the original Transformer paper.
236
+ attention_probs = self.attention_dropout(attention_probs)
237
+
238
+ # query layer shape: [b * np, sq, hn]
239
+ # value layer shape: [b, np, sk, hn]
240
+ # attention shape: [b, np, sq, sk]
241
+ # context layer shape: [b, np, sq, hn]
242
+ output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
243
+ # change view [b * np, sk, hn]
244
+ value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
245
+ # change view [b * np, sq, sk]
246
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
247
+ # matmul: [b * np, sq, hn]
248
+ context_layer = torch.bmm(attention_probs, value_layer)
249
+ # change view [b, np, sq, hn]
250
+ context_layer = context_layer.view(*output_size)
251
+ # [b, np, sq, hn] --> [b, sq, np, hn]
252
+ context_layer = context_layer.transpose(1, 2).contiguous()
253
+ # [b, sq, np, hn] --> [b, sq, hp]
254
+ splited_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
255
+ context_layer = context_layer.reshape(*splited_context_layer_shape)
256
+
257
+ return context_layer
258
+
259
+
260
+ class SdpaAttention(CoreAttention):
261
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
262
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
263
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
264
+ is_causal=True,
265
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
266
+ else:
267
+ if attention_mask is not None:
268
+ attention_mask = ~attention_mask
269
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
270
+ attention_mask,
271
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
272
+ context_layer = context_layer.transpose(1, 2).contiguous()
273
+ splited_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
274
+ context_layer = context_layer.reshape(*splited_context_layer_shape)
275
+ return context_layer
276
+
277
+
278
+ def _get_unpad_data(attention_mask):
279
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
280
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
281
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
282
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
283
+ return (
284
+ indices,
285
+ cu_seqlens,
286
+ max_seqlen_in_batch,
287
+ )
288
+
289
+
290
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
291
+ class FlashAttention2(CoreAttention):
292
+ def __init__(self, *args, **kwargs):
293
+ super().__init__(*args, **kwargs)
294
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
295
+
296
+ def forward(self, query_states, key_states, value_states, attention_mask):
297
+ query_states = query_states.transpose(1, 2)
298
+ key_states = key_states.transpose(1, 2)
299
+ value_states = value_states.transpose(1, 2)
300
+ batch_size, query_length = query_states.shape[:2]
301
+ if not self._flash_attn_uses_top_left_mask:
302
+ causal = self.is_causal
303
+ else:
304
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
305
+ causal = self.is_causal and query_length != 1
306
+ dropout = self.config.attention_dropout if self.training else 0.0
307
+ # Contains at least one padding token in the sequence
308
+ if attention_mask is not None:
309
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
310
+ query_states, key_states, value_states, attention_mask, query_length
311
+ )
312
+
313
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
314
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
315
+
316
+ attn_output_unpad = flash_attn_varlen_func(
317
+ query_states,
318
+ key_states,
319
+ value_states,
320
+ cu_seqlens_q=cu_seqlens_q,
321
+ cu_seqlens_k=cu_seqlens_k,
322
+ max_seqlen_q=max_seqlen_in_batch_q,
323
+ max_seqlen_k=max_seqlen_in_batch_k,
324
+ dropout_p=dropout,
325
+ softmax_scale=None,
326
+ causal=causal,
327
+ )
328
+
329
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
330
+ else:
331
+ attn_output = flash_attn_func(
332
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
333
+ )
334
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
335
+ return attn_output
336
+
337
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
338
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
339
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
340
+
341
+ key_layer = index_first_axis(
342
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
343
+ )
344
+ value_layer = index_first_axis(
345
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
346
+ )
347
+ if query_length == kv_seq_len:
348
+ query_layer = index_first_axis(
349
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
350
+ indices_k
351
+ )
352
+ cu_seqlens_q = cu_seqlens_k
353
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
354
+ indices_q = indices_k
355
+ elif query_length == 1:
356
+ max_seqlen_in_batch_q = 1
357
+ cu_seqlens_q = torch.arange(
358
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
359
+ ) # There is a memcpy here, that is very bad.
360
+ indices_q = cu_seqlens_q[:-1]
361
+ query_layer = query_layer.squeeze(1)
362
+ else:
363
+ # The -q_len: slice assumes left padding.
364
+ attention_mask = attention_mask[:, -query_length:]
365
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
366
+
367
+ return (
368
+ query_layer,
369
+ key_layer,
370
+ value_layer,
371
+ indices_q,
372
+ (cu_seqlens_q, cu_seqlens_k),
373
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
374
+ )
375
+
376
+
377
+ CORE_ATTENTION_CLASSES = {
378
+ "eager": CoreAttention,
379
+ "sdpa": SdpaAttention,
380
+ "flash_attention_2": FlashAttention2
381
+ }
382
+
383
+
384
+ class SelfAttention(torch.nn.Module):
385
+ """Parallel self-attention layer abstract class.
386
+ Self-attention layer takes input with size [s, b, h]
387
+ and returns output of the same size.
388
+ """
389
+
390
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
391
+ super(SelfAttention, self).__init__()
392
+ self.layer_number = max(1, layer_number)
393
+
394
+ self.projection_size = config.kv_channels * config.num_attention_heads
395
+
396
+ # Per attention head and per partition values.
397
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
398
+ self.num_attention_heads_per_partition = config.num_attention_heads
399
+
400
+ self.multi_query_attention = config.multi_query_attention
401
+ self.qkv_hidden_size = 3 * self.projection_size
402
+ if self.multi_query_attention:
403
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
404
+ self.qkv_hidden_size = (
405
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
406
+ )
407
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
408
+ bias=config.add_bias_linear or config.add_qkv_bias,
409
+ device=device, **_config_to_kwargs(config)
410
+ )
411
+
412
+ self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
413
+
414
+ # Output.
415
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
416
+ device=device, **_config_to_kwargs(config)
417
+ )
418
+
419
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
420
+ if self.multi_query_attention:
421
+ num_attention_heads = self.num_multi_query_groups_per_partition
422
+ else:
423
+ num_attention_heads = self.num_attention_heads_per_partition
424
+ return torch.empty(
425
+ inference_max_sequence_len,
426
+ batch_size,
427
+ num_attention_heads,
428
+ self.hidden_size_per_attention_head,
429
+ dtype=dtype,
430
+ device=device,
431
+ )
432
+
433
+ def forward(
434
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
435
+ ):
436
+ # hidden_states: [b, sq, h]
437
+
438
+ # =================================================
439
+ # Pre-allocate memory for key-values for inference.
440
+ # =================================================
441
+ # =====================
442
+ # Query, Key, and Value
443
+ # =====================
444
+
445
+ # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
446
+ mixed_x_layer = self.query_key_value(hidden_states)
447
+
448
+ if self.multi_query_attention:
449
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
450
+ [
451
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
452
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
453
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
454
+ ],
455
+ dim=-1,
456
+ )
457
+ query_layer = query_layer.view(
458
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
459
+ )
460
+ key_layer = key_layer.view(
461
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
462
+ )
463
+ value_layer = value_layer.view(
464
+ value_layer.size()[:-1]
465
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
466
+ )
467
+ else:
468
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
469
+ (self.num_attention_heads_per_partition,
470
+ 3 * self.hidden_size_per_attention_head)
471
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
472
+
473
+ # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
474
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
475
+
476
+ # [b, sq, np, hn] -> [b, np, sq, hn]
477
+ query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
478
+
479
+ # apply relative positional encoding (rotary embedding)
480
+ if rotary_pos_emb is not None:
481
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
482
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
483
+
484
+ # adjust key and value for inference
485
+ if kv_cache is not None:
486
+ cache_k, cache_v = kv_cache
487
+ key_layer = torch.cat((cache_k, key_layer), dim=2)
488
+ value_layer = torch.cat((cache_v, value_layer), dim=2)
489
+ if use_cache:
490
+ if kv_cache is None:
491
+ kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
492
+ dim=1)
493
+ else:
494
+ kv_cache = (key_layer, value_layer)
495
+ else:
496
+ kv_cache = None
497
+
498
+ if self.multi_query_attention:
499
+ key_layer = key_layer.unsqueeze(2)
500
+ key_layer = key_layer.expand(
501
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
502
+ )
503
+ key_layer = key_layer.contiguous().view(
504
+ key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
505
+ )
506
+ value_layer = value_layer.unsqueeze(2)
507
+ value_layer = value_layer.expand(
508
+ -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
509
+ )
510
+ value_layer = value_layer.contiguous().view(
511
+ value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
512
+ )
513
+
514
+ # ==================================
515
+ # core attention computation
516
+ # ==================================
517
+
518
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
519
+
520
+ # =================
521
+ # Output. [sq, b, h]
522
+ # =================
523
+
524
+ output = self.dense(context_layer)
525
+
526
+ return output, kv_cache
527
+
528
+
529
+ def _config_to_kwargs(args):
530
+ common_kwargs = {
531
+ "dtype": args.torch_dtype,
532
+ }
533
+ return common_kwargs
534
+
535
+
536
+ class MLP(torch.nn.Module):
537
+ """MLP.
538
+ MLP will take the input with h hidden state, project it to 4*h
539
+ hidden dimension, perform nonlinear transformation, and project the
540
+ state back into h hidden dimension.
541
+ """
542
+
543
+ def __init__(self, config: ChatGLMConfig, device=None):
544
+ super(MLP, self).__init__()
545
+
546
+ self.add_bias = config.add_bias_linear
547
+
548
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
549
+ self.dense_h_to_4h = nn.Linear(
550
+ config.hidden_size,
551
+ config.ffn_hidden_size * 2,
552
+ bias=self.add_bias,
553
+ device=device,
554
+ **_config_to_kwargs(config)
555
+ )
556
+
557
+ def swiglu(x):
558
+ x = torch.chunk(x, 2, dim=-1)
559
+ return F.silu(x[0]) * x[1]
560
+
561
+ self.activation_func = swiglu
562
+
563
+ # Project back to h.
564
+ self.dense_4h_to_h = nn.Linear(
565
+ config.ffn_hidden_size,
566
+ config.hidden_size,
567
+ bias=self.add_bias,
568
+ device=device,
569
+ **_config_to_kwargs(config)
570
+ )
571
+
572
+ def forward(self, hidden_states):
573
+ # [s, b, 4hp]
574
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
575
+ intermediate_parallel = self.activation_func(intermediate_parallel)
576
+ # [s, b, h]
577
+ output = self.dense_4h_to_h(intermediate_parallel)
578
+ return output
579
+
580
+
581
+ class GLMBlock(torch.nn.Module):
582
+ """A single transformer layer.
583
+ Transformer layer takes input with size [s, b, h] and returns an
584
+ output of the same size.
585
+ """
586
+
587
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
588
+ super(GLMBlock, self).__init__()
589
+ self.layer_number = layer_number
590
+
591
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
592
+
593
+ self.fp32_residual_connection = config.fp32_residual_connection
594
+
595
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
596
+ # Layernorm on the input data.
597
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
598
+ dtype=config.torch_dtype)
599
+
600
+ # Self attention.
601
+ self.self_attention = SelfAttention(config, layer_number, device=device)
602
+ self.hidden_dropout = config.hidden_dropout
603
+
604
+ # Layernorm on the attention output
605
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
606
+ dtype=config.torch_dtype)
607
+
608
+ # MLP
609
+ self.mlp = MLP(config, device=device)
610
+
611
+ def forward(
612
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
613
+ ):
614
+ # hidden_states: [s, b, h]
615
+
616
+ # Layer norm at the beginning of the transformer layer.
617
+ layernorm_output = self.input_layernorm(hidden_states)
618
+ # Self attention.
619
+ attention_output, kv_cache = self.self_attention(
620
+ layernorm_output,
621
+ attention_mask,
622
+ rotary_pos_emb,
623
+ kv_cache=kv_cache,
624
+ use_cache=use_cache
625
+ )
626
+
627
+ # Residual connection.
628
+ if self.apply_residual_connection_post_layernorm:
629
+ residual = layernorm_output
630
+ else:
631
+ residual = hidden_states
632
+
633
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
634
+ layernorm_input = residual + layernorm_input
635
+
636
+ # Layer norm post the self attention.
637
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
638
+
639
+ # MLP.
640
+ mlp_output = self.mlp(layernorm_output)
641
+
642
+ # Second residual connection.
643
+ if self.apply_residual_connection_post_layernorm:
644
+ residual = layernorm_output
645
+ else:
646
+ residual = layernorm_input
647
+
648
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
649
+ output = residual + output
650
+
651
+ return output, kv_cache
652
+
653
+
654
+ class GLMTransformer(torch.nn.Module):
655
+ """Transformer class."""
656
+
657
+ def __init__(self, config: ChatGLMConfig, device=None):
658
+ super(GLMTransformer, self).__init__()
659
+
660
+ self.fp32_residual_connection = config.fp32_residual_connection
661
+ self.post_layer_norm = config.post_layer_norm
662
+
663
+ # Number of layers.
664
+ self.num_layers = config.num_layers
665
+
666
+ # Transformer layers.
667
+ def build_layer(layer_number):
668
+ return GLMBlock(config, layer_number, device=device)
669
+
670
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
671
+
672
+ if self.post_layer_norm:
673
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
674
+ # Final layer norm before output.
675
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
676
+ dtype=config.torch_dtype)
677
+
678
+ self.gradient_checkpointing = False
679
+
680
+ def _get_layer(self, layer_number):
681
+ return self.layers[layer_number]
682
+
683
+ def forward(
684
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
685
+ use_cache: Optional[bool] = True,
686
+ output_hidden_states: Optional[bool] = False,
687
+ ):
688
+ if not kv_caches:
689
+ kv_caches = [None for _ in range(self.num_layers)]
690
+ presents = () if use_cache else None
691
+ if self.gradient_checkpointing and self.training:
692
+ if use_cache:
693
+ logger.warning_once(
694
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
695
+ )
696
+ use_cache = False
697
+
698
+ all_self_attentions = None
699
+ all_hidden_states = () if output_hidden_states else None
700
+ for index in range(self.num_layers):
701
+ if output_hidden_states:
702
+ all_hidden_states = all_hidden_states + (hidden_states,)
703
+
704
+ layer = self._get_layer(index)
705
+ if self.gradient_checkpointing and self.training:
706
+ layer_ret = torch.utils.checkpoint.checkpoint(
707
+ layer,
708
+ hidden_states,
709
+ attention_mask,
710
+ rotary_pos_emb,
711
+ kv_caches[index],
712
+ use_cache,
713
+ use_reentrant=False
714
+ )
715
+ else:
716
+ layer_ret = layer(
717
+ hidden_states,
718
+ attention_mask,
719
+ rotary_pos_emb,
720
+ kv_cache=kv_caches[index],
721
+ use_cache=use_cache
722
+ )
723
+ hidden_states, kv_cache = layer_ret
724
+ if use_cache:
725
+ # token by token decoding, use tuple format
726
+ if kv_caches[0] is not None:
727
+ presents = presents + (kv_cache,)
728
+ # prefilling in decoding, use tensor format to save cuda memory
729
+ else:
730
+ if len(presents) == 0:
731
+ presents = kv_cache
732
+ else:
733
+ presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
734
+
735
+ if output_hidden_states:
736
+ all_hidden_states = all_hidden_states + (hidden_states,)
737
+
738
+ # Final layer norm.
739
+ if self.post_layer_norm:
740
+ hidden_states = self.final_layernorm(hidden_states)
741
+
742
+ return hidden_states, presents, all_hidden_states, all_self_attentions
743
+
744
+
745
+ class ChatGLMPreTrainedModel(PreTrainedModel):
746
+ """
747
+ An abstract class to handle weights initialization and
748
+ a simple interface for downloading and loading pretrained models.
749
+ """
750
+
751
+ is_parallelizable = False
752
+ supports_gradient_checkpointing = True
753
+ config_class = ChatGLMConfig
754
+ base_model_prefix = "transformer"
755
+ _no_split_modules = ["GLMBlock"]
756
+ _supports_flash_attn_2 = True
757
+ _supports_sdpa = True
758
+
759
+ def _init_weights(self, module: nn.Module):
760
+ """Initialize the weights."""
761
+ return
762
+
763
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
764
+ if self.config._attn_implementation == "flash_attention_2":
765
+ if padding_mask is not None and not padding_mask.all():
766
+ return padding_mask
767
+ return None
768
+ batch_size, seq_length = input_ids.shape
769
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
770
+ full_attention_mask.tril_()
771
+ past_length = 0
772
+ if past_key_values:
773
+ past_length = past_key_values[0][0].shape[2]
774
+ if past_length:
775
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
776
+ device=input_ids.device), full_attention_mask), dim=-1)
777
+ if padding_mask is not None:
778
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
779
+ if not past_length and padding_mask is not None:
780
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
781
+ full_attention_mask = (full_attention_mask < 0.5).bool()
782
+ full_attention_mask.unsqueeze_(1)
783
+ return full_attention_mask
784
+
785
+ def get_position_ids(self, input_ids, device):
786
+ batch_size, seq_length = input_ids.shape
787
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
788
+ return position_ids
789
+
790
+ class Embedding(torch.nn.Module):
791
+ """Language model embeddings."""
792
+
793
+ def __init__(self, config: ChatGLMConfig, device=None):
794
+ super(Embedding, self).__init__()
795
+
796
+ self.hidden_size = config.hidden_size
797
+ # Word embeddings (parallel).
798
+ self.word_embeddings = nn.Embedding(
799
+ config.padded_vocab_size,
800
+ self.hidden_size,
801
+ dtype=config.torch_dtype,
802
+ device=device
803
+ )
804
+ self.fp32_residual_connection = config.fp32_residual_connection
805
+
806
+ def forward(self, input_ids):
807
+ # Embeddings.
808
+ words_embeddings = self.word_embeddings(input_ids)
809
+ embeddings = words_embeddings
810
+ # If the input flag for fp32 residual connection is set, convert for float.
811
+ if self.fp32_residual_connection:
812
+ embeddings = embeddings.float()
813
+ return embeddings
814
+
815
+
816
+ class ChatGLMModel(ChatGLMPreTrainedModel):
817
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
818
+ super().__init__(config)
819
+ if empty_init:
820
+ init_method = skip_init
821
+ else:
822
+ init_method = default_init
823
+ init_kwargs = {}
824
+ if device is not None:
825
+ init_kwargs["device"] = device
826
+ self.embedding = init_method(Embedding, config, **init_kwargs)
827
+ self.num_layers = config.num_layers
828
+ self.multi_query_group_num = config.multi_query_group_num
829
+ self.kv_channels = config.kv_channels
830
+
831
+ # Rotary positional embeddings
832
+ self.seq_length = config.seq_length
833
+ rotary_dim = (
834
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
835
+ )
836
+
837
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
838
+ original_impl=config.original_rope,
839
+ device=device, dtype=config.torch_dtype)
840
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
841
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
842
+ dtype=config.torch_dtype, **init_kwargs)
843
+
844
+ def get_input_embeddings(self):
845
+ return self.embedding.word_embeddings
846
+
847
+ def set_input_embeddings(self, value):
848
+ self.embedding.word_embeddings = value
849
+
850
+ def forward(
851
+ self,
852
+ input_ids,
853
+ position_ids: Optional[torch.Tensor] = None,
854
+ attention_mask: Optional[torch.BoolTensor] = None,
855
+ full_attention_mask: Optional[torch.BoolTensor] = None,
856
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
857
+ inputs_embeds: Optional[torch.Tensor] = None,
858
+ use_cache: Optional[bool] = None,
859
+ output_attentions: Optional[bool] = None,
860
+ output_hidden_states: Optional[bool] = None,
861
+ return_dict: Optional[bool] = None,
862
+ ):
863
+ output_hidden_states = (
864
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
865
+ )
866
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
867
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
868
+
869
+ batch_size, seq_length = input_ids.shape
870
+
871
+ if inputs_embeds is None:
872
+ inputs_embeds = self.embedding(input_ids)
873
+
874
+ if full_attention_mask is None:
875
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
876
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
877
+
878
+ # Rotary positional embeddings
879
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
880
+ if position_ids is not None:
881
+ rotary_pos_emb = rotary_pos_emb[position_ids]
882
+ else:
883
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
884
+
885
+ # Run encoder.
886
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
887
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
888
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
889
+ )
890
+ if presents is not None and type(presents) is torch.Tensor:
891
+ presents = presents.split(1, dim=0)
892
+ presents = list(presents)
893
+ presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
894
+ presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
895
+ presents = tuple(presents)
896
+
897
+ if not return_dict:
898
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
899
+
900
+ return BaseModelOutputWithPast(
901
+ last_hidden_state=hidden_states,
902
+ past_key_values=presents,
903
+ hidden_states=all_hidden_states,
904
+ attentions=all_self_attentions,
905
+ )
906
+
907
+
908
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
909
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
910
+ super().__init__(config)
911
+
912
+ self.max_sequence_length = config.max_length
913
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
914
+ self.config = config
915
+
916
+ def _update_model_kwargs_for_generation(
917
+ self,
918
+ outputs: ModelOutput,
919
+ model_kwargs: Dict[str, Any],
920
+ is_encoder_decoder: bool = False,
921
+ ) -> Dict[str, Any]:
922
+ # update past_key_values
923
+ cache_name, cache = self._extract_past_from_model_output(outputs)
924
+ model_kwargs[cache_name] = cache
925
+
926
+ # update attention mask
927
+ if "attention_mask" in model_kwargs:
928
+ attention_mask = model_kwargs["attention_mask"]
929
+ model_kwargs["attention_mask"] = torch.cat(
930
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
931
+ )
932
+
933
+ # update position ids
934
+ if "position_ids" in model_kwargs:
935
+ position_ids = model_kwargs["position_ids"]
936
+ new_position_id = position_ids[..., -1:].clone()
937
+ new_position_id += 1
938
+ model_kwargs["position_ids"] = torch.cat(
939
+ [position_ids, new_position_id], dim=-1
940
+ )
941
+
942
+ model_kwargs["is_first_forward"] = False
943
+ return model_kwargs
944
+
945
+ def prepare_inputs_for_generation(
946
+ self,
947
+ input_ids: torch.LongTensor,
948
+ past_key_values: Optional[torch.Tensor] = None,
949
+ attention_mask: Optional[torch.Tensor] = None,
950
+ position_ids: Optional[torch.Tensor] = None,
951
+ use_cache: Optional[bool] = None,
952
+ is_first_forward: bool = True,
953
+ **kwargs
954
+ ) -> dict:
955
+ # only last token for input_ids if past is not None
956
+ if position_ids is None:
957
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
958
+ if not is_first_forward:
959
+ if past_key_values is not None:
960
+ position_ids = position_ids[..., -1:]
961
+ input_ids = input_ids[:, -1:]
962
+ return {
963
+ "input_ids": input_ids,
964
+ "past_key_values": past_key_values,
965
+ "position_ids": position_ids,
966
+ "attention_mask": attention_mask,
967
+ "return_last_logit": True,
968
+ "use_cache": use_cache
969
+ }
970
+
971
+ def forward(
972
+ self,
973
+ input_ids: Optional[torch.Tensor] = None,
974
+ position_ids: Optional[torch.Tensor] = None,
975
+ attention_mask: Optional[torch.Tensor] = None,
976
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
977
+ inputs_embeds: Optional[torch.Tensor] = None,
978
+ labels: Optional[torch.Tensor] = None,
979
+ use_cache: Optional[bool] = None,
980
+ output_attentions: Optional[bool] = None,
981
+ output_hidden_states: Optional[bool] = None,
982
+ return_dict: Optional[bool] = None,
983
+ return_last_logit: Optional[bool] = False,
984
+ ):
985
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
986
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
987
+
988
+ transformer_outputs = self.transformer(
989
+ input_ids=input_ids,
990
+ position_ids=position_ids,
991
+ attention_mask=attention_mask,
992
+ past_key_values=past_key_values,
993
+ inputs_embeds=inputs_embeds,
994
+ use_cache=use_cache,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ hidden_states = transformer_outputs[0]
1000
+ if return_last_logit:
1001
+ hidden_states = hidden_states[:, -1:]
1002
+ lm_logits = self.transformer.output_layer(hidden_states)
1003
+
1004
+ loss = None
1005
+ if labels is not None:
1006
+ lm_logits = lm_logits.to(torch.float32)
1007
+
1008
+ # Shift so that tokens < n predict n
1009
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1010
+ shift_labels = labels[..., 1:].contiguous()
1011
+ # Flatten the tokens
1012
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1013
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1014
+
1015
+ lm_logits = lm_logits.to(hidden_states.dtype)
1016
+ loss = loss.to(hidden_states.dtype)
1017
+
1018
+ if not return_dict:
1019
+ output = (lm_logits,) + transformer_outputs[1:]
1020
+ return ((loss,) + output) if loss is not None else output
1021
+
1022
+ return CausalLMOutputWithPast(
1023
+ loss=loss,
1024
+ logits=lm_logits,
1025
+ past_key_values=transformer_outputs.past_key_values,
1026
+ hidden_states=transformer_outputs.hidden_states,
1027
+ attentions=transformer_outputs.attentions,
1028
+ )
1029
+
1030
+ @staticmethod
1031
+ def _reorder_cache(
1032
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1033
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1034
+ """
1035
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1036
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1037
+ beam_idx at every generation step.
1038
+ Output shares the same memory storage as `past`.
1039
+ """
1040
+ return tuple(
1041
+ (
1042
+ layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
1043
+ layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
1044
+ )
1045
+ for layer_past in past
1046
+ )
1047
+
1048
+ @torch.inference_mode()
1049
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1050
+ max_length: int = 65536, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95,
1051
+ **kwargs):
1052
+ if history is None:
1053
+ history = []
1054
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1055
+ "temperature": temperature, **kwargs}
1056
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1057
+ inputs = inputs.to(self.device)
1058
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1059
+ tokenizer.get_command("<|observation|>")]
1060
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1061
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1062
+ response = tokenizer.decode(outputs).strip()
1063
+ history.append({"role": role, "content": query})
1064
+ return response, history
tokenization_chatglm.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex as re
2
+ import base64
3
+ import os
4
+ import json
5
+ import tiktoken
6
+ from transformers import PreTrainedTokenizer
7
+ from typing import List, Optional, Union, Dict
8
+ from transformers import PreTrainedTokenizer
9
+ from transformers.utils import logging, PaddingStrategy
10
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
11
+
12
+
13
+ class ChatGLM4Tokenizer(PreTrainedTokenizer):
14
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
15
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
16
+
17
+ def __init__(
18
+ self,
19
+ vocab_file,
20
+ padding_side="left",
21
+ clean_up_tokenization_spaces=False,
22
+ encode_special_tokens=False,
23
+ **kwargs
24
+ ):
25
+ self.name = "GLMTokenizer"
26
+ self.vocab_file = vocab_file
27
+ pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
28
+ self.pat_str = re.compile(pat_str)
29
+ self.encode_special_tokens = encode_special_tokens
30
+
31
+ mergeable_ranks = {}
32
+ with open(vocab_file) as f:
33
+ for line in f:
34
+ token, rank = line.strip().split()
35
+ rank = int(rank)
36
+ token = base64.b64decode(token)
37
+ mergeable_ranks[token] = rank
38
+
39
+ self.mergeable_ranks = mergeable_ranks
40
+ self.special_tokens = ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
41
+ "<|user|>", "<|assistant|>", "<|observation|>", "<|begin_of_image|>", "<|end_of_image|>",
42
+ "<|begin_of_video|>", "<|end_of_video|>"]
43
+
44
+ self.special_tokens = {
45
+ token: idx for idx, token in enumerate(self.special_tokens, start=len(mergeable_ranks))
46
+ }
47
+ self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}
48
+
49
+ self.tokenizer = tiktoken.Encoding(
50
+ name="my_tokenizer",
51
+ pat_str=pat_str,
52
+ mergeable_ranks=mergeable_ranks,
53
+ special_tokens=self.special_tokens
54
+ )
55
+ self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
56
+ self.n_words = len(self.decoder) + len(self.special_tokens)
57
+
58
+ super().__init__(
59
+ padding_side=padding_side,
60
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
61
+ **kwargs
62
+ )
63
+
64
+ def get_command(self, token):
65
+ assert token in self.special_tokens
66
+ return self.special_tokens[token]
67
+
68
+ @property
69
+ def vocab_size(self):
70
+ return self.n_words
71
+
72
+ @property
73
+ def eos_token_id(self):
74
+ return self.get_command("<|endoftext|>")
75
+
76
+ def get_vocab(self):
77
+ """ Returns vocab as a dict """
78
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
79
+ vocab.update(self.added_tokens_encoder)
80
+ return vocab
81
+
82
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
83
+ """
84
+ Converts a sequence of tokens in a single string.
85
+ """
86
+ text = ""
87
+ temp = b""
88
+ for t in tokens:
89
+ if isinstance(t, str):
90
+ if temp:
91
+ text += temp.decode("utf-8", errors="replace")
92
+ temp = b""
93
+ text += t
94
+ elif isinstance(t, bytes):
95
+ temp += t
96
+ else:
97
+ raise TypeError("token should only be of type types or str")
98
+ if temp:
99
+ text += temp.decode("utf-8", errors="replace")
100
+ return text
101
+
102
+ def _tokenize(self, text, **kwargs):
103
+ tokens = []
104
+ if self.encode_special_tokens:
105
+ ids = self.tokenizer.encode(text, allowed_special="all")
106
+ else:
107
+ ids = self.tokenizer.encode(text, disallowed_special=())
108
+ for t in ids:
109
+ tokens.append(self.decoder[t])
110
+ return tokens
111
+
112
+ def _convert_token_to_id(self, token):
113
+ """ Converts a token (str) in an id using the vocab. """
114
+ if token in self.special_tokens:
115
+ return self.special_tokens[token]
116
+ return self.mergeable_ranks[token]
117
+
118
+ def _convert_id_to_token(self, index):
119
+ """Converts an index (integer) in a token (str) using the vocab."""
120
+ if index in self.special_token_ids:
121
+ return self.special_token_ids[index]
122
+ return self.decoder[index]
123
+
124
+ def save_vocabulary(self, save_directory, filename_prefix=None):
125
+ """
126
+ Save the vocabulary and special tokens file to a directory.
127
+
128
+ Args:
129
+ save_directory (`str`):
130
+ The directory in which to save the vocabulary.
131
+ filename_prefix (`str`, *optional*):
132
+ An optional prefix to add to the named of the saved files.
133
+
134
+ Returns:
135
+ `Tuple(str)`: Paths to the files saved.
136
+ """
137
+ if os.path.isdir(save_directory):
138
+ vocab_file = os.path.join(
139
+ save_directory, self.vocab_files_names["vocab_file"]
140
+ )
141
+ else:
142
+ vocab_file = save_directory
143
+
144
+ with open(self.vocab_file, 'rb') as fin:
145
+ proto_str = fin.read()
146
+
147
+ with open(vocab_file, "wb") as writer:
148
+ writer.write(proto_str)
149
+
150
+ return (vocab_file,)
151
+
152
+ def get_prefix_tokens(self):
153
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("<sop>")]
154
+ return prefix_tokens
155
+
156
+ def build_single_message(self, role, metadata, message):
157
+ assert role in ["system", "user", "assistant", "observation"], role
158
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
159
+ message_tokens = self.tokenizer.encode(message, disallowed_special=())
160
+ tokens = role_tokens + message_tokens
161
+ return tokens
162
+
163
+ def build_chat_input(self, query, history=None, role="user"):
164
+ if history is None:
165
+ history = []
166
+ input_ids = []
167
+ for item in history:
168
+ content = item["content"]
169
+ if item["role"] == "system" and "tools" in item:
170
+ for function in item["tools"]:
171
+ content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
172
+ content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
173
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
174
+ input_ids.extend(self.build_single_message(role, "", query))
175
+ input_ids.extend([self.get_command("<|assistant|>")])
176
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
177
+
178
+ def build_inputs_with_special_tokens(
179
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
180
+ ) -> List[int]:
181
+ """
182
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
183
+ adding special tokens. A BERT sequence has the following format:
184
+
185
+ - single sequence: `[CLS] X [SEP]`
186
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
187
+
188
+ Args:
189
+ token_ids_0 (`List[int]`):
190
+ List of IDs to which the special tokens will be added.
191
+ token_ids_1 (`List[int]`, *optional*):
192
+ Optional second list of IDs for sequence pairs.
193
+
194
+ Returns:
195
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
196
+ """
197
+ prefix_tokens = self.get_prefix_tokens()
198
+ token_ids_0 = prefix_tokens + token_ids_0
199
+ if token_ids_1 is not None:
200
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
201
+ return token_ids_0
202
+
203
+ def _pad(
204
+ self,
205
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
206
+ max_length: Optional[int] = None,
207
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
208
+ pad_to_multiple_of: Optional[int] = None,
209
+ return_attention_mask: Optional[bool] = None,
210
+ ) -> dict:
211
+ """
212
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
213
+
214
+ Args:
215
+ encoded_inputs:
216
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
217
+ max_length: maximum length of the returned list and optionally padding length (see below).
218
+ Will truncate by taking into account the special tokens.
219
+ padding_strategy: PaddingStrategy to use for padding.
220
+
221
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
222
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
223
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
224
+ The tokenizer padding sides are defined in self.padding_side:
225
+
226
+ - 'left': pads on the left of the sequences
227
+ - 'right': pads on the right of the sequences
228
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
229
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
230
+ `>= 7.5` (Volta).
231
+ return_attention_mask:
232
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
233
+ """
234
+ # Load from model defaults
235
+ assert self.padding_side == "left"
236
+
237
+ required_input = encoded_inputs[self.model_input_names[0]]
238
+ seq_length = len(required_input)
239
+
240
+ if padding_strategy == PaddingStrategy.LONGEST:
241
+ max_length = len(required_input)
242
+
243
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
244
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
245
+
246
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
247
+
248
+ # Initialize attention mask if not present.
249
+ if "attention_mask" not in encoded_inputs:
250
+ encoded_inputs["attention_mask"] = [1] * seq_length
251
+
252
+ if "position_ids" not in encoded_inputs:
253
+ encoded_inputs["position_ids"] = list(range(seq_length))
254
+
255
+ if needs_to_be_padded:
256
+ difference = max_length - len(required_input)
257
+
258
+ if "attention_mask" in encoded_inputs:
259
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
260
+ if "position_ids" in encoded_inputs:
261
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
262
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
263
+
264
+ return encoded_inputs
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a493598071550244b2ee7f26118f3edec2150b9dfa967929a99052ac83fe716
3
+ size 2623634
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "151329": {
4
+ "content": "<|endoftext|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ }
11
+ },
12
+ "auto_map": {
13
+ "AutoTokenizer": [
14
+ "tokenization_chatglm.ChatGLM4Tokenizer",
15
+ null
16
+ ]
17
+ },
18
+ "chat_template": "{% for message in messages %}{% if loop.first %}[gMASK]<sop><|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
19
+ "clean_up_tokenization_spaces": false,
20
+ "do_lower_case": false,
21
+ "eos_token": "<|endoftext|>",
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "padding_side": "left",
24
+ "remove_space": false,
25
+ "tokenizer_class": "ChatGLM4Tokenizer"
26
+ }