File size: 3,395 Bytes
2056905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import asyncio
import logging
import os
import time
from typing import List, Union

from pillow_heif import register_heif_opener

# Register the HEIF/HEIC plugin with Pillow before transformers/gradio
# import Pillow, so the captioning pipeline can decode .heic images.
register_heif_opener()

import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from transformers import pipeline

# Runtime configuration, overridable via environment variables.
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
MAX_URLS = int(os.getenv("MAX_URLS", 5))  # max URLs accepted per /caption/ request
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 200))  # cap on generated caption length
# https://huggingface.co/models?pipeline_tag=image-to-text&sort=likes
MODEL = os.getenv("MODEL", "../models/Salesforce/blip-image-captioning-large")


logging.basicConfig(level=LOG_LEVEL)
logger = logging.getLogger(__name__)


app = FastAPI()


# Shared state for lazy model initialization (model is loaded in the
# background at startup; see load_model / startup_event).
captioner = None  # Placeholder for the captioner pipeline
is_initialized = asyncio.Event()  # Event to track initialization status
lock = asyncio.Lock()  # serializes access to the single pipeline instance


def load_model():
    """Build the image-to-text pipeline and signal readiness.

    Intended to run in a worker thread: publishes the pipeline through the
    module-level ``captioner`` and then sets ``is_initialized`` so waiting
    request handlers can proceed.
    """
    global captioner
    logger.info("Loading model...")
    # simpler model: "ydshieh/vit-gpt2-coco-en"
    pipeline_kwargs = {
        "model": MODEL,
        "max_new_tokens": MAX_NEW_TOKENS,
    }
    captioner = pipeline("image-to-text", **pipeline_kwargs)
    logger.info("Done loading model.")
    # NOTE(review): set() is called from a worker thread, but asyncio.Event
    # is not documented as thread-safe — works in practice here; confirm.
    is_initialized.set()


class Image(BaseModel):
    """Request body for /caption/: one image URL or a list of image URLs."""
    url: Union[HttpUrl, List[HttpUrl]]  # url can be a string or a list of strings


@app.on_event("startup")
async def startup_event():
    """Kick off background model loading and mount the Gradio UI.

    The model load runs in a worker thread so startup does not block;
    handlers gate on ``is_initialized`` until it completes.
    """
    global app
    # Keep a reference to the task: the event loop holds only weak
    # references to tasks, so a fire-and-forget create_task() result can be
    # garbage-collected before the model finishes loading.
    app.state.model_load_task = asyncio.create_task(asyncio.to_thread(load_model))
    # add gradio interface
    iface = gr.Interface(fn=captioner_gradapter, inputs="text", outputs=["text"], allow_flagging="never")
    app = gr.mount_gradio_app(app, iface, path="/gradio")


async def captioner_gradapter(image_url):
    """Gradio adapter: caption a single image URL once the model is ready."""
    await is_initialized.wait()
    # Run the blocking pipeline in a worker thread, one request at a time.
    async with lock:
        predictions = await asyncio.to_thread(captioner, image_url)
    return predictions[0]["generated_text"]


@app.get("/")
async def root():
    """Trivial landing endpoint."""
    greeting = {"message": "Hello World"}
    return greeting


# the image url is passed in as a "url" tag in the json body
@app.post("/caption/")
async def create_caption(image: Image):
    """Caption one image URL or a batch of up to MAX_URLS image URLs.

    Returns ``{"caption": <pipeline output>, "duration": <seconds>}``.
    Raises 400 when too many URLs are submitted and 500 when the pipeline
    fails.
    """
    if isinstance(image.url, list) and len(image.url) > MAX_URLS:
        logger.debug(
            f"Request with more than {MAX_URLS} URLs received. Refusing the request."
        )

        raise HTTPException(
            status_code=400,
            detail=f"Maximum of {MAX_URLS} URLs can be processed at once",
        )
    # Wait for model readiness *before* taking the lock, matching
    # captioner_gradapter — otherwise a request queued during model load
    # holds the lock while doing nothing.
    await is_initialized.wait()  # Wait until initialization is completed
    async with lock:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        # get the image url from the json body
        image_url = image.url
        try:
            caption = await asyncio.to_thread(captioner, image_url)
        except Exception as e:
            logger.error("Error during caption generation: %s", str(e))
            # Chain the cause so the original traceback is preserved in logs.
            raise HTTPException(
                status_code=500,
                detail="An error occurred during caption generation. Please try again later.",
            ) from e
        duration = time.perf_counter() - start_time
        logger.debug("Captioning completed. Time taken: %s seconds.", duration)

        return {"caption": caption, "duration": duration}


# add liveness probe
@app.get("/healthz")
async def healthz():
    """Liveness probe: reports ok whenever the process is serving."""
    status = {"status": "ok"}
    return status


# add readiness probe
@app.get("/readyz")
async def readyz():
    """Readiness probe: 503 until the model pipeline has finished loading."""
    if is_initialized.is_set():
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Initialization in progress")