""" |
An implementation of the Blockwise Parallel Transformer, https://arxiv.org/abs/2305.19370.
Also includes a reference implementation of the memory-efficient transformer, https://arxiv.org/abs/2112.05682.
""" |
import functools |
from typing import NamedTuple |
import flax.linen as nn |
import jax |
import jax.lax as lax |
import jax.numpy as jnp |
from einops import rearrange |
""" |
Compute the FFN blockwise without materializing the large hidden tensor, enabling
training on 4x longer sequences than the memory-efficient transformer.
Blockwise Parallel Transformer, https://arxiv.org/abs/2305.19370 (Liu et al., 2023).
""" |
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    """Run a feed-forward module over `inputs` one sequence slice at a time.

    The sequence axis is factored as ``(c n)`` with ``c = chunk_size`` and the
    module is scanned over the ``n`` axis, so every scan step receives a
    ``(b, chunk_size, d)`` slice instead of the whole sequence.

    Args:
        remat_ffn: Module handed to ``nn.scan`` and called as
            ``remat_ffn(hidden_states, deterministic=deterministic)``.
            Presumably a rematerialized FFN -- confirm at the call site.
        inputs: Rank-3 array matching the einops pattern ``b (c n) d``; the
            sequence length must be divisible by ``chunk_size``.
        chunk_size: Size of the ``c`` factor of the sequence axis.
        deterministic: Forwarded unchanged to ``remat_ffn``.

    Returns:
        The per-slice outputs merged back into the ``b (c n) d`` layout.
    """
    chunked = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)

    def ffn_step(module, carry, block):
        # The carry is unused (None); it is only threaded through for nn.scan.
        return carry, module(block, deterministic=deterministic)

    # Second-to-last axis of the 4-D tensor, i.e. the `n` axis.
    axis = chunked.ndim - 2
    scanned_ffn = nn.scan(
        ffn_step,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=axis,
        out_axes=axis,
    )
    _, chunk_outputs = scanned_ffn(remat_ffn, None, chunked)
    return rearrange(chunk_outputs, 'b c n d -> b (c n) d')
""" |
Compute attention blockwise without materializing the full attention matrix.
This was initially proposed in the memory-efficient transformer, https://arxiv.org/abs/2112.05682
(Rabe et al., 2021); flash attention, https://arxiv.org/abs/2205.14135 (Dao et al., 2022),
proposes a CUDA-efficient implementation; the blockwise parallel transformer,
https://arxiv.org/abs/2305.19370 (Liu et al., 2023), proposes computing both attention and
the FFN blockwise, enabling 4x longer sequences than memory-efficient/flash attention and
fusion of attention and FFN.
""" |
def blockwise_attn(
    query, key, value,
    bias=None,
    deterministic=True,
    dropout_rng=None,
    attn_pdrop=0.0,
    causal=True,
    query_chunk_size=2048,
    key_chunk_size=2048,
    dtype=jnp.float32,
    policy=jax.checkpoint_policies.nothing_saveable(),
    precision=None,
    float32_logits=True,
    prevent_cse=True,
):
    """Compute softmax attention chunk by chunk.

    Query chunks are processed by an outer ``lax.scan``; for each query chunk
    an inner ``lax.scan`` walks the key/value chunks while carrying a running
    numerator, denominator and max logit, so only one
    ``query_chunk_size x key_chunk_size`` block of logits exists at a time.
    Note that when dropout is active a full
    ``(batch, num_heads, q_len, kv_len)`` boolean mask is still sampled up
    front.

    Args:
        query: Array of shape ``(batch, q_len, num_heads, dim_per_head)``.
        key: Array of shape ``(batch, kv_len, num_heads, dim_per_head)``.
        value: Array of shape ``(batch, kv_len, num_heads, dim_per_head)``.
        bias: Optional additive logit bias; every dimension must be 1 or
            equal the matching entry of ``(batch, num_heads, q_len, kv_len)``.
        deterministic: When False and ``attn_pdrop > 0`` a dropout mask is
            sampled and applied to the logits.
        dropout_rng: PRNG key used to sample the dropout mask; required when
            dropout is active.
        attn_pdrop: Probability that a logit is masked out by dropout.
        causal: Mask out key positions greater than the query position.
        query_chunk_size: Query chunk length. Clamped to ``q_len``; must
            divide ``q_len``.
        key_chunk_size: Key/value chunk length. Clamped to ``kv_len``; must
            divide ``kv_len``.
        dtype: dtype of the query scale factor, the bias chunks and the
            returned array.
        policy: Forwarded to ``jax.checkpoint`` for the inner scan step.
        precision: Forwarded to ``jnp.einsum``.
        float32_logits: Cast ``query`` and ``key`` to float32 before the
            logits are computed.
        prevent_cse: Forwarded to ``jax.checkpoint``.

    Returns:
        Array of shape ``(batch, q_len, num_heads, dim_per_head)`` in ``dtype``.

    Raises:
        ValueError: If a sequence is empty, a sequence length is not a
            multiple of its (clamped) chunk size, or dropout is active but
            ``dropout_rng`` is None.
    """
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    # A chunk larger than the sequence degenerates to a single chunk; without
    # the clamp `num_q`/`num_kv` would be 0 and the reshape below would fail.
    query_chunk_size = min(query_chunk_size, q_len)
    key_chunk_size = min(key_chunk_size, kv_len)
    if query_chunk_size < 1 or key_chunk_size < 1:
        raise ValueError(
            "blockwise_attn requires non-empty sequences and positive chunk "
            f"sizes, got q_len={q_len}, kv_len={kv_len}, "
            f"query_chunk_size={query_chunk_size}, key_chunk_size={key_chunk_size}"
        )
    if q_len % query_chunk_size != 0:
        raise ValueError(
            f"q_len={q_len} must be a multiple of query_chunk_size={query_chunk_size}"
        )
    if kv_len % key_chunk_size != 0:
        raise ValueError(
            f"kv_len={kv_len} must be a multiple of key_chunk_size={key_chunk_size}"
        )

    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))

    # Put the chunk axis first so lax.scan iterates over chunks.
    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim

    if not deterministic and attn_pdrop > 0.0:
        if dropout_rng is None:
            raise ValueError(
                "dropout_rng must be provided when deterministic=False and attn_pdrop > 0"
            )
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        # True marks a logit that is to be masked out.
        attn_dropout = jax.random.bernoulli(
            attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)
        )
    else:
        attn_dropout = None

    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum(
                'bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision
            )
            # Bias chunks are laid out (b, h, q, k); logits are (b, q, h, k).
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            # Running max for a numerically stable softmax across key chunks.
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            # Rescale what was accumulated under the previous max.
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                # A key block is fully masked only when its first position lies
                # after the last position of the query block. Comparing raw
                # chunk indices is only valid for equal chunk sizes and drops
                # visible keys when query_chunk_size > key_chunk_size.
                first_key_pos = key_chunk_idx * key_chunk_size
                last_query_pos = query_chunk_idx * query_chunk_size + (query_chunk_size - 1)
                skip_block = first_key_pos > last_query_pos
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    # Merge the (num_q, chunk) axes back into a single sequence axis.
    res = rearrange(res, 'n b c h d -> b (n c) h d')
    return res
class Carry(NamedTuple):
    """Running state threaded through the key/value scan of `blockwise_attn`."""

    # Accumulated exp-weighted sum of value vectors.
    numerator: jax.Array
    # Accumulated sum of exp weights (the softmax normaliser).
    denominator: jax.Array
    # Largest logit seen so far; the accumulators are expressed relative to it.
    max_so_far: jax.Array
def _chunk_attention_bias(query_chunk_size, key_chunk_size, |
bias, deterministic, attn_dropout, attn_pdrop, causal, |
dtype, query_chunk_idx, key_chunk_idx): |
query_offset = query_chunk_idx * query_chunk_size |
key_offset = key_chunk_idx * key_chunk_size |
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype) |
if bias is not None: |
chunk_bias = lax.dynamic_slice( |
bias, |
start_indices=(0, 0, query_offset, key_offset), |
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)), |
) |
if causal: |
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0) |
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1) |
offset = query_offset - key_offset |
query_idx += offset |
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min |
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape) |
if not deterministic and attn_pdrop > 0.0: |
attn_dropout_slice = lax.dynamic_slice( |
attn_dropout, |
start_indices=(0, 0, query_offset, key_offset), |
slice_sizes=( |
*attn_dropout.shape[:2], |
min(attn_dropout.shape[-2], query_chunk_size), |
min(attn_dropout.shape[-1], key_chunk_size), |
), |
) |
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min |
return chunk_bias.astype(dtype) |
if __name__ == '__main__':
    # Self-check: blockwise attention must match dense softmax attention.

    def reference_attn(query, key, value, causal, dtype):
        """Dense attention that materializes the full logit matrix."""
        scaled_query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", scaled_query, key)
        if causal:
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            grid = (q_seq_len, kv_seq_len)
            rows = jax.lax.broadcasted_iota(jnp.int32, grid, 0)
            cols = jax.lax.broadcasted_iota(jnp.int32, grid, 1)
            # True where the key position lies in the query's future.
            future = (rows < cols)[None, None, :, :]
            logits = logits + jnp.where(future, jnp.finfo(logits.dtype).min, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    # (batch, seq_len, num_heads, dim_per_head)
    shape = (1, 32, 8, 64)
    q, k, v = (
        jax.random.normal(jax.random.PRNGKey(seed), shape) for seed in range(3)
    )

    causal = True
    chunk_size = 4
    policy = jax.checkpoint_policies.nothing_saveable()

    blockwise = blockwise_attn(
        q, k, v,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision='float32',
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(q, k, v, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)