Add tool calling support to chat template
Test script to confirm equivalence:
from pathlib import Path

from transformers import AutoTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, ToolMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
hf_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", revision="pr/68")
hf_tool = {
    "name": "get_current_weather",
    "description": "Get the current weather",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "format": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The temperature unit to use. Infer this from the users location.",
            },
        },
        "required": ["location", "format"],
    },
}
hf_tool = {"type": "function", "function": hf_tool}
# Build a conversation containing a tool call and its matching tool result.
test_chat = [{"role": "user", "content": "What's the weather like today in Paris"}]
tool_call = {"name": "get_current_weather", "arguments": {"location": "Paris, France"}}
test_chat.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call, "id": "abcdef123"}]})
test_chat.append({"role": "tool", "name": "get_current_weather", "tool_call_id": "abcdef123", "content": "22.0"})
hf_text = hf_tokenizer.apply_chat_template(test_chat, tokenize=False, tools=[hf_tool])
hf_tokens = hf_tokenizer.apply_chat_template(test_chat, tokenize=True, tools=[hf_tool])
mistral_models_path = Path.home().joinpath('mistral_models', '7B-Instruct-v0.3')
mistral_tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3")
mistral_tool = Tool(
    function=Function(
        name="get_current_weather",
        description="Get the current weather",
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
            },
            "required": ["location", "format"],
        },
    )
)
mistral_query = ChatCompletionRequest(
    tools=[mistral_tool],
    messages=[
        UserMessage(content="What's the weather like today in Paris"),
        AssistantMessage(tool_calls=[
            ToolCall(
                type="function",
                function=FunctionCall(name="get_current_weather", arguments={"location": "Paris, France"}),
                id="abcdef123",
            )
        ]),
        ToolMessage(content="22.0", tool_call_id="abcdef123"),
    ],
    model="test",
)
encoded = mistral_tokenizer.encode_chat_completion(mistral_query)
# The reference tokenizer renders spaces as the sentencepiece "▁" marker,
# so normalize it before comparing the rendered strings.
mistral_text = encoded.text.replace("▁", " ")
mistral_tokens = encoded.tokens
print(hf_text == mistral_text)
print(hf_tokens == mistral_tokens)
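If the template matches the reference tokenizer, both checks should print True.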
The changes to the chat template cause an issue when used with OpenAI/OpenAPI-specced tools. The tool_calls field declared on ChatCompletionRequestAssistantMessage is the cause of the issue: for assistant responses, the tool_calls field exists but is set to None, and hence the message is filtered out when we do | selectattr("tool_calls", "undefined"). The resulting loop_messages therefore contains no assistant responses, so the second user message ends up at index 1 (instead of 2) and throws the error "After the optional system message, conversation roles must alternate user/assistant/user/assistant/...".
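To see the filtering behavior in isolation, here is a minimal standalone sketch (illustrative only, using plain jinja2 rather than the tokenizer) showing that an attribute which exists but is set to None counts as defined, so the assistant message gets dropped:
from jinja2 import Environment

# Jinja's "undefined" test is false for an attribute that exists with value
# None, so selectattr("tool_calls", "undefined") drops such messages.
env = Environment()
template = env.from_string(
    '{{ messages | selectattr("tool_calls", "undefined") | map(attribute="role") | list }}'
)
messages = [
    {"role": "user", "content": "Hi, who are you?"},
    {"role": "assistant", "content": "I am an AI model", "tool_calls": None},  # filtered out
    {"role": "user", "content": "Tell me about robots"},
]
print(template.render(messages=messages))  # ['user', 'user'] -- no assistant left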
Steps to reproduce:
from transformers import AutoTokenizer
from kserve.protocol.rest.openai.types.openapi import ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestAssistantMessage
messages = [
    ChatCompletionRequestSystemMessage(
        content="You are a pirate chatbot who always responds in pirate speak!",
        role="system",
    ),
    ChatCompletionRequestUserMessage(
        content="Hi, who are you?",
        role="user",
    ),
    ChatCompletionRequestAssistantMessage(
        content="I am an AI model created by MistralAI",
        role="assistant",
    ),
    ChatCompletionRequestUserMessage(
        content="Tell me about robots",
        role="user",
    ),
]
messages_as_list = [
    {
        "content": "You are a pirate chatbot who always responds in pirate speak!",
        "role": "system",
    },
    {
        "content": "Hi, who are you?",
        "role": "user",
    },
    {
        "content": "I am an AI model created by MistralAI",
        "role": "assistant",
    },
    {
        "content": "Tell me about robots",
        "role": "user",
    },
]
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.3')
templated_list = tokenizer.apply_chat_template(messages_as_list, tokenize=False)
print(templated_list)
templated_messages = tokenizer.apply_chat_template(messages, tokenize=False)
Response:
<s>[INST] Hi, who are you?[/INST] I am an AI model created by MistralAI</s>[INST] You are a pirate chatbot who always responds in pirate speak!
Tell me about robots[/INST]
{
"name": "TemplateError",
"message": "After the optional system message, conversation roles must alternate user/assistant/user/assistant/...",
"stack": "---------------------------------------------------------------------------
TemplateError Traceback (most recent call last)
Cell In[3], line 45
43 templated_list = tokenizer.apply_chat_template(messages_as_list,tokenize=False)
44 print(templated_list)
---> 45 templated_messages = tokenizer.apply_chat_template(messages,tokenize=False)
File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1833, in PreTrainedTokenizerBase.apply_chat_template(self, conversation, tools, documents, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, return_assistant_tokens_mask, tokenizer_kwargs, **kwargs)
1831 all_generation_indices.append(generation_indices)
1832 else:
-> 1833 rendered_chat = compiled_template.render(
1834 messages=chat,
1835 tools=tool_schemas,
1836 documents=documents,
1837 add_generation_prompt=add_generation_prompt,
1838 **template_kwargs,
1839 )
1840 rendered.append(rendered_chat)
1842 if not is_batched:
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:1304, in Template.render(self, *args, **kwargs)
1302 return self.environment.concat(self.root_render_func(ctx)) # type: ignore
1303 except Exception:
-> 1304 self.environment.handle_exception()
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:939, in Environment.handle_exception(self, source)
934 \"\"\"Exception handling helper. This is used internally to either raise
935 rewritten exceptions or return a rendered traceback for the template.
936 \"\"\"
937 from .debug import rewrite_traceback_stack
--> 939 raise rewrite_traceback_stack(source=source)
File <template>:14, in top-level template code()
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/sandbox.py:394, in SandboxedEnvironment.call(_SandboxedEnvironment__self, _SandboxedEnvironment__context, _SandboxedEnvironment__obj, *args, **kwargs)
392 if not __self.is_safe_callable(__obj):
393 raise SecurityError(f\"{__obj!r} is not safely callable\")
--> 394 return __context.call(__obj, *args, **kwargs)
File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1914, in PreTrainedTokenizerBase._compile_jinja_template.<locals>.raise_exception(message)
1913 def raise_exception(message):
-> 1914 raise TemplateError(message)
TemplateError: After the optional system message, conversation roles must alternate user/assistant/user/assistant/..."
}
@imdatta0 I see, let me see if I can update the template to handle cases where tool_calls is present but null.
@imdatta0, can you try running this snippet to update your local template and then rerunning your code to check it works okay?
import json
tokenizer.chat_template = json.loads('"{%- if messages[0][\\"role\\"] == \\"system\\" %}\\n    {%- set system_message = messages[0][\\"content\\"] %}\\n    {%- set loop_messages = messages[1:] %}\\n{%- else %}\\n    {%- set loop_messages = messages %}\\n{%- endif %}\\n{%- if not tools is defined %}\\n    {%- set tools = none %}\\n{%- endif %}\\n{%- set user_messages = loop_messages | selectattr(\\"role\\", \\"equalto\\", \\"user\\") | list %}\\n\\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\\n{%- set ns = namespace() %}\\n{%- set ns.index = 0 %}\\n{%- for message in loop_messages %}\\n    {%- if not (message.role == \\"tool\\" or message.role == \\"tool_results\\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\\n        {%- if (message[\\"role\\"] == \\"user\\") != (ns.index % 2 == 0) %}\\n            {{- raise_exception(\\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\\") }}\\n        {%- endif %}\\n        {%- set ns.index = ns.index + 1 %}\\n    {%- endif %}\\n{%- endfor %}\\n\\n{{- bos_token }}\\n{%- for message in loop_messages %}\\n    {%- if message[\\"role\\"] == \\"user\\" %}\\n        {%- if tools is not none and (message == user_messages[-1]) %}\\n            {{- \\"[AVAILABLE_TOOLS] [\\" }}\\n            {%- for tool in tools %}\\n                {%- set tool = tool.function %}\\n                {{- \'{\\"type\\": \\"function\\", \\"function\\": {\' }}\\n                {%- for key, val in tool.items() if key != \\"return\\" %}\\n                    {%- if val is string %}\\n                        {{- \'\\"\' + key + \'\\": \\"\' + val + \'\\"\' }}\\n                    {%- else %}\\n                        {{- \'\\"\' + key + \'\\": \' + val|tojson }}\\n                    {%- endif %}\\n                    {%- if not loop.last %}\\n                        {{- \\", \\" }}\\n                    {%- endif %}\\n                {%- endfor %}\\n                {{- \\"}}\\" }}\\n                {%- if not loop.last %}\\n                    {{- \\", \\" }}\\n                {%- else %}\\n                    {{- \\"]\\" }}\\n                {%- endif %}\\n            {%- endfor %}\\n            {{- \\"[/AVAILABLE_TOOLS]\\" }}\\n        {%- endif %}\\n        {%- if loop.last and system_message is defined %}\\n            {{- \\"[INST] \\" + system_message + \\"\\\\n\\\\n\\" + message[\\"content\\"] + \\"[/INST]\\" }}\\n        {%- else %}\\n            {{- \\"[INST] \\" + message[\\"content\\"] + \\"[/INST]\\" }}\\n        {%- endif %}\\n    {%- elif message[\\"role\\"] == \\"tool_calls\\" or message.tool_calls is defined %}\\n        {%- if message.tool_calls is defined %}\\n            {%- set tool_calls = message.tool_calls %}\\n        {%- else %}\\n            {%- set tool_calls = message.content %}\\n        {%- endif %}\\n        {{- \\"[TOOL_CALLS] [\\" }}\\n        {%- for tool_call in tool_calls %}\\n            {%- set out = tool_call.function|tojson %}\\n            {{- out[:-1] }}\\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\\n                {{- raise_exception(\\"Tool call IDs should be alphanumeric strings with length 9!\\") }}\\n            {%- endif %}\\n            {{- \', \\"id\\": \\"\' + tool_call.id + \'\\"}\' }}\\n            {%- if not loop.last %}\\n                {{- \\", \\" }}\\n            {%- else %}\\n                {{- \\"]\\" + eos_token }}\\n            {%- endif %}\\n        {%- endfor %}\\n    {%- elif message[\\"role\\"] == \\"assistant\\" %}\\n        {{- \\" \\" + message[\\"content\\"] + eos_token}}\\n    {%- elif message[\\"role\\"] == \\"tool_results\\" or message[\\"role\\"] == \\"tool\\" %}\\n        {%- if message.content is defined and message.content.content is defined %}\\n            {%- set content = message.content.content %}\\n        {%- else %}\\n            {%- set content = message.content %}\\n        {%- endif %}\\n        {{- \'[TOOL_RESULTS] {\\"content\\": \' + content|string + \\", \\" }}\\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\\n            {{- raise_exception(\\"Tool call IDs should be alphanumeric strings with length 9!\\") }}\\n        {%- endif %}\\n        {{- \'\\"call_id\\": \\"\' + message.tool_call_id + \'\\"}[/TOOL_RESULTS]\' }}\\n    {%- else %}\\n        {{- raise_exception(\\"Only user and assistant roles are supported, with the exception of an initial optional system message!\\") }}\\n    {%- endif %}\\n{%- endfor %}\\n"')
Hey guys - can you clarify what the expected result is for parallel tool calls for a request like this?
{
  "model": "mistralai/Mistral-7B-Instruct-v0.3",
  "messages": [
    {
      "role": "user",
      "content": "Hi! How are you doing today?"
    },
    {
      "role": "assistant",
      "content": "I'm doing well! How can I help you?"
    },
    {
      "role": "user",
      "content": "Can you tell me what the weather will be in Dallas and Orlando in fahrenheit?"
    }
  ],
  "stream": false,
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
          "type": "object",
          "properties": {
            "city": {
              "type": "string",
              "description": "The city to find the weather for, e.g. 'San Francisco'"
            },
            "state": {
              "type": "string",
              "description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'"
            },
            "unit": {
              "type": "string",
              "description": "The unit to fetch the temperature in",
              "enum": [
                "celsius",
                "fahrenheit"
              ]
            }
          }
        }
      }
    }
  ]
}
I would expect something like [TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}], but I keep getting this using this chat template:
[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]
[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]
Can you confirm if this is the expected result, or if the model is not intended to support parallel tool calls?
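For reference, when both calls are passed in a single assistant message, the tool_calls loop in the template joins them into one [TOOL_CALLS] [...] list. A minimal sketch to check the template directly (hypothetical message content and call ids; the template requires 9-character ids):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# Two parallel calls inside ONE assistant message; the template's tool_calls
# loop should render them as a single [TOOL_CALLS] [...] list.
chat = [
    {"role": "user", "content": "Can you tell me what the weather will be in Dallas and Orlando in fahrenheit?"},
    {"role": "assistant", "tool_calls": [
        {"type": "function", "id": "call00001", "function": {
            "name": "get_current_weather",
            "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}},
        {"type": "function", "id": "call00002", "function": {
            "name": "get_current_weather",
            "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}},
    ]},
]
print(tokenizer.apply_chat_template(chat, tokenize=False))
If that prints a single list, the doubled output above is presumably coming from the two calls being rendered as separate assistant messages before templating.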
Hi @Rocketknight1, thanks for the swift response. Yeah, this new change seems to fix the alternating-message issue, but it throws a new error:
{
"name": "TypeError",
"message": "'NoneType' object is not iterable",
"stack": "---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 3
1 # templated_list = tokenizer.apply_chat_template(messages_as_list,tokenize=False)
2 # print(templated_list)
----> 3 templated_messages = tokenizer.apply_chat_template(messages,tokenize=False)
File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1833, in PreTrainedTokenizerBase.apply_chat_template(self, conversation, tools, documents, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, return_assistant_tokens_mask, tokenizer_kwargs, **kwargs)
1831 all_generation_indices.append(generation_indices)
1832 else:
-> 1833 rendered_chat = compiled_template.render(
1834 messages=chat,
1835 tools=tool_schemas,
1836 documents=documents,
1837 add_generation_prompt=add_generation_prompt,
1838 **template_kwargs,
1839 )
1840 rendered.append(rendered_chat)
1842 if not is_batched:
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:1304, in Template.render(self, *args, **kwargs)
1302 return self.environment.concat(self.root_render_func(ctx)) # type: ignore
1303 except Exception:
-> 1304 self.environment.handle_exception()
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:939, in Environment.handle_exception(self, source)
934 \"\"\"Exception handling helper. This is used internally to either raise
935 rewritten exceptions or return a rendered traceback for the template.
936 \"\"\"
937 from .debug import rewrite_traceback_stack
--> 939 raise rewrite_traceback_stack(source=source)
File <template>:71, in top-level template code()
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/runtime.py:422, in LoopContext.__init__(self, iterable, undefined, recurse, depth0)
413 \"\"\"
414 :param iterable: Iterable to wrap.
415 :param undefined: :class:`Undefined` class to use for next and
(...)
419 :param depth0: Incremented when looping recursively.
420 \"\"\"
421 self._iterable = iterable
--> 422 self._iterator = self._to_iterator(iterable)
423 self._undefined = undefined
424 self._recurse = recurse
File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/runtime.py:430, in LoopContext._to_iterator(iterable)
428 @staticmethod
429 def _to_iterator(iterable: t.Iterable[V]) -> t.Iterator[V]:
--> 430 return iter(iterable)
TypeError: 'NoneType' object is not iterable"
}
I tried to debug it:
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length != 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
{%- endif %}
{{- ', "id": "' + tool_call.id + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
This seems to be causing the issue. When it breaks, message is the assistant message: content='I am an AI model created by MistralAI' refusal=None role='assistant' name=None tool_calls=None function_call=None
So updating the first line of that check to elif message["role"] == "tool_calls" or message.tool_calls is defined and message.tool_calls is not none seems to work for me.
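For anyone who wants to apply that one-line fix to a local tokenizer before it lands upstream, a minimal sketch (a plain string patch; assumes the condition appears verbatim in tokenizer.chat_template):
# Hypothetical local patch: guard the tool_calls branch against None.
old = '{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}'
new = ('{%- elif message["role"] == "tool_calls" or '
       '(message.tool_calls is defined and message.tool_calls is not none) %}')
assert old in tokenizer.chat_template, "template text did not match"
tokenizer.chat_template = tokenizer.chat_template.replace(old, new)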
Would be great if we can include the above changes too :)
Hi @imdatta0 , good spot! I'll include that as well.
Hey @Rocketknight1,
Yeah, I was checking in on the repo frequently and noticed your PR. I tried the 7B changes and they do work for me.
I really appreciate your help with this so far.