---
license: cc-by-nc-4.0
base_model: Qwen/Qwen2-7B-Instruct
model-index:
- name: Squid
  results: []
tags:
- RAG
- on-device language model
- Retrieval Augmented Generation
inference: false
space: false
spaces: false
language:
- en
---
Squid: Long Context as a New Modality for on-device RAG
- Nexa Model Hub
- arXiv: https://arxiv.org/abs/2408.15518
Overview
Squid is a novel approach to accelerating language model inference that treats long context as a new modality, in the same way that multimodal models treat images, audio, and video. A small language model acts as an encoder that compresses the context into embeddings, which the main model then consumes much as a vision-language model consumes image features. Borrowing this multimodal design makes language model inference more efficient. Model highlights:
- 🧠 Context as a distinct modality
- 🗜️ Language encoder for context compression
- 🔗 Multimodal techniques applied to language processing
- ⚡ Optimized for energy efficiency and on-device use
- 📜 Specialized for long context understanding
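To put the compression idea in rough numbers, the sketch below counts how many input positions the main model processes for one query with and without compression. It assumes, as the inference example in this card does with MEMORY_SIZE = 32, that any context is reduced to a fixed number of memory embeddings regardless of its length; the helper names are hypothetical.

```python
# Back-of-the-envelope sketch. Assumption: every context is compressed into a
# fixed number of memory embeddings (MEMORY_SIZE = 32, as in this card's
# inference example), no matter how long the context is.
MEMORY_SIZE = 32


def main_decoder_positions(context_tokens: int, query_tokens: int, compressed: bool) -> int:
    """Number of input positions the main decoder processes for one query."""
    context_positions = MEMORY_SIZE if compressed else context_tokens
    return context_positions + query_tokens


def compression_ratio(context_tokens: int) -> float:
    """How many context tokens each memory embedding stands in for."""
    return context_tokens / MEMORY_SIZE


if __name__ == "__main__":
    ctx, query = 4096, 20
    print(main_decoder_positions(ctx, query, compressed=False))  # 4116
    print(main_decoder_positions(ctx, query, compressed=True))   # 52
    print(compression_ratio(ctx))                                # 128.0
```

The smaller encoder model still reads every context token; the saving applies to the main model, which attends over 32 embeddings instead of the full context.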
Model Architecture
Squid employs a decoder-decoder framework with two main components:
- A smaller decoder (0.5B parameters) that compresses information from long contexts into embeddings
- A larger decoder (7B parameters) that comprehends the current query and generates the response

A projector connects the two, aligning the embeddings produced by the smaller decoder (which acts as the text encoder) with the embedding space of the main decoder.
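The inference example in this card marks memory slots in the main decoder's input with the placeholder id -1 and appends [memory_i] tokens to the context. One plausible reading, assumed here rather than taken from Squid's code, is that the smaller decoder's hidden states at those memory tokens pass through the projector and fill the placeholder slots. The dependency-free sketch below shows that splicing step with a purely linear projector; `project`, `build_main_inputs`, and the toy dimensions are hypothetical.

```python
# Hypothetical sketch of how a projector could connect the two decoders.
# The linear projector, names, and shapes are illustrative assumptions,
# not Squid's actual implementation.
from typing import Dict, List, Sequence

Vector = List[float]
Matrix = List[List[float]]

PLACEHOLDER_ID = -1  # marks memory slots in the main decoder's input_ids


def project(hidden: Sequence[float], weight: Matrix, bias: Sequence[float]) -> Vector:
    """Linear projector: maps a d_small hidden state to a d_large embedding."""
    return [sum(w * x for w, x in zip(row, hidden)) + b for row, b in zip(weight, bias)]


def build_main_inputs(
    input_ids: Sequence[int],
    token_embeddings: Dict[int, Vector],
    memory_states: Sequence[Sequence[float]],
    weight: Matrix,
    bias: Sequence[float],
) -> List[Vector]:
    """Replace each placeholder slot with the next projected memory state and
    every ordinary token id with its regular embedding."""
    slots = sum(1 for t in input_ids if t == PLACEHOLDER_ID)
    if slots != len(memory_states):
        raise ValueError(f"{slots} placeholder slots but {len(memory_states)} memory states")
    projected = (project(h, weight, bias) for h in memory_states)
    return [next(projected) if t == PLACEHOLDER_ID else token_embeddings[t] for t in input_ids]


if __name__ == "__main__":
    # Toy sizes: d_small = 2, d_large = 3, two memory slots.
    weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    bias = [0.0, 0.0, 0.0]
    embeds = {5: [9.0, 9.0, 9.0], 7: [8.0, 8.0, 8.0]}
    print(build_main_inputs([5, -1, -1, 7], embeds, [[1.0, 0.0], [0.0, 1.0]], weight, bias))
    # [[9.0, 9.0, 9.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [8.0, 8.0, 8.0]]
```

The count check matters because the placeholders and the memory tokens are generated independently; if they disagree, the projected context would be misaligned with the prompt.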
Running the Model
Method 1
Download this repository and run the example script:
```bash
git lfs install
git clone https://huggingface.co/NexaAIDev/Squid
cd Squid
python inference_example.py
```
Method 2
Install the nexaai-squid package:
```bash
pip install nexaai-squid
```
Then run the following Python code:
```python
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from squid.configuration_squid import SquidConfig
from squid.modeling_squid import SquidForCausalLM

MEMORY_SIZE = 32  # number of memory embeddings the context is compressed into
EOS_TOKEN_ID = 151643  # Qwen2 <|endoftext|> token; generation stops here


@torch.no_grad()  # inference only: skip autograd bookkeeping
def inference_instruct(mycontext, question, device="cuda:0"):
    # Uses the module-level `tokenizer` and `model` created in the __main__ block.
    start_time = time.time()
    generated_token_ids = []

    # Main-decoder prompt: the <context> marker becomes MEMORY_SIZE placeholder
    # ids (-1) that mark where the compressed context goes.
    prompt = f" <context>{question}"
    text_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<context>")]
    input_ids = (
        torch.tensor(
            text_chunks[0] + [-1] * MEMORY_SIZE + text_chunks[1], dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )

    # Small-decoder input: the raw context followed by the [memory_i] tokens.
    context_tokenized = tokenizer(
        mycontext + "".join([f"[memory_{i}]" for i in range(MEMORY_SIZE)]),
        return_tensors="pt",
    )
    context_tokenized = {k: v.to(device) for k, v in context_tokenized.items()}
    context_token_count = context_tokenized["input_ids"].shape[1] - MEMORY_SIZE

    # Greedy decoding, capped at the number of context tokens.
    for _ in range(context_token_count):
        next_token = (
            model(
                input_ids,
                context_input_ids=context_tokenized["input_ids"],
                context_attention_mask=context_tokenized["attention_mask"],
            )
            .logits[:, -1]
            .argmax(-1)
        )
        if next_token.item() == EOS_TOKEN_ID:
            break
        generated_token_ids.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=-1)

    result = tokenizer.decode(generated_token_ids)
    print(f"Time taken: {time.time() - start_time:.2f}s")
    return result


if __name__ == "__main__":
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Register the custom Squid classes with the Auto* factories.
    AutoConfig.register("squid", SquidConfig)
    AutoModelForCausalLM.register(SquidConfig, SquidForCausalLM)

    tokenizer = AutoTokenizer.from_pretrained("NexaAIDev/Squid")
    model = AutoModelForCausalLM.from_pretrained(
        "NexaAIDev/Squid",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map=device_name,
    )

    # Run inference example
    mycontext = "Nexa AI is a Cupertino-based company founded in May 2023 that researches and develops models and tools for on-device AI applications. The company is founded by Alex and Zack. The company is known for its Octopus-series models, which rival large-scale language models in capabilities such as function-calling, multimodality, and action-planning, while remaining efficient and compact for edge device deployment. Nexa AI's mission is to advance on-device AI in collaboration with the global developer community. To this end, the company has created an on-device model hub for users to find, share, and collaborate on open-source AI models optimized for edge devices, as well as an SDK for developers to run and deploy AI models locally"
    question = "Who founded Nexa AI?"
    result = inference_instruct(mycontext, question, device=device_name)
    print("Result:", result)
```
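The example builds two separate input sequences, and their alignment is easy to miss. The pure-Python sketch below reproduces that layout without downloading the model. `toy_tokenize` is a hypothetical whitespace tokenizer standing in for the Qwen2 tokenizer (which treats the [memory_i] strings as single tokens, so the real code needs no separator), and MEMORY_SIZE is shrunk to 4 for readability.

```python
# Toy reproduction of the two input sequences built in inference_instruct.
# toy_tokenize and its ids are hypothetical stand-ins for the Qwen2 tokenizer.
from typing import Dict, List, Tuple

MEMORY_SIZE = 4  # the real example uses 32; kept small so the output is readable
PLACEHOLDER_ID = -1


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Whitespace tokenizer that assigns the next free id to every unseen word."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def build_inputs(context: str, question: str, vocab: Dict[str, int]) -> Tuple[List[int], List[int]]:
    # Main decoder: text before <context>, one placeholder per memory slot, then the question.
    before, after = f" <context>{question}".split("<context>")
    input_ids = toy_tokenize(before, vocab) + [PLACEHOLDER_ID] * MEMORY_SIZE + toy_tokenize(after, vocab)
    # Small decoder: the full context followed by the memory tokens
    # (space-separated here only because the toy tokenizer splits on whitespace).
    memory_tokens = " ".join(f"[memory_{i}]" for i in range(MEMORY_SIZE))
    context_ids = toy_tokenize(context + " " + memory_tokens, vocab)
    return input_ids, context_ids


if __name__ == "__main__":
    vocab: Dict[str, int] = {}
    ids, ctx = build_inputs("Nexa AI was founded by Alex and Zack", "Who founded Nexa AI?", vocab)
    print(ids)  # [-1, -1, -1, -1, 0, 1, 2, 3]
    print(ctx)  # [2, 4, 5, 1, 6, 7, 8, 9, 10, 11, 12, 13]
```

Both sequences derive their slot count from MEMORY_SIZE, so the number of -1 placeholders always matches the number of memory tokens; `len(ctx) - MEMORY_SIZE` corresponds to `context_token_count` in the real example.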
Training Process
Squid's training involves three stages:
- Restoration Training: Reconstructing original context from compressed embeddings
- Continual Training: Generating context continuations from partial compressed contexts
- Instruction Fine-tuning: Generating responses to queries given compressed contexts
Each stage builds on the previous one: the model first learns to recover information from the compressed embeddings, then to extend it, and finally to answer queries grounded in it.
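The three stages differ mainly in what is compressed and what the main decoder must produce. The sketch below illustrates that difference by turning one tokenized document into an example for each stage; the `Example` fields, function names, and split point are hypothetical and do not describe Squid's actual data pipeline.

```python
# Hypothetical sketch: building one example per training stage.
# Field names and the split point are illustrative assumptions.
from dataclasses import dataclass
from typing import List


@dataclass
class Example:
    context: List[str]  # fed to the small decoder and compressed into memory embeddings
    prompt: List[str]   # fed to the main decoder after the memory slots
    target: List[str]   # tokens the main decoder is trained to generate


def restoration_example(context: List[str]) -> Example:
    # Stage 1: reconstruct the original context from its compressed embeddings.
    return Example(context=list(context), prompt=[], target=list(context))


def continuation_example(context: List[str], split: int) -> Example:
    # Stage 2: compress only a prefix and learn to generate what follows it.
    if not 0 < split < len(context):
        raise ValueError("split must fall strictly inside the context")
    return Example(context=context[:split], prompt=[], target=context[split:])


def instruction_example(context: List[str], question: List[str], answer: List[str]) -> Example:
    # Stage 3: answer a query given the compressed context.
    return Example(context=list(context), prompt=list(question), target=list(answer))


if __name__ == "__main__":
    doc = "Nexa AI was founded by Alex and Zack".split()
    print(restoration_example(doc).target == doc)  # True
    print(continuation_example(doc, 4).target)     # ['by', 'Alex', 'and', 'Zack']
    qa = instruction_example(doc, "Who founded Nexa AI ?".split(), "Alex and Zack".split())
    print(qa.prompt)                               # ['Who', 'founded', 'Nexa', 'AI', '?']
```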
Citation
If you use Squid in your research, please cite our paper:
```bibtex
@article{chen2024squidlongcontextnew,
  title={Squid: Long Context as a New Modality for Energy-Efficient On-Device Language Models},
  author={Wei Chen and Zhiyuan Li and Shuo Xin and Yihao Wang},
  year={2024},
  eprint={2408.15518},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2408.15518},
}
```
Contact
For questions or feedback, please contact us.