mlnet-samples / embedding.py
XiaoYun Zhang
update
6abb254
raw
history blame contribute delete
No virus
1.66 kB
import openai
class Embedding:
type: str|None = None
vector_size: int|None = None
def generate_embedding(self, content: str) -> list[float]:
pass
class OpenAITextAda002(Embedding):
    """Embedding backend backed by ``openai.Embedding.create``.

    Requests are issued with ``api_type='openai'`` against the
    ``text-embedding-ada-002`` model.
    """

    type: str = 'text-ada-002'
    vector_size: int = 1536
    # Class-level default; each instance stores its own key in ``__init__``.
    api_key: str | None = None

    def __init__(self, api_key: str):
        """Store the API key that is sent with every embedding request."""
        self.api_key = api_key

    def generate_embedding(self, content: str) -> list[float]:
        """Return the embedding of ``content``.

        Newlines are replaced by spaces and only the first 6000 characters
        of the result are submitted to the service.
        """
        # Flatten the text onto one line, then keep at most 6000 characters.
        prepared = content.replace('\n', ' ')[:6000]
        response = openai.Embedding.create(
            api_key=self.api_key,
            api_type='openai',
            input=prepared,
            model="text-embedding-ada-002",
        )
        # Hand back the vector of the first entry in the response data.
        return response["data"][0]["embedding"]
class AzureOpenAITextAda002(Embedding):
    """Embedding backend backed by ``openai.Embedding.create`` on Azure.

    Requests are issued with ``api_type='azure'`` against the endpoint
    ``api_base``, using ``model_name`` as the ``engine`` argument.
    """

    type: str = 'text-ada-002'
    vector_size: int = 1536
    # Class-level defaults; ``__init__`` overrides them per instance. Keeping
    # ``api_version`` here means instances created without ``__init__``
    # still send the original version string.
    api_key: str | None = None
    api_version: str = "2023-07-01-preview"

    def __init__(
            self,
            api_base: str,
            model_name: str,
            api_key: str,
            api_version: str = "2023-07-01-preview"):
        """Store the connection settings used for every request.

        Args:
            api_base: Value passed as ``api_base`` to the OpenAI client.
            model_name: Value passed as ``engine`` to the OpenAI client.
            api_key: Key passed as ``api_key`` to the OpenAI client.
            api_version: Value passed as ``api_version``; defaults to the
                version this class has always used, so existing callers
                are unaffected.
        """
        self.api_base = api_base
        self.model_name = model_name
        self.api_key = api_key
        self.api_version = api_version

    def generate_embedding(self, content: str) -> list[float]:
        """Return the embedding of ``content``.

        Newlines are replaced by spaces and only the first 6000 characters
        of the result are submitted to the service.
        """
        # Flatten the text onto one line, then keep at most 6000 characters.
        prepared = content.replace('\n', ' ')[:6000]
        response = openai.Embedding.create(
            api_key=self.api_key,
            api_type='azure',
            api_base=self.api_base,
            input=prepared,
            engine=self.model_name,
            api_version=self.api_version,
        )
        # Hand back the vector of the first entry in the response data.
        return response["data"][0]["embedding"]