import json
import os.path
from pathlib import Path
from typing import List, Optional, Union

from langchain.schema import AIMessage, HumanMessage, SystemMessage

from gpt_engineer.core.ai import AI
from gpt_engineer.core.token_usage import TokenUsageLog

Message = Union[AIMessage, HumanMessage, SystemMessage]


class CachingAI(AI):
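    """
    AI subclass that memoizes LLM responses in a JSON file.

    `next` keys the cache on the serialized message history: a hit is
    answered straight from ai_cache.json without contacting the model,
    which keeps repeated runs deterministic and free of API costs.
    """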
|
    def __init__(self, *args, **kwargs):
        # Try to build a real chat model; if that fails (e.g. no API key is
        # configured), fall back to a cache-only stub with no live model.
        self.temperature = 0.1
        self.azure_endpoint = ""
        self.streaming = False
        try:
            self.model_name = "gpt-4-1106-preview"
            self.llm = self._create_chat_model()
        except Exception:
            self.model_name = "cached_response_model"
            self.llm = None
        self.token_usage_log = TokenUsageLog("gpt-4-1106-preview")
        self.cache_file = Path(__file__).parent / "ai_cache.json"
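
    # The on-disk cache is, roughly, a flat JSON object mapping a serialized
    # message history to the same history serialized again with the model's
    # reply appended (the exact encoding comes from AI.serialize_messages),
    # e.g.:
    #
    #   {"[{...system...}, {...human...}]": "[{...system...}, {...human...}, {...ai...}]"}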
    def next(
        self,
        messages: List[Message],
        prompt: Optional[str] = None,
        *,
        step_name: str,
    ) -> List[Message]:
""" |
|
Advances the conversation by sending message history |
|
to LLM and updating with the response. |
|
|
|
Parameters |
|
---------- |
|
messages : List[Message] |
|
The list of messages in the conversation. |
|
prompt : Optional[str], optional |
|
The prompt to use, by default None. |
|
step_name : str |
|
The name of the step. |
|
|
|
Returns |
|
------- |
|
List[Message] |
|
The updated list of messages in the conversation. |
|
""" |
|
""" |
|
Advances the conversation by sending message history |
|
to LLM and updating with the response. |
|
""" |
|
        if prompt:
            messages.append(HumanMessage(content=prompt))

        # Load the response cache from disk, or start empty if none exists.
        if os.path.isfile(self.cache_file):
            with open(self.cache_file, "r") as cache_file:
                cache = json.load(cache_file)
        else:
            cache = {}

        # The serialized message history is the cache key: identical
        # histories are answered from the cache without a model call.
        messages_key = self.serialize_messages(messages)
        if messages_key not in cache:
            callbacks = []
            print("calling backoff inference")
            response = self.backoff_inference(messages, callbacks)
            self.token_usage_log.update_log(
                messages=messages, answer=response.content, step_name=step_name
            )
            print("called backoff inference")
            print("cost in usd:", self.token_usage_log.usage_cost())

            # Persist the full exchange (history plus response) for reuse.
            messages.append(response)
            cache[messages_key] = self.serialize_messages(messages)
            with open(self.cache_file, "w") as cache_file:
                json.dump(cache, cache_file)
                cache_file.write("\n")

        return self.deserialize_messages(cache[messages_key])
|
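
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of gpt-engineer itself). It
# assumes either a configured API key for the first run or a pre-populated
# ai_cache.json. The first `next` call is a cache miss and performs real
# inference; replaying the identical history is served from the cache.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ai = CachingAI()
    history: List[Message] = [
        SystemMessage(content="You are a helpful assistant.")
    ]
    # First call: cache miss -> backoff_inference runs and the exchange is
    # written to ai_cache.json. Note that `next` mutates the list it is
    # given, so pass a copy to keep `history` reusable.
    first = ai.next(list(history), prompt="Say hello.", step_name="demo")
    # Same history and prompt again: answered entirely from the cache.
    second = ai.next(list(history), prompt="Say hello.", step_name="demo")
    print(second[-1].content)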