Add empty_init option
modeling_chatglm.py (+37 -14)
@@ -346,10 +346,18 @@ def attention_fn(
     return outputs
 
 
+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()
 
         self.layer_id = layer_id
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
 
         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )
 
-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):
 
 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func
 
@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )
 
         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )
 
     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
 
-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )
 
         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
 
         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.position_encoding_2d = config.position_encoding_2d
 
-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)
 
-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,
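For context, the new `default_init` helper and the existing `skip_init` calls differ only in whether parameter initialization runs at construction time. A minimal sketch of the two paths, using PyTorch's documented `torch.nn.utils.skip_init` utility (the `fast_linear`/`plain_linear` names are illustrative, not part of the repository):

```python
import torch
from torch.nn.utils import skip_init


def default_init(cls, *args, **kwargs):
    # Plain construction: runs the module's usual reset_parameters().
    return cls(*args, **kwargs)


# skip_init allocates the module without running its weight initialization;
# the parameters hold uninitialized memory until a checkpoint is loaded into them.
fast_linear = skip_init(torch.nn.Linear, 4096, 4096)

# default_init pays for the normal random initialization up front.
plain_linear = default_init(torch.nn.Linear, 4096, 4096)
```

With `empty_init=True` (the default) the layers above keep the previous behaviour and are built with `skip_init`; `empty_init=False` routes every layer through `default_init` instead.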
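A possible way to reach the new flag from user code, assuming the usual `transformers` remote-code loading path in which keyword arguments that `from_pretrained` does not consume are forwarded to the model's `__init__` (the repository id below is assumed for illustration):

```python
from transformers import AutoModel

# empty_init=False is forwarded to ChatGLMForConditionalGeneration.__init__,
# selecting default_init (regular construction) instead of skip_init.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
    empty_init=False,
)
```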