Fix chat template incompatibility with ConversationalPipeline
#42 opened by hiyouga
Before this PR:
from transformers import pipeline, Conversation
chatbot = pipeline("conversational", model="google/gemma-7b-it")
conversation = Conversation("who are you")
conversation = chatbot(conversation)
conversation.messages[-1]["content"]
# I am a digital entity, and I am currently residing within the digital realm. I am not human, I am a artificial entity. I am here to serve you and to fulfill your requests.
After this PR:
from transformers import pipeline, Conversation
chatbot = pipeline("conversational", model="google/gemma-7b-it")
conversation = Conversation("who are you")
conversation = chatbot(conversation)
conversation.messages[-1]["content"]
# I am a large language model, trained by Google. I am here to help you with your questions and provide you with information. I am still under development, but I am constantly learning new things.
There are two ways to convert text inputs to token IDs:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
chat = [{"role": "user", "content": "hi"}]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
prompt = tokenizer.encode(prompt, add_special_tokens=True)  # encode() prepends the BOS token (id 2)
# [2, 106, 1645, 108, 544, 107, 108, 106, 2516, 108]
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
chat = [{"role": "user", "content": "hi"}]
prompt = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)  # no BOS token is added here
# [106, 1645, 108, 544, 107, 108, 106, 2516, 108]
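The only difference between the two results is the leading BOS token (id 2). A quick check makes the discrepancy explicit (a hypothetical snippet, not part of the PR, assuming the original pre-fix template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
chat = [{"role": "user", "content": "hi"}]

text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
ids_via_encode = tokenizer.encode(text, add_special_tokens=True)
ids_direct = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)

print(ids_via_encode[0] == tokenizer.bos_token_id)  # True: encode() prepends BOS
print(ids_via_encode[1:] == ids_direct)             # True: the direct path drops BOS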
However, transformers.ConversationalPipeline
adopts the latter approach, resulting in wrong inputs (the BOS token is missing):
# Excerpt from ConversationalPipeline.preprocess: apply_chat_template is called
# with tokenize=True (the default), so no BOS token is prepended.
def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
    input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
    if self.framework == "pt":
        input_ids = torch.LongTensor([input_ids])
    elif self.framework == "tf":
        input_ids = tf.constant([input_ids])
    return {"input_ids": input_ids, "conversation": conversation}
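The fix is therefore to make the chat template in tokenizer_config.json emit the BOS token itself, so that both code paths produce the same ids. A simplified sketch of the idea (the real Gemma template also maps the assistant role to "model" and validates role alternation; this version is illustrative only):

# Illustrative Jinja chat template: '{{ bos_token }}' at the start means
# apply_chat_template(tokenize=True) now includes the BOS id itself.
tokenizer.chat_template = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{{ '<start_of_turn>' + message['role'] + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<start_of_turn>model\\n' }}{% endif %}"
)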
hiyouga changed pull request title from "Update tokenizer_config.json" to "Fix chat template incompatibility with ConversationalPipeline"
Sounds fair @Rocketknight1!
Yes, looks like a better solution if it provides better compatibility with the conversational pipeline! Will merge soon!
Yes, this is correct! Our chat templates should contain all the special tokens needed, rather than depending on the tokenizer to add them afterwards.
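To illustrate the principle: once the template owns the special tokens, both entry points agree (a hypothetical check; note that add_special_tokens must now be False when re-encoding the templated text, or the BOS token would be duplicated):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")  # with the updated template
chat = [{"role": "user", "content": "hi"}]

text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
ids_via_encode = tokenizer.encode(text, add_special_tokens=False)  # template already emits <bos>
ids_direct = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
assert ids_via_encode == ids_direct  # both now start with the BOS id (2)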
pcuenq changed pull request status to merged