import os

import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Your existing Gradio interface code here
def image_generator(prompt: str):
    """Generate an image for the text *prompt* (placeholder).

    This is a stub: it currently produces no image and returns ``None``.
    Replace the body with real image-generation logic.

    Args:
        prompt: The text entered in the Gradio text input.

    Returns:
        ``None`` for now. NOTE(review): once implemented, this presumably
        should return something Gradio's ``"image"`` output accepts
        (e.g. a PIL image, a numpy array, or a file path) — confirm against
        the Gradio Image component docs.
    """
    # Your image generation logic here
    return None
# Gradio UI: one text input feeding image_generator, one image output.
interface = gr.Interface(fn=image_generator, inputs="text", outputs="image")
# Wrap the Gradio interface with FastAPI
app = FastAPI()

# Add CORS middleware.
# Allowed origins come from the ALLOWED_ORIGINS environment variable as a
# comma-separated list, e.g.
#   ALLOWED_ORIGINS="https://example.com,https://app.example.com"
# When it is unset or empty, every origin is allowed ("*").
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Never combine a wildcard origin with credentials (cookies / Authorization
# headers). FastAPI's CORS docs state allow_origins cannot be ["*"] when
# credentials are allowed. Combining them would effectively let any website
# make authenticated cross-origin calls. Credentials are therefore enabled
# only when origins are listed explicitly.
_allow_credentials = "*" not in _allowed_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the Gradio app at the site root. mount_gradio_app returns the same
# FastAPI app, so rebinding `app` keeps the middleware configured above.
# NOTE: a mount at "/" matches every path, so any extra FastAPI routes should
# be registered before this line; routes added afterwards may be shadowed.
app = gr.mount_gradio_app(app, interface, path="/")
# If you're not using Gradio, drop the mount above and use FastAPI directly:
# @app.post("/predict")
# async def predict(data: dict):
#     # Your prediction logic here
#     pass