"""Inference-only Gemma model implementation.""" |
|
|
|
import dataclasses

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
|
|
|
|
|
@dataclasses.dataclass |
|
class GemmaConfig: |
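    """Gemma model configuration; the defaults correspond to the 7B variant."""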
|
|
|
vocab_size: int = 256000 |
|
|
|
max_position_embeddings: int = 8192 |
|
|
|
num_hidden_layers: int = 28 |
|
|
|
num_attention_heads: int = 16 |
|
|
|
num_key_value_heads: int = 16 |
|
|
|
hidden_size: int = 3072 |
|
|
|
intermediate_size: int = 24576 |
|
|
|
head_dim: int = 256 |
|
|
|
rms_norm_eps: float = 1e-6 |
|
|
|
|
|
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0):
    """Precomputes the rotary embedding frequencies as complex exponentials."""
    freqs = 1.0 / (theta**(tf.cast(tf.range(0, dim, 2)[:(dim // 2)], 'float32') / dim))
    t = tf.cast(tf.range(end), 'float32')
    freqs = tf.cast(tf.experimental.numpy.outer(t, freqs), 'float32')
    # Unit-magnitude rotations cos(freqs) + i*sin(freqs) (polar form), not 1 + i*freqs.
    freqs_cis = tf.complex(tf.cos(freqs), tf.sin(freqs))
    return freqs_cis
|
|
|
|
|
def apply_rotary_emb(x, freqs_cis):
    """Applies the rotary embedding to the query and key tensors."""
    # View the head dimension as complex numbers: the first half of head_dim is
    # the real part, the second half is the imaginary part.
    x_ = tf.complex(
        *tf.split(tf.cast(tf.transpose(x, [0, 2, 1, 3]), 'float32'),
                  num_or_size_splits=2, axis=-1),
    )
    x_ = x_ * tf.cast(freqs_cis, x_.dtype)
    x_out = tf.cast(tf.stack([tf.math.real(x_), tf.math.imag(x_)], axis=-1),
                    x.dtype)
    x_out = tf.concat(tf.split(x_out, num_or_size_splits=2, axis=-1), axis=-2)
    x_out = tf.transpose(tf.reshape(x_out, (x_out.shape[0], x_out.shape[1],
                                            x_out.shape[2], -1)), (0, 2, 1, 3))
    return x_out
|
|
|
|
|
class Embedder(tf.keras.layers.Layer): |
|
"""Embedder module.""" |
|
    def __init__(self, config: GemmaConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.input_embedding_table = self.add_weight(
            name='input_embedding_table',
            shape=(self.vocab_size, self.embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )
|
|
|
def encode(self, x): |
|
x = tf.gather(self.input_embedding_table, x) |
|
        # Scale embeddings by sqrt(hidden_size); embed_dim is a Python int, so cast first.
        x *= tf.cast(tf.math.sqrt(tf.cast(self.embed_dim, 'float32')), x.dtype)
|
return x |
|
|
|
def decode(self, x): |
|
return tf.matmul(x, tf.transpose(self.input_embedding_table)) |
|
|
|
|
|
class RMSNorm(tf.keras.layers.Layer):
    """RMS normalization with a zero-initialized scale and optional unit offset."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = self.add_weight(
            name='weight',
            shape=(self.dim,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )
|
|
|
def _norm(self, x): |
|
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True) + self.eps) |
|
|
|
def __call__(self, x): |
|
x = tf.cast(self._norm(tf.cast(x, 'float32')), x.dtype) |
|
if self.add_unit_offset: |
|
output = x * (1 + self.weight) |
|
else: |
|
output = x * self.weight |
|
return output |
|
|
|
|
|
class GemmaMLP: |
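    """Gated feed-forward block: down_proj(gelu(gate_proj(x)) * up_proj(x))."""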
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
): |
|
self.gate_proj = Dense(intermediate_size) |
|
self.up_proj = Dense(intermediate_size) |
|
self.down_proj = Dense(hidden_size) |
|
|
|
def __call__(self, x): |
|
gate = self.gate_proj(x) |
|
gate = tf.nn.gelu(gate) |
|
up = self.up_proj(x) |
|
fuse = gate * up |
|
outputs = self.down_proj(fuse) |
|
return outputs |
|
|
|
|
|
class GemmaAttention: |
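    """Multi-head attention with rotary embeddings, a KV cache, and grouped-query support."""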
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
num_heads: int, |
|
num_kv_heads: int, |
|
head_dim: int, |
|
): |
|
self.num_heads = num_heads |
|
self.num_kv_heads = num_kv_heads |
|
|
|
assert self.num_heads % self.num_kv_heads == 0 |
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads |
|
|
|
self.hidden_size = hidden_size |
|
self.head_dim = head_dim |
|
|
|
self.q_size = self.num_heads * self.head_dim |
|
self.kv_size = self.num_kv_heads * self.head_dim |
|
|
|
self.scaling = self.head_dim**-0.5 |
|
|
|
self.qkv_proj = Dense( |
|
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim, |
|
) |
|
self.o_proj = Dense( |
|
self.hidden_size, |
|
) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
freqs_cis, |
|
kv_write_indices, |
|
kv_cache, |
|
mask, |
|
): |
|
hidden_states_shape = hidden_states.shape |
|
assert len(hidden_states_shape) == 3 |
|
|
|
batch_size, input_len, _ = hidden_states_shape |
|
|
|
qkv = self.qkv_proj(hidden_states) |
|
xq, xk, xv = tf.split(qkv, [self.q_size, self.kv_size, self.kv_size], |
|
axis=-1) |
|
|
|
xq = tf.reshape(xq, (batch_size, -1, self.num_heads, self.head_dim)) |
|
xk = tf.reshape(xk, (batch_size, -1, self.num_kv_heads, self.head_dim)) |
|
xv = tf.reshape(xv, (batch_size, -1, self.num_kv_heads, self.head_dim)) |
|
|
|
|
|
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) |
|
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) |
|
|
|
|
|
|
|
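        # Write this step's keys/values into the persistent KV cache at kv_write_indices.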
k_cache, v_cache = kv_cache |
|
k_cache.assign(tf.tensor_scatter_nd_update(k_cache, kv_write_indices, xk)) |
|
v_cache.assign(tf.tensor_scatter_nd_update(v_cache, kv_write_indices, xv)) |
|
|
|
key = k_cache |
|
value = v_cache |
|
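        # Grouped-query attention: repeat each cached KV head so every query
        # head has a matching key/value head.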
if self.num_kv_heads != self.num_heads: |
|
|
|
batch_size, seq_len, num_heads, head_dim = key.shape |
|
key = tf.reshape(tf.tile(key[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]), |
|
[batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim]) |
|
batch_size, seq_len, num_heads, head_dim = value.shape |
|
value = tf.reshape(tf.tile(value[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]), |
|
[batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim]) |
|
|
|
|
|
q = tf.transpose(xq, (0, 2, 1, 3)) |
|
|
|
k = tf.transpose(key, (0, 2, 1, 3)) |
|
v = tf.transpose(value, (0, 2, 1, 3)) |
|
|
|
|
|
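        # Scaled dot-product attention; the mask is additive (0 where attention
        # is allowed, a large negative value where it is masked).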
scores = tf.matmul(q, tf.transpose(k, (0, 1, 3, 2))) * self.scaling |
|
scores = scores + mask |
|
scores = tf.cast(tf.nn.softmax(tf.cast(scores, 'float32'), axis=-1), q.dtype) |
|
|
|
|
|
output = tf.matmul(scores, v) |
|
|
|
|
|
        output = tf.reshape(tf.transpose(output, (0, 2, 1, 3)),
                            (batch_size, input_len, -1))
|
output = self.o_proj(output) |
|
return output |
|
|
|
|
|
class GemmaDecoderLayer: |
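    """Pre-norm decoder block: self-attention and MLP, each with an RMSNorm and a residual connection."""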
|
|
|
def __init__( |
|
self, |
|
config: GemmaConfig, |
|
): |
|
self.self_attn = GemmaAttention( |
|
hidden_size=config.hidden_size, |
|
num_heads=config.num_attention_heads, |
|
num_kv_heads=config.num_key_value_heads, |
|
head_dim=config.head_dim, |
|
) |
|
self.mlp = GemmaMLP( |
|
hidden_size=config.hidden_size, |
|
intermediate_size=config.intermediate_size, |
|
) |
|
self.input_layernorm = RMSNorm(config.hidden_size, |
|
eps=config.rms_norm_eps) |
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, |
|
eps=config.rms_norm_eps) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
freqs_cis, |
|
kv_write_indices, |
|
kv_cache, |
|
mask, |
|
): |
|
|
|
residual = hidden_states |
|
hidden_states = self.input_layernorm(hidden_states) |
|
hidden_states = self.self_attn( |
|
hidden_states=hidden_states, |
|
freqs_cis=freqs_cis, |
|
kv_write_indices=kv_write_indices, |
|
kv_cache=kv_cache, |
|
mask=mask, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
|
|
return hidden_states |
|
|
|
|
|
class Gemma(Model): |
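    """Gemma decoder-only transformer for inference; logits come from the tied input embedding table."""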
|
|
|
def __init__(self, config: GemmaConfig): |
|
super(Gemma, self).__init__() |
|
self.config = config |
|
self.vocab_size = config.vocab_size |
|
|
|
        self.embedder = Embedder(config)
        # `layers` is a read-only property on tf.keras.Model, so the decoder
        # stack is kept under a different attribute name.
        self.decoder_layers = []
        for _ in range(config.num_hidden_layers):
            self.decoder_layers.append(GemmaDecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # No separate output projection: logits are produced with the tied
        # input embedding table via `self.embedder.decode`.
|
|
|
def __call__( |
|
self, |
|
data, |
|
freqs_cis, |
|
kv_write_indices, |
|
kv_caches, |
|
mask |
|
): |
|
hidden_states = self.embedder.encode(data) |
|
        for i in range(len(self.decoder_layers)):
            layer = self.decoder_layers[i]
|
hidden_states = layer( |
|
hidden_states=hidden_states, |
|
freqs_cis=freqs_cis, |
|
kv_write_indices=kv_write_indices, |
|
kv_cache=kv_caches[i], |
|
mask=mask, |
|
) |
|
hidden_states = self.norm(hidden_states) |
|
logits = self.embedder.decode(hidden_states) |
|
return logits |
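

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It shows one way to drive
# the model for a single prefill step; the small config values, the
# scatter-index layout for `kv_write_indices`, and the additive-mask
# convention below are assumptions inferred from the call sites above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = GemmaConfig(
        vocab_size=1000,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        hidden_size=256,
        intermediate_size=1024,
        head_dim=64,
    )
    model = Gemma(config)

    batch_size, input_len, max_seq_len = 1, 4, 16
    tokens = tf.constant([[1, 5, 7, 9]], dtype=tf.int32)

    # Rotary frequencies for the positions processed in this step.
    freqs_cis = precompute_freqs_cis(config.head_dim, max_seq_len)[:input_len]

    # One (K, V) cache pair per layer, stored as tf.Variable so the attention
    # layers can update them in place.
    cache_shape = (batch_size, max_seq_len, config.num_key_value_heads,
                   config.head_dim)
    kv_caches = [(tf.Variable(tf.zeros(cache_shape)),
                  tf.Variable(tf.zeros(cache_shape)))
                 for _ in range(config.num_hidden_layers)]

    # Scatter indices of shape (batch, input_len, 2): each entry is a
    # (batch_index, position) pair for tf.tensor_scatter_nd_update.
    positions = tf.range(input_len)
    kv_write_indices = tf.stack(
        [tf.zeros_like(positions), positions], axis=-1)[None, :, :]

    # Additive causal mask over the cache: 0 where attention is allowed,
    # a large negative value elsewhere.
    row = positions[:, None]
    col = tf.range(max_seq_len)[None, :]
    mask = tf.where(col <= row, 0.0, -1e9)[None, None, :, :]

    logits = model(tokens, freqs_cis, kv_write_indices, kv_caches, mask)
    print(logits.shape)  # (batch_size, input_len, vocab_size)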