Abstract
Language models only really need to use an exponential fraction of their neurons for individual inferences. As proof, we present UltraFastBERT, a BERT variant that uses 0.3% of its neurons during inference while performing on par with similar BERT models. UltraFastBERT selectively engages just 12 out of 4095 neurons for each layer inference. This is achieved by replacing feedforward networks with fast feedforward networks (FFFs). While no truly efficient implementation currently exists to unlock the full acceleration potential of conditional neural execution, we provide high-level CPU code achieving 78x speedup over the optimized baseline feedforward implementation, and a PyTorch implementation delivering 40x speedup over the equivalent batched feedforward inference. We publish our training code, benchmarking setup, and model weights.
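To make the "exponential fraction" concrete: the 12-of-4095 figure matches a single balanced binary tree with 12 levels. Such a tree has 2^12 - 1 = 4095 neurons, but one inference only visits one neuron per level. The snippet below just does that arithmetic, assuming one such tree per layer; the function name is ours, not from the released code.

```python
def fff_neuron_counts(levels: int) -> tuple[int, int]:
    """Neuron counts for one balanced binary FFF tree with `levels` levels.

    Returns (total neurons in the tree, neurons evaluated per inference).
    """
    total = 2 ** levels - 1  # a full binary tree with `levels` levels has 2^levels - 1 nodes
    used = levels            # inference follows one root-to-leaf path: one node per level
    return total, used


total, used = fff_neuron_counts(12)
print(total, used, f"{100 * used / total:.2f}%")  # 4095 12 0.29%
```

The number of neurons used grows only as log2(total + 1), which is where "exponential fraction" comes from.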
Community
Fantastic work.
If my understanding is correct, instead of "brute forcing" everything all at once, this computes a tree, where only the relevant branches are computed. Hence the efficiency.
How does this conditional "sparse" computation affect cross-domain generation?
Exactly! In short, the idea is that "whether some neurons end up being used at all depends on the result of some other neurons". The optimal way to employ conditionality is a tree structure. And, since the activation status ("activated" vs "not activated") of each neuron offers a natural dichotomy, we use binary trees.
Regarding your question: This work focuses on the BERT architecture, so only downstream NLU(-like) tasks are considered. Of course, the ultimate potential (i.e. "big money") is in generation, but sadly, we lack the resources to train FFF-powered generative models from scratch.
Naturally, we'd be willing to extend our support to anyone who wishes to train generative models with FFFs instead of the traditional feedforward layers.
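A minimal pure-Python sketch of inference through one such tree for a single token. The breadth-first node layout, the GELU activation, and the rule "positive pre-activation goes to the right child" are assumptions made for illustration; this is not a transcription of the released `fff.py`, and all names are hypothetical.

```python
import math
import random


def gelu(z: float) -> float:
    """Exact GELU activation."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def fff_forward(x, w_in, w_out, levels):
    """Evaluate one token through a single fast feedforward (FFF) tree.

    Nodes are stored breadth-first, so node i has children 2*i+1 (left)
    and 2*i+2 (right). Exactly `levels` nodes are evaluated.
    """
    y = [0.0] * len(w_out[0])
    node, path = 0, []
    for _ in range(levels):
        logit = dot(w_in[node], x)                              # this neuron's pre-activation
        act = gelu(logit)
        y = [yi + act * wo for yi, wo in zip(y, w_out[node])]   # add its output contribution
        path.append(node)
        node = 2 * node + 1 + int(logit > 0)                    # its sign picks the next child
    return y, path


random.seed(0)
d, levels = 8, 4
n_nodes = 2 ** levels - 1
w_in = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
w_out = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
x = [random.gauss(0, 1) for _ in range(d)]
y, path = fff_forward(x, w_in, w_out, levels)
print(path)  # 4 of the 15 node indices, one per level
```

Only `levels` of the `2**levels - 1` rows of `w_in` and `w_out` are ever read, which is where the savings come from.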
If I had the money I'd be throwing it at you, whilst sailing around the Mediterranean in my luxury yacht. I hope you get a research grant to continue this; it's very exciting. Not sure NVIDIA will be too happy, though.
Haha, thanks.
NVIDIA actually stands to gain a lot from this. As we explain in Section 3.2 of the paper, CMM (conditional matrix multiplication) is completely compatible with the CUDA single-instruction-multiple-threads (SIMT) approach to computation. This requires no adjustments on the hardware front (except perhaps for the caching strategies at L0/L1).
In other words, NVIDIA could be selling the same amount of silicon with much greater inference potential without any (urgent) need for innovation on the manufacturing front.
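As a rough illustration of why conditional evaluation fits SIMT, here is a hedged pure-Python sketch that advances a whole batch through an FFF tree one level at a time. It assumes a breadth-first node layout and sign-based branching (positive pre-activation goes right); it is not the paper's CUDA kernel, only the loop structure one might give such a kernel, and all names are hypothetical.

```python
import math
import random


def gelu(z: float) -> float:
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cmm_lockstep(X, w_in, w_out, levels):
    """Conditionally evaluate an FFF tree for a batch of rows, level by level.

    Every row runs the same instruction sequence for the same number of
    levels; rows differ only in which node's weights they read. That
    uniformity is what would let each row map onto its own GPU thread.
    """
    batch = len(X)
    nodes = [0] * batch                                  # current node index per row
    Y = [[0.0] * len(w_out[0]) for _ in range(batch)]
    for _ in range(levels):                              # identical trip count for all rows
        for b in range(batch):                           # on a GPU: one thread per row
            n = nodes[b]
            logit = dot(w_in[n], X[b])                   # gather: row-specific weight read
            act = gelu(logit)
            Y[b] = [yi + act * wo for yi, wo in zip(Y[b], w_out[n])]
            nodes[b] = 2 * n + 1 + int(logit > 0)        # branch-free child selection
    return Y


random.seed(0)
d, levels, batch = 8, 4, 3
n_nodes = 2 ** levels - 1
w_in = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
w_out = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(batch)]
Y = cmm_lockstep(X, w_in, w_out, levels)
```

The only data-dependent part is the address of the weights each row reads, not the control flow, which is the property the SIMT argument relies on.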
It sounds great! If I want to use it to accelerate the inference of a BERT classification model, what should I do? Thanks!
Nice work.
Should be possible to "just" fine-tune an existing model for this inference mode.
At least I'll try it.
No, it wouldn't. The model needs to learn to put the parts which should be selected together, together.
Exactly as I said
Perhaps I'm missing something, but I'm getting slower inference times from this over standard bert-base-uncased
EDIT: Please see https://huggingface.co/papers/2311.10770#655e4dd5c6b7c3ab7ef6df09
This is bittersweet LOL. I've been obsessed with FFF networks since the original paper dropped, and now I have a better implementation to work with. I was perplexed that nobody was talking about FFFs -- things shift so quickly. Obviously kudos for optimizing arguably the most expensive component in a transformer. I never attempted to go down to 0.3% of neurons. That was shocking. The level of sparsity relative to performance feels very human-like to me.
It sounds great! If I want to use it to accelerate the inference of a BERT classification model, what should I do? Thanks!
You can just finetune this BERT model for your classification task. Then, depending on your target inference device, you will need to choose an implementation of CMM that will deliver a speedup.
Nice work.
Should be possible to "just" fine-tune an existing model for this inference mode.
At least I'll try it.
I'm afraid that as @someone13574 says, you'll likely need to train the whole model from scratch.
Perhaps I'm missing something, but I'm getting slower inference times from this over standard bert-base-uncased
Indeed, you seem to have missed the Reproducibility subsection of the introduction and Section 3.3 of the paper. The code provided here serves to demonstrate that only 12 neurons per feedforward layer are needed to perform inference. This can be verified by looking into `training/cramming/architectures/fff.py` and checking out the `forward` function.
The performance speedups were measured using the code and setup in the separate `benchmark_cpu`, `benchmark_pytorch`, and `benchmark_cuda` directories. Some of these implementations can be loaded into PyTorch as extensions and used, but all except the CPU ones will underperform the native fused dense matrix multiplication implementations used for traditional feedforward networks.
To explain the measurements you are getting: when running the masking code that proves only a small fraction of the neurons is used, you are actually computing all 4095 neurons (even though you don't need them all) and comparing the inference time to the BERT model that has 3072 neurons. One would expect a 1.33x slowdown rather than a speedup, and you're getting 1.61x, which is close enough for a single trial without warmup and with CPU tokenization included.
"While we do not provide an efficient PyTorch or TensorFlow implementation of CMM, the fact that only 12 neurons are used in the inference of UltraFastBERT can be verified simply by masking out the output of all but the chosen neurons, and we give the code for this."
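The masking argument can be illustrated with a hedged pure-Python sketch, assuming a breadth-first node layout, GELU activations, and sign-based branching (all function names hypothetical, not the repository's code). `masked_dense_fff` computes every neuron, as a dense layer would, then zeroes the contribution of every neuron off the chosen path; its output matches `conditional_fff`, which only evaluates the path. Because the masked version still computes all neurons, it shows which neurons matter but costs about 4095/3072 of a 3072-neuron BERT feedforward layer.

```python
import math
import random


def gelu(z: float) -> float:
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def conditional_fff(x, w_in, w_out, levels):
    """Evaluate only the `levels` neurons on the root-to-leaf path."""
    y = [0.0] * len(w_out[0])
    node = 0
    for _ in range(levels):
        logit = dot(w_in[node], x)
        act = gelu(logit)
        y = [yi + act * wo for yi, wo in zip(y, w_out[node])]
        node = 2 * node + 1 + int(logit > 0)
    return y


def masked_dense_fff(x, w_in, w_out, levels):
    """Compute all neurons like a dense layer, then mask out those off the path."""
    logits = [dot(w, x) for w in w_in]           # every neuron is computed (the costly part)
    mask = [0.0] * len(w_in)
    node = 0
    for _ in range(levels):                      # recover the path from the precomputed logits
        mask[node] = 1.0
        node = 2 * node + 1 + int(logits[node] > 0)
    y = [0.0] * len(w_out[0])
    for i, logit in enumerate(logits):
        coeff = mask[i] * gelu(logit)            # off-path neurons contribute zero
        y = [yi + coeff * wo for yi, wo in zip(y, w_out[i])]
    return y, int(sum(mask))


random.seed(1)
d, levels = 8, 5
n_nodes = 2 ** levels - 1
w_in = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
w_out = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
x = [random.gauss(0, 1) for _ in range(d)]

y_cond = conditional_fff(x, w_in, w_out, levels)
y_mask, n_used = masked_dense_fff(x, w_in, w_out, levels)
print(n_used, max(abs(a - b) for a, b in zip(y_cond, y_mask)))
# Cost ratio of the dense-style masking approach vs. a 3072-neuron BERT FF layer:
print(f"{4095 / 3072:.2f}")  # 1.33
```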
I see, apologies. I should have read more carefully. Was just very eager to test out the models.
Will update my comment.
You can just finetune this BERT model for your classification task. Then, depending on your target inference device, you will need to choose an implementation of CMM that will deliver a speedup.
@pbelcak I tried to directly use AutoModelForSequenceClassification.from_pretrained('pbelcak/UltraFastBERT-1x11-long'), but got an error:
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
* (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
Maybe I should manually add a classification layer behind it?
This is an automated message from the Librarian Bot. I found the following papers similar to this paper.
The following papers were recommended by the Semantic Scholar API
- ReLU Strikes Back: Exploiting Activation Sparsity in Large Language Models (2023)
- Towards End-to-end 4-Bit Inference on Generative Large Language Models (2023)
- Approximating Two-Layer Feedforward Networks for Efficient Transformers (2023)
- Sparse Fine-tuning for Inference Acceleration of Large Language Models (2023)
- Efficient LLM Inference on CPUs (2023)
From a quick skim of the paper I can't tell: do they essentially compute attention as normal but use the FFN sparsely? If so, I'm not sure how that would give an exponential speedup, as attention + LayerNorm + skip connections should be quite a bit of the computation.
tl;dr -- Chat GPT4 thinks you're right -- but there may be room for more improvement with Flash Attention and Flash Decoding...
Edit: @timothelaborie's comment below is also a great idea.
Long version:
I uploaded the paper to Chat GPT4 and asked it... it said:
In the paper "Exponentially Faster Language Modeling," the authors focus on enhancing the efficiency of the feedforward networks within the BERT architecture, but they do not modify the attention mechanism. [...] The paper explicitly states that they leave the attention layers untouched and focus solely on the intermediate layers hosting the feedforward networks. This indicates that the attention computation in UltraFastBERT remains the same as in traditional BERT models.
I asked it if it would still provide an improvement to LLMs, and it said:
While the full attention computation in [LLMs] remains a significant computational task, the advancements in the efficiency of feedforward networks contribute to a realistic expectation of reduced overall computational demands for LLMs. This represents a meaningful stride towards more efficient and sustainable AI models.
I also asked it questions about the quadratic complexity -- and the scaling... I hypothesized that with FFFs you could run a 70B parameter model as if it was a 2.1B parameter model, if the context was limited to a small size (but a larger context would require more memory and compute power). -- It responded with like a whole page of observations about which parts were correct and incorrect. Notably:
FFFs primarily optimize the feedforward layers of a transformer model, like BERT or its variants, reducing the computational load during inference. While this could theoretically allow a larger model to operate with reduced computational costs, it doesn't necessarily equate to running a 70B model with the cost of a 2B model. The savings are significant, but they may not scale linearly or to such an extent, especially considering other components of the model like the attention mechanism.
The hypothesis suggests a trade-off between context size and computational resources (memory and processing speed). This is accurate in the sense that more context requires more memory and increases processing time due to the attention mechanism. However, the reduction in computational costs from FFFs may not scale to the extent of running a 70B model with the [computational] cost of a 2B model.
The processing speed for transformers is affected by the quadratic complexity of the self-attention mechanism. As the context size increases, the number of calculations required for the attention mechanism grows quadratically. This means that even with FFFs, a larger context size would still significantly increase the computational load due to the attention calculations, impacting processing speed.
Very-very long version:
(Semi-)Finally, I asked it to spitball some numbers for how many floating point operations (FLOPs) are required for inferencing on a 70B parameter model (assuming 2K context size, and 20 tokens in, and 20 tokens out, with 16 self attention heads) -- it came back with: 3.15 trillion FLOPs.
I asked it to break that down between attention and FF calculations, it said this:
Attention Operations | Feedforward Network (Neuron Activations) |
---|---|
3.145 trillion FLOPs | 2.8 billion FLOPs |
I asked it to hypothesize what the new FLOPs would be like with FFF, and it said (assuming a straight 99.7% reduction):
Attention Operations | Feedforward Network (Neuron Activations) |
---|---|
3.145 trillion FLOPs | 8.4 million FLOPs |
I also uploaded the paper for Flash Attention and linked to PyTorch's article on Flash Decoding. It claims that all three could potentially be applied to the same model, as none of them overlap in what they interact with. -- I specifically asked if the beam search optimization from Flash Decoding would impact FFFs and it said no. (Is it wrong?)
Since these two don't reduce the overall FLOPs, it's hard to estimate the actual theoretical improvement here... it did a kind of hand-wavy thing and shrugged out an answer that said total inferencing speed would be between 2x and 8x faster, and cautioned me to take that number with a grain of salt (indicating that its confidence in the answer isn't any higher than mine is, or yours should be either...)
For giggles, I asked it what its confidence was in these numbers, and it said between 20% and 30% for the Flash Attention and Decoding stuff (citing silly things like simplified calculations and a complete lack of empirical data), and 50% - 60% for the FLOPs stuff in the tables above.
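For anyone who wants to sanity-check numbers like these, here is a back-of-envelope per-token estimator using common approximations: about 2 FLOPs per multiply-accumulate, dense multi-head attention, a two-matrix non-gated FFN, and biases, norms and softmax ignored. The dimensions are hypothetical placeholders at roughly 70B scale, not any particular model's configuration (Llama 2 70B, for instance, uses grouped-query attention and a gated FFN), and the FFF term assumes a single 12-level tree per layer.

```python
def per_token_macs(d_model: int, d_ff: int, ctx: int, fff_levels: int) -> dict[str, int]:
    """Rough multiply-accumulate counts for one token in one transformer layer.

    `ctx` is the number of positions the token attends to. The head count does
    not change these totals under dense multi-head attention, only their split.
    """
    return {
        "attn_proj": 4 * d_model * d_model,   # Q, K, V and output projections
        "attn_ctx": 2 * ctx * d_model,        # q.k scores plus weighted sum of values
        "ffn": 2 * d_model * d_ff,            # up- and down-projection of a dense FFN
        "fff": 2 * fff_levels * d_model,      # one input + one output vector per visited node
    }


def per_token_flops(layers: int, **dims) -> dict[str, int]:
    # ~2 FLOPs (one multiply, one add) per multiply-accumulate, summed over layers.
    return {k: 2 * layers * v for k, v in per_token_macs(**dims).items()}


# Hypothetical 70B-scale dimensions, for illustration only.
est = per_token_flops(80, d_model=8192, d_ff=4 * 8192, ctx=2048, fff_levels=12)
for part, flops in est.items():
    print(f"{part:>9}: {flops / 1e9:8.2f} GFLOPs per token")
```

Under these formulas the context-dependent attention term only exceeds the dense FFN term once `ctx` is larger than `d_ff`, so how the cost splits between attention and feedforward depends heavily on the assumed dimensions and context length.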
Btw, sorry if this was terribly off topic -- it was just a fun learning experience... (for the parts that were actually informational and not just guesses) -- Also, good call out bob12345. :)
One interesting thing that could be tried with this is putting the attention weights and KV cache on the GPU, while having the FFF weights on the CPU's RAM. The idea is that since the FFF runs very quickly on the CPU, there is no point in wasting the very limited VRAM of current GPUs on it. This setup would make it possible to run very large models on a reasonably priced PC with 128 GB RAM and a 3090.
I have an idea for how this method could be used without having to pre-train from scratch:
- Take base Llama 2 70B, give it 10k+ diverse documents, record every input and output vector from its FFNs
- For each FFN in the original model, train a slightly larger FFF from scratch using the recorded vectors (this causes the FFFs to approximate the original FFNs)
- Replace the FFNs in the original model with the new FFFs and train the new model on something like open-orca
This should be much cheaper than pre-training from scratch. Not sure how much quality will be lost, but might be good enough for testing purposes.
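A heavily simplified, toy-scale pure-Python sketch of the first two steps: record input/output pairs of a small dense "teacher" FFN, then fit an FFF to them with an MSE loss. To keep it short (and the loss convex), the FFF's routing weights are frozen at random values and only its output vectors are fitted with full-batch gradient descent; a real attempt would also learn the routing/input weights and use activations recorded from the actual model. All names and sizes are hypothetical.

```python
import math
import random


def gelu(z):
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def teacher_ffn(x, W1, W2):
    """The dense FFN being imitated: y = W2 . gelu(W1 . x)."""
    h = [gelu(dot(row, x)) for row in W1]
    return [dot(row, h) for row in W2]


def fff_path(x, w_in, levels):
    """Return [(node, activation), ...] along the root-to-leaf path."""
    out, node = [], 0
    for _ in range(levels):
        logit = dot(w_in[node], x)
        out.append((node, gelu(logit)))
        node = 2 * node + 1 + int(logit > 0)
    return out


def fff_predict(path, w_out, d):
    y = [0.0] * d
    for node, act in path:
        y = [yi + act * wo for yi, wo in zip(y, w_out[node])]
    return y


def mse(records, paths, w_out, d):
    err = 0.0
    for (_, target), path in zip(records, paths):
        pred = fff_predict(path, w_out, d)
        err += sum((p - t) ** 2 for p, t in zip(pred, target))
    return err / (len(records) * d)


random.seed(0)
d, hidden, levels, n_samples, lr = 4, 16, 3, 200, 0.05
n_nodes = 2 ** levels - 1
W1 = [[random.gauss(0, 0.5) for _ in range(d)] for _ in range(hidden)]
W2 = [[random.gauss(0, 0.5) for _ in range(hidden)] for _ in range(d)]

# Step 1: record (input, output) pairs of the teacher FFN.
records = []
for _ in range(n_samples):
    x = [random.gauss(0, 1) for _ in range(d)]
    records.append((x, teacher_ffn(x, W1, W2)))

# Step 2 (simplified): freeze random routing weights, fit only the output vectors.
w_in = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_nodes)]
w_out = [[0.0] * d for _ in range(n_nodes)]
paths = [fff_path(x, w_in, levels) for x, _ in records]

initial = mse(records, paths, w_out, d)
for _ in range(100):
    grad = [[0.0] * d for _ in range(n_nodes)]
    for (_, target), path in zip(records, paths):
        pred = fff_predict(path, w_out, d)
        for node, act in path:
            for j in range(d):
                grad[node][j] += 2.0 * act * (pred[j] - target[j]) / (n_samples * d)
    for node in range(n_nodes):
        w_out[node] = [w - lr * g for w, g in zip(w_out[node], grad[node])]
final = mse(records, paths, w_out, d)
print(f"MSE {initial:.4f} -> {final:.4f}")
```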
@timothelaborie this could work; this kind of per-layer distillation with MSE loss does indeed sometimes work if there's enough data for distillation and when a sufficiently large dropout p was used during training. That being said, in my experience, it will never be as good as training from scratch. The best I ever got when trying to distill BERT-base FF layers into FFFs was a relative per-dimension error of about 20% -- not terrible, but also not great.
@pbelcak, @timothelaborie -- Assuming you could do this to a larger LLM and get that 20% relative per-dimension error... could you then fine-tune that model on its original training data (as if it had always been an FFF network) to reduce that error?
Well, one could try, but there seems to be an inherent limit on how well FFFs can approximate functions produced by FFs of the same size. Ultimately, they might be equivalent in representational power (or richer), but the functions they have to approximate may have a more complex surface than those that can be constructed by the tree organization of neurons.
Could the idea behind the FFF also be used to speed up the attention mechanism?
For example, you could create a model where each layer has 64 attention heads separated into 16 groups of size 4, and then use a decision tree to select the group of attention heads that gets used. This would also reduce the size of the KV cache.
One potential issue is that the queries produced by the tokens will be compared with keys created by many different groups of attention heads, which could cause stability issues. If that ends up being a problem then perhaps the attention mechanism would need to be tweaked a bit first.
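The routing part of this proposal can be sketched in pure Python: a binary decision tree with 4 levels of routing nodes picks one of 16 head groups, and the token then only uses (and only needs K/V cache entries for) the 4 heads in that group. This is a hypothetical illustration of the comment's idea, with made-up names and a sign-based routing rule; it does not address the stability concern about queries meeting keys from different groups.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def select_head_group(x, router_w, levels):
    """Descend a binary decision tree and return the chosen leaf (head-group) index.

    router_w holds one weight vector per internal node, stored breadth-first
    (children of node i are 2*i+1 and 2*i+2), so there are 2**levels - 1 of them.
    """
    node = 0
    for _ in range(levels):
        node = 2 * node + 1 + int(dot(router_w[node], x) > 0)
    return node - (2 ** levels - 1)  # leaves are numbered right after the internal nodes


def heads_in_group(group, group_size):
    return list(range(group * group_size, (group + 1) * group_size))


# The example from the comment: 64 heads in 16 groups of 4, i.e. a 4-level tree.
random.seed(0)
d_model, n_heads, group_size = 32, 64, 4
n_groups = n_heads // group_size        # 16
levels = n_groups.bit_length() - 1      # log2(16) = 4
router_w = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(2 ** levels - 1)]
x = [random.gauss(0, 1) for _ in range(d_model)]
group = select_head_group(x, router_w, levels)
print(group, heads_in_group(group, group_size))
```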
Yep, some variants of this idea could indeed work well. I was looking into ALR and ALSR (https://arxiv.org/pdf/2311.10642.pdf?ref=nural.cc) as options for attention replacement, but these do not seem to address the elephant in the room -- the growing context size.
Does FFF work on GPU? Or is this only applicable to CPU?
https://arxiv.org/pdf/2312.07987.pdf
This paper shows that MoE can be applied to the value and output matrices. It might be worth checking whether the attention MoE performs better using a decision tree instead, as in the FFF.