"""
An implementation of Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370

Also includes a reference implementation of the memory-efficient transformer https://arxiv.org/abs/2112.05682
"""
import functools
from typing import NamedTuple

import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange


"""
Compute the FFN blockwise without materializing the large hidden tensor, enabling
training on sequences 4x longer than the memory-efficient transformer supports.
Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
"""
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    # remat_ffn: a rematerialized FFN module (e.g. wrapped with nn.remat)
    # inputs: (batch, seq_len, dim)
    # chunk_size: number of positions processed per FFN block
    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)
    def scan_ffn(remat_ffn, carry, hidden_states):
        outputs = remat_ffn(hidden_states, deterministic=deterministic)
        return carry, outputs
    scan_axis = inputs.ndim - 2
    _, res = nn.scan(
        scan_ffn,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=scan_axis,
        out_axes=scan_axis,
    )(remat_ffn, None, inputs)
    res = rearrange(res, 'b c n d -> b (c n) d')
    return res
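

"""
Usage sketch (not from the original source): blockwise_ffn expects a Flax module that
maps (batch, chunk, dim) -> (batch, chunk, dim) and accepts a `deterministic` keyword,
typically wrapped with nn.remat so per-chunk activations are rematerialized. The
FeedForward module and the hyperparameters below are illustrative assumptions.
"""
class FeedForward(nn.Module):
    # Hypothetical position-wise FFN used only to illustrate blockwise_ffn.
    dim: int = 512
    hidden_dim: int = 2048

    @nn.compact
    def __call__(self, x, deterministic=True):
        # `deterministic` is accepted because blockwise_ffn passes it through;
        # a real FFN might use it to toggle dropout.
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.gelu(x)
        return nn.Dense(self.dim)(x)


class BlockwiseFFNExample(nn.Module):
    # Minimal wrapper showing how blockwise_ffn is called from a Flax module.
    dim: int = 512
    hidden_dim: int = 2048
    chunk_size: int = 512

    def setup(self):
        self.ffn = nn.remat(
            FeedForward,
            policy=jax.checkpoint_policies.nothing_saveable,
        )(dim=self.dim, hidden_dim=self.hidden_dim)

    def __call__(self, inputs, deterministic=True):
        # inputs: (batch, seq_len, dim); seq_len must be divisible by chunk_size.
        return blockwise_ffn(self.ffn, inputs, self.chunk_size, deterministic)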


"""
Compute attention blockwise without materializing the full attention matrix.
This was initially proposed in the memory-efficient transformer
https://arxiv.org/abs/2112.05682 Rabe et al. 2021; flash attention
https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA-efficient
implementation; Blockwise Parallel Transformer https://arxiv.org/abs/2305.19370
Liu et al. 2023 proposes computing both attention and the FFN blockwise, enabling
4x longer sequences than memory-efficient attention/flash attention as well as
fusion of attention and FFN.
"""
def blockwise_attn(
    query, key, value,
    bias=None,
    deterministic=True,
    dropout_rng=None,
    attn_pdrop=0.0,
    causal=True,
    query_chunk_size=2048,
    key_chunk_size=2048,
    dtype=jnp.float32,
    policy=jax.checkpoint_policies.nothing_saveable,
    precision=None,
    float32_logits=True,
    prevent_cse=True,
):
    # query, key, value: (batch, seq_len, num_heads, dim_per_head)
    # bias: broadcastable to (batch, num_heads, q_len, kv_len), e.g. an attention mask
    # causal: whether to apply a causal mask
    # query_chunk_size, key_chunk_size: block sizes along the query and key/value axes
    # policy: jax.checkpoint policy used to rematerialize the inner key/value scan
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    # Optionally compute attention logits in float32 for numerical stability.
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))

    # Move the chunk axis to the front so lax.scan iterates over chunks.
    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            # Online softmax: track the running max and rescale the accumulated
            # numerator/denominator so the softmax never needs the full attention matrix.
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            # Under causal masking, key chunks strictly after the query chunk are
            # fully masked out, so skip them entirely.
            key_chunk, value_chunk, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                skip_block = query_chunk_idx < key_chunk_idx
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    res = rearrange(res, 'n b c h d -> b (n c) h d')
    return res


class Carry(NamedTuple):
    numerator: jax.Array
    denominator: jax.Array
    max_so_far: jax.Array


def _chunk_attention_bias(query_chunk_size, key_chunk_size,
                          bias, deterministic, attn_dropout, attn_pdrop, causal,
                          dtype, query_chunk_idx, key_chunk_idx):
    # Build the additive bias for one (query chunk, key chunk) block:
    # user-provided bias + causal mask + attention dropout.
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )

    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        offset = query_offset - key_offset
        query_idx += offset
        causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
        chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)

    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
    return chunk_bias.astype(dtype)
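

"""
Usage sketch (not from the original source): a minimal causal self-attention module
built on blockwise_attn. The qkv projection layout, head sizes, and chunk size are
illustrative assumptions; blockwise_attn only requires query/key/value of shape
(batch, seq_len, num_heads, dim_per_head) with seq_len divisible by the chunk sizes.
"""
class BlockwiseSelfAttentionExample(nn.Module):
    num_heads: int = 8
    dim_per_head: int = 64
    chunk_size: int = 512

    @nn.compact
    def __call__(self, x, deterministic=True):
        # x: (batch, seq_len, dim); seq_len must be divisible by chunk_size.
        qkv = nn.DenseGeneral((3, self.num_heads, self.dim_per_head), name='qkv')(x)
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out = blockwise_attn(
            query, key, value,
            deterministic=deterministic,
            causal=True,
            query_chunk_size=self.chunk_size,
            key_chunk_size=self.chunk_size,
        )
        out = rearrange(out, 'b s h d -> b s (h d)')
        return nn.Dense(self.num_heads * self.dim_per_head, name='out')(out)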


if __name__ == '__main__':
    # Sanity check: blockwise_attn should match a straightforward full-attention reference.
    def reference_attn(query, key, value, causal, dtype):
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        if causal:
            mask_value = jnp.finfo(logits.dtype).min
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            mask_shape = (q_seq_len, kv_seq_len)
            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            causal_mask = (row_ids < col_ids)[None, None, :, :]
            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
        return out

    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable

    blockwise = blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision='float32',
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(query, key, value, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)