---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
- en
datasets:
- chtmp223/suri
---
|
|
|
# Suri-SFT |
|
Suri-SFT is a LoRA adapter for [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), trained with supervised fine-tuning (SFT) on the Suri dataset. Please see [our paper](https://arxiv.org/abs/2406.19371) for details on the method.
|
|
|
## Model Details
|
|
|
### Model Description |
|
|
|
- **Language(s) (NLP):** English |
|
- **License:** Apache-2.0 |
|
- **Finetuned from model:** [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) |
|
|
|
### Model Sources |
|
|
|
- **Repository:** [GitHub repository](https://github.com/chtmp223/suri), which contains code to reconstruct the books3 subset.
|
- **Paper:** [arXiv:2406.19371](https://arxiv.org/abs/2406.19371)
|
- **Demo:** [Website](https://chtmp223.github.io/suri) |
|
|
|
## Getting Started
|
|
|
Use the code in [this repository](https://github.com/chtmp223/suri) for training and inference. |
|
|
|
|
|
## Training Details
|
|
|
### Training Data |
|
|
|
[chtmp223/suri](https://huggingface.co/datasets/chtmp223/suri) |
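
To take a quick look at the training data, the snippet below loads the dataset from the Hugging Face Hub and prints its splits and columns. It is a minimal sketch using the `datasets` library and makes no assumptions about the dataset's schema.

```python
from datasets import load_dataset

# Download the Suri dataset and inspect its splits and column names.
ds = load_dataset("chtmp223/suri")
print(ds)
```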
|
|
|
### Training Procedure |
|
|
|
| **Configurations**                | **Values**   |
|-----------------------------------|--------------|
| Hardware (Training and Inference) | 4xA100s      |
| Tracking                          | wandb        |
| lora_r                            | 16           |
| lora_alpha                        | 16           |
| lora_dropout                      | 0.05         |
| gradient_accumulation_steps       | 1            |
| gradient_checkpointing            | True         |
| learning_rate                     | 5.0e-5       |
| lr_scheduler_type                 | cosine       |
| max_length                        | 15024        |
| max_completion_length             | 15000        |
| max_prompt_length                 | 5000         |
| num_train_epochs                  | 2            |
| optim                             | adamw_torch  |
| per_device_train_batch_size       | 1            |
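
For orientation, the table above maps roughly onto the following PEFT and Transformers configuration objects. This is a hedged sketch reconstructed from the table, not the actual training script: LoRA target modules and the sequence-length settings (`max_length`, `max_prompt_length`, `max_completion_length`) are handled by the training code in the repository and are omitted here, and the output directory is a placeholder.

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA hyperparameters from the table above.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# Optimization hyperparameters from the table above.
training_args = TrainingArguments(
    output_dir="suri-sft",  # placeholder output directory
    learning_rate=5.0e-5,
    lr_scheduler_type="cosine",
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_torch",
)
```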
|
|
|
|
|
#### Software |
|
|
|
Training code is adapted from the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [TRL](https://github.com/huggingface/trl).
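
For readers who want to see how the pieces fit together without opening the repository, the sketch below wires the dataset, LoRA config, and training arguments from the previous sections into TRL's `SFTTrainer`. It is an illustrative approximation, not the released training script: the `"train"` split name is an assumption, `SFTTrainer` argument names vary across TRL versions, and the repository's data formatting is not reproduced here.

```python
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

# Base model to fine-tune; ds, lora_config, and training_args come from the sketches above.
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

trainer = SFTTrainer(
    model=base_model,
    args=training_args,         # TrainingArguments (newer TRL versions expect trl.SFTConfig)
    train_dataset=ds["train"],  # split name is an assumption
    peft_config=lora_config,    # attach the LoRA adapter during training
)
trainer.train()
```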
|
|
|
## Inference
|
```python
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the base model and attach the Suri-SFT LoRA adapter.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Replace with your own multi-constraint instruction.
user_prompt = "Write a short story about a lighthouse keeper. Constraints: ..."
prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]

# Format the prompt with the chat template, then tokenize.
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)

output = model.generate(
    input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()

print(tokenizer.decode(output[0]))
```
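
If you prefer a standalone checkpoint (no PEFT dependency at inference time), you can fold the adapter weights into the base model after loading, as sketched below; the output directory name is a placeholder.

```python
# Merge the LoRA weights into the base model and save a standalone checkpoint.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("suri-sft-merged")  # placeholder output directory
tokenizer.save_pretrained("suri-sft-merged")
```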
|
|
|
|
|
## Citation
|
|
|
``` |
|
@misc{pham2024surimulticonstraintinstructionfollowing,
      title={Suri: Multi-constraint Instruction Following for Long-form Text Generation},
      author={Chau Minh Pham and Simeng Sun and Mohit Iyyer},
      year={2024},
      eprint={2406.19371},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.19371},
}
|
``` |
|
|
|
### Framework versions
|
|
|
- PEFT 0.11.1 |