Llama-3-openhermes-reft
Llama-3-openhermes-reft is a fine-tuned version of meta-llama/Meta-Llama-3-8B, trained on a 10K-example subset of the teknium/OpenHermes-2.5 dataset using Representation Fine-Tuning (ReFT). The model was trained for 1 epoch on a single A100 GPU with the PyReFT library.
What is ReFT?
ReFT methods are drop-in replacements for weight-based PEFTs. Parameter-efficient fine-tuning (PEFT) methods offer an efficient, cheaper alternative to full fine-tuning. They update only a small fraction of the weights, so they use less memory and finish training faster. Current state-of-the-art PEFTs such as LoRA and DoRA modify the model's weights but not its representations. Representation Finetuning (ReFT), by contrast, keeps the base model frozen and learns task-specific interventions on its hidden representations.
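The paragraph above does not spell out what an "intervention on hidden representations" looks like. The sketch below assumes the LoReFT parameterization from the ReFT paper, h + R^T(Wh + b - Rh). In that formula, R is a low-rank r x d matrix with orthonormal rows, W is an r x d matrix, and b is a bias vector. The card does not say which intervention type this checkpoint uses. The matrices here are hypothetical toy values in plain Python, not pyreft code.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: a list of rows


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * x for a, x in zip(row, v)) for row in m]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


def loreft(h: Vector, R: Matrix, W: Matrix, b: Vector) -> Vector:
    """LoReFT-style edit of one hidden vector: h + R^T (W h + b - R h).

    Assumes R (r x d) has orthonormal rows; this is not checked here.
    W is r x d and b has length r; both stand in for learned parameters.
    """
    target = [wh + bi for wh, bi in zip(matvec(W, h), b)]  # W h + b, in the r-dim subspace
    current = matvec(R, h)                                 # R h, h projected onto the subspace
    delta = [t - c for t, c in zip(target, current)]       # change needed inside the subspace
    correction = matvec(transpose(R), delta)               # map the change back to d dims
    return [hi + ci for hi, ci in zip(h, correction)]


# Toy example: hidden size d = 4, intervention rank r = 2 (hypothetical values).
h = [1.0, 2.0, 3.0, 4.0]
R = [[1.0, 0.0, 0.0, 0.0],
     [0.0, 1.0, 0.0, 0.0]]   # orthonormal rows spanning dims 0 and 1
W = [[0.5, 0.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]   # hypothetical learned projection
b = [0.1, -0.2]               # hypothetical learned bias

h_new = loreft(h, R, W, b)
print(h_new)
```

The rows of R are orthonormal, so R·h_new equals W·h + b. The part of h that is orthogonal to R's row space is left untouched. In this toy run, dimensions 2 and 3 keep their original values. Only the frozen model's activations are edited, never its weights.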
PyReFT
PyReFT is a Python library for training and sharing ReFTs.
This library is built on top of pyvene, a library for performing and training activation interventions on arbitrary PyTorch models.
- Codebase: PyReFT
- PyPI release: pyreft (install with pip install pyreft)
- Any pretrained language model available on Hugging Face can be fine-tuned with ReFT methods through pyreft, and the resulting models can easily be uploaded to Hugging Face.
Inference
import torch, transformers, pyreft
device = "cuda"
model_name_or_path = "meta-llama/Meta-Llama-3-8B"
# Load the frozen base model and its tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
# Attach the trained ReFT interventions from the Hugging Face Hub
reft_model = pyreft.ReftModel.load(
    "Syed-Hasan-8503/Llama-3-openhermes-reft", model, from_huggingface_hub=True
)
reft_model.set_device(device)
instruction = "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?"
prompt_no_input_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
# Tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)
# Intervene on the last token of the prompt
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
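The inference example applies the intervention only at the last prompt token. The sketch below factors that bookkeeping into two hypothetical helpers, build_prompt and last_position_unit_locations. These are illustrative names, not part of pyreft. The sketch assumes unit_locations is nested as [intervention][batch example][position]. That nesting fits the single-intervention value [[[base_unit_location]]] used above, but this card does not document it.

```python
# Llama-3 chat template, copied from the inference example above.
PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%s"
    "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)


def build_prompt(instruction: str) -> str:
    """Fill the chat template with a single user instruction (no system turn)."""
    return PROMPT_TEMPLATE % instruction


def last_position_unit_locations(num_prompt_tokens: int, num_interventions: int = 1) -> dict:
    """Hypothetical helper: target the last prompt token with every intervention.

    Assumes the nesting [intervention][batch example][position] for a batch of one.
    """
    if num_prompt_tokens < 1:
        raise ValueError("prompt must contain at least one token")
    last = num_prompt_tokens - 1  # zero-based index of the final prompt token
    # Build a separate inner list per intervention to avoid shared references.
    base_locations = [[[last]] for _ in range(num_interventions)]
    # The first element (source locations) is None, as in the example above.
    return {"sources->base": (None, base_locations)}


print(build_prompt("What is ReFT?"))
print(last_position_unit_locations(10))
```

With a real tokenizer, the token count would come from prompt["input_ids"].shape[-1]. The returned dict would then be passed to reft_model.generate as unit_locations.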