ff_li
目录调整
f67d239
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Union
from agent.log import logger
from agent.utils.utils import print_traceback
class FnCallNotImplError(NotImplementedError):
    """Signals that a chat model backend has no function-calling support.

    Raised by the default ``BaseChatModel.chat_with_functions`` and caught by
    ``BaseChatModel.support_function_calling`` to mean "not supported".
    """
class BaseChatModel(ABC):
    """Abstract base class for chat model backends.

    Subclasses must implement ``_chat_stream`` and ``_chat_no_stream``.
    Backends that support function calling should also override
    ``chat_with_functions``; the default implementation raises
    ``FnCallNotImplError``.
    """

    def __init__(self):
        # Cached result of the function-calling probe; None means
        # "not probed yet" (see support_function_calling).
        self._support_fn_call: Optional[bool] = None

    # It is okay to use the same code to handle the output
    # regardless of whether stream is True or False, as follows:
    # ```py
    # for chunk in chat_model.chat(..., stream=True/False):
    #   response += chunk
    #   yield response
    # ```
    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[str, Iterator[str]]:
        """Send a conversation to the model and return its reply.

        Exactly one of ``prompt`` and ``messages`` must be given. A bare
        ``prompt`` is wrapped as a single ``{'role': 'user', ...}`` message.

        Args:
            prompt: A single user message; used only when ``messages`` is None.
            messages: The conversation as a list of message dicts.
            stop: Optional stop sequences, forwarded to the backend.
            stream: If True, delegate to ``_chat_stream`` and return its
                iterator of text chunks; otherwise delegate to
                ``_chat_no_stream`` and return its string.

        Returns:
            The reply as a str, or an iterator of str chunks when ``stream``.

        Raises:
            TypeError: If ``messages`` is None and ``prompt`` is not a str.
            ValueError: If both ``prompt`` and ``messages`` are given.
        """
        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would let invalid arguments through.
        if messages is None:
            if not isinstance(prompt, str):
                raise TypeError(
                    'prompt must be a str when messages is not given.')
            messages = [{'role': 'user', 'content': prompt}]
        elif prompt is not None:
            raise ValueError(
                'Do not pass prompt and messages at the same time.')
        logger.debug(messages)
        if stream:
            return self._chat_stream(messages, stop=stop)
        return self._chat_no_stream(messages, stop=stop)

    def support_function_calling(self) -> bool:
        """Return whether this backend supports function calling.

        On the first call, probe the backend by sending a sample weather
        question along with a ``get_current_weather`` function schema through
        ``chat_with_functions``. Support is reported only if the response
        contains a truthy ``'function_call'`` entry.

        The result is cached on the instance, so later calls return it
        without probing again. The probe is best-effort: if it raises
        ``FnCallNotImplError``, or any other exception (which is printed via
        ``print_traceback``), support is reported as False.
        """
        if self._support_fn_call is None:
            functions = [{
                'name': 'get_current_weather',
                'description': 'Get the current weather in a given location.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'location': {
                            'type': 'string',
                            'description':
                            'The city and state, e.g. San Francisco, CA',
                        },
                        'unit': {
                            'type': 'string',
                            'enum': ['celsius', 'fahrenheit'],
                        },
                    },
                    'required': ['location'],
                },
            }]
            messages = [{
                'role': 'user',
                'content': 'What is the weather like in Boston?'
            }]
            # Pessimistic default: it is cached even if the probe fails below.
            self._support_fn_call = False
            try:
                response = self.chat_with_functions(messages=messages,
                                                    functions=functions)
                if response.get('function_call'):
                    logger.info('Support of function calling is detected.')
                    self._support_fn_call = True
            except FnCallNotImplError:
                # The backend does not implement function calling at all.
                pass
            except Exception:  # TODO: more specific
                print_traceback()
        return self._support_fn_call

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Chat with function definitions available to the model.

        The base implementation always raises; backends that support
        function calling override it. ``support_function_calling`` reads the
        returned dict's ``'function_call'`` key.

        Raises:
            FnCallNotImplError: Always, in this base implementation.
        """
        raise FnCallNotImplError

    @abstractmethod
    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Backend hook: return the reply as an iterator of str chunks."""
        raise NotImplementedError

    @abstractmethod
    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Backend hook: return the complete reply as a single str."""
        raise NotImplementedError