from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
class Item(BaseModel):
    prompt: str
    history: list  # list of [user_prompt, bot_response] pairs from prior turns
    system_prompt: str
    temperature: float = 0.01
    top_p: float = 1.0
    details: bool = True
    return_full_text: bool = False
    stream: bool = False
def format_prompt(message, history):
    # Build a Mixtral-instruct prompt: each prior turn is wrapped in
    # [INST] ... [/INST] with the bot response appended, then the new
    # user message is added as the final instruction.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
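# Illustrative example of the resulting prompt string (not from the original):
#   format_prompt("Hi", [("Hello", "Hey!")])
#   -> "<s>[INST] Hello [/INST] Hey!</s> [INST] Hi [/INST]"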
def generate(item: Item):
    # The inference API rejects temperatures at or near zero,
    # so clamp to a small positive floor.
    temperature = max(float(item.temperature), 1e-2)
    top_p = float(item.top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=1048,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    response = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=item.stream,
        details=item.details,
        return_full_text=item.return_full_text,
    )
    output = ""
    if item.stream:
        # Streaming with details=True yields token objects; without details,
        # the client yields plain string chunks.
        for chunk in response:
            output += chunk.token.text if hasattr(chunk, "token") else chunk
    elif item.details:
        # Non-streaming with details=True returns an object carrying the text.
        output = response.generated_text
    else:
        # Non-streaming without details returns the generated text directly.
        output = response
    return [{'msg': output}]
@app.post("/generate/")
async def generate_text(item: Item):
    try:
        response = generate(item)
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
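# Example request (illustrative; the port and payload values are assumptions,
# not confirmed by the original):
#   curl -X POST http://localhost:7860/generate/ \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello", "history": [], "system_prompt": "You are a helpful assistant."}'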
# @app.get("/health")
# async def health_check():
#     return {
#         "status": "healthy",
#         "huggingface_client": "initialized",
#         "auth_required": True
#     }
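# Minimal local entrypoint (a sketch, assuming this file is run directly;
# 7860 is the usual Spaces port, an assumption not stated in the original).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)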