"""Smoke-test an xLAM function-calling model served through vLLM."""

import argparse
import json
import time
from typing import Dict, List, Union

from jinja2 import Template
from vllm import LLM, SamplingParams


TASK_INSTRUCTION = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
""".strip()

FORMAT_INSTRUCTION = """
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
    "tool_calls": [
        {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
        ... (more tool calls as required)
    ]
}
```
""".strip()


class XLAMHandler:
    def __init__(self, model: str, temperature: float = 0.3, top_p: float = 1.0, max_tokens: int = 512):
        self.llm = LLM(model=model)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        # Reuse the chat template that ships with the model's tokenizer.
        self.chat_template = self.llm.get_tokenizer().chat_template

    @staticmethod
    def apply_chat_template(template, messages):
        # Render the tokenizer's raw Jinja2 chat template around the messages.
        # add_generation_prompt=True lets templates that gate the assistant
        # prefix on that flag emit it; templates that ignore it are unaffected.
        jinja_template = Template(template)
        return jinja_template.render(messages=messages, add_generation_prompt=True)
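
    # Note: Hugging Face tokenizers expose an equivalent built-in helper that
    # also fills in template defaults and special tokens. A sketch of the same
    # step using it (assuming the tokenizer bundles a chat template):
    #
    #   formatted = self.llm.get_tokenizer().apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)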

    def process_query(self, query: str, tools: Union[Dict, List], task_instruction: str, format_instruction: str):
        # Flatten the OpenAI-style schema(s) into the xLAM tool format.
        xlam_tools = self.convert_to_xlam_tool(tools)

        prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)

        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.apply_chat_template(self.chat_template, messages)

        # Time the generation call so end-to-end latency can be reported.
        start_time = time.time()
        outputs = self.llm.generate([formatted_prompt], self.sampling_params)
        latency = time.time() - start_time

        result = outputs[0].outputs[0].text
        parsed_result, success, _ = self.parse_response(result)

        metadata = {
            "latency": latency,
            "success": success,
        }

        return parsed_result, metadata

    def convert_to_xlam_tool(self, tools):
        # Strip the JSON Schema wrapper ("type"/"properties"/"required") down
        # to the flat name/description/parameters layout xLAM expects.
        if isinstance(tools, dict):
            return {
                "name": tools["name"],
                "description": tools["description"],
                "parameters": dict(tools["parameters"].get("properties", {})),
            }
        elif isinstance(tools, list):
            return [self.convert_to_xlam_tool(tool) for tool in tools]
        else:
            return tools
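
    # Illustration: for the weather schema defined in __main__ below, the
    # conversion yields (property specs abbreviated):
    #
    #   {"name": "get_weather",
    #    "description": "Get the current weather for a location",
    #    "parameters": {"location": {...}, "unit": {...}}}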

    def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
        prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
        prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
        return prompt
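
    # The assembled prompt is four bracketed sections in this fixed order:
    #
    #   [BEGIN OF TASK INSTRUCTION] ... [END OF TASK INSTRUCTION]
    #   [BEGIN OF AVAILABLE TOOLS] ... [END OF AVAILABLE TOOLS]
    #   [BEGIN OF FORMAT INSTRUCTION] ... [END OF FORMAT INSTRUCTION]
    #   [BEGIN OF QUERY] ... [END OF QUERY]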

    def parse_response(self, response):
        # Parse the model's JSON reply into [{name: arguments}, ...] and report
        # success along with any error messages.
        try:
            data = json.loads(response)
            tool_calls = data.get('tool_calls', []) if isinstance(data, dict) else data
            result = [
                {tool_call['name']: tool_call['arguments']}
                for tool_call in tool_calls if isinstance(tool_call, dict)
            ]
            return result, True, []
        except json.JSONDecodeError:
            return [], False, ["Failed to parse JSON response"]
        except (KeyError, TypeError) as exc:
            # Valid JSON, but not the expected {"tool_calls": [...]} shape.
            return [], False, [f"Unexpected response structure: {exc}"]
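
    # Illustration (hypothetical model outputs):
    #
    #   parse_response('{"tool_calls": [{"name": "get_weather", "arguments": {"location": "Paris"}}]}')
    #   -> ([{"get_weather": {"location": "Paris"}}], True, [])
    #
    #   parse_response("I cannot answer that.")
    #   -> ([], False, ["Failed to parse JSON response"])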


if __name__ == "__main__":
    # Example invocation: python <this script> --model <model-path-or-hf-id>
    parser = argparse.ArgumentParser(description="Test XLAM model with vLLM")
    parser.add_argument("--model", required=True, help="Path to the model")
    parser.add_argument("--temperature", type=float, default=0.3, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for sampling")
    parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")

    args = parser.parse_args()

    handler = XLAMHandler(args.model, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)

    # OpenAI-style function schema for the weather tool used in the tests below.
    weather_api = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return"
                }
            },
            "required": ["location"]
        }
    }

    # Single-tool queries; the stock-price query has no matching tool, so the
    # model should refuse rather than call get_weather.
    test_queries = [
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",
        "What's the current temperature in Paris in Fahrenheit?"
    ]

    for query in test_queries:
        print(f"Query: {query}")
        result, metadata = handler.process_query(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
        print(f"Result: {json.dumps(result, indent=2)}")
        print(f"Metadata: {json.dumps(metadata, indent=2)}")
        print("-" * 50)

    # Second schema so one query can exercise multiple tool calls at once.
    calculator_api = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The mathematical operation to perform"
                },
                "x": {
                    "type": "number",
                    "description": "The first number"
                },
                "y": {
                    "type": "number",
                    "description": "The second number"
                }
            },
            "required": ["operation", "x", "y"]
        }
    }

    multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    multi_api_result, multi_api_metadata = handler.process_query(
        multi_api_query,
        [weather_api, calculator_api],
        TASK_INSTRUCTION,
        FORMAT_INSTRUCTION
    )

    print("Multi-API Query Result:")
    print(json.dumps(multi_api_result, indent=2))
    print(f"Metadata: {json.dumps(multi_api_metadata, indent=2)}")