Spaces:
Running
Running
from abc import ABC, abstractmethod | |
from typing import Dict, Iterator, List, Optional, Union | |
from agent.log import logger | |
from agent.utils.utils import print_traceback | |
class FnCallNotImplError(NotImplementedError):
    """Signal that a chat-model backend has no function-calling support.

    ``BaseChatModel.support_function_calling`` catches this exception and
    treats it as a plain "not supported" answer rather than as an error.
    """
class BaseChatModel(ABC):
    """Common front end shared by chat-model backends.

    Subclasses supply the actual model calls by overriding ``_chat_stream``
    and ``_chat_no_stream``, and optionally ``chat_with_functions``. The base
    implementations of those hooks raise.
    """

    def __init__(self):
        # Cached result of the function-calling probe; None means "not probed yet".
        self._support_fn_call: Optional[bool] = None

    # It is okay to use the same code to handle the output
    # regardless of whether stream is True or False, as follows:
    # ```py
    # for chunk in chat_model.chat(..., stream=True/False):
    #     response += chunk
    #     yield response
    # ```
    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[str, Iterator[str]]:
        """Send a chat request to the underlying model.

        Pass exactly one of ``prompt`` or ``messages``. A bare ``prompt`` is
        wrapped into a single ``{'role': 'user', 'content': prompt}`` message.

        Args:
            prompt: A single user utterance. Used only when ``messages`` is None.
            messages: A full message list, forwarded to the backend as-is.
            stop: Optional stop strings, forwarded to the backend.
            stream: If True, delegate to ``_chat_stream`` and return its
                iterator. Otherwise delegate to ``_chat_no_stream`` and return
                its string.

        Returns:
            An iterator of text chunks when ``stream`` is True, otherwise a
            ``str``. Iterating a ``str`` yields its characters, so the loop in
            the comment above rebuilds the full text in both modes.

        Raises:
            TypeError: ``messages`` is None and ``prompt`` is not a ``str``.
            ValueError: both ``prompt`` and ``messages`` were supplied.
        """
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let bad arguments reach the backend.
        if messages is None:
            if not isinstance(prompt, str):
                raise TypeError(
                    'prompt must be a str when messages is not given, '
                    f'got {type(prompt).__name__}.')
            messages = [{'role': 'user', 'content': prompt}]
        elif prompt is not None:
            raise ValueError(
                'Do not pass prompt and messages at the same time.')
        logger.debug(messages)
        if stream:
            return self._chat_stream(messages, stop=stop)
        return self._chat_no_stream(messages, stop=stop)

    def support_function_calling(self) -> bool:
        """Report whether this backend supports function calling.

        On the first call, this sends a sample weather question together with
        a ``get_current_weather`` function schema to ``chat_with_functions``.
        Support is assumed only if the response has a truthy
        ``'function_call'`` entry. The result is cached in
        ``self._support_fn_call``, and later calls return the cached value
        without probing again.

        Returns:
            True if the probe produced a function call, otherwise False.
        """
        if self._support_fn_call is None:
            functions = [{
                'name': 'get_current_weather',
                'description': 'Get the current weather in a given location.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'location': {
                            'type':
                            'string',
                            'description':
                            'The city and state, e.g. San Francisco, CA',
                        },
                        'unit': {
                            'type': 'string',
                            'enum': ['celsius', 'fahrenheit'],
                        },
                    },
                    'required': ['location'],
                },
            }]
            messages = [{
                'role': 'user',
                'content': 'What is the weather like in Boston?'
            }]
            # Cache a negative answer up front so a failing probe is not
            # retried on every call.
            self._support_fn_call = False
            try:
                response = self.chat_with_functions(messages=messages,
                                                    functions=functions)
                if response.get('function_call', None):
                    logger.info('Support of function calling is detected.')
                    self._support_fn_call = True
            except FnCallNotImplError:
                # Backend explicitly lacks function calling: not an error.
                pass
            except Exception:  # TODO: more specific
                # Best-effort probe: report the failure but keep answering False.
                print_traceback()
        return self._support_fn_call

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Chat with function definitions available to the model.

        Subclasses that support function calling override this method and
        return a dict. ``support_function_calling`` reads that dict's
        ``'function_call'`` key.

        Raises:
            FnCallNotImplError: always, in this base implementation.
        """
        raise FnCallNotImplError

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Backend hook for streaming chat; subclasses must override it."""
        raise NotImplementedError

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Backend hook for non-streaming chat; subclasses must override it."""
        raise NotImplementedError