import torch
from transformers import pipeline, BitsAndBytesConfig
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
from PIL import Image
from io import BytesIO
# Set up device (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Configure 4-bit quantization when a GPU is available
if device == "cuda":
    print("GPU found. Using 4-bit quantization.")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
else:
    print("GPU not found. Using CPU with default settings.")
    quantization_config = None

# Load the model pipeline. The quantization config has to be passed through
# model_kwargs, and bitsandbytes handles device placement itself, so the
# device is set explicitly only in the CPU case.
model_id = "bczhou/tiny-llava-v1-hf"
if quantization_config is not None:
    pipe = pipeline("image-to-text", model=model_id,
                    model_kwargs={"quantization_config": quantization_config})
else:
    pipe = pipeline("image-to-text", model=model_id, device=device)
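# Note: the image-to-text pipeline returns a list of dicts shaped like
# [{"generated_text": "..."}]; with LLaVA-style models the prompt is
# typically echoed back at the start of the generated text.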
# Initialize FastAPI application
app = FastAPI()

# Health check endpoint to confirm the API is running
@app.get("/")
async def root():
    return {"message": "API is running fine."}
# Define Pydantic model for the request input
class ImagePromptInput(BaseModel):
    image_url: str
    prompt: str
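# An example request body this schema accepts (values are illustrative):
#   {"image_url": "https://example.com/cat.jpg", "prompt": "Describe this image."}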
# FastAPI route for generating text from an image
# (the route path is an assumption; the original did not specify one)
@app.post("/generate")
async def generate_text(input_data: ImagePromptInput):
    try:
        # Download and decode the image, failing fast on HTTP errors
        response = requests.get(input_data.image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((750, 500))  # Resize image to fixed dimensions

        # Build the full prompt in the LLaVA chat format
        full_prompt = f"USER: <image>\n{input_data.prompt}\nASSISTANT: "

        # Generate a response with the model pipeline
        outputs = pipe(image, prompt=full_prompt, generate_kwargs={"max_new_tokens": 200})

        # Return the generated text
        generated_text = outputs[0]["generated_text"]  # type: ignore
        return {"response": generated_text}
    except Exception as e:
        # Surface any failure as an HTTP 500 error
        raise HTTPException(status_code=500, detail=str(e))
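# A minimal way to serve and smoke-test the API. This is a sketch: the port
# (7860, the Hugging Face Spaces default) and the /generate route above are
# assumptions, not confirmed by the source.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example client call from a separate script (URL and prompt are illustrative):
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate",
#       json={"image_url": "https://example.com/cat.jpg",
#             "prompt": "Describe this image."},
#   )
#   print(r.json()["response"])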