arxiv:2404.03592

ReFT: Representation Finetuning for Language Models

Published on Apr 4 · Submitted by akhaliq on Apr 5
#1 Paper of the day

Abstract

Parameter-efficient fine-tuning (PEFT) methods seek to adapt large models via updates to a small number of weights. However, much prior interpretability work has shown that representations encode rich semantic information, suggesting that editing representations might be a more powerful alternative. Here, we pursue this hypothesis by developing a family of Representation Finetuning (ReFT) methods. ReFT methods operate on a frozen base model and learn task-specific interventions on hidden representations. We define a strong instance of the ReFT family, Low-rank Linear Subspace ReFT (LoReFT). LoReFT is a drop-in replacement for existing PEFTs and learns interventions that are 10x-50x more parameter-efficient than prior state-of-the-art PEFTs. We showcase LoReFT on eight commonsense reasoning tasks, four arithmetic reasoning tasks, Alpaca-Eval v1.0, and GLUE. In all these evaluations, LoReFT delivers the best balance of efficiency and performance, and almost always outperforms state-of-the-art PEFTs. We release a generic ReFT training library publicly at https://github.com/stanfordnlp/pyreft.

Community

As a developer, the key takeaway for me is: 7% greater accuracy, 27x fewer parameters, and 18 minutes to train a 7B model that can compete with GPT-3.5. That's a nice step up in performance and efficiency.

But I'm totally confused as to how this actually works. Here's what I think I understand.

  1. During finetuning, we "simply" modify the inference output via a type of mask/filter.
  2. This mask is a type of contextualised embedding.
  3. During inference we simply pass the original output through our new learned mask.

Am I right in my (abstract) understanding?

If so, there should be zero parameters modified... Hence my confusion

Paper author
edited Apr 7

Thanks for your comments!

PEFTs update model subcomponents (e.g., layer weight diffs), add new components (e.g., adapters), or tune some embeddings (e.g., prefix embeddings).

What ReFT does instead is train interventions that act on the hidden representations, in the following steps:

  1. collect representations using hooks (as callback functions).
  2. learn a transformation function f, that applies to those representations.
  3. put them back into the computation graph.

The learnable parameters are in the function f. We provide one way to parameterize f in the paper, which we call LoReFT. But you can design your own transformation function.
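Roughly, as a sketch (not the actual pyreft code; the layer choice, rank, and hook wiring here are illustrative assumptions), steps 1-3 look like this in PyTorch, with a LoReFT-style f of the form h + R^T(Wh + b − Rh):

```python
import torch
import torch.nn as nn

class LowRankIntervention(nn.Module):
    """Illustrative LoReFT-style edit: h + R^T (W h + b - R h).

    In the paper R has orthonormal rows; here we only initialize it that way
    for simplicity rather than enforcing it during training.
    """
    def __init__(self, hidden_size: int, rank: int):
        super().__init__()
        # R projects the residual-stream vector onto a low-rank subspace
        self.R = nn.Parameter(torch.empty(rank, hidden_size))
        nn.init.orthogonal_(self.R)
        # W, b produce the target values within that subspace
        self.proj = nn.Linear(hidden_size, rank)

    def forward(self, h):
        # Only the r-dimensional subspace spanned by the rows of R gets edited
        return h + (self.proj(h) - h @ self.R.T) @ self.R

intervention = LowRankIntervention(hidden_size=4096, rank=4)

def hook(module, inputs, output):
    hidden = output[0] if isinstance(output, tuple) else output
    edited = intervention(hidden)  # step 2: apply f to the collected representations
    # step 3: return the edited tensor so it re-enters the computation graph
    return (edited,) + output[1:] if isinstance(output, tuple) else edited

# step 1: collect representations with a hook on a frozen base model
# (layer index 15 is an arbitrary choice for illustration)
# handle = model.model.layers[15].register_forward_hook(hook)
```

Only the intervention's parameters (R, W, b) are trained; the base model itself stays frozen.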

This paper looks really promising, but I am having a hard time understanding what "representation" means. I didn't find a definition in the paper either.

What comes closest to a definition is at the beginning of Section 3, where representation seems to be a synonym for embedding (input tokens x_1,...,x_n are translated to representations h_1,...,h_n).

This is in line with MichaelBarryUK's comment. The author (zhengxuanzenwu) replies saying representations are

  1. model subcomponents or
  2. new components or
  3. Prefix-tune.

Unfortunately, I don't understand 1, 2, and 3 either.

Paper author
edited Apr 7

Hey! Thanks for the question.

Yes, the representations we intervene on are h_1,...,h_n. These are the residual streams, or block outputs, at each layer and each token position.

I also updated my previous answer to try to be clearer. Let me know if this makes sense, though.
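Concretely (a rough illustration only, not the paper's code; the model name is an assumed example), these are the per-layer, per-token hidden states you can pull out of any Transformer with Hugging Face transformers:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed example; any causal LM works
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tok("ReFT edits hidden representations.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# out.hidden_states is a tuple: (embedding output, layer 1 output, ..., layer L output).
# Each entry has shape (batch, seq_len, hidden_size): one vector h_i per token, per layer.
h = out.hidden_states[16]   # block output of layer 16, i.e. the residual stream there
print(h.shape)              # (1, seq_len, 4096)
```

ReFT's interventions edit exactly these tensors at chosen layers and token positions.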

@zhengxuanzenwu This is a very interesting paper! Great work! In the results section, it looks like a lot of methods outperform LoReFT on the arithmetic reasoning tasks. Do you have a hypothesis for why this is? My initial thought is that maybe it's just not as simple to capture mathematical representations in the outputs of attention layers?
EDIT: In the paper, you say that the length of the generations might have something to do with this. Did you run any tests to see if the effectiveness of the method reduces as the length of generation increases?


@shamikbose89 Thanks for your interest! Yes, LoReFT underperforms for arithmetic reasoning tasks, especially for GSM8K. In short, we don't know how to fix it yet. But here are a couple of hypotheses:

  • Hyperparameter selection is not optimal. Although we tried hyperparameter tuning, our grid search is still pretty limited. We also haven't tried layerwise intervention weight sharing, etc.
  • Intervening on decoding steps might help. Currently, we only intervene on the prompt. It is surprising that this is sufficient for the other two tasks with LLaMA models. For math reasoning, which requires CoT generation, intervening on decoding steps might help with long-form reasoning.
  • More complex parameterization of the intervention. LoReFT is just one way of defining the intervention function. Coming up with more complex interventions could help.

Offline, we also tried training and testing on GSM8K only (the GSM8K dataset is also cleaner, without GPT-4 generated CoTs). LoReFT with Llama-2 7B still slightly underperforms LoRA + Llama-2 7B (approximately 32% vs. 35%). However, LoReFT has far fewer trainable parameters. See the LoftQ paper for Llama-2 performance on GSM8K.

Re EDIT: Here, the generation length is shorter, since the gold labels from GSM8K are shorter than GPT-4 generated CoTs. Yeah, it could be interesting to look into this, since 32% vs. 35% is a smaller gap.

P.S. If you want to improve the math reasoning ability of ReFTs, I would also recommend using the GSM8K setup from the LoftQ paper. It is just much cleaner than the LLM-Adapters setup, which we used for the sake of benchmarking only.

@zhengxuanzenwu I am struggling to understand one aspect. I get the LoReFT equation, but after that I don't understand why you would apply it selectively to various positions. Why not apply it to all positions and make it a layer-based operation?

Also have you done any tests to compare how this impacts latency?

Awesome work, and thanks for sharing!!

Paper author

@derek-thomas thanks for your interest! We tried hyperparameter tuning on whether we share LoReFT weights across all prompt token positions. It seems like once we go above a certain threshold (and this is task dependent), performance does not increase. Intuitively, this may suggest that editing every residual stream in the same way is not ideal (i.e., each position stores information differently). On the other hand, if we don't share weights across positions, the parameter count for LoReFT is going to be higher.

We did a preliminary latency analysis in Appendix H.

One thing to note is that intervening only on the prompt tokens (i.e., only on the KV cache) makes the ReFT paradigm efficient. This is different from adapters, where every decoding step has overhead. It is also different from merging LoRA weights before serving, since ReFT allows dynamic task-based interventions within a batch.
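As a rough sketch of what prompt-only intervention means mechanically (the shapes and position selection below are illustrative assumptions, not the pyreft code): the edit runs once during the prefill pass over the prompt, and decoding then reads the already-edited KV cache with no extra per-step work.

```python
import torch

def intervene_on_prompt(hidden: torch.Tensor, intervention, positions):
    """Apply an intervention only at selected prompt positions during prefill.

    hidden: (batch, prompt_len, hidden_size) hidden states from the prefill pass.
    intervention: any callable mapping (..., hidden_size) -> (..., hidden_size).
    positions: indices of prompt tokens to edit, e.g. a few leading and trailing ones.
    """
    edited = hidden.clone()
    edited[:, positions, :] = intervention(hidden[:, positions, :])
    return edited

# Hypothetical position selection: first 4 and last 4 prompt tokens.
prompt_len = 32
positions = list(range(4)) + list(range(prompt_len - 4, prompt_len))
```

During decoding, no intervention needs to run at all; the edited prompt representations are already baked into the KV cache that later tokens attend to.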

