Update Whisper.py
Browse files- Whisper.py +13 -3
Whisper.py
CHANGED
@@ -150,7 +150,7 @@ class AudioEncoder:
|
|
150 |
return x
|
151 |
|
152 |
|
153 |
-
class TextDecoder:
|
154 |
def __init__(
|
155 |
self,
|
156 |
n_vocab: int,
|
@@ -160,8 +160,18 @@ class TextDecoder:
|
|
160 |
n_layer: int,
|
161 |
dtype = tf.float16,
|
162 |
):
|
163 |
-
self.token_embedding =
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
self.blocks = [
|
167 |
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
|
|
150 |
return x
|
151 |
|
152 |
|
153 |
+
class TextDecoder(tf.keras.layers.Layer):
|
154 |
def __init__(
|
155 |
self,
|
156 |
n_vocab: int,
|
|
|
160 |
n_layer: int,
|
161 |
dtype = tf.float16,
|
162 |
):
|
163 |
+
self.token_embedding = self.add_weight(
|
164 |
+
name='token_embedding',
|
165 |
+
shape=[self.n_vocab, self.n_state],
|
166 |
+
initializer=tf.keras.initializers.RandomNormal(stddev=0.02), # 设定标准差 stddev
|
167 |
+
trainable=True
|
168 |
+
)
|
169 |
+
self.positional_embedding = self.add_weight(
|
170 |
+
name='positional_embedding',
|
171 |
+
shape=[self.n_ctx, self.n_state],
|
172 |
+
initializer=tf.keras.initializers.Zeros(), # 初始化为全零
|
173 |
+
trainable=True
|
174 |
+
)
|
175 |
|
176 |
self.blocks = [
|
177 |
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|