Good folks at @Apple have developed a novel method called KV Prediction that significantly reduces the "time to first token" (TTFT) for on-device LLM inference.
Some highlights of the paper:
• Uses a small auxiliary transformer model to efficiently predict the KV cache of a larger base model
• Reduces TTFT by up to 4x while retaining 60-80% accuracy on benchmarks
• Achieves Pareto-optimal efficiency-accuracy trade-off compared to baselines
• Demonstrates 15-50% relative accuracy improvements on TriviaQA at equal TTFT FLOP budgets
• Shows up to 30% accuracy gains on HumanEval code completion at fixed TTFT FLOP counts
• Validated on an Apple M2 Pro CPU, showing that FLOP reductions translate to real-world TTFT speedups
So, how's it done?
Based on the KV Prediction method described in the paper, here are the key steps:
1. Choose a base model and an auxiliary model:
- The base model is a larger, pretrained transformer model that will be used for final generation.
- The auxiliary model is a smaller transformer model used to efficiently process the input prompt.
2. Design the KV predictor:
- Create a set of learned linear projections to map from the auxiliary model's KV cache to the base model's KV cache.
- Define a mapping from auxiliary cache layers to base cache layers.
3. Training process:
- Pass input tokens through the auxiliary model to get its KV cache.
- Use the KV predictor to generate a predicted KV cache for the base model.
- Run the base model using the predicted KV cache and compute losses.
- Backpropagate the loss through the frozen base model to update only the auxiliary model and the KV predictor.
4. Inference process:
- Process the input prompt with the auxiliary model to get its KV cache.
- Use the KV predictor to generate the predicted base model KV cache.
- Run a single token generation step with the base model using the predicted KV cache.
- Continue autoregressive generation with the base model as normal.
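To make the inference flow concrete, here is a minimal numpy sketch under several assumptions that are mine, not the paper's. The two "transformers" are toys (single-head attention, residual only, no MLP, norm, or positional encoding). The names `ToyTransformer`, `KVPredictor`, and `kv_prediction_generate` are hypothetical. The layer mapping `[0, 0, 1, 1]` is made up. I also assume the predicted cache covers all prompt tokens except the last, which the base model then consumes in its single step. All weights are random, and the training step (backprop through the frozen base model) is not shown, so the generated tokens are meaningless. The sketch only illustrates the data flow.

```python
import numpy as np

rng = np.random.default_rng(0)


class ToyTransformer:
    """Hypothetical toy decoder: single-head causal attention per layer,
    residual connections only, tied input/output embeddings."""

    def __init__(self, vocab, d, n_layers, rng):
        self.d = d
        self.emb = rng.normal(0, 0.02, (vocab, d))
        self.layers = [
            {n: rng.normal(0, d ** -0.5, (d, d)) for n in ("wq", "wk", "wv", "wo")}
            for _ in range(n_layers)
        ]

    def forward(self, tokens, cache=None):
        """Process `tokens` on top of an optional KV cache (one (K, V) pair
        per layer). Returns logits for the new tokens and the extended cache."""
        x = self.emb[tokens]  # (T_new, d)
        new_cache = []
        for i, p in enumerate(self.layers):
            k_new, v_new = x @ p["wk"], x @ p["wv"]
            if cache is not None:
                k = np.concatenate([cache[i][0], k_new])
                v = np.concatenate([cache[i][1], v_new])
            else:
                k, v = k_new, v_new
            scores = (x @ p["wq"]) @ k.T / np.sqrt(self.d)
            t_new, t_all = scores.shape
            # New token j sits at absolute position t_all - t_new + j and
            # may only attend to keys at or before that position.
            mask = np.triu(np.ones((t_new, t_all), dtype=bool), k=t_all - t_new + 1)
            scores = np.where(mask, -np.inf, scores)
            w = np.exp(scores - scores.max(axis=-1, keepdims=True))
            w /= w.sum(axis=-1, keepdims=True)
            x = x + (w @ v) @ p["wo"]
            new_cache.append((k, v))
        return x @ self.emb.T, new_cache


class KVPredictor:
    """Hypothetical predictor: base layer i reads auxiliary layer layer_map[i]
    and applies separate linear projections to its keys and values."""

    def __init__(self, d_aux, d_base, layer_map, rng):
        self.layer_map = layer_map
        self.pk = [rng.normal(0, d_aux ** -0.5, (d_aux, d_base)) for _ in layer_map]
        self.pv = [rng.normal(0, d_aux ** -0.5, (d_aux, d_base)) for _ in layer_map]

    def predict(self, aux_cache):
        return [
            (aux_cache[j][0] @ self.pk[i], aux_cache[j][1] @ self.pv[i])
            for i, j in enumerate(self.layer_map)
        ]


def kv_prediction_generate(aux, base, predictor, prompt, n_new):
    """Greedy generation with a predicted base-model KV cache (len(prompt) >= 2)."""
    # Cheap prefill: only the small auxiliary model sees the prompt prefix.
    _, aux_cache = aux.forward(prompt[:-1])
    # Map the auxiliary cache to a predicted base-model cache.
    cache = predictor.predict(aux_cache)
    # The first loop iteration is the single base-model step on the last
    # prompt token; later iterations are ordinary autoregressive decoding.
    token, out = prompt[-1], []
    for _ in range(n_new):
        logits, cache = base.forward(np.array([token]), cache)
        token = int(np.argmax(logits[-1]))
        out.append(token)
    return out


aux = ToyTransformer(vocab=50, d=16, n_layers=2, rng=rng)
base = ToyTransformer(vocab=50, d=32, n_layers=4, rng=rng)
predictor = KVPredictor(d_aux=16, d_base=32, layer_map=[0, 0, 1, 1], rng=rng)
prompt = np.array([3, 14, 15, 9, 26])
print(kv_prediction_generate(aux, base, predictor, prompt, n_new=5))
```

Because the base model's `forward` accepts any cache with the right shapes, swapping its own prefill cache for the predicted one requires no change to the decoding loop. That is the property KV Prediction relies on. With trained projections, the predicted cache would approximate the one the base model would have computed itself.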
Excited to hear your thoughts!