import argparse
import io
import random
from typing import List, Optional

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image

import uvicorn
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="eval_configs/minigpt4_eval.yaml",
        help="path to configuration file.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
        "in xxx=yyy format will be merged into the config file (deprecated, "
        "use --cfg-options instead).",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
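
# `setup_seeds` is defined above but never invoked in this script. If
# deterministic behaviour is desired, a plausible call site (an assumption,
# not part of the original control flow) would be right after the config
# is built:
#
#   cfg = Config(parse_args())
#   setup_seeds(cfg)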
# ========================================
#           Model Initialization
# ========================================

print("Initializing Chat")
cfg = Config(parse_args())

model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to("cuda:0")

vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor)
print("Initialization Finished")

# ========================================
#             FastAPI Setting
# ========================================

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace "*" with your frontend domain
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


class Item(BaseModel):
    # Pydantic request models cannot carry UploadFile form fields directly,
    # so the image is accepted via File(...) in the /upload_img/ endpoint
    # signature instead; only the text input is modeled here.
    text_input: Optional[str] = None


# Module-level conversation state shared by all requests.
chat_state = CONV_VISION.copy()
img_list = []
chatbot = []


@app.get("/")
async def root():
    return RedirectResponse(url="/docs")


@app.post("/upload_img/")
async def upload_img(file: UploadFile = File(...)):
    # Decode the uploaded bytes and register the image with the chat session.
    pil_image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    chat.upload_img(pil_image, chat_state, img_list)
    return {"message": "image uploaded successfully."}


@app.post("/process/")
async def process_item(prompts: List[str] = Form(...)):
    if not img_list:  # No image has been uploaded yet
        raise HTTPException(status_code=400, detail="No images uploaded.")

    responses = []
    for prompt in prompts:  # Process each prompt individually
        chat.ask(prompt, chat_state)
        chatbot.append([prompt, None])
        llm_message = chat.answer(
            conv=chat_state,
            img_list=img_list,
            max_new_tokens=300,
            num_beams=1,
            temperature=1,
            max_length=2000,
        )[0]
        chatbot[-1][1] = llm_message
        responses.append({"prompt": prompt, "response": llm_message})
    return responses


@app.post("/reset/")
async def reset():
    global chat_state, img_list, chatbot  # reassign the module-level state
    if chat_state is not None:
        chat_state.messages = []
    img_list = []
    chatbot = []
    return {"message": "chat reset successfully."}


if __name__ == "__main__":
    # Run the FastAPI app with Uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860)
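
# ----------------------------------------------------------------------------
# Example client usage (a sketch, not part of the server): assuming the server
# is running on localhost:7860, the endpoints above can be exercised with the
# `requests` library. The image path "example.jpg" and the prompt text are
# hypothetical placeholders.
#
#   import requests
#
#   base = "http://localhost:7860"
#   with open("example.jpg", "rb") as f:
#       requests.post(f"{base}/upload_img/", files={"file": f})
#   resp = requests.post(
#       f"{base}/process/",
#       data={"prompts": ["Describe the image in detail."]},
#   )
#   print(resp.json())
#   requests.post(f"{base}/reset/")
# ----------------------------------------------------------------------------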