# Author: youssefabdelmottaleb
# Commit 9b01bfc: "Add: SWIN-Transformer-Model-Deployment"
from PIL import Image
import torch
import os
import threading
import time
from dotenv import load_dotenv
from azure.storage.blob import BlobServiceClient
from transformers import SwinForImageClassification, AutoImageProcessor
from models import GarbageClassifier
# Directories are resolved from the process's working directory (not from this
# file's location); ROOT_DIR is the parent of that directory.
# NOTE(review): results depend on where the script is launched from — consider
# deriving these from __file__ if the project layout is fixed.
CURRENT_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))

# Load variables from a .env file into os.environ before any settings are read.
load_dotenv()
# --- Azure Blob Storage settings and model location, read from the environment ---
AZURE_CONNECTION_STRING = os.getenv("AZURE_CONNECTION_STRING")
AZURE_CONTAINER_NAME = os.getenv("AZURE_CONTAINER_NAME")
# When MODEL_PATH is unset, fall back to what appears to be a Hugging Face Hub
# repo id (from_pretrained also accepts a local directory here).
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "youssefabdelmottaleb/Garbage-Classification-SWIN-Transformer",
)

# Load the classifier and its matching image preprocessor once, at import time,
# so that every call to predict() reuses the same objects.
model = SwinForImageClassification.from_pretrained(MODEL_PATH)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
def save_image_to_azure(image, result_class):
    """Upload ``image`` as a JPEG to Azure Blob Storage, filed under its predicted class.

    The blob is written to ``predicted_images/<result_class>/`` in the container
    named by ``AZURE_CONTAINER_NAME``, using the account in
    ``AZURE_CONNECTION_STRING``.

    Args:
        image: PIL image to upload. Modes JPEG cannot store (anything other
            than RGB or L, e.g. RGBA) are converted to RGB first.
        result_class: Predicted class label, used as folder name and file prefix.

    Returns:
        The name of the uploaded blob.

    Raises:
        ValueError: If the Azure connection string or container name is not set.
    """
    from io import BytesIO
    from uuid import uuid4

    from azure.storage.blob import ContentSettings

    if not AZURE_CONNECTION_STRING or not AZURE_CONTAINER_NAME:
        raise ValueError(
            "AZURE_CONNECTION_STRING and AZURE_CONTAINER_NAME must be set to upload images"
        )

    # JPEG cannot store alpha/palette modes; Pillow raises OSError for e.g. RGBA.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Encode the image in memory; no temporary file is needed for the upload.
    buffer = BytesIO()
    image.save(buffer, format='JPEG')
    image_bytes = buffer.getvalue()

    # A one-second timestamp alone collides when two images of the same class are
    # classified within the same second; upload_blob (overwrite=False by default)
    # would then raise and the image would be lost. The random suffix keeps names
    # unique while preserving the "<class>_<timestamp>" prefix.
    blob_name = (
        f"predicted_images/{result_class}/"
        f"{result_class}_{int(time.time())}_{uuid4().hex[:8]}.jpg"
    )

    blob_service_client = BlobServiceClient.from_connection_string(AZURE_CONNECTION_STRING)
    blob_client = blob_service_client.get_blob_client(container=AZURE_CONTAINER_NAME, blob=blob_name)
    # Tag the blob as JPEG so it is served as an image rather than octet-stream.
    blob_client.upload_blob(
        image_bytes,
        content_settings=ContentSettings(content_type="image/jpeg"),
    )
    print(f"Image saved to Azure Blob Storage: {blob_name}")
    return blob_name
def save_image_to_local_path(image, result_class):
    """Save ``image`` as a JPEG under ``./predicted_images/<result_class>/``.

    The directory is created if needed, relative to the current working directory.

    Args:
        image: PIL image to save. Modes JPEG cannot store (anything other than
            RGB or L, e.g. RGBA) are converted to RGB first.
        result_class: Predicted class label, used as folder name and file prefix.

    Returns:
        The path of the written file.
    """
    from uuid import uuid4

    directory = os.path.join('./predicted_images', result_class)
    os.makedirs(directory, exist_ok=True)

    # The ".jpg" extension makes Pillow encode JPEG, which rejects e.g. RGBA.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # A one-second timestamp alone would make a second same-class image saved in
    # the same second silently overwrite the first; the random suffix prevents that.
    image_path = os.path.join(
        directory, f"{result_class}_{int(time.time())}_{uuid4().hex[:8]}.jpg"
    )
    image.save(image_path)
    print(f"Image saved to {image_path}")
    return image_path
def predict(image_file):
    """Classify an image with the Swin Transformer and archive it in the background.

    Args:
        image_file: A path or binary file object accepted by ``PIL.Image.open``.

    Returns:
        ``{"class": label}``, where ``label`` is looked up in the model
        config's ``id2label`` mapping.
    """
    # Use a context manager so the underlying file handle is released;
    # convert() returns an independent in-memory copy of the pixels.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')  # Ensure image is RGB

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The processor already returns batched tensors (a leading batch dim of 1).
    # The previous extra unsqueeze(0) produced a 5-D tensor that Swin's
    # (batch, channels, height, width) input cannot accept.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    # Class names live on the model config; the image processor has no `labels`.
    predicted_class = model.config.id2label[predicted_class_idx]

    # Upload in a separate thread so the caller gets the prediction without
    # waiting for the network I/O to finish.
    thread = threading.Thread(target=save_image_to_azure, args=(image, predicted_class))
    thread.start()
    return {"class": predicted_class}
if __name__ == "__main__":
    # Manual smoke test: classify one sample image located under ROOT_DIR/input.
    sample_image_path = os.path.join(ROOT_DIR, "input/Ecomate_Dataset/metal/ecomate_metal_41.jpg")
    print("Prediction result:", predict(sample_image_path))