from __future__ import annotations import logging import openai logger = logging.getLogger(__name__) class AI: def __init__(self, model="gpt-4", temperature=0.1): self.temperature = temperature try: openai.Model.retrieve(model) self.model = model except openai.InvalidRequestError: print( f"Model {model} not available for provided API key. Reverting " "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: " "https://openai.com/waitlist/gpt-4-api" ) self.model = "gpt-3.5-turbo" def start(self, system, user): messages = [ {"role": "system", "content": system}, {"role": "user", "content": user}, ] return self.next(messages) def fsystem(self, msg): return {"role": "system", "content": msg} def fuser(self, msg): return {"role": "user", "content": msg} def fassistant(self, msg): return {"role": "assistant", "content": msg} def next(self, messages: list[dict[str, str]], prompt=None): if prompt: messages += [{"role": "user", "content": prompt}] logger.debug(f"Creating a new chat completion: {messages}") response = openai.ChatCompletion.create( messages=messages, stream=True, model=self.model, temperature=self.temperature, ) chat = [] for chunk in response: delta = chunk["choices"][0]["delta"] msg = delta.get("content", "") print(msg, end="") chat.append(msg) print() messages += [{"role": "assistant", "content": "".join(chat)}] logger.debug(f"Chat completion finished: {messages}") return messages