Chat template ignores system message after a tool response.
The chat template doesn't appear to add the system message if the conversation ends in a tool call. For example:
from transformers import AutoTokenizer
from transformers.utils import get_json_schema
tokenizer = AutoTokenizer.from_pretrained(mistral_models_path)
def get_current_temperature(location: str, unit: str) -> float:
"""
Get the current temperature at a location.
Args:
location: The location to get the temperature for, in the format "City, Country"
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
Returns:
The current temperature at the specified location in the specified units, as a float.
"""
return 22. # A real function should probably actually get the temperature!
def get_current_wind_speed(location: str) -> float:
"""
Get the current wind speed in km/h at a given location.
Args:
location: The location to get the temperature for, in the format "City, Country"
Returns:
The current wind speed at the given location in km/h, as a float.
"""
return 6. # A real function should probably actually get the wind speed!
tools = [get_current_temperature, get_current_wind_speed]
tools = [get_json_schema(tool) for tool in tools]
messages = [
{"role": "system", "content": "You are a bot that responds to weather queries. Your responses should be in the style of a pirate"},
{"role": "user", "content": "What's the weather like in Paris?"},
{"role": "assistant", "tool_calls": [
{
"type": "function",
"function": {
"name": "get_current_temperature",
"arguments": {"location": "Paris, France", "format": "celsius"},
},
"id": "abcdef123"
}]
},
{
"role": "tool",
"name": "get_current_temperature",
"tool_call_id": "abcdef123",
"content": "22.0"
},
]
print(tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, tokenize=False))
Results in the following, where you can see the system prompt isn't included.
<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_current_temperature", "description": "Get the current temperature at a location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \"City, Country\""}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit to return the temperature in."}}, "required": ["location", "unit"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \"City, Country\""}}, "required": ["location"]}}}][/AVAILABLE_TOOLS][INST]What's the weather like in Paris?[/INST][TOOL_CALLS][{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "format": "celsius"}, "id": "abcdef123"}]</s>[TOOL_RESULTS]{"content": 22.0, "call_id": "abcdef123"}[/TOOL_RESULTS]
I don't know whether this is the exact solution you would want. But the following works.
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}
{%- set ns = namespace() %}
{%- set ns.index = 0 %}
{%- for message in loop_messages %}
{%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}
{%- if (message["role"] == "user") != (ns.index % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- set ns.index = ns.index + 1 %}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS][" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST]" + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST]" + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif (message.tool_calls is defined and message.tool_calls is not none) %}
{{- "[TOOL_CALLS][" }}
{%- for tool_call in message.tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length != 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
{%- endif %}
{{- ', "id": "' + tool_call.id + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- message["content"] + eos_token}}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS]{"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }}
{#- THIS BIT HERE vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv #}
{%- if loop.last and system_message is defined %}
{{- "[INST]" + system_message + "[/INST]" }}
{%- endif %}
{#- THIS BIT HERE ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}