# File size: 3,625 bytes; revision 41d1bc5.
from typing import List, Union, Optional, Literal
import dataclasses
from tenacity import (
retry,
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
)
import openai
# Permitted values for ``Message.role``. This is a static-typing hint only;
# nothing in this module checks it at runtime.
MessageRole = Literal["system", "user", "assistant"]


@dataclasses.dataclass
class Message:
    """A single chat turn: the speaker (``role``) and the text (``content``).

    ``gpt_chat`` serializes instances with ``dataclasses.asdict``, producing
    ``{"role": ..., "content": ...}`` dicts for the chat API.
    """

    role: MessageRole
    content: str
def message_to_str(message: Message) -> str:
    """Render one message as ``"<role>: <content>"``."""
    role, content = message.role, message.content
    return f"{role}: {content}"
def messages_to_str(messages: List[Message]) -> str:
    """Render a transcript as ``"<role>: <content>"`` entries joined by newlines.

    Each message is formatted with ``message_to_str``; an empty list yields ``""``.
    """
    return "\n".join(map(message_to_str, messages))
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
)
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps=1,
) -> Union[List[str], str]:
    """Request text completion(s) from the OpenAI Completions endpoint.

    The ``tenacity.retry`` decorator re-runs the whole call on any exception,
    waiting a randomized exponential backoff between 1 and 60 seconds, for at
    most 6 attempts.

    Args:
        model: Completion model identifier forwarded to the API.
        prompt: Prompt text.
        max_tokens: Upper bound on generated tokens per completion.
        stop_strs: Optional stop sequences, forwarded as ``stop``.
        temperature: Sampling temperature.
        num_comps: Number of completions to request, forwarded as ``n``.

    Returns:
        The text of the first choice when ``num_comps == 1``; otherwise a list
        holding the text of every returned choice, in response order.
    """
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    choices = response.choices  # type: ignore
    if num_comps != 1:
        return [choice.text for choice in choices]
    return choices[0].text
@retry(
    wait=wait_random_exponential(min=1, max=180),
    stop=stop_after_attempt(6),
)
def gpt_chat(
    model: str,
    messages: List,
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps=1,
) -> Union[List[str], str]:
    """Request chat completion(s) from the OpenAI ChatCompletion endpoint.

    Any exception raised in the body is printed to stdout and then re-raised,
    so the ``tenacity.retry`` decorator sees it and retries with a randomized
    exponential backoff between 1 and 180 seconds, for at most 6 attempts.

    Args:
        model: Chat model identifier forwarded to the API.
        messages: Dataclass instances (e.g. ``Message``); each is converted to
            a dict with ``dataclasses.asdict`` before sending.
        max_tokens: Upper bound on generated tokens per completion.
        temperature: Sampling temperature.
        num_comps: Number of completions to request, forwarded as ``n``.

    Returns:
        The first choice's message content when ``num_comps == 1``; otherwise
        a list with every choice's message content, in response order.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[dataclasses.asdict(msg) for msg in messages],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps != 1:
            return [choice.message.content for choice in response.choices]  # type: ignore
        return response.choices[0].message.content  # type: ignore
    except Exception as err:
        # Surface the failure on every attempt, then let the retry decorator decide.
        print(f"An error occurred while calling OpenAI: {err}")
        raise
class ModelBase:
    """Minimal interface shared by the model wrappers in this module.

    Attributes:
        name: Model identifier; also used as the object's ``repr``.
        is_chat: ``False`` here; chat-capable subclasses (e.g. ``GPTChat``)
            set it to ``True`` and implement ``generate_chat``.
    """

    def __init__(self, name: str):
        self.is_chat = False
        self.name = name

    def __repr__(self) -> str:
        return f"{self.name}"

    def generate_chat(
        self,
        messages: List[Message],
        max_tokens: int = 1024,
        temperature: float = 0.2,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Generate reply text for a chat transcript; subclasses must override."""
        raise NotImplementedError

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        stop_strs: Optional[List[str]] = None,
        temperature: float = 0.0,
        num_comps=1,
    ) -> Union[List[str], str]:
        """Generate completion text for a plain prompt; subclasses must override."""
        raise NotImplementedError
class GPTChat(ModelBase):
    """Chat model wrapper that delegates generation to ``gpt_chat``.

    Args:
        model_name: OpenAI chat model identifier, stored as ``self.name``.
    """

    def __init__(self, model_name: str):
        # Run the base initializer so every attribute ModelBase defines is set,
        # then flag this wrapper as chat-capable.
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(
        self,
        messages: List[Message],
        max_tokens: int = 1024,
        temperature: float = 0.2,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Return chat reply text for ``messages`` from model ``self.name``.

        Note the default ``temperature`` here is 0.2, not ``gpt_chat``'s 0.0.

        Returns:
            A single string when ``num_comps == 1``, otherwise a list of strings.
        """
        # Keyword arguments keep this call correct if gpt_chat's parameters
        # are ever reordered.
        return gpt_chat(
            self.name,
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            num_comps=num_comps,
        )
class GPT4(GPTChat):
    """Chat wrapper pinned to the ``gpt-4`` model."""

    def __init__(self):
        model_id = "gpt-4"
        super().__init__(model_id)
class GPT35(GPTChat):
    """Chat wrapper pinned to the ``gpt-3.5-turbo`` model."""

    def __init__(self):
        model_id = "gpt-3.5-turbo"
        super().__init__(model_id)
class GPTDavinci(ModelBase):
    """Completion (non-chat) model wrapper that delegates to ``gpt_completion``.

    Args:
        model_name: OpenAI completion model identifier, stored as ``self.name``.
    """

    def __init__(self, model_name: str):
        # Must run the base initializer: skipping it leaves ``is_chat``
        # undefined, so ``model.is_chat`` would raise AttributeError.
        super().__init__(model_name)

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        stop_strs: Optional[List[str]] = None,
        temperature: float = 0.0,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Return completion text for ``prompt`` from model ``self.name``.

        Returns:
            A single string when ``num_comps == 1``, otherwise a list of strings.
        """
        return gpt_completion(
            self.name,
            prompt,
            max_tokens=max_tokens,
            stop_strs=stop_strs,
            temperature=temperature,
            num_comps=num_comps,
        )