Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
Abstract
Transformer-based language models spread FLOPs uniformly across input sequences. In this work we demonstrate that transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence, optimising the allocation along the sequence for different layers across the model depth. Our method enforces a total compute budget by capping the number of tokens (k) that can participate in the self-attention and MLP computations at a given layer. The tokens to be processed are determined by the network using a top-k routing mechanism. Since k is defined a priori, this simple procedure uses a static computation graph with known tensor sizes, unlike other conditional computation techniques. Nevertheless, since the identities of the k tokens are fluid, this method can expend FLOPs non-uniformly across the time and model depth dimensions. Thus, compute expenditure is entirely predictable in sum total, but dynamic and context-sensitive at the token-level. Not only do models trained in this way learn to dynamically allocate compute, they do so efficiently. These models match baseline performance for equivalent FLOPs and wall-clock times to train, but require a fraction of the FLOPs per forward pass, and can be upwards of 50% faster to step during post-training sampling.
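As a rough illustration of the mechanism described in the abstract, here is a minimal PyTorch-style sketch of a block with top-k routing. It is not the authors' implementation: `MoDBlock`, the `capacity` argument, the wrapped `block` module, and the sigmoid gating are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    """Minimal Mixture-of-Depths sketch (illustrative, not the paper's code).

    A scalar router scores every token; only the top-k tokens per sequence are
    processed by the wrapped transformer block, the rest pass through unchanged
    via the residual stream.
    """

    def __init__(self, d_model: int, block: nn.Module, capacity: int):
        super().__init__()
        self.router = nn.Linear(d_model, 1)   # scalar score per token
        self.block = block                    # stand-in for attention + MLP
        self.capacity = capacity              # k: tokens processed per sequence

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]
        scores = self.router(x).squeeze(-1)              # [batch, seq_len]
        k = min(self.capacity, x.shape[1])
        topk_scores, topk_idx = scores.topk(k, dim=-1)   # fixed k -> static shapes

        # Gather the selected tokens and run them through the heavy block.
        idx = topk_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
        selected = torch.gather(x, 1, idx)               # [batch, k, d_model]
        processed = self.block(selected)

        # Gate the block output by the router score so routing decisions stay
        # on the gradient path, then scatter back; unselected tokens keep
        # their residual unchanged.
        update = selected + processed * torch.sigmoid(topk_scores).unsqueeze(-1)
        return x.scatter(1, idx, update)
```

Because k is fixed ahead of time, tensor shapes are static even though which tokens are selected changes with context; the sigmoid gating above is one simple choice, not necessarily the exact gating used in the paper.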
Community
MoD TL;DR:
- Caps the number of tokens (k) that can participate in the self-attention and MLP computations at a given layer.
- Adds a top-k router that learns which tokens should be processed at each layer.
Results:
- MoD matches baseline performance with 66% faster training up to 1B scale
- Can be combined with MoEs into MoDE
- Inference improvements might be limited to batch_size=1
Do you expect MoDs to scale up to >=13B parameters?
Am I correct that all these Mo* approaches are not for the GPU poor? (MoE on a single GPU doesn't give much benefit; MoD on a 13B-parameter model would strongly reduce latency, but you would still have a hard time running a 13B-parameter model on a single smaller GPU.)
Awesome work! Here is a blog that was also nice to skim/read: https://huggingface.co/blog/joey00072/mixture-of-depth-is-vibe
Here were some interesting insights from the blog:
- With MoD, you can't easily do batched inference, since each token in a different batch element can get routed around the block (see the toy sketch after this list). You can if you use a mask, but at that point it's the same as running inference on a normal model, plus the overhead of a router.
- Feeding the whole sequence at once doesn't give much speed-up either: the same problem applies, since some tokens go through a block and some don't, and at inference time we don't want fixed-capacity routing.
- Existing speed-up techniques such as speculative decoding either won't work or are less useful than they are in normal models.
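To make the first two points concrete, here is a toy snippet of a single decoding step with batch size 4; the names and the threshold-style routing decision are illustrative assumptions (the paper trains an auxiliary predictor for causal routing at sample time).

```python
import torch
import torch.nn as nn

# Toy illustration of the batched-decoding problem described above.
torch.manual_seed(0)
d_model, batch = 8, 4
router = nn.Linear(d_model, 1)

# One decoding step: a single new token per sequence in the batch.
x = torch.randn(batch, 1, d_model)
scores = router(x).squeeze(-1)          # [batch, 1]
go_through_block = scores > 0.0         # per-sample decision at this step

print(go_through_block.squeeze(-1))     # e.g. tensor([ True, False,  True, False])
# Some sequences need the block at this step and others do not, so a naive
# batched implementation either pads/masks (running the block for everyone,
# which erases the FLOP savings) or splits the batch by routing decision.
```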