|
import argparse |
|
import markdown2 |
|
import sys |
|
import uvicorn |
|
|
|
from pathlib import Path |
|
from typing import Union, Optional |
|
|
|
from fastapi import FastAPI |
|
from pydantic import BaseModel, Field |
|
from fastapi.responses import HTMLResponse |
|
from tclogger import logger, OSEnver |
|
|
|
from transforms.embed import JinaAIEmbedder |
|
from configs.constants import AVAILABLE_MODELS |
|
|
|
info_path = Path(__file__).parents[1] / "configs" / "info.json" |
|
ENVER = OSEnver(info_path) |
|
|
|
|
|
class EmbeddingApp: |
|
def __init__(self): |
|
self.app = FastAPI( |
|
docs_url="/", |
|
title=ENVER["app_name"], |
|
swagger_ui_parameters={"defaultModelsExpandDepth": -1}, |
|
version=ENVER["version"], |
|
) |
|
self.embedder = JinaAIEmbedder() |
|
self.setup_routes() |
|
|
|
def get_available_models(self): |
|
return AVAILABLE_MODELS |
|
|
|
def get_readme(self): |
|
readme_path = Path(__file__).parents[1] / "README.md" |
|
with open(readme_path, "r", encoding="utf-8") as rf: |
|
readme_str = rf.read() |
|
readme_html = markdown2.markdown( |
|
readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"] |
|
) |
|
return readme_html |
|
|
|
class CalcEmbeddingPostItem(BaseModel): |
|
text: Union[str, list[str]] = Field( |
|
default=None, |
|
summary="Input text(s) to embed", |
|
) |
|
model: Optional[str] = Field( |
|
default=AVAILABLE_MODELS[0], |
|
summary="Embedding model name", |
|
) |
|
|
|
def calc_embedding(self, item: CalcEmbeddingPostItem): |
|
logger.note(f"> Encoding text: [{item.text}]", end=" ") |
|
if item.model != self.embedder.model: |
|
self.embedder.switch_model(item.model) |
|
embeddings = self.embedder.encode(item.text).tolist() |
|
logger.success(f"[{len(embeddings[0])}]") |
|
if len(embeddings) == 1: |
|
return embeddings[0] |
|
else: |
|
return embeddings |
|
|
|
def setup_routes(self): |
|
self.app.get( |
|
"/models", |
|
summary="Get available models", |
|
)(self.get_available_models) |
|
|
|
self.app.post( |
|
"/embedding", |
|
summary="Calculate embedding for input text(s)", |
|
)(self.calc_embedding) |
|
|
|
self.app.get( |
|
"/readme", |
|
summary="README of Embed API", |
|
response_class=HTMLResponse, |
|
include_in_schema=False, |
|
)(self.get_readme) |
|
|
|
|
|
class ArgParser(argparse.ArgumentParser): |
|
def __init__(self, *args, **kwargs): |
|
super(ArgParser, self).__init__(*args, **kwargs) |
|
|
|
self.add_argument( |
|
"-s", |
|
"--server", |
|
type=str, |
|
default=ENVER["server"], |
|
help=f"Server IP ({ENVER['server']}) for Embed API", |
|
) |
|
self.add_argument( |
|
"-p", |
|
"--port", |
|
type=int, |
|
default=ENVER["port"], |
|
help=f"Server Port ({ENVER['port']}) for Embed API", |
|
) |
|
|
|
self.args = self.parse_args(sys.argv[1:]) |
|
|
|
|
|
app = EmbeddingApp().app |
|
|
|
if __name__ == "__main__": |
|
args = ArgParser().args |
|
uvicorn.run("__main__:app", host=args.server, port=args.port) |
|
|
|
|
|
|