TRL Model
This is a TRL language model. It has been fine-tuned with reinforcement learning to steer its outputs according to a reward signal (a value, a function, or human feedback), and it can be used for text generation.

This project aims to reduce the toxicity of text generated by the LaMini-Flan-T5-248M language model using Reinforcement Learning from AI Feedback (RLAIF). Reinforcement Learning from Human Feedback (RLHF) is a method for aligning a model with a particular kind of data or preference. It trains a reward model on human feedback and then fine-tunes the language model against that reward using Proximal Policy Optimization (PPO). RLAIF follows the same recipe but replaces the human feedback with feedback from a capable AI model.

The model was fine-tuned on the Social Reasoning Dataset by ProlificAI for 191 steps (1 epoch) with PPO. A RoBERTa hate-speech detection model served as the PPO reward model.
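The RLAIF loop described above has two pieces: a reward computed from the hate-speech classifier's output, and PPO's clipped surrogate objective. The plain-Python sketch below illustrates both under stated assumptions. `toxicity_reward` is a hypothetical helper that assumes the classifier returns one logit per class, with the non-toxic class at index 0. `ppo_clipped_objective` computes the standard clipped surrogate for a batch of probability ratios and advantage estimates. This is not TRL's implementation, which also adds a value loss, a KL penalty against the reference model, and batching.

```python
def toxicity_reward(classifier_logits, not_hate_index=0):
    """Hypothetical reward: the classifier's logit for the non-toxic class.

    Assumes one logit per class, with the non-toxic ("nothate") class at
    `not_hate_index`. Higher values mean the text looks less toxic.
    """
    return classifier_logits[not_hate_index]


def ppo_clipped_objective(ratios, advantages, clip_eps=0.2):
    """Mean PPO clipped surrogate objective (to be maximised).

    ratios[i]     = pi_new(a_i | s_i) / pi_old(a_i | s_i)
    advantages[i] = advantage estimate for action a_i
    """
    total = 0.0
    for r, adv in zip(ratios, advantages):
        # Keep the ratio inside [1 - eps, 1 + eps] so one update cannot move
        # the policy too far from the policy that collected the samples.
        clipped_r = min(max(r, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms.
        total += min(r * adv, clipped_r * adv)
    return total / len(ratios)


rewards = [toxicity_reward([2.0, -1.0]), toxicity_reward([-0.5, 1.5])]
print(rewards)
print(round(ppo_clipped_objective([1.5, 0.5], [1.0, -1.0]), 4))  # 0.2
```

In the usage example, the first ratio (1.5) is clipped to 1.2. For the second sample the pessimistic term is the clipped one (0.8 × -1). A `clip_eps` of 0.2 is a common choice, but it is only an assumption here.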
The main strength of this model is its small size: it is only about 500 MB, yet it performs well for its size. It is intended for conversation, text generation, and context-based Q&A. It may perform poorly on tasks such as mathematics, science, or coding, and may hallucinate on them. After quantization, the model should be small enough to run on edge devices such as smartphones and embedded microprocessor boards.
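A model's weight footprint is roughly its parameter count times the bytes stored per parameter. That makes it easy to estimate how precision and quantization affect size. The sketch below assumes the nominal 248M parameters from the model name. It ignores the value head, tokenizer files, and serialization overhead, so the figures are approximations only.

```python
def checkpoint_size_mb(num_params, bytes_per_param):
    """Approximate weight storage in megabytes (1 MB = 10**6 bytes)."""
    return num_params * bytes_per_param / 1e6


NUM_PARAMS = 248_000_000  # nominal parameter count taken from the model name

for dtype, nbytes in [("float32", 4), ("float16/bfloat16", 2), ("int8", 1), ("4-bit", 0.5)]:
    print(f"{dtype:>16}: ~{checkpoint_size_mb(NUM_PARAMS, nbytes):.0f} MB")
```

This gives about 992 MB in float32, 496 MB in half precision, 248 MB in int8, and 124 MB at 4 bits. A footprint of about 500 MB is therefore consistent with 2 bytes per parameter, and quantizing to 8 or 4 bits would shrink it further.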
The training logs for this model are available on a Weights & Biases page.
Note: This model is a fine-tuned version of LaMini-Flan-T5-248M, which is itself a fine-tuned version of Google's Flan-T5. Flan-T5 uses an encoder-decoder architecture, unlike GPT-style models, which are decoder-only.
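In an encoder-decoder model like Flan-T5, the encoder reads the whole input with bidirectional self-attention. The decoder then generates output tokens with causal self-attention and also cross-attends to the encoder's outputs. A decoder-only GPT-style model instead applies a single causal mask over the prompt and the output together. This toy plain-Python sketch contrasts the two self-attention mask patterns, where 1 means "may attend". It ignores padding masks and cross-attention.

```python
def bidirectional_mask(n):
    """Encoder-style mask: every position may attend to every other position."""
    return [[1] * n for _ in range(n)]


def causal_mask(n):
    """Decoder-style mask: position i may attend only to positions 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


for row in causal_mask(3):
    print(row)
# [1, 0, 0]
# [1, 1, 0]
# [1, 1, 1]
```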
Usage
To use this model for inference, first install the TRL library:
```bash
python -m pip install trl
```
You can then generate text as follows:
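```python
import torch
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForSeq2SeqLMWithValueHead

checkpoint = "ARahul2003/lamini_flan_t5_detoxify_rlaif"

# Load the tokenizer and the fine-tuned T5 model (with its PPO value head).
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    checkpoint,
    device_map="cpu",  # or "auto" / "cuda:0"
    torch_dtype=torch.float32,
)

# Flan-T5 is an encoder-decoder model, so use the text2text-generation task.
pipe = pipeline(
    "text2text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_length=512,   # maximum length of the generated output, in tokens
    do_sample=True,   # sample instead of greedy decoding
    temperature=0.3,  # lower values give more deterministic text
    top_p=0.95,       # nucleus (top-p) sampling threshold
)

prompt = "Hello! How are you?"
print(pipe(prompt)[0]["generated_text"])
```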
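When `do_sample=True`, the `temperature=0.3` and `top_p=0.95` arguments control how each next token is sampled. The plain-Python sketch below shows the idea on a single vector of logits. Temperature rescales the logits. Nucleus (top-p) filtering then keeps the smallest set of most-likely tokens whose probability mass reaches `top_p`, and one token is sampled from that set. This is an illustrative sketch only; the `transformers` implementation differs in details such as tie handling and additional default filters.

```python
import math
import random


def sample_top_p(logits, temperature=0.3, top_p=0.95, rng=random):
    """Sample a token index with temperature scaling and nucleus (top-p) filtering."""
    # 1. Temperature: dividing by a value < 1 sharpens the distribution.
    scaled = [l / temperature for l in logits]
    # 2. Softmax, subtracting the max for numerical stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 3. Keep the most likely tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # 4. Sample from the kept tokens, proportionally to their probabilities.
    r = rng.random() * sum(probs[i] for i in kept)
    cumulative = 0.0
    for i in kept:
        cumulative += probs[i]
        if r < cumulative:
            return i
    return kept[-1]  # guard against floating-point round-off


print(sample_top_p([2.0, 1.0, 0.1, -1.0]))  # 0
```

With these example logits and `temperature=0.3`, token 0 already holds about 96% of the probability mass. It is therefore the only token in the nucleus, which shows why a low temperature combined with top-p makes the output close to deterministic.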
If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
```python
from transformers import AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead

checkpoint = "ARahul2003/lamini_flan_t5_detoxify_rlaif"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# T5 is encoder-decoder: use the Seq2Seq class, not AutoModelForCausalLMWithValueHead.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
# In recent TRL versions this returns a tuple (lm_logits, loss, value).
outputs = model(**inputs, labels=inputs["input_ids"])
```
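The value head that `AutoModelForSeq2SeqLMWithValueHead` adds is a small projection on top of the model's final hidden states. For this seq2seq model those are the decoder's hidden states. It produces one scalar value estimate per token, which PPO uses as a baseline when computing advantages. TRL's version is a linear layer with optional dropout. The plain-Python `ToyValueHead` below is a hypothetical stand-in that shows only the shape of the computation: one hidden vector per token in, one number per token out.

```python
import random


class ToyValueHead:
    """Hypothetical value head: a single linear layer from hidden state to scalar."""

    def __init__(self, hidden_size, seed=0):
        rng = random.Random(seed)
        # Small random weights, as a freshly initialised head would have.
        self.weight = [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        self.bias = 0.0

    def __call__(self, hidden_states):
        """hidden_states: one vector per token -> one value estimate per token."""
        return [
            sum(w * h for w, h in zip(self.weight, vec)) + self.bias
            for vec in hidden_states
        ]


head = ToyValueHead(hidden_size=4)
values = head([[0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0]])
print(len(values))  # 2: one value per token
```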
To serve the model in a Gradio chat app (this also requires `python -m pip install gradio`), you can use the following code:
```python
import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForSeq2SeqLMWithValueHead

title = "LaMini Flan T5 248M"
checkpoint = "ARahul2003/lamini_flan_t5_detoxify_rlaif"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    checkpoint,
    device_map="cpu",  # or "auto"
    torch_dtype=torch.float32,
)
pipe = pipeline(
    "text2text-generation",
    model=base_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
    temperature=0.3,
    top_p=0.95,
)


def chat_with_model(inp_chat, chat_history=None):
    # gr.ChatInterface calls fn(message, history). The history is ignored
    # here, so each reply depends only on the latest message.
    prompt = f"{inp_chat}"  # alternative template: f"User: {inp_chat} Bot:"
    responses = pipe(prompt)
    return responses[0]["generated_text"]


examples = [
    "Hi!",
    "How are you?",
    'Please let me know your thoughts on the given place and why you think it deserves to be visited: \n"Barcelona, Spain"',
]

gr.ChatInterface(
    fn=chat_with_model,
    title=title,
    examples=examples,
).launch()
```
Make sure to keep the model and all input tensors on the same device (CPU or GPU).
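The Gradio example's `chat_with_model` ignores `chat_history`, so the model never sees earlier turns. If you want some conversational context, one option is to fold recent turns into the prompt with the commented-out `"User: ... Bot:"` template. The helper `build_prompt` below is a hypothetical sketch. It assumes `history` is a list of `(user, bot)` string pairs (newer Gradio versions can pass message dictionaries instead). It also keeps only the last few turns so the prompt stays short.

```python
def build_prompt(message, history=None, max_turns=3):
    """Fold the most recent chat turns into a single text2text prompt."""
    history = history or []
    # Keep only the last `max_turns` exchanges (none if max_turns <= 0).
    recent = history[-max_turns:] if max_turns > 0 else []
    parts = [f"User: {user_msg} Bot: {bot_msg}" for user_msg, bot_msg in recent]
    parts.append(f"User: {message} Bot:")
    return " ".join(parts)


print(build_prompt("And tomorrow?", [("What's the weather like?", "Sunny.")]))
# User: What's the weather like? Bot: Sunny. User: And tomorrow? Bot:
```

To use it, you could replace the `prompt = ...` line in `chat_with_model` with `prompt = build_prompt(inp_chat, chat_history)`. Whether the model actually benefits from this template has not been verified.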