---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
- en
datasets:
- chtmp223/suri
---
# Suri-SFT
Suri-SFT is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.2 using supervised fine-tuning with LoRA. Please check [our paper](https://arxiv.org/abs/2406.19371) for more details on the method.
## 📒 Model Details
### Model Description
- **Language(s) (NLP):** English
- **License:** Apache-2.0
- **Finetuned from model:** [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
### Model Sources
- **Repository:** [GitHub repository](https://github.com/chtmp223/suri), which contains code to reconstruct the Books3 subset.
- **Paper:** [Link](https://arxiv.org/abs/2406.19371)
- **Demo:** [Website](https://chtmp223.github.io/suri)
## ⚠️ Getting Started
Use the code in [this repository](https://github.com/chtmp223/suri) for training and inference.
## πŸ’» Training Details
### Training Data
[chtmp223/suri](https://huggingface.co/datasets/chtmp223/suri)
### Training Procedure
| **Configurations** | **Values** |
|----------------------------------|--------------|
| Hardware (Training and Inference)| 4× A100 GPUs |
| Tracking | wandb |
| lora_r | 16 |
| lora_alpha | 16 |
| lora_dropout | 0.05 |
| gradient_accumulation_steps | 1 |
| gradient_checkpointing | True |
| learning_rate | 5.0e-5 |
| lr_scheduler_type | cosine |
| max_length | 15024 |
| max_completion_length | 15000 |
| max_prompt_length | 5000 |
| num_train_epochs | 2 |
| optim | adamw_torch |
| per_device_train_batch_size | 1 |
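
A few quantities follow directly from the table. The sketch below is plain Python with no training libraries. It makes three assumptions that the table does not state. First, the four A100s run plain data parallelism. Second, LoRA uses the standard `lora_alpha / lora_r` scaling. Third, purely for illustration, the adapters sit on `q_proj` and `v_proj` only. The actual target modules are not listed in this card, so treat the parameter count as hypothetical.

```python
"""Back-of-the-envelope numbers implied by the training table (a sketch)."""

# Values copied from the table above.
NUM_GPUS = 4  # assumption: the 4 A100s are used for plain data parallelism
PER_DEVICE_TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 1
LORA_R = 16
LORA_ALPHA = 16

# Mistral-7B shapes: hidden size 4096, 8 KV heads x 128 dims = 1024, 32 layers.
HIDDEN = 4096
KV_DIM = 1024
NUM_LAYERS = 32


def effective_batch_size(num_gpus: int, per_device: int, grad_accum: int) -> int:
    """Number of sequences that contribute to one optimizer step."""
    return num_gpus * per_device * grad_accum


def lora_scaling(alpha: int, r: int) -> float:
    """Standard LoRA multiplies the low-rank update B @ A by alpha / r."""
    return alpha / r


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A is r x d_in and B is d_out x r, so one adapter adds r * (d_in + d_out) weights."""
    return r * (d_in + d_out)


if __name__ == "__main__":
    batch = effective_batch_size(
        NUM_GPUS, PER_DEVICE_TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS
    )
    print("effective batch size:", batch)
    print("LoRA scaling:", lora_scaling(LORA_ALPHA, LORA_R))
    # Hypothetical target modules: q_proj (4096 -> 4096) and v_proj (4096 -> 1024) in every layer.
    per_layer = lora_params(HIDDEN, HIDDEN, LORA_R) + lora_params(HIDDEN, KV_DIM, LORA_R)
    print("trainable LoRA params (q_proj + v_proj):", per_layer * NUM_LAYERS)
```

With `lora_alpha` equal to `lora_r`, the low-rank update is added at unit scale. The effective batch size is small because each device processes one long sequence per step.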
#### Software
Training code is adapted from the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [TRL](https://github.com/huggingface/trl).
## 🤗 Inference
```python
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"  # LoRA adapter
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # base model the adapter was trained on

# Load the base model, then attach the Suri-SFT LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Placeholder: replace with your own instruction.
user_prompt = "<your instruction here>"
prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]

# Render the conversation as a prompt string using the tokenizer's chat template.
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
# The rendered template already contains the BOS token, so don't add special tokens again.
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)

# Sample a completion; max_length counts prompt and generated tokens together.
with torch.no_grad():
    output = model.generate(
        input_ids, max_length=10000, do_sample=True, use_cache=True
    ).cpu()
print(tokenizer.decode(output[0]))
```
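
In `generate`, `max_length` bounds the prompt and the completion together. A long instruction therefore leaves fewer tokens for the generated text. The sketch below uses a hypothetical helper, `completion_budget`, to make this explicit. In the script above, the prompt length would be `input_ids.shape[-1]`. The sample lengths are illustrative, and 5000 matches `max_prompt_length` in the training table.

```python
def completion_budget(prompt_tokens: int, max_length: int = 10_000) -> int:
    """Tokens left for the completion when generate() is called with max_length.

    max_length bounds prompt + completion together, so a longer prompt leaves
    fewer tokens for the generated text.
    """
    if prompt_tokens < 0 or max_length < 0:
        raise ValueError("token counts must be non-negative")
    return max(max_length - prompt_tokens, 0)


# Illustrative prompt lengths; 5000 matches max_prompt_length in the training table.
for n in (500, 2_000, 5_000):
    print(f"{n:>5} prompt tokens -> up to {completion_budget(n):>5} generated tokens")
```

If you want a completion budget that does not depend on prompt length, `generate` also accepts `max_new_tokens`, which counts only newly generated tokens.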
## πŸ“œ Citation
```
@misc{pham2024surimulticonstraintinstructionfollowing,
title={Suri: Multi-constraint Instruction Following for Long-form Text Generation},
author={Chau Minh Pham and Simeng Sun and Mohit Iyyer},
year={2024},
eprint={2406.19371},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2406.19371},
}
```
### βš™οΈ Framework versions
- PEFT 0.11.1