"""FastAPI service that classifies uploaded garbage images.

Exposes ``POST /predict``, which accepts an image upload and returns the
predicted class label from the model loaded by ``GarbageClassifier``.

To run the app: uvicorn app:app --host 0.0.0.0 --port 8000
"""

import io
import logging
import sys

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image, UnidentifiedImageError

from models import GarbageClassifier

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and processor once, at import time, so every request reuses them.
model_dir = "youssefabdelmottaleb/Garbage-Classification-SWIN-Transformer"
classifier = GarbageClassifier(model_dir)

# Default local image for the manual smoke test in the __main__ block below.
# Kept as a plain constant so importing this module has no file-system effect.
image_path = "C:/Users/youss/Downloads/paperr.jpg"


def _decode_image(data: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        UnidentifiedImageError: if ``data`` is not a format PIL recognises.
    """
    # Close the source image deterministically; convert() returns a new image
    # object, so the result stays valid after the ``with`` block exits.
    with Image.open(io.BytesIO(data)) as img:
        return img.convert("RGB")


def _classify(image: Image.Image):
    """Run the classifier on ``image`` and return the predicted label.

    Relies on ``classifier`` exposing ``processor``, ``model`` and ``labels``
    exactly as the original endpoint used them.
    """
    inputs = classifier.processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = classifier.model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return classifier.labels[predicted_class_idx]


# Endpoint to receive images and return predictions
@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"class": <predicted label>}``.

    Raises:
        HTTPException(400): the upload is not a recognisable image.
        HTTPException(500): any other failure while decoding or predicting.
    """
    try:
        data = await file.read()
        # Decoding and inference are CPU-bound and synchronous; run them in a
        # worker thread so they don't block the event loop for other requests.
        image = await run_in_threadpool(_decode_image, data)
        predicted_class = await run_in_threadpool(_classify, image)
    except UnidentifiedImageError as exc:
        # Client sent something that isn't an image: a client error, not a 500.
        logger.warning("Rejected upload %r: %s", file.filename, exc)
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    except Exception as exc:
        # Top-level boundary: keep the traceback in the logs, hide it from clients.
        logger.exception("Error processing image")
        raise HTTPException(status_code=500, detail="Error processing image") from exc
    return {"class": predicted_class}


if __name__ == "__main__":
    # Manual smoke test: `python app.py [path/to/image.jpg]`.
    classifier.evaluate_image(sys.argv[1] if len(sys.argv) > 1 else image_path)