# main.py — FastAPI service exposing Mixtral-8x7B-Instruct text generation
# via the Hugging Face Inference API.
# NOTE(review): the lines previously here ("raw", "history blame", "2.88 kB", …)
# were Hugging Face web-viewer residue, not Python — removed.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
class Item(BaseModel):
prompt: str
history: list
system_prompt: str
temperature: float = 0.01
top_p: float = 1.0
details: bool = True
return_full_text: bool = False
stream: bool = False
def format_prompt(message, history):
prompt = "<s>"
for user_prompt, bot_response in history:
prompt += f"[INST] {user_prompt} [/INST]"
prompt += f" {bot_response}</s> "
prompt += f"[INST] {message} [/INST]"
return prompt
def generate(item: Item):
temperature = float(item.temperature)
if temperature < 1e-2:
temperature = 1e-2
top_p = float(item.top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=1048,
top_p=top_p,
repetition_penalty=1.0,
do_sample=True,
seed=42,
)
formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
#
stream = client.text_generation(
formatted_prompt,
**generate_kwargs,
# stream=item.stream,
stream=False,
details=item.details,
return_full_text=item.return_full_text
)
# return stream
output = ""
for response in stream:
# Check if response has the attribute 'token'
if hasattr(response, 'tokens'):
print('tokens')
output += response.token.text
else:
output += response # If not, treat it as a string
return [{'msg': output}]
# return output
# (Removed: a fully commented-out earlier draft of generate() duplicated the
# function above; recover it from version-control history if needed.)
@app.post("/generate/")
async def generate_text(item: Item):
try:
response = generate(item)
return {"response": response}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# @app.get("/health")
# async def health_check():
# return {
# "status": "healthy",
# "huggingface_client": "initialized",
# "auth_required": True
# }