from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel

app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


class ChatMessage(BaseModel):
    role: str
    content: str


class GenerationRequest(BaseModel):
    # 'prompt' and 'message' are interchangeable; at least one must be provided.
    prompt: Optional[str] = None
    message: Optional[str] = None
    system_message: Optional[str] = None
    history: Optional[List[ChatMessage]] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95


def format_prompt(
    message: str,
    history: Optional[List[ChatMessage]] = None,
    system_message: Optional[str] = None,
) -> str:
    prompt = ""

    # Add system message if provided
    if system_message:
        prompt += f"[INST] {system_message} [/INST]"

    # Add conversation history
    if history:
        for msg in history:
            if msg.role == "user":
                prompt += f"[INST] {msg.content} [/INST]"
            else:
                prompt += f" {msg.content}"

    # Add the current message
    prompt += f"[INST] {message} [/INST]"
    return prompt


@app.post("/generate/")
async def generate_text(request: GenerationRequest):
    # Use either 'prompt' or 'message'; validate before the try block so the
    # 400 error is not swallowed and re-raised as a 500
    message = request.prompt if request.prompt else request.message
    if not message:
        raise HTTPException(
            status_code=400,
            detail="Either 'prompt' or 'message' must be provided",
        )

    try:
        # Format the prompt with history and system message if provided
        formatted_prompt = format_prompt(
            message=message,
            history=request.history,
            system_message=request.system_message,
        )

        # Generation parameters
        params = {
            "temperature": max(request.temperature, 0.01),  # Ensure temperature isn't too low
            "max_new_tokens": 1048,
            "top_p": request.top_p,
            "repetition_penalty": 1.0,
            "do_sample": True,
            "seed": 42,
        }

        # Generate the response - returned directly as a single string
        response = client.text_generation(formatted_prompt, **params)

        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
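
# Example request (a minimal sketch; assumes the server is running locally on
# port 8000 and that the Hugging Face Inference API is reachable):
#
#   curl -X POST http://localhost:8000/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Summarize FastAPI in one sentence.", "temperature": 0.7}'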