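# CoI_Agent / LLM.py
# Hugging Face file-viewer residue kept for provenance:
#   avatar caption "jianghuyihei's picture", commit message "first commit" (863d8a3),
#   viewer links "raw" / "history blame", file size 6.46 kB.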
from openai import AzureOpenAI, OpenAI,AsyncAzureOpenAI,AsyncOpenAI
from abc import abstractmethod
import os
import httpx
import base64
import logging
import asyncio
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_fixed,
)
def get_content_between_a_b(start_tag, end_tag, text):
    """Collect every span of ``text`` enclosed by ``start_tag`` and ``end_tag``.

    The text is scanned left to right. Each time ``start_tag`` is found, the
    next ``end_tag`` after it closes the span; scanning resumes after that
    ``end_tag``. A ``start_tag`` with no closing ``end_tag`` ends the scan.

    Args:
        start_tag: Marker that opens a span.
        end_tag: Marker that closes a span.
        text: String to search.

    Returns:
        The extracted spans joined by single spaces, with leading and
        trailing whitespace stripped. Returns ``""`` when nothing matches.
    """
    # With both tags empty every find() succeeds at the same offset, so the
    # cursor would never advance; there is nothing meaningful to extract.
    if not start_tag and not end_tag:
        return ""

    pieces = []
    start_index = text.find(start_tag)
    while start_index != -1:
        content_start = start_index + len(start_tag)
        end_index = text.find(end_tag, content_start)
        if end_index == -1:
            # Unterminated span: keep what was collected so far.
            break
        pieces.append(text[content_start:end_index])
        start_index = text.find(start_tag, end_index + len(end_tag))
    return " ".join(pieces).strip()
def before_retry_fn(retry_state):
    """Log a message before every attempt after the first one.

    Used in this file as the ``before=`` hook of the tenacity ``retry``
    decorators.

    Args:
        retry_state: Object exposing ``attempt_number``; it is also rendered
            into the log message.
    """
    if retry_state.attempt_number > 1:
        # Lazy %-style arguments: the state is only formatted if the record
        # is actually emitted.
        logging.info(
            "Retrying API call. Attempt #%s, %s",
            retry_state.attempt_number,
            retry_state,
        )
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as base64 text.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents as a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def get_openai_url(img_pth):
    """Build a base64 ``data:image/...`` URL for the image file at ``img_pth``.

    The image subtype is taken from the text after the last ``.`` in the
    path, lower-cased so that ``photo.JPG`` and ``photo.jpg`` behave the
    same; ``jpg`` is mapped to ``jpeg``.

    Args:
        img_pth: Path of the image file.

    Returns:
        A string of the form ``data:image/<subtype>;base64,<payload>``.
    """
    subtype = img_pth.rsplit(".", 1)[-1].lower()
    if subtype == "jpg":
        # The registered media type is image/jpeg, not image/jpg.
        subtype = "jpeg"
    base64_image = encode_image(img_pth)
    return f"data:image/{subtype};base64,{base64_image}"
class base_llm:
    """Bare interface that LLM backends in this file derive from.

    Note: ``response`` is marked with ``abstractmethod`` but the class does
    not use an ABC metaclass, so the marker is not enforced and the class
    can be instantiated directly.
    """

    def __init__(self) -> None:
        """Nothing to initialise in the base class."""

    @abstractmethod
    def response(self, messages, **kwargs):
        """Hook for producing a reply to ``messages``; the base version returns ``None``."""

    def get_imgs(self, prompt, save_path="saves/dalle3.jpg"):
        """Placeholder hook taking a prompt and a save path; the base version returns ``None``."""
def _return_none_after_retries(retry_state):
    """tenacity ``retry_error_callback``: once every attempt has failed, return ``None``.

    Keeps the contract callers already rely on (``None`` means "the call
    failed") instead of letting tenacity raise ``RetryError``.
    """
    return None


def _api_retry():
    """Build the retry policy shared by all API-calling methods of ``openai_llm``.

    Up to 10 attempts, 10 seconds apart; ``before_retry_fn`` logs each
    re-attempt; exhausted retries resolve to ``None``.
    """
    return retry(
        wait=wait_fixed(10),
        stop=stop_after_attempt(10),
        before=before_retry_fn,
        retry_error_callback=_return_none_after_retries,
    )


class openai_llm(base_llm):
    """Azure OpenAI backend offering chat completions and embeddings.

    Configuration is read from environment variables:

    * ``AZURE_OPENAI_ENDPOINT`` / ``AZURE_OPENAI_KEY`` - required.
    * ``AZURE_OPENAI_API_VERSION`` - optional; empty is treated as unset.
    * ``EMBEDDING_API_ENDPOINT`` / ``EMBEDDING_API_KEY`` - optional dedicated
      embedding resource; when the endpoint is unset the chat client is used.
    * ``EMBEDDING_MODEL`` - embedding model name, default
      ``text-embedding-3-large``.

    Every API method retries on failure (see ``_api_retry``) and returns
    ``None`` if all attempts fail.
    """

    def __init__(self, model="gpt4o-0513") -> None:
        """Create the sync and async Azure clients.

        Args:
            model: Default model/deployment name used for chat completions.

        Raises:
            ValueError: If ``AZURE_OPENAI_ENDPOINT`` or ``AZURE_OPENAI_KEY``
                is missing or empty.
        """
        super().__init__()
        self.model = model
        if not os.environ.get("AZURE_OPENAI_ENDPOINT"):
            raise ValueError("AZURE_OPENAI_ENDPOINT is not set")
        if not os.environ.get("AZURE_OPENAI_KEY"):
            raise ValueError("AZURE_OPENAI_KEY is not set")
        # An empty version string is normalised to None so the client applies
        # its own fallback instead of sending an empty api-version.
        api_version = os.environ.get("AZURE_OPENAI_API_VERSION") or None
        credentials = dict(
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
            api_key=os.environ["AZURE_OPENAI_KEY"],
            api_version=api_version,
        )
        self.client = AzureOpenAI(**credentials)
        self.async_client = AsyncAzureOpenAI(**credentials)

    @staticmethod
    def _dedicated_embedding_settings():
        """Return client kwargs for the dedicated embedding resource, or ``None`` if not configured."""
        endpoint = os.environ.get("EMBEDDING_API_ENDPOINT")
        if not endpoint:
            return None
        return dict(
            azure_endpoint=endpoint,
            api_key=os.environ.get("EMBEDDING_API_KEY"),
            # Same empty-string normalisation as in __init__.
            api_version=os.environ.get("AZURE_OPENAI_API_VERSION") or None,
            azure_deployment="embedding-3-large",
        )

    @staticmethod
    def _report_failure(message, error):
        """Print and log a failed API call."""
        print(f"{message}: {error}")
        logging.info(error)

    def _chat_request(self, messages, kwargs):
        """Assemble chat-completion arguments, applying this class's defaults."""
        return dict(
            model=kwargs.get("model", self.model),
            messages=messages,
            n=kwargs.get("n", 1),
            temperature=kwargs.get("temperature", 0.7),
            max_tokens=kwargs.get("max_tokens", 4000),
            timeout=kwargs.get("timeout", 180),
        )

    def cal_cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two vectors.

        Args:
            vec1: First vector (list or NumPy array).
            vec2: Second vector (list or NumPy array).

        Returns:
            ``dot(vec1, vec2) / (|vec1| * |vec2|)``, or ``0.0`` when either
            vector has zero norm (the ratio is undefined there).
        """
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm_product == 0:
            # Avoid a division by zero that would yield nan.
            return 0.0
        return np.dot(vec1, vec2) / norm_product

    @_api_retry()
    def response(self, messages, **kwargs):
        """Request a chat completion and return the first choice's text.

        Args:
            messages: Chat messages passed to the completions API.
            **kwargs: Optional ``model``, ``n``, ``temperature``,
                ``max_tokens`` and ``timeout`` overrides.

        Returns:
            The content of the first choice, or ``None`` if every attempt
            failed.
        """
        request = self._chat_request(messages, kwargs)
        try:
            completion = self.client.chat.completions.create(**request)
        except Exception as e:
            self._report_failure(f"get {request['model']} response failed", e)
            # Re-raise so the retry decorator actually sees the failure;
            # swallowing it here would make the decorator a no-op.
            raise
        return completion.choices[0].message.content

    @_api_retry()
    def get_embbeding(self, text):
        """Return the embedding vector for ``text``.

        Args:
            text: Input passed to the embeddings API.

        Returns:
            The embedding of the first result, or ``None`` if every attempt
            failed.
        """
        dedicated = self._dedicated_embedding_settings()
        client = AzureOpenAI(**dedicated) if dedicated else self.client
        try:
            result = client.embeddings.create(
                model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
                input=text,
                timeout=180,
            )
            return result.data[0].embedding
        except Exception as e:
            self._report_failure("get embbeding failed", e)
            raise  # hand the failure to the retry policy

    @_api_retry()
    async def get_embbeding_async(self, text):
        """Async counterpart of ``get_embbeding``.

        Args:
            text: Input passed to the embeddings API.

        Returns:
            The embedding of the first result, or ``None`` if every attempt
            failed.
        """
        dedicated = self._dedicated_embedding_settings()
        # A dedicated async client is created per call rather than cached.
        async_client = AsyncAzureOpenAI(**dedicated) if dedicated else self.async_client
        try:
            result = await async_client.embeddings.create(
                model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
                input=text,
                timeout=180,
            )
            return result.data[0].embedding
        except Exception as e:
            self._report_failure("get embbeding failed", e)
            raise  # hand the failure to the retry policy

    @_api_retry()
    async def response_async(self, messages, **kwargs):
        """Async counterpart of ``response``.

        Args:
            messages: Chat messages passed to the completions API.
            **kwargs: Optional ``model``, ``n``, ``temperature``,
                ``max_tokens`` and ``timeout`` overrides.

        Returns:
            The content of the first choice, or ``None`` if every attempt
            failed.
        """
        request = self._chat_request(messages, kwargs)
        try:
            completion = await self.async_client.chat.completions.create(**request)
        except Exception as e:
            self._report_failure(f"get {request['model']} response failed", e)
            raise  # hand the failure to the retry policy
        return completion.choices[0].message.content
if __name__ == "__main__":
    # Manual smoke test: send one prompt through the async chat path.
    # The original referenced `gemini_llm`, which is not defined or imported
    # in this module; `openai_llm` is the backend this file provides and it
    # reads its credentials from the AZURE_OPENAI_* environment variables.
    llm = openai_llm()
    prompt = """
"""
    messages = [{"role": "user", "content": prompt}]
    response = asyncio.run(llm.response_async(messages))
    print(response)