import json
from typing import Any, Generator, List

import fastapi
import markdown
import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from typing_extensions import Literal

from dialogue import DialogueTemplate
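# `dialogue` is a local helper module imported above; DialogueTemplate folds the
# system message and chat history into a single inference prompt via
# get_inference_prompt() (see the POST handlers below).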

llm = AutoModelForCausalLM.from_pretrained(
    "gsaivinay/airoboros-13B-gpt4-1.3-GGML",
    model_file="airoboros-13b-gpt4-1.3.ggmlv3.q4_1.bin",
    model_type="llama",
)
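# The ctransformers LLM object exposes tokenize/generate/detokenize, which the
# endpoints below rely on. Illustrative usage only (not executed at import time):
#   tokens = llm.tokenize("<|user|> Hello <|assistant|>")
#   text = "".join(llm.detokenize(tok) for tok in llm.generate(tokens))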

app = fastapi.FastAPI(title="Starchat Beta")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
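# CORS is left wide open (any origin, method, and header) so browser clients
# hosted elsewhere can call the API; tighten these settings outside of a demo.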


# Route path assumed: serve the repository README, rendered to HTML, at the root URL.
@app.get("/")
async def index():
    with open("README.md", "r", encoding="utf-8") as readme_file:
        md_template_string = readme_file.read()
    html_content = markdown.markdown(md_template_string)
    return HTMLResponse(content=html_content, status_code=200)
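# Example (assuming the server runs locally on port 8000):
#   curl http://localhost:8000/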


# The route decorator is missing from the source; "/demo" is an assumed
# placeholder path for this simple GET streaming endpoint.
@app.get("/demo")
async def chat(prompt: str = "<|user|> Write an express server with server sent events. <|assistant|>"):
    tokens = llm.tokenize(prompt)

    async def server_sent_events(chat_chunks, llm):
        yield prompt
        for chat_chunk in llm.generate(chat_chunks):
            yield llm.detokenize(chat_chunk)
        yield ""

    return EventSourceResponse(server_sent_events(tokens, llm))
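# Example (path is the placeholder above; -N disables curl buffering so the
# server-sent events appear as they are generated):
#   curl -N "http://localhost:8000/demo"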


class ChatCompletionRequestMessage(BaseModel):
    role: Literal["system", "user", "assistant"] = Field(
        default="user", description="The role of the message."
    )
    content: str = Field(default="", description="The content of the message.")


class ChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionRequestMessage] = Field(
        default=[], description="A list of messages to generate completions for."
    )


system_message = "Below is a conversation between a human user and a helpful AI coding assistant."
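# A request body for the POST endpoints below mirrors the OpenAI chat format:
#   {"messages": [{"role": "user", "content": "Write a bubble sort in Python"}]}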


# Route path assumed to mirror the OpenAI chat completions API.
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
    kwargs = request.dict()
    dialogue_template = DialogueTemplate(
        system=system_message, messages=kwargs["messages"]
    )
    prompt = dialogue_template.get_inference_prompt()
    tokens = llm.tokenize(prompt)
    try:
        chat_chunks = llm.generate(tokens)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    async def format_response(chat_chunks: Generator) -> Any:
        for chat_chunk in chat_chunks:
            text = llm.detokenize(chat_chunk)
            response = {
                "choices": [
                    {
                        "message": {"role": "system", "content": text},
                        "finish_reason": "stop" if text == "[DONE]" else "unknown",
                    }
                ]
            }
            yield f"data: {json.dumps(response)}\n\n"
        yield "event: done\ndata: {}\n\n"

    return EventSourceResponse(format_response(chat_chunks), media_type="text/event-stream")
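# Example (route as assumed above):
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'
# Each SSE event carries a JSON chunk of the form
#   {"choices": [{"message": {"role": "system", "content": "..."}, "finish_reason": "unknown"}]}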


# Route path assumed; this earlier variant streams raw detokenized text rather
# than OpenAI-style JSON chunks.
@app.post("/v0/chat/completions")
async def chatV0(request: ChatCompletionRequest, response_mode=None):
    kwargs = request.dict()
    dialogue_template = DialogueTemplate(
        system=system_message, messages=kwargs["messages"]
    )
    prompt = dialogue_template.get_inference_prompt()
    tokens = llm.tokenize(prompt)

    async def server_sent_events(chat_chunks, llm):
        for token in llm.generate(chat_chunks):
            yield dict(data=llm.detokenize(token))
        yield dict(data="[DONE]")

    return EventSourceResponse(server_sent_events(tokens, llm))
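# Difference between the two POST endpoints: chatV0 emits each token as a bare
# SSE data field and closes with "[DONE]", while the handler above wraps every
# token in an OpenAI-like "choices" payload and ends with an explicit done event.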


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
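# Run with `python app.py` (assuming this module is saved as app.py), or with
# uvicorn directly: `uvicorn app:app --host 0.0.0.0 --port 8000`.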