# App / main.py
# Yjhhh — "Create main.py" (commit bd132d0, verified)
import json
import os
import uuid
from typing import AsyncGenerator, NoReturn

import google.generativeai as genai
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# Load environment variables (python-dotenv) before GOOGLE_API_KEY is read below.
load_dotenv()

genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# The model id can be overridden with the GEMINI_MODEL environment variable,
# so the model can be swapped without a code change (e.g. if "gemini-pro" is
# unavailable). When GEMINI_MODEL is unset the original "gemini-pro" is used.
model = genai.GenerativeModel(os.getenv("GEMINI_MODEL", "gemini-pro"))

app = FastAPI()
# Chat prompt template; ``{message}`` is replaced with the user's text.
PROMPT = (
    "\n"
    "You are a helpful assistant, skilled in explaining complex concepts in simple terms.\n"  # noqa: E501
    "{message}\n"
)

# Image prompt template; ``{description}`` is replaced with the requested scene.
IMAGE_PROMPT = (
    "\n"
    "Generate an image based on the following description:\n"
    "{description}\n"
)
async def get_ai_response(message: str) -> AsyncGenerator[str, None]:
    """Stream Gemini's reply to *message* as cumulative JSON snapshots.

    Each yielded string is ``json.dumps({"id": ..., "text": ...})``:
    ``id`` is one UUID shared by every snapshot of this reply, and ``text``
    is all reply text received so far (not just the newest chunk).
    Chunks that carry no candidates are skipped without yielding.
    """
    stream = await model.generate_content_async(
        PROMPT.format(message=message), stream=True
    )
    reply_id = str(uuid.uuid4())
    accumulated = ""
    async for chunk in stream:
        if not chunk.candidates:
            continue
        # Only the first candidate is used; concatenate all its text parts.
        parts = chunk.candidates[0].content.parts
        accumulated += "".join(part.text for part in parts)
        yield json.dumps({"id": reply_id, "text": accumulated})
async def get_ai_image(description: str) -> str:
    """Ask the configured model for an image matching *description*.

    Returns a JSON string: ``{"image_url": ...}`` for the first generated
    image, or ``{"error": ...}`` when no image could be produced.

    NOTE(review): ``google.generativeai.GenerativeModel`` does not appear to
    expose ``generate_image_async``. Calling it directly would raise
    ``AttributeError`` (an HTTP 500 from the endpoint), so it is looked up
    defensively here and a missing method yields an error payload instead.
    Confirm against the installed SDK before relying on image output.
    """
    generate = getattr(model, "generate_image_async", None)
    if generate is None:
        return json.dumps(
            {"error": "Image generation is not supported by the configured model"}
        )
    response = await generate(IMAGE_PROMPT.format(description=description))
    images = getattr(response, "images", None)
    if images:
        # Only the first generated image is returned.
        return json.dumps({"image_url": images[0].url})
    return json.dumps({"error": "No image generated"})
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve streaming Gemini replies over a WebSocket at ``/ws``.

    After accepting the connection, each text frame from the client is
    treated as a prompt. Every JSON snapshot produced by ``get_ai_response``
    is sent back as a text frame. When the client disconnects,
    ``WebSocketDisconnect`` is caught and the handler returns normally
    instead of surfacing an unhandled exception.
    """
    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()
            async for text in get_ai_response(message):
                await websocket.send_text(text)
    except WebSocketDisconnect:
        # Client closed the connection; end this handler normally.
        return
@app.post("/generate-image/")
async def generate_image_endpoint(description: str):
    """Generate an image for *description* and return the decoded payload.

    ``get_ai_image`` returns a JSON string (either an ``image_url`` or an
    ``error`` object); it is parsed here so FastAPI serializes it as a dict.
    """
    payload_json = await get_ai_image(description)
    payload = json.loads(payload_json)
    return payload
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)