NoteDance committed on
Commit
a17b1d3
1 Parent(s): 04dbedd

Update Gemma.py

Browse files
Files changed (1) hide show
  1. Gemma.py +14 -4
Gemma.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -66,12 +66,17 @@ def apply_rotary_emb(x, freqs_cis):
66
  return x_out
67
 
68
 
69
- class Embedder:
70
  """Embedder module."""
71
  def __init__(self, config: GemmaConfig):
72
  self.vocab_size = config.vocab_size
73
  self.embed_dim = config.hidden_size
74
- self.input_embedding_table = tf.Variable(tf.random.normal((self.vocab_size, self.embed_dim)))
 
 
 
 
 
75
 
76
  def encode(self, x):
77
  x = tf.gather(self.input_embedding_table, x)
@@ -92,7 +97,12 @@ class RMSNorm:
92
  ):
93
  self.eps = eps
94
  self.add_unit_offset = add_unit_offset
95
- self.weight = tf.Variable(tf.random.zeros((dim)))
 
 
 
 
 
96
 
97
  def _norm(self, x):
98
  return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True) + self.eps)
 
1
+ # Copyright 2024 NoteDance
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
66
  return x_out
67
 
68
 
69
+ class Embedder(tf.keras.layers.Layer):
70
  """Embedder module."""
71
  def __init__(self, config: GemmaConfig):
72
  self.vocab_size = config.vocab_size
73
  self.embed_dim = config.hidden_size
74
+ self.input_embedding_table = self.add_weight(
75
+ name='input_embedding_table',
76
+ shape=(self.vocab_size, self.embed_dim),
77
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
78
+ trainable=True
79
+ )
80
 
81
  def encode(self, x):
82
  x = tf.gather(self.input_embedding_table, x)
 
97
  ):
98
  self.eps = eps
99
  self.add_unit_offset = add_unit_offset
100
+ self.weight = self.add_weight(
101
+ name='weight',
102
+ shape=(self.dim,),
103
+ initializer=tf.keras.initializers.Zeros(),
104
+ trainable=True
105
+ )
106
 
107
  def _norm(self, x):
108
  return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True) + self.eps)