beomi
posted an update Apr 18
🚀 **InfiniTransformer, Gemma/Llama3 based Implementation!** 🌌

> Update @ 2024.04.19: It now supports Llama-3!

> Note: this implementation is unofficial

This implementation is designed to handle virtually infinite context lengths.

Here's the github repo: https://github.com/Beomi/InfiniTransformer

📄 **Read the original Paper:** https://arxiv.org/abs/2404.07143

## **Focus on Infini-Attention**

- **Two Types of Implementation Available:** attention-layer-only implementation / model- and training-wise implementation
- **Fixed (Segment-Dependent) Memory Usage:** Enables training larger models on longer sequences without the memory overhead typical of standard Transformer implementations.
- **Infinite Context Capability:** Train with unprecedented sequence lengths. Imagine handling sequences of up to 1 million tokens on standard hardware!
- You can train Gemma-2B with a 1M sequence length and a 2K segment size on a single H100 GPU.
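To make the "fixed memory" idea concrete, here is a minimal single-head sketch of the compressive memory described in the Infini-attention paper: an ELU+1 nonlinearity, the linear memory update (the paper's delta-rule variant is omitted), and a sigmoid gate that mixes memory retrieval with local dot-product attention. It is plain Python rather than PyTorch and is not the code from the InfiniTransformer repo; the names `CompressiveMemory`, `gate`, and `num_segments` are hypothetical. The segment count assumes "1M" means 2^20 tokens and "2K" means 2,048.

```python
import math

# Hypothetical single-head sketch of Infini-attention's compressive memory
# (linear update variant from arXiv 2404.07143). Not the InfiniTransformer repo code.


def sigma(x):
    # ELU(x) + 1: x + 1 for x > 0 and exp(x) for x <= 0, so it is always positive
    return x + 1.0 if x > 0 else math.exp(x)


def matmul(a, b):
    # Plain-list matrix product: (n x k) @ (k x m) -> (n x m)
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


class CompressiveMemory:
    def __init__(self, d_key, d_value):
        # M is d_key x d_value and z has d_key entries; neither grows with sequence length
        self.M = [[0.0] * d_value for _ in range(d_key)]
        self.z = [0.0] * d_key

    def retrieve(self, Q):
        # A_mem = sigma(Q) M_{s-1} / (sigma(Q) z_{s-1}), read before this segment's update
        sq = [[sigma(x) for x in row] for row in Q]
        num = matmul(sq, self.M)
        out = []
        for row_q, row_num in zip(sq, num):
            denom = sum(q * z for q, z in zip(row_q, self.z))
            # Empty memory (first segment): return zeros instead of dividing by zero
            out.append([v / denom if denom > 0 else 0.0 for v in row_num])
        return out

    def update(self, K, V):
        # M_s = M_{s-1} + sigma(K)^T V ;  z_s = z_{s-1} + sum_t sigma(K_t)
        sk = [[sigma(x) for x in row] for row in K]
        sk_t = [list(col) for col in zip(*sk)]
        delta = matmul(sk_t, V)
        for i, row in enumerate(delta):
            for j, v in enumerate(row):
                self.M[i][j] += v
        for row in sk:
            for i, x in enumerate(row):
                self.z[i] += x


def gate(a_mem, a_dot, beta):
    # A = sigmoid(beta) * A_mem + (1 - sigmoid(beta)) * A_dot, beta is a learned scalar
    g = 1.0 / (1.0 + math.exp(-beta))
    return [[g * m + (1.0 - g) * d for m, d in zip(rm, rd)]
            for rm, rd in zip(a_mem, a_dot)]


def num_segments(seq_len, segment_len):
    # Segments are processed one after another, carrying (M, z) between them
    return math.ceil(seq_len / segment_len)


if __name__ == "__main__":
    mem = CompressiveMemory(d_key=2, d_value=2)
    for s in range(3):
        Q = K = [[0.1 * s, -0.5], [0.2, 0.3]]  # toy segment of 2 tokens
        V = [[1.0, 0.0], [0.0, 1.0]]
        mem.retrieve(Q)   # reads memory built from previous segments only
        mem.update(K, V)  # folds this segment into the fixed-size memory
    print(len(mem.M), len(mem.M[0]), num_segments(2**20, 2048))  # prints: 2 2 512
```

However long the input, the state carried from one segment to the next in this sketch is only `M` and `z`, so its size depends on the head dimensions rather than on the total sequence length; a 1M-token sequence with 2K segments is simply 512 sequential segment steps.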

## **Try InfiniTransformer**

1. **Clone the repository:**
   ```bash
   git clone https://github.com/Beomi/InfiniTransformer
   cd InfiniTransformer
   ```
2. **Install necessary tools:**
   ```bash
   pip install -r requirements.txt
   pip install -e git+https://github.com/huggingface/transformers.git@b109257f4f#egg=transformers
   ```
3. **Dive Deep into Custom Training:**
   - Train with long sequence lengths using scripts such as `./train.gemma.infini.noclm.1Mseq.sh`.

For more detailed info, please visit the repo: https://github.com/Beomi/InfiniTransformer

Looking forward to your feedback! 😊

P.S. The training loss plot is here 😉

Nice job! What are your findings so far? Can you reasonably handle the lengths that they claim?


I'm testing it with 32K (the length claimed in the paper) and 1M sequence lengths.

I'm training these models on the MiniPile dataset. So far, it seems that minimal continual training (fewer than 1B tokens) could be sufficient for the model to adapt to the 'memory'.

Training isn't finished yet, but once the loss converges I can run needle-in-a-haystack tests or inference tests. It won't take long :)
