metadata
license: mit
datasets:
  - CarperAI/openai_summarize_tldr
language:
  - en
base_model:
  - EleutherAI/gpt-j-6b
  - CarperAI/openai_summarize_tldr_sft

ALT-RM model (reward model-based feedback)

A GPT-J (6B) model fine-tuned on the TL;DR summarization dataset to better align with human preferences for summaries, i.e., along axes such as accuracy, coverage, and coherence, following the alignment approach introduced in the ALT paper. This is the official model checkpoint, and the code can be found here.

Model description

The alignment process starts from an SFT checkpoint released by CarperAI and trained with their trlx library.

In a nutshell, the ALT method consists of providing textual feedback on on-policy sampled generations in order to learn the conditional probability distribution of a generation given both the prompt and the feedback. This logic is implemented as a decoupled three-stage pipeline (sampling, feedback, and training), where training uses a language modelling objective with the feedback tokens prepended to the prompt. In this way, the model learns to discriminate between generations associated with different types of feedback: it learns from both positive and negative examples spanning the entire feedback spectrum, overcoming one of the main limitations of supervised fine-tuning, which typically learns only from positive demonstrations.
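As an illustration of how the training stage consumes the feedback, the sketch below assembles ALT-style training pairs for one prompt. It is a minimal sketch, not the official implementation: it assumes the feedback is prepended using the same "{feedback} input: {prompt}" template described under How to use, and that the language-modelling loss is taken on the generation (target) only. The helper names `build_alt_example` and `build_alt_batch` are hypothetical.

```python
from typing import List, Tuple


def build_alt_example(feedback: str, prompt: str, generation: str) -> Tuple[str, str]:
    """Return (conditioning_text, target_text) for one ALT training example.

    Assumption: the feedback is prepended to the prompt with the
    "<feedback> input: <prompt>" template used at inference time, and the
    loss is computed on the generation (target) tokens only.
    """
    conditioning_text = f"{feedback} input: {prompt}"
    return conditioning_text, generation


def build_alt_batch(
    prompt: str, generations: List[str], feedbacks: List[str]
) -> List[Tuple[str, str]]:
    """Pair every sampled generation for a prompt with its own feedback.

    Good and bad samples are both kept, so the model is exposed to the whole
    feedback spectrum rather than only to positive demonstrations.
    """
    if len(generations) != len(feedbacks):
        raise ValueError("each generation needs exactly one feedback string")
    return [build_alt_example(f, prompt, g) for g, f in zip(generations, feedbacks)]


# Hypothetical toy prompt and generations, for illustration only.
prompt = "SUBREDDIT: r/example\nTITLE: Example title\nPOST: Example post.\nTL;DR:"
batch = build_alt_batch(
    prompt,
    [" A faithful summary.", " An off-topic summary."],
    ["Excellent.", "Horrible."],
)
for conditioning_text, target_text in batch:
    print(repr(conditioning_text + target_text))
```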

For a detailed description of the ALT method, please refer to the paper.

In particular, the ALT-RM checkpoint collects feedback by using a reward model to score the generations, and then maps the reward quantiles, computed over several generations sampled for the same prompt, to predefined textual feedback. For the summarization task on the TL;DR dataset, the following mapping from quantiles to feedback was used:

{'QUANTILE 0': 'Excellent.',
 'QUANTILE 1': 'Good.',
 'QUANTILE 2': 'Mediocre.',
 'QUANTILE 3': 'Bad.',
 'QUANTILE 4': 'Horrible.'}
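The mapping above can be sketched in a few lines. This is a minimal illustration under two assumptions not spelled out in this card: the generations sampled for one prompt are ranked by reward and split into five equal-count bins, and QUANTILE 0 holds the highest rewards (consistent with it mapping to "Excellent."). The function name `rewards_to_feedback` is hypothetical, and the exact binning used for ALT-RM may differ; refer to the paper.

```python
from typing import List

# Mapping from the card; QUANTILE 0 is assumed to hold the best-scored generations.
QUANTILE_TO_FEEDBACK = {
    0: "Excellent.",
    1: "Good.",
    2: "Mediocre.",
    3: "Bad.",
    4: "Horrible.",
}


def rewards_to_feedback(rewards: List[float]) -> List[str]:
    """Map the rewards of generations sampled for ONE prompt to feedback strings.

    Assumption: generations are ranked by reward (highest first) and split into
    len(QUANTILE_TO_FEEDBACK) bins of roughly equal size; bin 0 is the top bin.
    """
    n = len(rewards)
    num_quantiles = len(QUANTILE_TO_FEEDBACK)
    # Indices sorted from highest to lowest reward.
    order = sorted(range(n), key=lambda i: rewards[i], reverse=True)
    feedback = [""] * n
    for rank, idx in enumerate(order):
        quantile = rank * num_quantiles // n  # always in 0 .. num_quantiles - 1
        feedback[idx] = QUANTILE_TO_FEEDBACK[quantile]
    return feedback


print(rewards_to_feedback([0.1, 0.9, 0.5, 0.3, 0.7]))
# ['Horrible.', 'Excellent.', 'Mediocre.', 'Bad.', 'Good.']
```

With more samples per prompt than quantiles, several generations share a bin; for example, with 10 samples the two highest-reward generations both receive "Excellent.".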

Thus, at inference time, the expected aligned behavior can be obtained by conditioning the input on the "Excellent." feedback.

Related Models: ALT-Quark.

Intended uses & limitations

This model originates from a research project focused on alignment and is intended primarily for research purposes. Commercial use as an off-the-shelf model is discouraged, as it was not designed with such applications in mind. The model is tailored specifically for the summarization task, having been trained on the TL;DR dataset, though some out-of-distribution generalization may be possible for related datasets.

How to use

Format the input by prepending the feedback as follows: Excellent. input: {prompt}
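A small helper for building inputs in this format is sketched below. It assumes the template is exactly "{feedback} input: {prompt}" with a single space after each colon, as in the example prompt in the code that follows; `format_alt_input` and `FEEDBACK_STRINGS` are hypothetical names. Passing one of the other feedback strings instead of the default "Excellent." may be a simple way to probe how strongly the model conditions on feedback.

```python
# Feedback strings from the quantile mapping above; "Excellent." targets the aligned behavior.
FEEDBACK_STRINGS = ("Excellent.", "Good.", "Mediocre.", "Bad.", "Horrible.")


def format_alt_input(prompt: str, feedback: str = "Excellent.") -> str:
    """Prepend the textual feedback to a TL;DR prompt (assumed template)."""
    if feedback not in FEEDBACK_STRINGS:
        raise ValueError(f"unknown feedback {feedback!r}; expected one of {FEEDBACK_STRINGS}")
    return f"{feedback} input: {prompt}"


# Hypothetical toy post, for illustration only.
post = "SUBREDDIT: r/relationship_advice\nTITLE: Example title\nPOST: Example post.\nTL;DR:"
print(format_alt_input(post))
```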

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

checkpoint_path = "sauc-abadal-lloret/gpt-j-6b-ALT-RM-tldr"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no dedicated padding token

model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
model.eval()

prompt = "Excellent. input: SUBREDDIT: r/relationship_advice\nTITLE: I'm [18M] going to a party where an old middle \
school crush [17F] is also going.\nPOST: Story time! Back in the summer after 8th grade, I hung out with my group of \
friends everyday for the whole summer. There was this girl in the group and I really liked her. Like I had the biggest \
and dumbest crush on her. I was only 13 so I didn't know shit, but I was thinking she's perfect for me, I gotta marry \
her and all this dumb stuff. The puppy love was so strong I wanted to be a part of her life and I wanted her to be a \
part of my life. I never had the courage to ask her out, and we went to different high schools. Eventually we stopped \
talking but during high school I never really liked anyone else. Every other girl felt dull compared to her. I still \
get nostalgic thinking about her and what would've been different if I had the balls to ask her out. Anyway I'm going \
to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually \
ends up in a relationship.\nTL;DR:"

inputs = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
input_seq_len = inputs["input_ids"].shape[1]

# Greedy decoding of up to 64 new tokens (max_new_tokens takes precedence over
# max_length, so setting both is redundant).
generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=False,
    num_beams=1,
    bad_words_ids=None,
    num_return_sequences=1,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.pad_token_id,
)

outputs = model.generate(**inputs, generation_config=generation_config)
# Keep only the newly generated tokens, i.e., drop the prompt.
generated_input_ids = outputs["sequences"][:, input_seq_len:]
generated_text = tokenizer.batch_decode(
    generated_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print(generated_text)
# [" I have a huge crush on a girl who I never asked out and we went to different high schools. I'm going to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually ends up in a relationship."]

Training data

The model was trained on the TL;DR summarization dataset introduced in Stiennon et al.'s paper "Learning to Summarize from Human Feedback". We used the dataset version from CarperAI, which can be found on the Hugging Face Hub here.

Training procedure

The exact training procedure and hyperparameter configuration can be found in our paper.

Variables and metrics

As an evaluation metric, we compute GPT-4 win rates against PPO on a random subset of 1k prompts from the test set. We use the prompt provided in the DPO paper and ask GPT-4 to compare the generations of ALT-RM and of Quark against those of PPO. Furthermore, we report the following metrics computed on the whole test set: average reward model score, perplexity measured by the SFT reference policy as a proxy for fluency, and average generation length. In addition, we conduct an out-of-domain evaluation and compute GPT-4 win rates on 100 articles from the test split of the CNN/DailyMail dataset.

| Model | TL;DR (in-domain) | CNN/DailyMail (out-of-domain) |
|---|---|---|
| Quark vs PPO | 0.36 | 0.40 |
| ALT-RM vs PPO | 0.50 | 0.48 |

Win rates against PPO as judged by GPT-4: TL;DR on 1,000 randomly chosen test prompts and CNN/DailyMail on 100 randomly chosen test prompts.

| Model | RM | PPL | Avg. len | # Train |
|---|---|---|---|---|
| SFT | 2.89 | 1.96 | 31.25 | - |
| References | 2.89 | 11.84 | 32.60 | - |
| PPO | 3.38 | 2.29 | 67.52 | 116k |
| Quark | 3.52 | 1.82 | 49.42 | 19k |
| ALT-RM | 3.58 | 2.20 | 46.14 | 19k |

TL;DR metrics on the whole test set: average reward model score (RM), perplexity (PPL), average generation length (Avg. len), and number of training prompts (# Train).
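For readers reproducing these numbers, two of the metrics can be computed along the lines below. This is a minimal sketch under assumptions not stated in this card: each GPT-4 judgment is reduced to a per-prompt verdict "win", "loss", or "tie" for the candidate (ALT-RM or Quark) against PPO, with ties counted as half a win; and perplexity is the exponential of the mean negative per-token log-probability that the SFT reference policy assigns to a generation. The function names are hypothetical, and the paper's exact conventions (e.g. tie handling, per-sequence vs per-token averaging) may differ.

```python
import math
from typing import Sequence


def win_rate(verdicts: Sequence[str]) -> float:
    """Fraction of prompts where the judge prefers the candidate over PPO.

    Assumption: each verdict is "win", "loss" or "tie"; ties count as half a win.
    """
    if not verdicts:
        raise ValueError("no verdicts")
    score = {"win": 1.0, "tie": 0.5, "loss": 0.0}
    return sum(score[v] for v in verdicts) / len(verdicts)


def perplexity(token_logprobs: Sequence[float]) -> float:
    """Perplexity of one generation from the natural-log probabilities that a
    reference model (assumed here to be the SFT policy) assigns to its tokens."""
    if not token_logprobs:
        raise ValueError("empty generation")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)


print(win_rate(["win", "loss", "win", "tie"]))  # 0.625
print(perplexity([math.log(0.5), math.log(0.5)]))  # approximately 2.0
```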

BibTeX entry and citation info

@misc{lloret2024aligninglanguagemodelstextual,
      title={Towards Aligning Language Models with Textual Feedback}, 
      author={Saüc Abadal Lloret and Shehzaad Dhuliawala and Keerthiram Murugesan and Mrinmaya Sachan},
      year={2024},
      eprint={2407.16970},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2407.16970}, 
}