base_model:
- EleutherAI/gpt-j-6b
- CarperAI/openai_summarize_tldr_sft
---
# ALT-RM model (reward model-based feedback)

A **GPT-J (6B)** model fine-tuned on the **TL;DR summarization** dataset to be better aligned with human preferences on summaries, i.e., accounting for axes such as accuracy, coverage, and coherence, following the alignment approach introduced in the [ALT paper](https://www.arxiv.org/abs/2407.16970). This is the official model checkpoint; the code can be found [here](https://github.com/sauc-abadal/ALT/tree/main).

# Model description
The alignment process starts from an [SFT checkpoint](https://huggingface.co/CarperAI/openai_summarize_tldr_sft) released by CarperAI and trained using their [trlx](https://github.com/CarperAI/trlx/tree/main/examples/summarize_rlhf) library.

In a nutshell, the ALT method consists of providing textual feedback on on-policy sampled generations in order to learn the conditional probability distribution of a generation given both the prompt and the feedback. This logic is implemented in a three-stage decoupled pipeline, namely *sampling*, *feedback*, and *training*, where training is based on a language modelling objective with the feedback tokens prepended before the prompt.
In this way, the model learns to discriminate between generations associated with different feedback types: it learns from both positive and negative examples spanning the entire feedback spectrum, overcoming one of the main limitations of supervised fine-tuning, which typically learns only from positive demonstrations.
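For illustration, below is a minimal, hypothetical sketch of how such a feedback-conditioned training example could be assembled. The `{feedback} input: {prompt}` format mirrors the inference usage shown later in this card; masking the loss to the summary tokens is an assumption of this sketch, and the exact implementation lives in the linked repository.

```python
from transformers import AutoTokenizer

# Hypothetical construction of one ALT training example: the textual feedback is
# prepended before the prompt, and the on-policy sampled summary is the continuation to learn.
tokenizer = AutoTokenizer.from_pretrained("CarperAI/openai_summarize_tldr_sft")

feedback = "Excellent."                                # feedback assigned to this sample
prompt = "SUBREDDIT: r/...\nPOST: ...\nTL;DR:"         # placeholder TL;DR prompt
summary = " A placeholder on-policy sampled summary."  # placeholder generation

conditioned_prompt = f"{feedback} input: {prompt}"
prompt_ids = tokenizer(conditioned_prompt).input_ids
summary_ids = tokenizer(summary + tokenizer.eos_token).input_ids

# Standard causal-LM objective; the loss is masked to the summary tokens (an assumption
# of this sketch), so the model learns P(summary | feedback, prompt).
input_ids = prompt_ids + summary_ids
labels = [-100] * len(prompt_ids) + summary_ids
```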
For extensive coverage of the ALT method, please refer to the paper.

In particular, the **ALT-RM** checkpoint collects the feedback by leveraging a [Reward Model](https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint) to score the generations, and then maps reward quantiles computed for several generations under the same prompt to pre-defined textual feedback strings. For the summarization task on the TL;DR dataset, the quantile-to-feedback mapping employed was:

```python
{'QUANTILE 0': 'Excellent.',
 'QUANTILE 1': 'Good.',
 'QUANTILE 2': 'Mediocre.',
 'QUANTILE 3': 'Bad.',
 'QUANTILE 4': 'Horrible.'}
```
Thus, at inference time, the expected aligned behavior can be attained by conditioning the input on the *'Excellent.'* feedback.
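The exact binning used in the pipeline is defined in the repository; the following is a minimal, hypothetical sketch of how reward-model scores for several generations of the same prompt could be mapped to these feedback strings, assuming quantile 0 corresponds to the highest-reward bin and that samples are split by rank into equally sized bins.

```python
import numpy as np

QUANTILE_FEEDBACK = {0: "Excellent.", 1: "Good.", 2: "Mediocre.", 3: "Bad.", 4: "Horrible."}

def rewards_to_feedback(rewards, num_quantiles=5):
    """Assign each generation of a single prompt to a reward quantile (0 = best)
    and return the corresponding textual feedback."""
    rewards = np.asarray(rewards)
    order = np.argsort(-rewards)            # indices sorted from highest to lowest reward
    ranks = np.empty_like(order)
    ranks[order] = np.arange(len(rewards))  # rank of each generation (0 = highest reward)
    quantiles = (ranks * num_quantiles) // len(rewards)
    return [QUANTILE_FEEDBACK[int(q)] for q in quantiles]

# Example: reward-model scores for 10 samples of the same prompt.
scores = [3.1, 2.4, 3.8, 1.9, 2.9, 3.5, 2.2, 3.0, 1.5, 2.7]
print(rewards_to_feedback(scores))
```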
**Related Models:** [ALT-Quark]().

# Intended uses & limitations
This model originates from a research project focused on alignment and is intended primarily for research purposes. Commercial use as an off-the-shelf model is discouraged, as it was not designed with such applications in mind. The model is tailored specifically for the summarization task, having been trained on the TL;DR dataset, though some out-of-distribution generalization may be possible for related datasets.

# How to use
You should format the input by prepending the feedback as follows: `Excellent. input: {prompt}`

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

checkpoint_path = "sauc-abadal-lloret/gpt-j-6b-ALT-RM-tldr"

# Load the ALT-RM checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
model.eval()

# The prompt is conditioned on the 'Excellent.' feedback to elicit the aligned behavior.
prompt = "Excellent. input: SUBREDDIT: r/relationship_advice\nTITLE: I'm [18M] going to a party where an old middle \
school crush [17F] is also going.\nPOST: Story time! Back in the summer after 8th grade, I hung out with my group of \
friends everyday for the whole summer. There was this girl in the group and I really liked her. Like I had the biggest \
and dumbest crush on her. I was only 13 so I didn't know shit, but I was thinking she's perfect for me, I gotta marry \
her and all this dumb stuff. The puppy love was so strong I wanted to be a part of her life and I wanted her to be a \
part of my life. I never had the courage to ask her out, and we went to different high schools. Eventually we stopped \
talking but during high school I never really liked anyone else. Every other girl felt dull compared to her. I still \
get nostalgic thinking about her and what would've been different if I had the balls to ask her out. Anyway I'm going \
to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually \
ends up in a relationship.\nTL;DR:"

inputs = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
input_seq_len = inputs["input_ids"].shape[1]

# Greedy decoding of up to 64 new tokens.
generation_config = GenerationConfig(
    max_length=2048,
    max_new_tokens=64,
    do_sample=False,
    num_beams=1,
    bad_words_ids=None,
    num_return_sequences=1,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.pad_token_id,
)

outputs = model.generate(**inputs, generation_config=generation_config)
# Keep only the newly generated tokens, i.e., the summary.
generated_input_ids = outputs["sequences"][:, input_seq_len:]
generated_text = tokenizer.batch_decode(
    generated_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
generated_text
```

```
[" I have a huge crush on a girl who I never asked out and we went to different high schools. I'm going to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually ends up in a relationship."]
```
## Training data

The model was trained on the TL;DR summarization dataset introduced in Stiennon et al.'s ["Learning to Summarize from Human Feedback"](https://arxiv.org/abs/2009.01325) paper. We employed the dataset version from CarperAI, which can be found on the Hugging Face Hub [here](https://huggingface.co/datasets/CarperAI/openai_summarize_tldr).
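If helpful, the dataset can be pulled directly from the Hub with the `datasets` library; the snippet below is only a minimal loading example (check the dataset card for the exact split and column names).

```python
from datasets import load_dataset

# Download CarperAI's TL;DR summarization dataset from the Hugging Face Hub.
dataset = load_dataset("CarperAI/openai_summarize_tldr")
print(dataset)              # available splits and columns
print(dataset["train"][0])  # inspect one training example
```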
## Training procedure

The exact training procedure and hyper-parameter configuration can be found in our paper.
## Variables and metrics

As an evaluation metric, we compute GPT-4 win-rates over PPO on a 1k random subset of the test set, using the prompt provided in the DPO paper and asking GPT-4 to compare generations from ALT-RM, Quark, and PPO. Furthermore, we report the following metrics computed on the whole test set: average reward model score, perplexity measured by the SFT reference policy as a proxy for fluency, and average length of the generations. In addition, we conduct an out-of-domain evaluation and compute GPT-4 win-rates on 100 articles from the test split of the CNN/DailyMail dataset.
| **Model**     | **TL;DR** (In-domain) | **CNN/DailyMail** (Out-of-domain) |
|:-------------:|:---------------------:|:---------------------------------:|
| Quark vs PPO  | 0.36                  | 0.40                              |
| ALT-RM vs PPO | 0.50                  | 0.48                              |

*Win-rates judged by GPT-4: TL;DR on 1,000 randomly chosen test prompts and CNN/DailyMail on 100 randomly chosen test prompts.*
| **Model**  | **RM** | **PPL** | **Avg. len** | **# Train** |
|:----------:|:------:|:-------:|:------------:|:-----------:|
| SFT        | 2.89   | 1.96    | 31.25        | -           |
| References | 2.89   | 11.84   | 32.60        | -           |
| PPO        | 3.38   | 2.29    | 67.52        | 116k        |
| Quark      | 3.52   | 1.82    | 49.42        | 19k         |
| ALT-RM     | 3.58   | 2.20    | 46.14        | 19k         |

*TL;DR metrics on the whole test set: average reward model score (RM), perplexity under the SFT reference policy (PPL), average generation length, and number of training prompts.*
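As a reference for the PPL column, perplexity is measured under the SFT reference policy. The snippet below is a minimal, hypothetical sketch of how such a score could be computed for a single generated summary; masking the loss to the summary tokens is an assumption here, so treat it as illustrative rather than the official evaluation script.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Score a generated summary with the SFT reference policy:
# perplexity of the summary tokens conditioned on the prompt.
ref_name = "CarperAI/openai_summarize_tldr_sft"
tokenizer = AutoTokenizer.from_pretrained(ref_name)
model = AutoModelForCausalLM.from_pretrained(ref_name)
model.eval()

prompt = "SUBREDDIT: r/...\nPOST: ...\nTL;DR:"  # placeholder prompt
summary = " A placeholder generated summary."   # placeholder generation

prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
summary_ids = tokenizer(summary, return_tensors="pt").input_ids
input_ids = torch.cat([prompt_ids, summary_ids], dim=1)

labels = input_ids.clone()
labels[:, : prompt_ids.shape[1]] = -100         # only score the summary tokens

with torch.no_grad():
    loss = model(input_ids=input_ids, labels=labels).loss  # mean NLL over summary tokens
perplexity = torch.exp(loss).item()
print(perplexity)
```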
## BibTeX entry and citation info

```
@misc{lloret2024aligninglanguagemodelstextual,
      title={Towards Aligning Language Models with Textual Feedback},
      author={Saüc Abadal Lloret and Shehzaad Dhuliawala and Keerthiram Murugesan and Mrinmaya Sachan},
      year={2024},
      eprint={2407.16970},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2407.16970},
}
```