zxdu20 committed
Commit 9333486
1 Parent(s): 6466cdc

Add empty_init option
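
The new flag decides whether the model's Linear and Embedding submodules are built through torch.nn.utils.skip_init (parameters left uninitialized, to be overwritten when the checkpoint is loaded) or through a plain default_init that constructs them eagerly. A minimal usage sketch, assuming the checkpoint id THUDM/chatglm-6b (this commit page does not name the repository) and that transformers forwards unrecognized keyword arguments from from_pretrained to the model constructor:

from transformers import AutoModel

# empty_init=False bypasses the skip_init fast path and initializes parameters
# normally, e.g. on PyTorch builds where torch.nn.utils.skip_init is unavailable.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",        # assumed checkpoint id
    trust_remote_code=True,
    empty_init=False,
)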

Files changed (1)
  1. modeling_chatglm.py +37 -14
modeling_chatglm.py CHANGED
@@ -346,10 +346,18 @@ def attention_fn(
     return outputs
 
 
+def default_init(cls, *args, **kwargs):
+    return cls(*args, **kwargs)
+
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True):
+                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         super(SelfAttention, self).__init__()
 
         self.layer_id = layer_id
@@ -377,7 +385,7 @@ class SelfAttention(torch.nn.Module):
         self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
 
         # Strided linear layer.
-        self.query_key_value = skip_init(
+        self.query_key_value = init_method(
             torch.nn.Linear,
             hidden_size,
             3 * self.inner_hidden_size,
@@ -385,7 +393,7 @@ class SelfAttention(torch.nn.Module):
             dtype=params_dtype,
         )
 
-        self.dense = skip_init(
+        self.dense = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             hidden_size,
@@ -498,8 +506,12 @@ class GEGLU(torch.nn.Module):
 
 class GLU(torch.nn.Module):
     def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
+                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
         super(GLU, self).__init__()
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         self.layer_id = layer_id
         self.activation_func = activation_func
 
@@ -508,7 +520,7 @@ class GLU(torch.nn.Module):
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
         self.inner_hidden_size = inner_hidden_size
-        self.dense_h_to_4h = skip_init(
+        self.dense_h_to_4h = init_method(
             torch.nn.Linear,
             self.hidden_size,
             self.inner_hidden_size,
@@ -516,7 +528,7 @@ class GLU(torch.nn.Module):
             dtype=params_dtype,
         )
         # Project back to h.
-        self.dense_4h_to_h = skip_init(
+        self.dense_4h_to_h = init_method(
             torch.nn.Linear,
             self.inner_hidden_size,
             self.hidden_size,
@@ -552,7 +564,8 @@ class GLMBlock(torch.nn.Module):
             use_bias=True,
             params_dtype=torch.float,
             num_layers=28,
-            position_encoding_2d=True
+            position_encoding_2d=True,
+            empty_init=True
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -572,7 +585,8 @@ class GLMBlock(torch.nn.Module):
             hidden_size_per_attention_head=hidden_size_per_attention_head,
             bias=use_bias,
             params_dtype=params_dtype,
-            position_encoding_2d=self.position_encoding_2d
+            position_encoding_2d=self.position_encoding_2d,
+            empty_init=empty_init
         )
 
         # Layernorm on the input data.
@@ -587,6 +601,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
+            empty_init=empty_init
         )
 
     def forward(
@@ -781,9 +796,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
-
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
         # recording parameters
         self.max_sequence_length = config.max_sequence_length
         self.hidden_size = config.hidden_size
@@ -798,7 +816,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
 
-        self.word_embeddings = skip_init(
+        self.word_embeddings = init_method(
             torch.nn.Embedding,
             num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
             dtype=self.params_dtype
@@ -817,6 +835,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
+                empty_init=empty_init
             )
 
         self.layers = torch.nn.ModuleList(
@@ -1004,8 +1023,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
 
 class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-    def __init__(self, config: ChatGLMConfig):
+    def __init__(self, config: ChatGLMConfig, empty_init=True):
         super().__init__(config)
+        if empty_init:
+            init_method = skip_init
+        else:
+            init_method = default_init
 
         # self.hidden_size = config.hidden_size
         # self.params_dtype = torch.half
@@ -1014,9 +1037,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         self.position_encoding_2d = config.position_encoding_2d
 
-        self.transformer = ChatGLMModel(config)
+        self.transformer = ChatGLMModel(config, empty_init=empty_init)
 
-        self.lm_head = skip_init(
+        self.lm_head = init_method(
             nn.Linear,
             config.hidden_size,
             config.vocab_size,
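
For readers unfamiliar with the pattern, here is a minimal standalone sketch of the dispatch the diff applies to every Linear and Embedding; make_linear and the layer sizes are illustrative only and not part of modeling_chatglm.py:

import torch
from torch.nn.utils import skip_init


def default_init(cls, *args, **kwargs):
    # Plain construction: parameters are allocated and randomly initialized.
    return cls(*args, **kwargs)


def make_linear(in_features, out_features, dtype=torch.float, empty_init=True):
    # skip_init builds the module on the meta device and then materializes
    # uninitialized storage, so no time is spent on random initialization that
    # a subsequent checkpoint load would overwrite anyway.
    init_method = skip_init if empty_init else default_init
    return init_method(torch.nn.Linear, in_features, out_features, dtype=dtype)


fast = make_linear(4096, 4096)                    # weights left uninitialized
safe = make_linear(4096, 4096, empty_init=False)  # standard PyTorch init

skip_init requires PyTorch 1.10 or newer, which is one reason the commit keeps an escape hatch to the eager default_init path.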