detroitnatif committed
Commit 8b7e5cc • 1 Parent(s): 2efa5aa
added pipeline

Files changed: rewardModel.py (+113 -0)
rewardModel.py
ADDED
@@ -0,0 +1,113 @@
import torch

from datasets import load_dataset
from transformers import pipeline, AutoTokenizer

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler


dataset = load_dataset("imdb")  # quick look at the raw IMDB data; rebuilt below via build_dataset
# print(dataset['train']['text'][:5])
# print(dataset['train']['label'][:5])


def build_dataset(config, dataset_name='imdb', input_min=2, input_max=8):
    # Tokenize the IMDB reviews and truncate each one to a short random-length
    # prompt (between input_min and input_max tokens) for PPO to complete.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x['review']) > 200, batched=False)

    input_size = LengthSampler(input_min, input_max)

    def tokenize(sample):
        sample['input_ids'] = tokenizer.encode(sample['review'])[: input_size()]
        sample['query'] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])


if __name__ == '__main__':  # START
    print("running")
    config = PPOConfig(
        model_name="lvwerra/gpt2-imdb",
        learning_rate=1.41e-5,
        log_with="wandb",
    )

    import wandb
    wandb.init()

    dataset = build_dataset(config)
    model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)  # frozen reference model we compare the policy against

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset, data_collator=collator)
    device = ppo_trainer.accelerator.device
    if ppo_trainer.accelerator.num_processes == 1:
        device = 0 if torch.cuda.is_available() else "cpu"

    # Sentiment classifier used as the reward model.
    sentiment_pipe = pipeline("sentiment-analysis", model='lvwerra/distilbert-imdb', device=device)

    sent_kwargs = {'return_all_scores': True, "function_to_apply": "none", "batch_size": 16}
    text = "this was really bad!!"
    print(sentiment_pipe(text, **sent_kwargs))

    text = "this was really good!!"
    print(sentiment_pipe(text, **sent_kwargs))

    output_min_length = 4
    output_max_length = 16
    output_length_sampler = LengthSampler(output_min_length, output_max_length)

    response_generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }

    for epoch, batch in enumerate(ppo_trainer.dataloader):

        query_tensors = batch["input_ids"]

        response_tensors = []

        for query in query_tensors:
            gen_len = output_length_sampler()
            response_generation_kwargs["max_new_tokens"] = gen_len
            response = ppo_trainer.generate(query, **response_generation_kwargs)
            response_tensors.append(response.squeeze()[-gen_len:])  # keep only the generated piece, drop the actual prompt
        batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        texts = [q + r for q, r in zip(batch['query'], batch['response'])]

        pipe_outputs = sentiment_pipe(texts, **sent_kwargs)  # list of label/score dicts per text; the POSITIVE score is the reward

        rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        ppo_trainer.log_stats(stats, batch, rewards)

    # model.save_pretrained('gpt-imbd-pos-v2', push_to_hub=False)

    # tokenizer.save_pretrained('gpt-imbd-pos-v2', push_to_hub=False)