license: wtfpl
datasets:
  - HuggingFaceH4/no_robots
pipeline_tag: text-generation

MAMBA (2.8B) 🐍 fine-tuned on H4/no_robots dataset for chat / instruction

Model Card is still WIP!

Base model info

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.

Dataset info

Look Ma, an instruction dataset that wasn't generated by GPTs!

Dataset Description

Dataset Summary

No Robots is a high-quality dataset of 10,000 instructions and demonstrations created by skilled human annotators. This data can be used for supervised fine-tuning (SFT) to make language models follow instructions better. No Robots was modelled after the instruction dataset described in OpenAI's InstructGPT paper, and consists mostly of single-turn instructions across the following categories:

| Category   | Count |
|------------|-------|
| Generation | 4560  |
| Open QA    | 1240  |
| Brainstorm | 1120  |
| Chat       | 850   |
| Rewrite    | 660   |
| Summarize  | 420   |
| Coding     | 350   |
| Classify   | 350   |
| Closed QA  | 260   |
| Extract    | 190   |
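As a quick sanity check on the category counts listed above, the sketch below recomputes the total and each category's share of the dataset. The numbers are copied from the table; the variable names are illustrative only.

```python
# Category counts copied from the No Robots table above.
category_counts = {
    "Generation": 4560,
    "Open QA": 1240,
    "Brainstorm": 1120,
    "Chat": 850,
    "Rewrite": 660,
    "Summarize": 420,
    "Coding": 350,
    "Classify": 350,
    "Closed QA": 260,
    "Extract": 190,
}

total = sum(category_counts.values())

# Share of each category as a percentage of the whole dataset.
category_shares = {
    name: round(100 * count / total, 1)
    for name, count in category_counts.items()
}

print(f"Total examples: {total}")
for name, share in category_shares.items():
    print(f"{name:<10} {share:5.1f}%")
```

The counts add up to the stated 10,000 examples, and Generation alone accounts for 45.6% of them, so the fine-tuned model sees far more open-ended writing prompts than, say, extraction tasks.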

Usage

pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "clibrain/mamba-2.8b-chat-no_robots"

# Use <|endoftext|> as both the EOS and padding token,
# and borrow the chat template from the Zephyr tokenizer.
eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16
)

messages = []
prompt = "Tell me 5 sites to visit in Spain"
messages.append(dict(role="user", content=prompt))

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

decoded = tokenizer.batch_decode(out)
# Keep only the text after the last assistant tag and strip the EOS token.
assistant_message = (
    decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
)

print(assistant_message)
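
To make the string handling in the snippet above easier to follow, here is a minimal pure-Python sketch of the prompt layout it relies on. It assumes the Zephyr chat template renders each turn as `<|role|>`, a newline, the content, the EOS token and a newline, and that the generation prompt is `<|assistant|>` followed by a newline. That layout is an assumption here, and `tokenizer.apply_chat_template` remains the authoritative source. The helper names `render_prompt` and `extract_reply` are hypothetical and not part of any library.

```python
EOS_TOKEN = "<|endoftext|>"


def render_prompt(messages, eos_token=EOS_TOKEN, add_generation_prompt=True):
    """Approximate the Zephyr-style layout (assumed, not read from the tokenizer)."""
    parts = [f"<|{m['role']}|>\n{m['content']}{eos_token}\n" for m in messages]
    if add_generation_prompt:
        # Cue the model to answer as the assistant.
        parts.append("<|assistant|>\n")
    return "".join(parts)


def extract_reply(decoded, eos_token=EOS_TOKEN):
    """Keep the text after the last assistant tag and drop EOS markers."""
    return decoded.split("<|assistant|>\n")[-1].replace(eos_token, "")


messages = [{"role": "user", "content": "Tell me 5 sites to visit in Spain"}]
prompt_text = render_prompt(messages)

# Simulated model output: the prompt followed by a made-up completion.
fake_decoded = prompt_text + "1. Alhambra, Granada" + EOS_TOKEN
reply = extract_reply(fake_decoded)
print(reply)

# For a multi-turn chat, append the reply and the next user message,
# then render the whole history again.
messages.append({"role": "assistant", "content": reply})
messages.append({"role": "user", "content": "Which one is best in winter?"})
followup_prompt = render_prompt(messages)
```

Because the generated sequence contains the prompt as well as the completion, splitting on the last `<|assistant|>` tag is what isolates the newest reply, including in multi-turn conversations.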

Gradio Demo

git clone https://github.com/mrm8488/mamba-chat.git
cd mamba-chat

pip install -r requirements.txt
pip install -q gradio==4.8.0

python app.py \
--model clibrain/mamba-2.8b-chat-no_robots \
--share

Evaluations

Coming soon!

Acknowledgments

Thanks to mamba-chat for heavily inspiring our work.