"""Classify garbage images with a Swin Transformer and archive each classified
image in the background (Azure Blob Storage, or optionally a local directory)."""

from io import BytesIO
import os
import threading
import time
import uuid

import torch
from PIL import Image
from dotenv import load_dotenv
from azure.storage.blob import BlobServiceClient
from transformers import SwinForImageClassification, AutoImageProcessor

# NOTE(review): GarbageClassifier is not referenced in the code shown; kept in
# case other code relies on this import (or its side effects) — confirm.
from models import GarbageClassifier

CURRENT_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(CURRENT_DIR, ".."))

load_dotenv()

# Azure Blob Storage settings and model location, read from the environment
# after .env has been loaded.
AZURE_CONNECTION_STRING = os.environ.get("AZURE_CONNECTION_STRING")
AZURE_CONTAINER_NAME = os.environ.get("AZURE_CONTAINER_NAME")
MODEL_PATH = os.environ.get(
    "MODEL_PATH", "youssefabdelmottaleb/Garbage-Classification-SWIN-Transformer"
)

# Load the model and processor once. Device placement and eval mode are set
# here instead of being redone on every predict() call.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinForImageClassification.from_pretrained(MODEL_PATH)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
model.to(DEVICE)
model.eval()


def _unique_image_name(result_class):
    """Return a ``.jpg`` file name for *result_class* that will not collide.

    A bare second-resolution timestamp collides when two images of the same
    class are saved within the same second, so a random suffix is appended.
    """
    return f"{result_class}_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"


def save_image_to_azure(image, result_class):
    """Upload *image* as a JPEG under ``predicted_images/<result_class>/``.

    Args:
        image: PIL image to upload (must be JPEG-encodable, e.g. RGB).
        result_class: predicted label, used as the blob "folder" and name prefix.

    Raises:
        ValueError: if AZURE_CONNECTION_STRING or AZURE_CONTAINER_NAME is unset.
    """
    if not AZURE_CONNECTION_STRING or not AZURE_CONTAINER_NAME:
        raise ValueError(
            "AZURE_CONNECTION_STRING and AZURE_CONTAINER_NAME must be set "
            "to save images to Azure Blob Storage"
        )

    # Encode in memory; nothing is written to local disk.
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    image_bytes = buffer.getvalue()

    blob_service_client = BlobServiceClient.from_connection_string(AZURE_CONNECTION_STRING)
    # upload_blob() refuses to replace an existing blob by default, so the
    # name must be unique per upload.
    blob_name = f"predicted_images/{result_class}/{_unique_image_name(result_class)}"
    blob_client = blob_service_client.get_blob_client(
        container=AZURE_CONTAINER_NAME, blob=blob_name
    )
    blob_client.upload_blob(image_bytes)
    print(f"Image saved to Azure Blob Storage: {blob_name}")


def save_image_to_local_path(image, result_class):
    """Save *image* under ``./predicted_images/<result_class>/``.

    The directory is created if missing. The path is relative to the current
    working directory.
    """
    directory = os.path.join("./predicted_images", result_class)
    os.makedirs(directory, exist_ok=True)
    # Unique name so a second image in the same second does not overwrite the first.
    image_path = os.path.join(directory, _unique_image_name(result_class))
    image.save(image_path)
    print(f"Image saved to {image_path}")


def predict(image_file):
    """Classify the image at *image_file* and archive it in the background.

    Args:
        image_file: a path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        dict: ``{"class": <label>}``, where the label comes from the model
        config's ``id2label`` mapping.
    """
    # Use ``with`` so a file PIL opened from a path gets closed. convert()
    # returns a fully loaded RGB copy that remains valid after the block.
    with Image.open(image_file) as img:
        image = img.convert("RGB")

    # The processor already returns a batch (leading dimension 1). Adding
    # another unsqueeze(0) would feed a 5-D tensor to the model.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    # Image processors carry no label names; they live in the model config.
    predicted_class = model.config.id2label[predicted_class_idx]

    # Upload on a separate (non-daemon) thread so the caller is not blocked on
    # network I/O. The interpreter still waits for it before exiting.
    threading.Thread(target=save_image_to_azure, args=(image, predicted_class)).start()

    return {"class": predicted_class}


if __name__ == "__main__":
    image_file = os.path.join(ROOT_DIR, "input/Ecomate_Dataset/metal/ecomate_metal_41.jpg")
    result = predict(image_file)
    print("Prediction result:", result)