import json
import multiprocessing
from re import compile, Match, Pattern
from threading import Lock
from functools import partial
from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict

from typing_extensions import TypedDict, Literal

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse

from llama2_wrapper.model import LLAMA2_WRAPPER
from llama2_wrapper.types import (
    Completion,
    CompletionChunk,
    ChatCompletion,
    ChatCompletionChunk,
)


class Settings(BaseSettings):
    model_path: str = Field(
        default="",
        description="The path to the model to use for generating completions.",
    )
    backend_type: str = Field(
        default="llama.cpp",
        description="Backend for llama2, options: llama.cpp, gptq, transformers",
    )
    max_tokens: int = Field(default=4000, ge=1, description="Maximum context size.")
    load_in_8bit: bool = Field(
        default=False,
        description="Whether to use bitsandbytes to run the model in 8-bit mode (only for transformers models).",
    )
    verbose: bool = Field(
        default=False,
        description="Whether to print verbose output to stderr.",
    )
    host: str = Field(default="localhost", description="API address")
    port: int = Field(default=8000, description="API port")
    interrupt_requests: bool = Field(
        default=True,
        description="Whether to interrupt requests when a new request is received.",
    )
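

# Configuration sketch: pydantic's BaseSettings also reads the fields above from
# environment variables (matched case-insensitively by field name), so the server
# can be configured without code changes. The model path below is only an
# illustrative placeholder, not a file shipped with this project:
#
#   MODEL_PATH=./models/llama-2-7b-chat.Q4_0.gguf BACKEND_TYPE=llama.cpp PORT=8000
#
# or, equivalently, constructed directly in code:
#
#   settings = Settings(model_path="./models/llama-2-7b-chat.Q4_0.gguf", port=8000)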


class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    message: str
    type: str
    param: Optional[str]
    code: Optional[str]


class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""
        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        completion_tokens = request.max_new_tokens
        if hasattr(request, "messages"):
            # Chat completion
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Text completion
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                completion_tokens + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""
        model_path = str(match.group(1))
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )


class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions"""

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }
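
    # Example (illustrative): a backend error such as
    #   "Requested tokens (5000) exceed context window of 4000"
    # matches the first pattern above and is formatted into a 400 response with
    # code "context_length_exceeded"; any error that matches no pattern falls
    # through to the generic 500 "internal_server_error" below.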

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response"""
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # When text completion or chat completion
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        them in OpenAI style error responses"""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                return await original_route_handler(request)
            except Exception as exc:
                json_body = await request.json()
                # Default to None so the variable is always bound, even when the
                # body is neither a chat nor a text completion request.
                body: Optional[
                    Union[
                        CreateChatCompletionRequest,
                        CreateCompletionRequest,
                    ]
                ] = None
                try:
                    if "messages" in json_body:
                        # Chat completion
                        body = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    # else:
                    #     # Embedding
                    #     body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body
                    body = None

                # Get proper error message from the exception
                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler


router = APIRouter(route_class=RouteErrorHandler)

settings: Optional[Settings] = None
llama2: Optional[LLAMA2_WRAPPER] = None


def create_app(settings: Optional[Settings] = None):
    if settings is None:
        settings = Settings()
    app = FastAPI(
        title="llama2-wrapper Fast API",
        version="0.0.1",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)
    global llama2
    llama2 = LLAMA2_WRAPPER(
        model_path=settings.model_path,
        backend_type=settings.backend_type,
        max_tokens=settings.max_tokens,
        load_in_8bit=settings.load_in_8bit,
        verbose=settings.verbose,
    )

    def set_settings(_settings: Settings):
        global settings
        settings = _settings

    set_settings(settings)
    return app


llama_outer_lock = Lock()
llama_inner_lock = Lock()


def get_llama():
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            release_outer_lock = False
            yield llama2
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()
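

# How the double lock interrupts a running stream (a sketch of the sequence
# implied by get_llama above and get_event_publisher below):
#   1. Request A acquires both locks, releases the outer one, and streams while
#      holding only the inner lock.
#   2. Request B arrives and blocks on llama_outer_lock.acquire().
#   3. Request A's stream observes that llama_outer_lock.locked() is True and,
#      if settings.interrupt_requests is enabled, sends "[DONE]" and cancels
#      itself, releasing the inner lock so request B can proceed.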


def get_settings():
    yield settings


async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
                if await request.is_disconnected():
                    raise anyio.get_cancelled_exc_class()()
                if settings.interrupt_requests and llama_outer_lock.locked():
                    await inner_send_chan.send(dict(data="[DONE]"))
                    raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            print("disconnected")
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
            raise e
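

# Wire-format sketch: each dict(data=...) sent above is emitted by
# EventSourceResponse as a server-sent event, roughly:
#
#   data: {"choices": [...], ...}
#
#   data: [DONE]
#
# The JSON payload shown is illustrative; the exact chunk fields come from
# llama2_wrapper.types.CompletionChunk / ChatCompletionChunk.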


stream_field = Field(
    default=False,
    description="Whether to stream the results as they are generated. Useful for chatbots.",
)
max_new_tokens_field = Field(
    default=1000, ge=1, description="The maximum number of tokens to generate."
)
temperature_field = Field(
    default=0.9,
    ge=0.0,
    le=2.0,
    description="The temperature to use for sampling.",
)
top_p_field = Field(
    default=1.0,
    ge=0.0,
    le=1.0,
    description="The top-p value to use for sampling.",
)
top_k_field = Field(
    default=40,
    ge=0,
    description="The top-k value to use for sampling.",
)
repetition_penalty_field = Field(
    default=1.0,
    ge=0.0,
    description="The penalty to apply to repeated tokens.",
)
# stop_field = Field(
#     default=None,
#     description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
# )


class CreateCompletionRequest(BaseModel):
    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt to generate text from."
    )
    stream: bool = stream_field
    max_new_tokens: int = max_new_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    top_k: int = top_k_field
    repetition_penalty: float = repetition_penalty_field
    # stop: Optional[Union[str, List[str]]] = stop_field

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                    # "stop": ["\n", "###"],
                }
            ]
        }
    }
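

# Request sketch for the completions endpoint registered below (host and port
# follow the Settings defaults; adjust as needed):
#
#   curl http://localhost:8000/v1/completions \
#       -H "Content-Type: application/json" \
#       -d '{"prompt": "Hello, my name is", "max_new_tokens": 64, "stream": false}'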


# OpenAI-compatible completions endpoint (route path follows the standard /v1 convention).
@router.post("/v1/completions")
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama2: LLAMA2_WRAPPER = Depends(get_llama),
) -> Completion:
    if isinstance(body.prompt, list):
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    kwargs = body.model_dump()
    iterator_or_completion: Union[
        Completion, Iterator[CompletionChunk]
    ] = await run_in_threadpool(llama2.completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[CompletionChunk]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
        )
    else:
        return iterator_or_completion


class ChatCompletionRequestMessage(BaseModel):
    role: Literal["system", "user", "assistant"] = Field(
        default="user", description="The role of the message."
    )
    content: str = Field(default="", description="The content of the message.")


class CreateChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionRequestMessage] = Field(
        default=[], description="A list of messages to generate completions for."
    )
    stream: bool = stream_field
    max_new_tokens: int = max_new_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    top_k: int = top_k_field
    repetition_penalty: float = repetition_penalty_field
    # stop: Optional[List[str]] = stop_field

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        ChatCompletionRequestMessage(
                            role="system", content="You are a helpful assistant."
                        ).model_dump(),
                        ChatCompletionRequestMessage(
                            role="user", content="What is the capital of France?"
                        ).model_dump(),
                    ]
                }
            ]
        }
    }
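

# Request sketch for the chat endpoint registered below, with streaming enabled
# (the server then answers with server-sent events as sketched above):
#
#   curl http://localhost:8000/v1/chat/completions \
#       -H "Content-Type: application/json" \
#       -d '{"messages": [{"role": "user", "content": "Hi!"}], "stream": true}'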


# OpenAI-compatible chat completions endpoint (route path follows the standard /v1 convention).
@router.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest,
    llama2: LLAMA2_WRAPPER = Depends(get_llama),
    settings: Settings = Depends(get_settings),
) -> ChatCompletion:
    kwargs = body.model_dump()
    iterator_or_completion: Union[
        ChatCompletion, Iterator[ChatCompletionChunk]
    ] = await run_in_threadpool(llama2.chat_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[ChatCompletionChunk]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
        )
    else:
        return iterator_or_completion


class ModelData(TypedDict):
    id: str
    object: Literal["model"]
    owned_by: str
    permissions: List[str]


class ModelList(TypedDict):
    object: Literal["list"]
    data: List[ModelData]


# OpenAI-compatible model listing endpoint (route path follows the standard /v1 convention).
@router.get("/v1/models")
async def get_models(
    settings: Settings = Depends(get_settings),
) -> ModelList:
    assert llama2 is not None
    return {
        "object": "list",
        "data": [
            {
                "id": settings.backend_type + " default model"
                if settings.model_path == ""
                else settings.model_path,
                "object": "model",
                "owned_by": "me",
                "permissions": [],
            }
        ],
    }
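

# Minimal launch sketch (assumes the `uvicorn` ASGI server is installed; the
# project may expose its own entry point instead):
if __name__ == "__main__":
    import uvicorn

    _settings = Settings()
    uvicorn.run(
        create_app(_settings),
        host=_settings.host,
        port=_settings.port,
    )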