from typing import List, Union, Optional, Literal

import dataclasses

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
import openai  # NOTE: uses the legacy (pre-1.0) openai SDK interface

MessageRole = Literal["system", "user", "assistant"]


@dataclasses.dataclass
class Message:
    role: MessageRole
    content: str


def message_to_str(message: Message) -> str:
    return f"{message.role}: {message.content}"


def messages_to_str(messages: List[Message]) -> str:
    return "\n".join([message_to_str(message) for message in messages])


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Query the legacy completions endpoint, retrying with exponential backoff."""
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    if num_comps == 1:
        return response.choices[0].text
    return [choice.text for choice in response.choices]


@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
    model: str,
    messages: List[Message],
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Query the legacy chat-completions endpoint, retrying with exponential backoff."""
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[dataclasses.asdict(message) for message in messages],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content
        return [choice.message.content for choice in response.choices]
    except Exception as e:
        # Log and re-raise so the tenacity decorator can retry the call.
        print(f"An error occurred while calling OpenAI: {e}")
        raise


class ModelBase:
    def __init__(self, name: str):
        self.name = name
        self.is_chat = False

    def __repr__(self) -> str:
        return f"{self.name}"

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError


class GPTChat(ModelBase):
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)


class GPT4(GPTChat):
    def __init__(self):
        super().__init__("gpt-4")


class GPT35(GPTChat):
    def __init__(self):
        super().__init__("gpt-3.5-turbo")


class GPTDavinci(ModelBase):
    def __init__(self, model_name: str):
        # Call the base initializer so is_chat is set (completion models are not chat models).
        super().__init__(model_name)

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
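

# Usage sketch, for illustration only: assumes a pre-1.0 `openai` package is
# installed and an API key is available to it (the legacy SDK reads the
# OPENAI_API_KEY environment variable by default). The model name comes from
# the GPT35 class above.
if __name__ == "__main__":
    chat_model = GPT35()
    reply = chat_model.generate_chat(
        messages=[
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Say hello in one word."),
        ],
        max_tokens=16,
    )
    print(reply)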