"""FastAPI service exposing ``POST /generate/`` backed by a Hugging Face ``InferenceClient``."""

import logging

import uvicorn
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): no token is passed here, so authentication depends on whatever
# huggingface_hub picks up from the environment — confirm for deployment.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


class Item(BaseModel):
    """Request body for ``POST /generate/``."""

    prompt: str  # the new user message
    history: list  # prior turns; each entry is unpacked as (user_prompt, bot_response)
    system_prompt: str  # sent as f"{system_prompt}, {prompt}" in the final turn
    temperature: float = 0.01  # clamped to >= 0.01 in generate()
    top_p: float = 1.0  # values >= 1.0 mean "no nucleus filtering"
    details: bool = True  # forwarded to the client; only the generated text is returned
    return_full_text: bool = False
    stream: bool = False  # accepted for compatibility; responses are never streamed
    max_new_tokens: int = 1048  # previously hard-coded; same default


def format_prompt(message, history):
    """Build an ``[INST]``-style prompt from ``history`` plus the new ``message``.

    Each prior ``(user_prompt, bot_response)`` pair is rendered as
    ``"[INST] {user_prompt} [/INST] {bot_response} "`` and the new message is
    appended as a final ``"[INST] {message} [/INST]"`` turn.
    """
    # NOTE(review): the published Mixtral-Instruct template also wraps the
    # conversation in <s> ... </s> tokens; they are omitted here — confirm intended.
    turns = [
        f"[INST] {user_prompt} [/INST] {bot_response} "
        for user_prompt, bot_response in history
    ]
    turns.append(f"[INST] {message} [/INST]")
    return "".join(turns)


def _extract_text(result):
    """Return the generated text from a ``client.text_generation`` result.

    With ``details=False`` the client returns a plain ``str``; with
    ``details=True`` it returns an object exposing ``generated_text``.
    Iterating the latter (as the previous code did) does not yield text.
    """
    if isinstance(result, str):
        return result
    return result.generated_text


def generate(item: Item):
    """Run one non-streaming completion for ``item``.

    Returns:
        ``[{"msg": <generated text>}]``.
    """
    # Keep temperature strictly positive: sampling is enabled below and the
    # backend presumably rejects temperature <= 0.
    temperature = max(float(item.temperature), 1e-2)

    # top_p >= 1.0 is equivalent to no nucleus filtering; passing None omits the
    # parameter. NOTE(review): TGI-backed endpoints validate top_p in (0, 1) and
    # reject exactly 1.0 — confirm against the deployed endpoint.
    top_p = float(item.top_p)
    if top_p >= 1.0:
        top_p = None

    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)

    # Always request a single complete response: this endpoint returns JSON,
    # so item.stream is deliberately not forwarded.
    result = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
        stream=False,
        details=item.details,
        return_full_text=item.return_full_text,
    )
    return [{"msg": _extract_text(result)}]


@app.post("/generate/")
def generate_text(item: Item):
    """Generate a completion for ``item`` and wrap it as ``{"response": ...}``.

    Declared with plain ``def`` so FastAPI runs it in its threadpool; the
    synchronous ``client.text_generation`` call inside ``generate`` would
    otherwise block the event loop.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            generation fails for any reason.
    """
    try:
        response = generate(item)
    except Exception as e:
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": response}