Safetensors · gLM2 · custom_code
andrecornman committed
Commit 5b795bf
1 Parent(s): 03cfbbb

Upload gLM2ForMaskedLM

Files changed (4)
  1. config.json +20 -0
  2. configuration_glm2.py +37 -0
  3. model.safetensors +3 -0
  4. modeling_glm2.py +565 -0
config.json ADDED
@@ -0,0 +1,20 @@
+{
+  "architectures": [
+    "gLM2ForMaskedLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_glm2.gLM2Config",
+    "AutoModel": "modeling_glm2.gLM2Model",
+    "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
+  },
+  "depth": 30,
+  "dim": 640,
+  "ffn_dim_multiplier": null,
+  "heads": 10,
+  "model_type": "gLM2",
+  "norm_eps": 1e-05,
+  "swiglu_multiple_of": 256,
+  "torch_dtype": "float32",
+  "transformers_version": "4.36.0",
+  "vocab_size": 37
+}
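
Because config.json registers the custom classes through auto_map, the checkpoint can be loaded with the standard transformers Auto classes once remote code is trusted. A minimal sketch, assuming the actual Hub repo id is substituted for the placeholder; flash-attn and a CUDA device are required, and the flash attention kernels expect fp16/bf16 activations:

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "<org>/<gLM2-repo>",         # placeholder: substitute the actual Hub repo id
    trust_remote_code=True,      # lets auto_map resolve configuration_glm2 / modeling_glm2
    torch_dtype=torch.bfloat16,  # flash attention kernels require fp16 or bf16
).cuda()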
configuration_glm2.py ADDED
@@ -0,0 +1,37 @@
+"""gLM2 model configuration"""
+
+from typing import Optional
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class gLM2Config(PretrainedConfig):
+    model_type = "gLM2"
+
+    def __init__(
+        self,
+        dim: int = 640,
+        depth: int = 30,
+        heads: int = 10,
+        vocab_size: int = 37,
+        swiglu_multiple_of: int = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+        norm_eps: float = 1e-5,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.depth = depth
+        self.heads = heads
+        self.vocab_size = vocab_size
+        self.swiglu_multiple_of = swiglu_multiple_of
+        self.ffn_dim_multiplier = ffn_dim_multiplier
+        self.norm_eps = norm_eps
+
+        self.auto_map = {
+            "AutoConfig": "configuration_glm2.gLM2Config",
+            "AutoModel": "modeling_glm2.gLM2Model",
+            "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
+        }
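
As a quick sanity check, the configuration class can also be instantiated on its own. A minimal sketch, assuming it is run from a checkout of this repository so configuration_glm2.py is importable:

from configuration_glm2 import gLM2Config

config = gLM2Config()  # defaults mirror config.json above
assert config.dim == 640 and config.depth == 30 and config.heads == 10
print(config.to_json_string())  # standard PretrainedConfig serialization, includes auto_map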
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:047facd39472afe082c985058fa042f43a7b10d4ffb2ab51b9a3e6c63e9f3834
+size 609855088
modeling_glm2.py ADDED
@@ -0,0 +1,565 @@
+"""PyTorch gLM2 model.
+
+Requires flash attention.
+Some modules adapted from:
+https://github.com/meta-llama/llama/blob/main/llama/model.py
+"""
+import math
+import torch
+from einops import rearrange
+from typing import Optional, Tuple, Union
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    MaskedLMOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+try:
+    from flash_attn.ops.activations import swiglu
+    from flash_attn.layers.rotary import apply_rotary_emb_func
+    from flash_attn import (
+        flash_attn_kvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+    )
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn.ops.triton.layer_norm import RMSNorm
+except ImportError:
+    raise ImportError(
+        "gLM2 requires flash attention: `pip install flash-attn --no-build-isolation`")
+
+from .configuration_glm2 import gLM2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+    Changed to only support passing in q or k individually, so that we can use varlen rotary.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        base=10000.0,
+        interleaved=False,
+        scale_base=None,
+        pos_idx_in_fp32=True,
+        device=None,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.base = float(base)
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.interleaved = interleaved
+        self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+            / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)
+
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, device=device,
+                             dtype=torch.float32)
+                / self.dim
+            )
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        # Reset the tables if the sequence length has changed,
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device,
+                                 dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, inv_freq)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = (
+                    torch.arange(
+                        seqlen, dtype=self.scale.dtype, device=self.scale.device
+                    )
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(
+                    power, "s -> s 1"
+                )
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seqlen_offset: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        q: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
+            shape (total_seqlen, nheads, headdim).
+        k: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
+            shape (total_seqlen, nheads, headdim).
+        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
+            should pass in max_seqlen, which will update the cos / sin cache up to that length.
+        Apply rotary embedding *inplace* to q and / or k.
+        """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+        seqlen = q.shape[1] if max_seqlen is None else max_seqlen
+        if max_seqlen is not None:
+            self._update_cos_sin_cache(
+                max_seqlen, device=q.device, dtype=q.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(
+                seqlen + seqlen_offset, device=q.device, dtype=q.dtype
+            )
+        q = apply_rotary_emb_func(
+            q,
+            self._cos_cached,
+            self._sin_cached,
+            interleaved=self.interleaved,
+            inplace=True,
+            seqlen_offsets=seqlen_offset,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        if self.scale is None:
+            k = apply_rotary_emb_func(
+                k,
+                self._cos_cached,
+                self._sin_cached,
+                interleaved=self.interleaved,
+                inplace=True,
+                seqlen_offsets=seqlen_offset,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+        else:
+            k = apply_rotary_emb_func(
+                k,
+                self._cos_k_cached,
+                self._sin_k_cached,
+                interleaved=self.interleaved,
+                inplace=True,
+                seqlen_offsets=seqlen_offset,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+        return q, k
+
+
+# @torch.jit.script
+# def rmsnorm_func(hidden_states, weight, variance_epsilon):
+#     """Apply the root mean square normalization."""
+#     input_dtype = hidden_states.dtype
+#     hidden_states = hidden_states.to(torch.float32)
+#     variance = hidden_states.pow(2).mean(-1, keepdim=True)
+#     hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+#     return (weight * hidden_states).to(input_dtype)
+
+
+# class RMSNorm(nn.Module):
+#     """Root mean square normalization."""
+
+#     def __init__(self, dim, eps=1e-6):
+#         super().__init__()
+#         self.weight = nn.Parameter(torch.ones(dim))
+#         self.register_buffer(
+#             "variance_epsilon",
+#             torch.tensor(eps),
+#             persistent=False,
+#         )
+
+#     def forward(self, hidden_states):
+#         return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
+
+
+class Attention(nn.Module):
+    """Multi-head attention module."""
+
+    def __init__(self, config: gLM2Config):
+        super().__init__()
+        self.n_heads = config.heads
+        self.head_dim = config.dim // config.heads
+
+        self.wqkv = nn.Linear(config.dim, self.n_heads *
+                              self.head_dim * 3, bias=False)
+        self.wo = nn.Linear(config.heads * self.head_dim,
+                            config.dim, bias=False)
+
+        self.rotary_emb = RotaryEmbedding(self.head_dim)
+
+    def _forward_varlen(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seq_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        total_seqlen, h_size = x.shape
+        qkv = self.wqkv(x)
+        q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
+
+        q = q.view(total_seqlen, self.n_heads, self.head_dim)
+        k = k.view(total_seqlen, self.n_heads, self.head_dim)
+        v = v.view(total_seqlen, self.n_heads, self.head_dim)
+
+        q, k = self.rotary_emb(
+            q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)
+
+        # (seqlen, 2, n_heads, head_dim)
+        kv = torch.stack([k, v], 1)
+
+        # (seqlen, n_heads, head_dim)
+        output = flash_attn_varlen_kvpacked_func(
+            q,
+            kv,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seq_len,
+            max_seqlen_k=max_seq_len,
+            dropout_p=0.0,
+            causal=False,
+        )
+        output = output.view(total_seqlen, h_size)
+        return self.wo(output)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seq_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if cu_seqlens is not None:
+            assert max_seq_len is not None
+            return self._forward_varlen(x, cu_seqlens, max_seq_len)
+
+        bsz, seqlen, h_size = x.shape
+        qkv = self.wqkv(x)
+        q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
+
+        q, k = self.rotary_emb(q, k)
+        # (bs, seqlen, 2, n_heads, head_dim)
+        kv = torch.stack([k, v], 2)
+
+        output = flash_attn_kvpacked_func(
+            q,
+            kv,
+            dropout_p=0.0,
+            causal=False,
+        )
+        output = output.view(bsz, seqlen, h_size)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        """
+        SwiGLU FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * \
+            ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(swiglu(self.w1(x), self.w3(x)))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: gLM2Config):
+        super().__init__()
+        self.n_heads = config.heads
+        self.dim = config.dim
+        self.head_dim = config.dim // config.heads
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(
+            dim=config.dim,
+            hidden_dim=4 * config.dim,
+            multiple_of=config.swiglu_multiple_of,
+            ffn_dim_multiplier=config.ffn_dim_multiplier,
+        )
+        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seq_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r = self.attention(
+            self.attention_norm(x), cu_seqlens, max_seq_len
+        )
+        h = x + r
+        r = self.feed_forward(self.ffn_norm(h))
+        out = h + r
+        return out
+
+
+class TransformerLayers(nn.Module):
+    def __init__(self, config: gLM2Config):
+        super().__init__()
+        self.config = config
+        self.layers = torch.nn.ModuleList(
+            [TransformerBlock(config=config) for _ in range(config.depth)]
+        )
+        self.apply(self._init_weights)
+        # Apply special scaled init to the residual projections, per GPT-2 paper.
+        # Weight w2 is output of FeedForward. Weight wo is output of Attention.
+        for pn, p in self.named_parameters():
+            if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
+                torch.nn.init.normal_(
+                    p, mean=0.0, std=0.02/math.sqrt(2 * self.config.depth))
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+
+    def forward(
+        self,
+        x: torch.FloatTensor,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        return_all_hiddens: bool = False,
+    ):
+        if x.shape[-1] != self.config.dim:
+            raise ValueError(
+                f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
+            )
+        batch_size, seq_len = x.shape[:2]
+        should_unpad = attention_mask is not None and not attention_mask.all()
+        if should_unpad:
+            x, indices, cu_seqlens, max_seq_len_in_batch = unpad_input(
+                x, attention_mask
+            )
+        else:
+            indices, cu_seqlens, max_seq_len_in_batch = None, None, None
+        hiddens = []
+        for layer in self.layers:
+            x = layer(x, cu_seqlens, max_seq_len_in_batch)
+            if return_all_hiddens:
+                hiddens.append(x)
+
+        if should_unpad:
+            x = pad_input(x, indices, batch_size, seq_len)
+            if return_all_hiddens:
+                hiddens = [pad_input(h, indices, batch_size, seq_len)
+                           for h in hiddens]
+
+        if return_all_hiddens:
+            return x, hiddens
+        return x
+
+
+class gLM2PreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+    config_class = gLM2Config
+    base_model_prefix = "glm2"
+    supports_gradient_checkpointing = False
+
+    # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
+    def _init_weights(self, module, initializer_range=0.02):
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, std=initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, std=initializer_range)
+            if module.padding_idx is not None:
+                nn.init.zeros_(module.weight[module.padding_idx])
+
+
+class gLM2Model(gLM2PreTrainedModel):
+    """gLM2 Model."""
+
+    def __init__(self, config: gLM2Config):
+        super().__init__(config)
+        self.config = config
+
+        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self._init_weights(self.tok_embeddings)
+        self.encoder = TransformerLayers(config)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        h = self.tok_embeddings(input_ids)
+        if output_hidden_states:
+            sequence_output, all_hidden_states = self.encoder(
+                h, attention_mask, return_all_hiddens=True)
+        else:
+            sequence_output = self.encoder(h, attention_mask)
+            all_hidden_states = None
+
+        if not return_dict:
+            return (sequence_output, all_hidden_states)
+
+        return BaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=all_hidden_states,
+        )
+
+
+class gLM2ForMaskedLM(gLM2PreTrainedModel):
+
+    def __init__(self, config: gLM2Config):
+        super().__init__(config)
+
+        self.glm2 = gLM2Model(config)
+        self.lm_head = gLM2LMHead(config)
+        self._init_weights(self.lm_head)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MaskedLMOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.glm2(
+            input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class gLM2LMHead(nn.Module):
+    """gLM2 head for masked language modeling."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.proj_output = nn.Linear(
+            config.dim, config.vocab_size, bias=False)
+
+    def forward(self, features):
+        return self.proj_output(self.norm(features))
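
For completeness, a hedged sketch of a masked-LM forward pass, continuing the loading example after config.json (`model` is the bf16 gLM2ForMaskedLM on CUDA). The token ids below are random placeholders since the tokenizer/vocabulary is not part of this commit, and the toy labels merely exercise the loss path:

import torch

vocab_size = model.config.vocab_size  # 37, per config.json
input_ids = torch.randint(0, vocab_size, (2, 128), device="cuda")
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
labels = input_ids.clone()  # toy labels over all positions

with torch.no_grad():
    out = model(input_ids, attention_mask=attention_mask, labels=labels)

print(out.logits.shape)  # torch.Size([2, 128, 37])
print(out.loss)          # cross-entropy computed in gLM2ForMaskedLM.forward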