# Author: youssefabdelmottaleb
# Commit 9b01bfc: "Add: SWIN-Transformer-Model-Deployment"
import logging

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from models import GarbageClassifier
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Build the classifier (model + image processor) once, at import time, so the
# /predict handler reuses the same loaded instance for every request.
# NOTE(review): model_dir looks like a Hugging Face Hub repo id; how it is
# resolved (download vs. local path) is up to models.GarbageClassifier.
model_dir = "youssefabdelmottaleb/Garbage-Classification-SWIN-Transformer"
classifier = GarbageClassifier(model_dir)
# Prediction endpoint. Image decoding and model inference are synchronous,
# compute-bound calls, so they run in FastAPI's worker thread pool rather than
# directly inside the async handler, where they would block the event loop and
# stall every other in-flight request.


def _decode_image(stream):
    """Decode an uploaded file object into a fully loaded RGB PIL image.

    Args:
        stream: Binary file-like object holding the uploaded bytes.

    Returns:
        An RGB ``PIL.Image.Image`` copy that no longer depends on ``stream``.

    Raises:
        OSError: If the bytes are not a readable image (Pillow raises
            ``UnidentifiedImageError``, an ``OSError`` subclass) or the image
            data is truncated.
        Image.DecompressionBombError: If the image exceeds Pillow's pixel limit.
    """
    # The context manager releases Pillow's handle on the image; Pillow does
    # not close a caller-supplied file object, so the UploadFile stays owned
    # by FastAPI.
    with Image.open(stream) as img:
        # convert() forces the lazy decode, so decoding errors surface here.
        return img.convert("RGB")


def _predict_label(image):
    """Run the module-level classifier on one RGB image and return its label.

    ``classifier.processor`` builds the model inputs, ``classifier.model``
    returns an output exposing ``logits``, and ``classifier.labels`` maps the
    winning index to a class name (presumably a list or an id2label-style
    dict; confirm in models.GarbageClassifier).
    """
    inputs = classifier.processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = classifier.model(**inputs).logits
    # argmax over the last axis (presumably the class axis); .item() requires
    # a single-element result, i.e. exactly one image in the batch.
    predicted_class_idx = logits.argmax(-1).item()
    return classifier.labels[predicted_class_idx]


@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: Multipart upload containing the image.

    Returns:
        A dict ``{"class": <predicted label>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image (a
            client error); 500 if inference itself fails.
    """
    try:
        image = await run_in_threadpool(_decode_image, file.file)
    except (OSError, Image.DecompressionBombError) as e:
        logger.warning("Rejected upload %r: %s", file.filename, e)
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    try:
        predicted_class = await run_in_threadpool(_predict_label, image)
    except Exception as e:
        # Request boundary: keep the full traceback in the logs but do not
        # leak internal error details to the client.
        logger.exception("Error processing image")
        raise HTTPException(status_code=500, detail="Error processing image") from e

    return {"class": predicted_class}
# Manual smoke test: `python app.py [IMAGE_PATH]` classifies one local image
# via GarbageClassifier.evaluate_image (what it reports is defined in
# models.py) without starting the server. The __main__ guard matters: uvicorn
# imports this module, and an unguarded call would run on every import/worker
# start and would presumably fail on any machine lacking the developer's file.
if __name__ == "__main__":
    import sys

    # Fall back to the original developer path when no argument is given.
    image_path = (
        sys.argv[1] if len(sys.argv) > 1 else "C:/Users/youss/Downloads/paperr.jpg"
    )
    classifier.evaluate_image(image_path)

# To run the app: uvicorn app:app --host 0.0.0.0 --port 8000