
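This is an attempt to replicate the RLHF (reinforcement learning from human feedback) pipeline, consisting of supervised fine-tuning, reward modeling, and reinforcement learning.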

Base Model

We used bloomz-7b1-mt as the base model because of its less restrictive license and its multilingual ability.

Supervised Fine-tuning

For SFT we used a combination of multiple datasets including:

Reward Model

For the RM, we used the code from the reward-modeling repo and datasets from
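The card does not state the reward model's training objective. Reward models in InstructGPT-style pipelines are commonly trained with a pairwise ranking loss on (chosen, rejected) response pairs, so the sketch below assumes that objective; `pairwise_ranking_loss` and `batch_loss` are hypothetical helper names, and the rewards are plain floats standing in for the reward model's scalar outputs.

```python
import math
from typing import List, Tuple


def pairwise_ranking_loss(r_chosen: float, r_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)), computed in a numerically stable way."""
    margin = r_chosen - r_rejected
    # -log(sigmoid(m)) == log(1 + exp(-m)); branch on the sign of m so exp() never overflows.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


def batch_loss(pairs: List[Tuple[float, float]]) -> float:
    """Mean loss over (chosen_reward, rejected_reward) pairs."""
    return sum(pairwise_ranking_loss(c, r) for c, r in pairs) / len(pairs)


# Equal rewards give log(2); a large positive margin drives the loss toward 0,
# while a large negative margin makes it grow roughly linearly.
print(pairwise_ranking_loss(1.0, 1.0))  # ≈ 0.6931 (log 2)
```

The loss only depends on the reward difference, so the reward model learns a ranking rather than calibrated absolute scores.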

Reinforcement Learning

For RL, we used the trlx codebase with slight modifications.
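In InstructGPT-style PPO, as implemented in libraries such as trlx, the frozen reference model keeps the policy from drifting too far: the reward-model score is combined with a per-token KL penalty against the reference. The sketch below shows that shaping on plain Python lists; it is not the card's modified trlx code, and `kl_coef` and the per-token log-probabilities are hypothetical inputs (the card does not state the coefficient used).

```python
from typing import List


def kl_shaped_rewards(
    score: float,
    policy_logprobs: List[float],
    ref_logprobs: List[float],
    kl_coef: float,
) -> List[float]:
    """Per-token PPO rewards for one generated response (sketch).

    Each token is penalised by kl_coef * (log pi(token) - log pi_ref(token)),
    a per-token estimate of the KL divergence from the reference model; the
    reward model's scalar score is added at the final token only.
    """
    if not policy_logprobs or len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("expected equal-length, non-empty log-prob sequences")
    rewards = [-kl_coef * (lp - rp) for lp, rp in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


# A token the policy now finds more likely than the reference does is penalised.
print(kl_shaped_rewards(1.0, [-0.5, -1.0], [-1.0, -1.0], kl_coef=0.1))  # [-0.05, 1.0]
```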

Instead of building the value network on top of the policy network with a single linear layer, we add a separate hydra head on top of the reference network's frozen bottom layers to serve as the value network.
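The layout described above can be sketched with toy layers. In trlx's hydra setup, the policy and the frozen reference model already share the reference's frozen bottom layers and branch only in the top layers; the modification adds a third branch, a separate copy of the top layers plus a scalar readout, as the value network. Everything here is hypothetical: `ToyBlock` stands in for a transformer block, `num_unfrozen` for the number of branched top layers, and the value branch is assumed to be initialised from the reference model's top layers with a uniform linear readout.

```python
import copy
from dataclasses import dataclass
from typing import List, Tuple

Vector = List[float]


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a transformer block: elementwise y = w * x + b."""

    w: float
    b: float
    trainable: bool = False

    def __call__(self, h: Vector) -> Vector:
        return [self.w * x + self.b for x in h]


def run(blocks: List[ToyBlock], h: Vector) -> Vector:
    """Apply a stack of blocks in order."""
    for block in blocks:
        h = block(h)
    return h


def branch(blocks: List[ToyBlock], trainable: bool) -> List[ToyBlock]:
    """Deep-copy a stack of blocks and mark the copies trainable or frozen."""
    copies = [copy.deepcopy(block) for block in blocks]
    for block in copies:
        block.trainable = trainable
    return copies


class HydraActorCritic:
    """Policy, reference and value branches on one frozen trunk (assumed layout)."""

    def __init__(self, reference: List[ToyBlock], num_unfrozen: int, hidden_size: int):
        split = len(reference) - num_unfrozen
        # Shared trunk: the reference model's bottom layers, frozen.
        self.trunk = branch(reference[:split], trainable=False)
        # Frozen copy of the reference's top layers (reference outputs, e.g. for a KL penalty).
        self.ref_top = branch(reference[split:], trainable=False)
        # Policy branch: a trainable copy of the same top layers.
        self.policy_top = branch(reference[split:], trainable=True)
        # Value branch: another trainable copy of the top layers ...
        self.value_top = branch(reference[split:], trainable=True)
        # ... followed by a linear readout to one scalar (initialised to a plain mean).
        self.readout_w = [1.0 / hidden_size] * hidden_size
        self.readout_b = 0.0

    def forward(self, x: Vector) -> Tuple[Vector, Vector, float]:
        h = run(self.trunk, x)  # computed once, shared by all three branches
        policy_out = run(self.policy_top, h)
        ref_out = run(self.ref_top, h)
        value_hidden = run(self.value_top, h)
        value = sum(w * v for w, v in zip(self.readout_w, value_hidden)) + self.readout_b
        return policy_out, ref_out, value


reference = [ToyBlock(w=1.0 + 0.1 * i, b=0.5) for i in range(4)]
model = HydraActorCritic(reference, num_unfrozen=2, hidden_size=3)
policy_out, ref_out, value = model.forward([1.0, 2.0, 3.0])
```

In this layout the value loss only updates `value_top` and the readout, so training the critic does not directly change the policy's parameters. With a single linear value head on the policy's own hidden states, value-loss gradients would instead flow into the policy's trainable layers.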

Example

We used the Vicuna v1.1 template for model training, so prompts should follow the same format, as in the example below:

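```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "keyfan/bloomz-rlhf"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Weights load in float32 by default (about 4 bytes per parameter, roughly 28 GB for 7B);
# passing torch_dtype=torch.float16 to from_pretrained halves that.
model = AutoModelForCausalLM.from_pretrained(checkpoint).cuda()

# Vicuna v1.1-style prompt used during training; {} is filled with the user's question.
template = ("A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions. "
            "USER: {}\nASSISTANT:")
question = template.format("Who was the president of the United States in 1955?")

# Tokenize the prompt and move the input IDs to the same GPU as the model.
inputs = tokenizer.encode(question, return_tensors="pt").cuda()

# Nucleus sampling (top_p=0.8), generating at most 512 new tokens.
outputs = model.generate(inputs, do_sample=True, top_p=0.8, max_new_tokens=512)

# The decoded text contains the prompt followed by the model's answer.
print(tokenizer.decode(outputs[0]))
```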
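Because `tokenizer.decode(outputs[0])` returns the prompt together with the answer, a small post-processing step is handy when only the reply is needed. This helper is a hypothetical convenience, not part of the original card: it assumes the answer follows the last `ASSISTANT:` marker of the template and that generation may end with BLOOM's `</s>` end-of-sequence token, which `decode` keeps unless `skip_special_tokens=True` is passed.

```python
def extract_reply(decoded: str, marker: str = "ASSISTANT:", eos: str = "</s>") -> str:
    """Return the text after the last assistant marker, minus a trailing EOS token."""
    # rpartition splits at the last occurrence; sep is "" if the marker is absent.
    _, sep, reply = decoded.rpartition(marker)
    if not sep:
        raise ValueError(f"marker {marker!r} not found in decoded text")
    reply = reply.strip()
    if reply.endswith(eos):
        reply = reply[: -len(eos)].rstrip()
    return reply


# Made-up decoded string for illustration (not real model output).
decoded = "... USER: Who was the president of the United States in 1955?\nASSISTANT: Dwight D. Eisenhower.</s>"
print(extract_reply(decoded))  # Dwight D. Eisenhower.
```

Alternatively, slicing off the prompt tokens before decoding, `tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)`, avoids string matching altogether.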

Evaluations

Results on the Chinese BELLE eval set:

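| Category | Score |
| --- | --- |
| others | 0.619 |
| rewrite | 0.873 |
| classification | 0.706 |
| generation | 0.934 |
| summarization | 0.755 |
| extract | 0.619 |
| open qa | 0.527 |
| brainstorming | 0.908 |
| closed qa | 0.615 |
| macro average | 0.728 |
| macro average (w/o others) | 0.742 |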
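The two summary scores are plain (unweighted) means of the per-category scores. The snippet below reproduces them using only the numbers reported above.

```python
from statistics import mean

# Per-category scores on the BELLE eval set, as reported above.
scores = {
    "others": 0.619,
    "rewrite": 0.873,
    "classification": 0.706,
    "generation": 0.934,
    "summarization": 0.755,
    "extract": 0.619,
    "open qa": 0.527,
    "brainstorming": 0.908,
    "closed qa": 0.615,
}

# Macro average: unweighted mean over categories, regardless of how many prompts each has.
macro_avg = mean(scores.values())
macro_avg_wo_others = mean(v for k, v in scores.items() if k != "others")

print(f"macro average: {macro_avg:.3f}")                      # 0.728
print(f"macro average w/o others: {macro_avg_wo_others:.3f}")  # 0.742
```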
  • We found that in GPT-4 evaluation, the order in which the responses are presented has a non-negligible effect on the final score, even with the well-designed Vicuna prompt. We therefore removed the scores on the Vicuna eval set.
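One way to quantify the ordering effect mentioned above is to have the judge score each pair in both presentation orders and compare. This is a sketch, not a procedure the card reports using: `judge` is a hypothetical callable returning scores for the (first, second) responses, standing in for a GPT-4 call, and the toy judge deliberately favours whichever response is shown first.

```python
from typing import Callable, Dict, Tuple

# Hypothetical judge interface: (question, first, second) -> (score_first, score_second).
Judge = Callable[[str, str, str], Tuple[float, float]]


def position_swapped_scores(judge: Judge, question: str, a: str, b: str) -> Dict[str, float]:
    """Score responses A and B in both presentation orders."""
    a_first, b_second = judge(question, a, b)
    b_first, a_second = judge(question, b, a)
    return {
        # Averaging both orders removes a constant first-position bonus from the A-vs-B comparison.
        "a": (a_first + a_second) / 2,
        "b": (b_first + b_second) / 2,
        # How much A's score moves with its position; 0.0 means no ordering effect for A.
        "a_order_gap": a_first - a_second,
    }


# Toy judge: fixed "true" quality plus a +1.0 bonus for whichever response is shown first.
quality = {"short answer": 6.0, "detailed answer": 8.0}


def biased_judge(question: str, first: str, second: str) -> Tuple[float, float]:
    return quality[first] + 1.0, quality[second]


result = position_swapped_scores(biased_judge, "Who was the US president in 1955?",
                                 "short answer", "detailed answer")
print(result)  # {'a': 6.5, 'b': 8.5, 'a_order_gap': 1.0}
```

A non-zero `a_order_gap` across many questions is a direct measure of the position bias; averaging over both orders only cancels it if the bias is the same regardless of which response comes first.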