# llava-fastapi/main.py
import torch
from transformers import pipeline, BitsAndBytesConfig
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
from PIL import Image
from io import BytesIO

# Set up device (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Configure 4-bit quantization if a GPU is available
if device == "cuda":
    print("GPU found. Using 4-bit quantization.")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
else:
    print("GPU not found. Using CPU with default settings.")
    quantization_config = None

# Load the model pipeline, routing the quantization config through
# model_kwargs so the 4-bit setting takes effect on GPU loads
model_id = "bczhou/tiny-llava-v1-hf"
if quantization_config is not None:
    pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
else:
    pipe = pipeline("image-to-text", model=model_id, device=device)
print(f"Using device: {device}")

# Initialize the FastAPI application
app = FastAPI()

# Health check endpoint to confirm the API is running
@app.get("/")
async def root():
    return {"message": "API is running fine."}

# Pydantic model for the request input
class ImagePromptInput(BaseModel):
    image_url: str
    prompt: str
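
# Example request body that validates against this model (the URL and
# prompt below are illustrative values, not part of the original file):
#   {"image_url": "https://example.com/cat.jpg", "prompt": "What is in this picture?"}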

# FastAPI route for generating text from an image
@app.post("/generate")
async def generate_text(input_data: ImagePromptInput):
    try:
        # Download and decode the image, failing fast on bad URLs
        response = requests.get(input_data.image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((750, 500))  # Resize image to fixed dimensions

        # Build the chat-style prompt expected by the LLaVA model
        full_prompt = f"USER: <image>\n{input_data.prompt}\nASSISTANT: "

        # Generate a response using the model pipeline
        outputs = pipe(image, prompt=full_prompt, generate_kwargs={"max_new_tokens": 200})
        generated_text = outputs[0]["generated_text"]
        return {"response": generated_text}
    except Exception as e:
        # Surface any failure as an HTTP 500 with the error message
        raise HTTPException(status_code=500, detail=str(e))
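
# A minimal entry point for running the app locally, assuming uvicorn is
# installed (FastAPI apps are typically served with an ASGI server). The
# host and port below are illustrative, not part of the original file.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example call once the server is running (hypothetical image URL):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"image_url": "https://example.com/cat.jpg", "prompt": "Describe this image."}'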