---
license: apache-2.0
datasets:
- CultriX/llama70B-dpo-dataset
language:
- en
base_model:
- NousResearch/Hermes-3-Llama-3.1-8B
pipeline_tag: text-generation
tags:
- dpo
- Llama3
- general
library_name: transformers
---
# Model Card for Llama3-8B-DPO

- **License:** Apache-2.0
- **Dataset:** CultriX/llama70B-dpo-dataset
- **Language:** English
- **Base model:** NousResearch/Hermes-3-Llama-3.1-8B
- **Pipeline tag:** text-generation
- **Tags:** DPO, Llama3, General
- **Library:** Transformers
## Performance
| Model Name | AGIEval | TruthfulQA | BigBench |
|---|---|---|---|
| Hermes-3-Llama-3.1-8B | 41.51 | 58.61 | 43.08 |
| Llama3-8B-DPO | 41.87 | 71.38 | 44.50 |
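
## Usage

To try the model without re-running the training script, it can be loaded like any other causal LM. The sketch below is a minimal example; the repo id is a placeholder, so replace it with the location where the fine-tuned weights are actually published. The model follows the ChatML prompt format of its Hermes-3 base.

```python
import torch
from transformers import pipeline

# Placeholder repo id: replace with the published location of this model
model_id = "CultriX/Llama3-8B-DPO"

generator = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Build a ChatML prompt (the format used during DPO training)
prompt = (
    "<|im_start|>system\nYou are a helpful assistant chatbot that provides concise answers.<|im_end|>\n"
    "<|im_start|>user\nWhat is Direct Preference Optimization?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

output = generator(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)
print(output[0]["generated_text"])
```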
## Training Script

The model was fine-tuned from Hermes-3-Llama-3.1-8B with Direct Preference Optimization (DPO) on LoRA adapters, using the script below.
```bash
# Install required libraries
pip install --upgrade pip
pip install git+https://github.com/huggingface/transformers
pip install git+https://github.com/huggingface/peft.git
pip install git+https://github.com/huggingface/trl.git
pip install --upgrade wandb accelerate datasets
```
```python
import os
import gc
import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig
from huggingface_hub import login, notebook_login

# Log in to Hugging Face and Weights & Biases
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)   # non-interactive login when a token is available
else:
    notebook_login()        # fall back to interactive login in a notebook

wb_token = os.getenv('WANDB_API_KEY')
if wb_token:
    wandb.login(key=wb_token)
else:
    wandb.login()
# Set model names
base_model_name = "NousResearch/Hermes-3-Llama-3.1-8B"
fine_tuned_model_name = "Llama3-8B-DPO"

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False  # caching is incompatible with gradient checkpointing

# Apply LoRA adapters for parameter-efficient fine-tuning
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, peft_config)
model.gradient_checkpointing_enable()
# Load the preference dataset and convert it to prompt/chosen/rejected triples
dataset = load_dataset("CultriX/llama70B-dpo-dataset")["train"]

def chatml_format(example):
    system = example.get("system", "")
    question = example.get("question", "")
    chosen = example.get("chosen", "")
    rejected = example.get("rejected", "")

    # Build the prompt in ChatML format, as used by the Hermes-3 base model
    prompt = ""
    if system:
        prompt += f"<|im_start|>system\n{system}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"

    return {
        "prompt": prompt,
        "chosen": f"{chosen}<|im_end|>\n",
        "rejected": f"{rejected}<|im_end|>\n",
    }

dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
# Fine-tune the model with the DPO trainer
training_args = DPOConfig(
    output_dir="model-output",
    logging_steps=50,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    num_train_epochs=4,
    save_strategy="no",
    optim="adamw_torch",
    warmup_ratio=0.03,
    bf16=True,
    report_to="wandb",
    beta=0.1,
    max_prompt_length=2048,
    max_length=4096,
    disable_dropout=False,
    force_use_ref_model=True,
)

trainer = DPOTrainer(
    model=model,
    ref_model=AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16),
    args=training_args,
    tokenizer=tokenizer,  # note: recent trl releases expect processing_class=tokenizer instead
    train_dataset=dataset,
)
trainer.train()

# Save the fine-tuned adapter and tokenizer
trainer.model.save_pretrained("final_ckpt")
tokenizer.save_pretrained("final_ckpt")
# Test the fine-tuned model
from transformers import pipeline

fine_tuned_model = AutoModelForCausalLM.from_pretrained(
    "final_ckpt",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
text_gen_pipeline = pipeline(
    "text-generation",
    model=fine_tuned_model,
    tokenizer=tokenizer,
    max_length=4096,
)

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant chatbot that provides concise answers.",
    },
    {
        "role": "user",
        "content": "What are GPUs and why would I use them for machine learning tasks?",
    },
]

# Build a ChatML prompt matching the training format, ending with the assistant cue
prompt = "".join(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n" for msg in messages)
prompt += "<|im_start|>assistant\n"

sequences = text_gen_pipeline(prompt, do_sample=True, temperature=0.7, top_p=0.9)
print(sequences[0]["generated_text"])
```
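
## Merging the Adapter

The script above saves only the LoRA adapter to `final_ckpt`, so a merged standalone checkpoint is often more convenient to share. Below is a minimal sketch of merging the adapter into the base model and optionally pushing the result to the Hub; the target repo id is a placeholder.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "NousResearch/Hermes-3-Llama-3.1-8B"

# Reload the base model and attach the trained adapter from final_ckpt
base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
merged = PeftModel.from_pretrained(base, "final_ckpt").merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Save a standalone checkpoint that no longer depends on peft at load time
merged.save_pretrained("Llama3-8B-DPO-merged")
tokenizer.save_pretrained("Llama3-8B-DPO-merged")

# Optional: push to the Hugging Face Hub (placeholder repo id)
# merged.push_to_hub("your-username/Llama3-8B-DPO")
# tokenizer.push_to_hub("your-username/Llama3-8B-DPO")
```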