gsaivinay committed on
Commit 9f97c26
1 Parent(s): abe31cd

Update app.py

Files changed (1)
  1. app.py +108 -2
app.py CHANGED
@@ -1,3 +1,109 @@
- import gradio as gr
 
- gr.Interface.load("models/gsaivinay/airoboros-13B-gpt4-1.3-GGML").launch()
+ import json
+ from typing import Any, Generator, List
+ import fastapi
+ import markdown
+ import uvicorn
+ from ctransformers import AutoModelForCausalLM
+ from fastapi import HTTPException
+ from fastapi.responses import HTMLResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from sse_starlette.sse import EventSourceResponse
+ from pydantic import BaseModel, Field
+ from typing_extensions import Literal
+ from dialogue import DialogueTemplate
 
+ llm = AutoModelForCausalLM.from_pretrained("gsaivinay/airoboros-13B-gpt4-1.3-GGML",
+                                             model_file="airoboros-13b-gpt4-1.3.ggmlv3.q4_1.bin",
+                                             model_type="llama")
+
+ app = fastapi.FastAPI(title="Starchat Beta")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/")
+ async def index():
+     with open("README.md", "r", encoding="utf-8") as readme_file:
+         md_template_string = readme_file.read()
+     html_content = markdown.markdown(md_template_string)
+     return HTMLResponse(content=html_content, status_code=200)
+
+
+ @app.get("/stream")
+ async def chat(prompt = "<|user|> Write an express server with server sent events. <|assistant|>"):
+     tokens = llm.tokenize(prompt)
+     async def server_sent_events(chat_chunks, llm):
+         yield prompt
+         for chat_chunk in llm.generate(chat_chunks):
+             yield llm.detokenize(chat_chunk)
+         yield ""
+
+     return EventSourceResponse(server_sent_events(tokens, llm))
+
+
+ class ChatCompletionRequestMessage(BaseModel):
+     role: Literal["system", "user", "assistant"] = Field(
+         default="user", description="The role of the message."
+     )
+     content: str = Field(default="", description="The content of the message.")
+
+ class ChatCompletionRequest(BaseModel):
+     messages: List[ChatCompletionRequestMessage] = Field(
+         default=[], description="A list of messages to generate completions for."
+     )
+
+ system_message = "Below is a conversation between a human user and a helpful AI coding assistant."
+
+ @app.post("/v1/chat/completions")
+ async def chat(request: ChatCompletionRequest):
+     kwargs = request.dict()
+     dialogue_template = DialogueTemplate(
+         system=system_message, messages=kwargs['messages']
+     )
+     prompt = dialogue_template.get_inference_prompt()
+     tokens = llm.tokenize(prompt)
+
+     try:
+         chat_chunks = llm.generate(tokens)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+     async def format_response(chat_chunks: Generator) -> Any:
+         for chat_chunk in chat_chunks:
+             response = {
+                 'choices': [
+                     {
+                         'message': {
+                             'role': 'system',
+                             'content': llm.detokenize(chat_chunk)
+                         },
+                         'finish_reason': 'stop' if llm.detokenize(chat_chunk) == "[DONE]" else 'unknown'
+                     }
+                 ]
+             }
+             yield f"data: {json.dumps(response)}\n\n"
+         yield "event: done\ndata: {}\n\n"
+
+     return EventSourceResponse(format_response(chat_chunks), media_type="text/event-stream")
+
+ @app.post("/v0/chat/completions")
+ async def chatV0(request: ChatCompletionRequest, response_mode=None):
+     kwargs = request.dict()
+     dialogue_template = DialogueTemplate(
+         system=system_message, messages=kwargs['messages']
+     )
+     prompt = dialogue_template.get_inference_prompt()
+     tokens = llm.tokenize(prompt)
+     async def server_sent_events(chat_chunks, llm):
+         for token in llm.generate(chat_chunks):
+             yield dict(data=llm.detokenize(token))
+         yield dict(data="[DONE]")
+
+     return EventSourceResponse(server_sent_events(tokens, llm))
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
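
For reference, a minimal way to exercise the new streaming endpoint could look like the sketch below. This is a hypothetical client, not part of the commit: it assumes the app is running locally on port 8000 (as in the __main__ block above) and that the requests package is available, and it simply prints the raw Server-Sent Events produced by EventSourceResponse rather than parsing them.

import requests

# Hypothetical payload matching the ChatCompletionRequest model above.
payload = {"messages": [{"role": "user", "content": "Write a FastAPI hello-world app."}]}

# stream=True keeps the connection open so SSE chunks arrive incrementally.
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # skip the blank lines that separate SSE events
            print(line)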