|
import json
import os
import uuid
from typing import AsyncGenerator, NoReturn

import google.generativeai as genai
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
|
# Populate os.environ from a .env file (python-dotenv) *before* the
# GOOGLE_API_KEY lookup below; the order of these two calls matters.
load_dotenv()

# Configure the Gemini SDK with the key taken from the environment.
# NOTE(review): if GOOGLE_API_KEY is unset, os.getenv returns None here;
# presumably the problem only surfaces on the first API call — consider
# failing fast at startup. TODO confirm.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# One module-level model instance, shared by every request handler below
# (get_ai_response and get_ai_image both call methods on it).
model = genai.GenerativeModel("gemini-pro")

# The ASGI application; routes are attached to it by the decorators below.
app = FastAPI()
|
|
|
# Chat prompt template used by get_ai_response. ``{message}`` is filled via
# str.format with the raw text the client sent over the /ws websocket.
PROMPT = """

You are a helpful assistant, skilled in explaining complex concepts in simple terms.



{message}

"""
|
|
|
# Image prompt template used by get_ai_image. ``{description}`` is filled via
# str.format with the ``description`` parameter of POST /generate-image/.
IMAGE_PROMPT = """

Generate an image based on the following description:



{description}

"""
|
|
|
async def get_ai_response(message: str) -> AsyncGenerator[str, None]:
    """Stream the Gemini reply to *message* as a series of JSON strings.

    The user text is substituted into ``PROMPT`` and sent to the model with
    ``stream=True``. For every streamed chunk that carries candidates, one
    JSON string ``{"id": <uuid>, "text": <reply so far>}`` is yielded.

    ``id`` is generated once per call and shared by all frames of the reply.
    ``text`` is cumulative: each frame holds everything received so far, not
    just the newest piece, so a client can overwrite what it displays for
    that id.
    """
    stream = await model.generate_content_async(
        PROMPT.format(message=message), stream=True
    )

    reply_id = str(uuid.uuid4())
    reply_so_far = ""

    async for chunk in stream:
        if not chunk.candidates:
            # No candidate in this chunk: nothing new to report.
            continue
        # Only the first candidate is read; its parts are appended in order.
        parts = chunk.candidates[0].content.parts
        reply_so_far += "".join(part.text for part in parts)
        yield json.dumps({"id": reply_id, "text": reply_so_far})
|
|
|
async def get_ai_image(description: str) -> str:
    """Ask the configured model for an image matching *description*.

    The description is substituted into ``IMAGE_PROMPT`` before the request.

    Returns:
        A JSON-encoded string: ``{"image_url": ...}`` when the response
        contains at least one image, otherwise ``{"error": ...}``.

    NOTE(review): the google-generativeai ``GenerativeModel`` API (``model``
    is ``gemini-pro``, a text model) does not document a
    ``generate_image_async`` method. An unconditional call therefore raised
    ``AttributeError``, which surfaced as an HTTP 500. The call is now guarded
    so callers always receive the JSON error shape instead. Wire up a real
    image-generation API to make this path work — confirm against the
    installed SDK version.
    """
    generate_image = getattr(model, "generate_image_async", None)
    if generate_image is None:
        # Same {"error": ...} shape as the "no image" case below, so the
        # endpoint and its clients need no special handling.
        return json.dumps(
            {"error": "Image generation is not supported by this model"}
        )

    response = await generate_image(IMAGE_PROMPT.format(description=description))

    if response.images:
        return json.dumps({"image_url": response.images[0].url})
    return json.dumps({"error": "No image generated"})
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve streamed Gemini answers over the ``/ws`` websocket.

    After accepting the connection, the handler repeatedly reads one text
    message from the client. For each message it forwards every JSON frame
    produced by ``get_ai_response``; each frame carries the cumulative reply
    text under a shared ``id``.

    The loop ends when the client disconnects. ``WebSocketDisconnect`` is
    caught so that a normal hang-up ends the session cleanly instead of
    escaping the handler as an unhandled exception.
    """
    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()
            async for frame in get_ai_response(message):
                await websocket.send_text(frame)
    except WebSocketDisconnect:
        # The client closed the connection; nothing to clean up here.
        return
|
|
|
@app.post("/generate-image/")
async def generate_image_endpoint(description: str):
    """Handle ``POST /generate-image/``.

    ``description`` is a scalar ``str`` parameter, so under FastAPI's rules it
    is taken from the query string. ``get_ai_image`` returns a JSON string;
    it is decoded here so that FastAPI serializes it back as the response
    body. The body is either ``{"image_url": ...}`` or ``{"error": ...}``.
    """
    return json.loads(await get_ai_image(description))
|
|
|
if __name__ == "__main__":
    # Run the app directly under uvicorn: listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|
|