HF transformers integration #28
opened by ybelkada
- config.json +23 -24
- generation_config.json +6 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +205 -206
- modeling_chatglm.py +88 -27
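Summary: this change set converts the repository to the standard transformers layout. config.json is rewritten with native config keys, a generation_config.json is added, the weights are re-sharded into three safetensors files indexed under model.*/lm_head.* names, and modeling_chatglm.py is rewired to match. A minimal sketch for trying the PR before merge (the "refs/pr/28" revision string is an assumption based on the discussion number; Hub PRs are normally served under refs/pr/<n>):

```python
# Load this PR's files rather than main. The repo id comes from
# "_name_or_path" in config.json below.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", revision="refs/pr/28", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm3-6b",
    revision="refs/pr/28",
    trust_remote_code=True,   # auto_map still points at modeling_chatglm.py
    torch_dtype="auto",       # honors "torch_dtype": "float16" from config.json
)
```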
config.json
CHANGED
@@ -1,9 +1,11 @@
 {
   "_name_or_path": "THUDM/chatglm3-6b",
-  "model_type": "chatglm",
+  "apply_query_key_layer_scaling": true,
   "architectures": [
-    "ChatGLMModel"
+    "ChatGlmForCausalLM"
   ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "configuration_chatglm.ChatGLMConfig",
     "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
@@ -11,32 +13,29 @@
     "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
     "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
   },
-  "add_bias_linear": false,
-  "add_qkv_bias": true,
-  "apply_query_key_layer_scaling": true,
-  "apply_residual_connection_post_layernorm": false,
-  "attention_dropout": 0.0,
-  "attention_softmax_in_fp32": true,
-  "bias_dropout_fusion": true,
-  "ffn_hidden_size": 13696,
-  "fp32_residual_connection": false,
-  "hidden_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
   "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 13696,
   "kv_channels": 128,
-  "layernorm_epsilon": 1e-05,
+  "max_position_embeddings": 2048,
+  "mlp_bias": false,
+  "model_type": "chatglm",
   "multi_query_attention": true,
   "multi_query_group_num": 2,
   "num_attention_heads": 32,
-  "num_layers": 28,
+  "num_hidden_layers": 28,
+  "num_key_value_heads": 32,
   "original_rope": true,
-  "padded_vocab_size": 65024,
-  "post_layer_norm": true,
-  "rmsnorm": true,
-  "seq_length": 8192,
-  "use_cache": true,
-  "torch_dtype": "float16",
-  "transformers_version": "4.30.2",
+  "partial_rotary_factor": 0.5,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
   "tie_word_embeddings": false,
-  "eos_token_id": 2,
-  "pad_token_id": 0
-}
+  "torch_dtype": "float16",
+  "transformers_version": "4.37.0.dev0",
+  "use_cache": true,
+  "vocab_size": 65024
+}
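The rewrite maps the bespoke ChatGLM keys onto the standard transformers names: ffn_hidden_size becomes intermediate_size, num_layers becomes num_hidden_layers, layernorm_epsilon becomes rms_norm_eps, and padded_vocab_size becomes vocab_size, with the values unchanged. A quick sanity-check sketch (revision is the same assumption as in the loading example above):

```python
# Values taken from the diff above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(
    "THUDM/chatglm3-6b", revision="refs/pr/28", trust_remote_code=True
)
assert cfg.num_hidden_layers == 28      # was "num_layers"
assert cfg.intermediate_size == 13696   # was "ffn_hidden_size"
assert cfg.rms_norm_eps == 1e-05        # was "layernorm_epsilon"
assert cfg.vocab_size == 65024          # was "padded_vocab_size"
```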
generation_config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "transformers_version": "4.37.0.dev0"
+}
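generation_config.json carries the defaults that generate() picks up automatically, so EOS-based stopping needs no explicit argument. A minimal sketch, reusing the model and tokenizer from the loading example:

```python
# eos_token_id=2 comes from the new generation_config.json, so generation
# stops at EOS without passing it explicitly.
inputs = tokenizer("What is ChatGLM?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```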
model-00001-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3684a0fd31c8a00e061d1242bcf0faadb34a9d0c70fb64d6ab40c703337e1cbe
+size 4907609888
model-00002-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39afb798e4edce6865d67d65a32aa6fd9b47f545937aed82f16837146bc6bc59
+size 4895070096
model-00003-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffa10cbc5159b8962f44210eb4ff6d180bfe867ae6f0c88615b0e4f58a0b0158
+size 2684511912
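The three entries above are Git LFS pointers (version, oid, size), not the tensors themselves; the binaries, about 12.5 GB in total, are fetched on `git lfs pull`. A sketch for verifying a downloaded shard against the sha256 oid recorded in its pointer:

```python
# Recompute a shard's digest and compare it to the LFS pointer above.
import hashlib

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "3684a0fd31c8a00e061d1242bcf0faadb34a9d0c70fb64d6ab40c703337e1cbe"
assert sha256_of("model-00001-of-00003.safetensors") == expected
```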
model.safetensors.index.json
CHANGED
@@ -1,207 +1,206 @@
 {
-  (lines 2-207 removed: the previous weight_map; its contents are not rendered in the diff view)
+  "metadata": {
+    "total_size": 12487168000
+  },
+  "weight_map": {
+    "lm_head.weight": "model-00003-of-00003.safetensors",
+    "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.0.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.0.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.1.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.1.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.10.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.10.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.10.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.10.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.10.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.11.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.11.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.11.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.11.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.11.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.12.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.12.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.13.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.13.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.14.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.14.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.15.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.15.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.16.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.16.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.17.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.17.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.18.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.18.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.19.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.19.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.2.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.2.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.2.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.2.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.2.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.20.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.20.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.20.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.20.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.20.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.mlp.dense_4h_to_h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.21.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.21.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.22.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.22.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.22.mlp.dense_h_to_4h.weight": "model-00002-of-00003.safetensors",
+    "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.22.self_attention.dense.weight": "model-00002-of-00003.safetensors",
+    "model.layers.22.self_attention.query_key_value.bias": "model-00002-of-00003.safetensors",
+    "model.layers.22.self_attention.query_key_value.weight": "model-00002-of-00003.safetensors",
+    "model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.23.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.23.mlp.dense_h_to_4h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.23.self_attention.dense.weight": "model-00003-of-00003.safetensors",
+    "model.layers.23.self_attention.query_key_value.bias": "model-00003-of-00003.safetensors",
+    "model.layers.23.self_attention.query_key_value.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.mlp.dense_h_to_4h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.self_attention.dense.weight": "model-00003-of-00003.safetensors",
+    "model.layers.24.self_attention.query_key_value.bias": "model-00003-of-00003.safetensors",
+    "model.layers.24.self_attention.query_key_value.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.mlp.dense_h_to_4h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.self_attention.dense.weight": "model-00003-of-00003.safetensors",
+    "model.layers.25.self_attention.query_key_value.bias": "model-00003-of-00003.safetensors",
+    "model.layers.25.self_attention.query_key_value.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.mlp.dense_h_to_4h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.self_attention.dense.weight": "model-00003-of-00003.safetensors",
+    "model.layers.26.self_attention.query_key_value.bias": "model-00003-of-00003.safetensors",
+    "model.layers.26.self_attention.query_key_value.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.mlp.dense_4h_to_h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.mlp.dense_h_to_4h.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.self_attention.dense.weight": "model-00003-of-00003.safetensors",
+    "model.layers.27.self_attention.query_key_value.bias": "model-00003-of-00003.safetensors",
+    "model.layers.27.self_attention.query_key_value.weight": "model-00003-of-00003.safetensors",
+    "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.3.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.3.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.3.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.3.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.3.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.4.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.4.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.5.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.5.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.6.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.6.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.7.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.7.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.8.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.8.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.mlp.dense_4h_to_h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.mlp.dense_h_to_4h.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.self_attention.dense.weight": "model-00001-of-00003.safetensors",
+    "model.layers.9.self_attention.query_key_value.bias": "model-00001-of-00003.safetensors",
+    "model.layers.9.self_attention.query_key_value.weight": "model-00001-of-00003.safetensors",
+    "model.norm.weight": "model-00003-of-00003.safetensors"
+  }
+}
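The index maps every parameter name to the shard that stores it, which is how from_pretrained knows which of the three files to open for a given tensor. A standalone sketch of the same lookup using the safetensors API:

```python
# Resolve one tensor through the index; names and file layout come from
# the weight_map above.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as fp:
    index = json.load(fp)

name = "model.layers.0.self_attention.query_key_value.weight"
shard = index["weight_map"][name]   # "model-00001-of-00003.safetensors"
with safe_open(shard, framework="pt") as f:
    tensor = f.get_tensor(name)
print(tensor.shape, tensor.dtype)
```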
modeling_chatglm.py
CHANGED
@@ -223,8 +223,7 @@ class CoreAttention(torch.nn.Module):
         if pytorch_major_version >= 2:
             query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 is_causal=True)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=True)
             else:
                 if attention_mask is not None:
                     attention_mask = ~attention_mask
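The two-line call is folded into one; behavior is unchanged. When no mask is given and the query and key lengths match, `is_causal=True` applies the causal mask inside the fused SDPA kernel. A self-contained illustration of the call shape (after the permute above, tensors are [batch, heads, seq, head_dim]):

```python
# Illustrative only; mirrors the branch taken when attention_mask is None.
import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 8, 128)   # [batch, heads, seq, head_dim]
k = torch.randn(1, 32, 8, 128)
v = torch.randn(1, 32, 8, 128)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                 # torch.Size([1, 32, 8, 128])
```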
@@ -312,7 +311,6 @@ class CoreAttention(torch.nn.Module):
 
 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
-
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -448,7 +446,6 @@ class SelfAttention(torch.nn.Module):
 
         return output, kv_cache
 
-
 def _config_to_kwargs(args):
     common_kwargs = {
         "dtype": args.torch_dtype,
@@ -504,7 +501,6 @@ class MLP(torch.nn.Module):
 
 class GLMBlock(torch.nn.Module):
     """A single transformer layer.
-
     Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -597,7 +593,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)
 
         self.gradient_checkpointing = False
@@ -653,7 +649,7 @@ class GLMTransformer(torch.nn.Module):
 
         # Final layer norm.
         if self.post_layer_norm:
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.norm(hidden_states)
 
         return hidden_states, presents, all_hidden_states, all_self_attentions
 
@@ -740,7 +736,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         init_kwargs = {}
         if device is not None:
             init_kwargs["device"] = device
-        self.embedding = init_method(Embedding, config, **init_kwargs)
+
+        self.embed_tokens = nn.Embedding(
+            config.padded_vocab_size,
+            config.hidden_size,
+            dtype=config.torch_dtype,
+            device=device
+        )
+
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
@@ -753,9 +756,21 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
-        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
-        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+
+        # Transformer layers.
+        def build_layer(layer_number):
+            return GLMBlock(config, layer_number, device=device)
+
+        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+        self.num_layers = config.num_layers
+        self.post_layer_norm = config.post_layer_norm
+
+        if self.post_layer_norm:
+            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+                                      dtype=config.torch_dtype)
+
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
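Taken together with the embed_tokens change above, the monolithic `encoder`/`output_layer` modules are flattened into `model.layers`, `model.norm`, and a top-level `lm_head`, which is exactly the naming the new weight_map uses. A hypothetical sketch of the corresponding state-dict key translation (the mapping is inferred from the module renames in these hunks, not code from the PR):

```python
# Hypothetical converter from the old ChatGLM key layout to the new one.
OLD_TO_NEW = {
    "transformer.embedding.word_embeddings.weight": "model.embed_tokens.weight",
    "transformer.encoder.final_layernorm.weight": "model.norm.weight",
    "transformer.output_layer.weight": "lm_head.weight",
}

def rename_key(key: str) -> str:
    if key in OLD_TO_NEW:
        return OLD_TO_NEW[key]
    # e.g. transformer.encoder.layers.0.mlp.dense_h_to_4h.weight
    #   -> model.layers.0.mlp.dense_h_to_4h.weight
    return key.replace("transformer.encoder.layers.", "model.layers.")

def convert(state_dict):
    return {rename_key(k): v for k, v in state_dict.items()}
```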
@@ -765,6 +780,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.prefix_encoder = PrefixEncoder(config)
             self.dropout = torch.nn.Dropout(0.1)
 
+        self.gradient_checkpointing = False
+
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
 
@@ -804,7 +821,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         batch_size, seq_length = input_ids.shape
 
         if inputs_embeds is None:
-            inputs_embeds = self.embedding(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
 
         if self.pre_seq_len is not None:
             if past_key_values is None:
@@ -827,10 +844,54 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
         # Run encoder.
-        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
-            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
-            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
-        )
+        if not past_key_values:
+            past_key_values = [None for _ in range(self.num_layers)]
+        presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_self_attentions = None
+        all_hidden_states = () if output_hidden_states else None
+
+        hidden_states = inputs_embeds
+        # To comply with former chat-glm format that expects (seqlen, bs, hd)
+        hidden_states = hidden_states.permute(1, 0, 2)
+
+        for index, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    full_attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    full_attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=past_key_values[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
+            if use_cache:
+                presents = presents + (kv_cache,)
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.norm(hidden_states)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
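The old `self.encoder(...)` call is inlined as an explicit layer loop, so the model itself handles the per-layer KV cache, `output_hidden_states` collection, and gradient checkpointing (which forcibly disables the cache, as the warning says); note the permute keeps the legacy (seq, batch, hidden) layout inside the loop. A training-time usage sketch, assuming the stock transformers checkpointing toggle reaches the flag set in `__init__`:

```python
# Sketch: gradient_checkpointing_enable() is the standard transformers
# toggle; whether this custom-code model wires it up is an assumption.
model.gradient_checkpointing_enable()
model.train()
out = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"])
# use_cache is silently set to False here, per the warning in the loop.
out.loss.backward()
```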
@@ -844,7 +905,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
     def quantize(self, weight_bit_width: int):
         from .quantization import quantize
-        quantize(self.encoder, weight_bit_width)
+        quantize(self, weight_bit_width)
         return self
 
 
@@ -853,7 +914,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.max_sequence_length = config.max_length
-        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.config = config
         self.quantized = False
 
@@ -934,7 +996,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.transformer(
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -948,8 +1010,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         hidden_states = transformer_outputs[0]
         if return_last_logit:
             hidden_states = hidden_states[-1:]
-        lm_logits = self.transformer.output_layer(hidden_states)
-        lm_logits = lm_logits.transpose(0, 1).contiguous()
+        lm_logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
@@ -1062,8 +1123,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
-            if self.transformer.pre_seq_len is not None:
-                past_length -= self.transformer.pre_seq_len
+            if self.model.pre_seq_len is not None:
+                past_length -= self.model.pre_seq_len
             inputs.position_ids += past_length
         attention_mask = inputs.attention_mask
         attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
@@ -1205,7 +1266,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.config.quantization_bit = bits
 
-        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
+        self.model = quantize(self.model, bits, empty_init=empty_init, device=device,
                                             **kwargs)
         return self
 
@@ -1215,7 +1276,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
         super().__init__(config)
 
         self.num_labels = config.num_labels
-        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
 
         self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
         if config.classifier_dropout is not None:
@@ -1242,7 +1303,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.transformer(
+        transformer_outputs = self.model(
             input_ids=input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -1293,4 +1354,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )