Spaces:
Runtime error
Runtime error
from PIL import Image | |
import torch | |
import os | |
import threading | |
import time | |
from dotenv import load_dotenv | |
from azure.storage.blob import BlobServiceClient | |
from transformers import SwinForImageClassification, AutoImageProcessor | |
from models import GarbageClassifier | |
# Paths are resolved from the process's working directory, not from this file's location.
CURRENT_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))

# Load variables from a .env file (if one is found) into the environment.
load_dotenv()

# Azure Blob Storage settings; each is None when its variable is unset.
AZURE_CONNECTION_STRING = os.getenv("AZURE_CONNECTION_STRING")
AZURE_CONTAINER_NAME = os.getenv("AZURE_CONTAINER_NAME")

# Identifier passed to from_pretrained(); override it with the MODEL_PATH variable.
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "youssefabdelmottaleb/Garbage-Classification-SWIN-Transformer",
)

# Load the Swin Transformer classifier and its image processor once, at import time.
model = SwinForImageClassification.from_pretrained(MODEL_PATH)
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
# Function to save image to Azure Blob Storage | |
def save_image_to_azure(image, result_class):
    """Upload a classified image to Azure Blob Storage as a JPEG.

    The blob is stored as
    ``predicted_images/<result_class>/<result_class>_<unix_seconds>_<suffix>.jpg``
    in the container named by ``AZURE_CONTAINER_NAME``.

    Args:
        image: PIL image to upload.
        result_class: Predicted class label; used as the folder name and as
            the file-name prefix.

    Raises:
        ValueError: If ``AZURE_CONNECTION_STRING`` or ``AZURE_CONTAINER_NAME``
            is not configured.
    """
    from io import BytesIO
    import uuid

    # Fail with a clear message instead of an obscure error inside the SDK.
    if not AZURE_CONNECTION_STRING or not AZURE_CONTAINER_NAME:
        raise ValueError(
            "AZURE_CONNECTION_STRING and AZURE_CONTAINER_NAME must be set "
            "to upload images to Azure Blob Storage"
        )

    # JPEG cannot store alpha or palette modes, so normalise those to RGB.
    if image.mode not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")

    # Encode the PIL image to JPEG bytes in memory.
    buffer = BytesIO()
    image.save(buffer, format='JPEG')
    image_bytes = buffer.getvalue()

    # A timestamp with one-second resolution is not unique: two uploads of the
    # same class within a second would collide and upload_blob() would raise.
    # A random suffix keeps the original prefix while making the name unique.
    unique_suffix = uuid.uuid4().hex[:8]
    blob_name = (
        f"predicted_images/{result_class}/"
        f"{result_class}_{int(time.time())}_{unique_suffix}.jpg"
    )

    # The context manager closes the client's connections after the upload.
    with BlobServiceClient.from_connection_string(AZURE_CONNECTION_STRING) as blob_service_client:
        blob_client = blob_service_client.get_blob_client(
            container=AZURE_CONTAINER_NAME, blob=blob_name
        )
        blob_client.upload_blob(image_bytes)

    print(f"Image saved to Azure Blob Storage: {blob_name}")
# Function to save image to local path | |
def save_image_to_local_path(image, result_class):
    """Save a classified image under ``./predicted_images/<result_class>/``.

    The file is named ``<result_class>_<unix_seconds>_<suffix>.jpg``; the
    directory is created if it does not exist.

    Args:
        image: PIL image to save.
        result_class: Predicted class label; used as the sub-directory name
            and as the file-name prefix.
    """
    import uuid

    # One directory per predicted class.
    directory = os.path.join('./predicted_images', result_class)
    os.makedirs(directory, exist_ok=True)

    # A timestamp with one-second resolution would let two saves of the same
    # class within a second silently overwrite each other; the random suffix
    # keeps every file.
    file_name = f"{result_class}_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    image_path = os.path.join(directory, file_name)

    # The .jpg extension selects the JPEG encoder, which cannot store alpha
    # or palette modes, so normalise those to RGB.
    if image.mode not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")

    image.save(image_path)
    print(f"Image saved to {image_path}")
# Function to predict using Swin Transformer model | |
def predict(image_file):
    """Classify an image with the Swin Transformer model.

    After inference, the RGB image is handed to ``save_image_to_azure`` on a
    background thread so the upload does not delay the returned result.

    Args:
        image_file: Path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        dict: ``{"class": <predicted label>}``.
    """
    # convert() returns an independent in-memory copy, so the source file can
    # be released as soon as the block exits.
    with Image.open(image_file) as source_image:
        image = source_image.convert('RGB')  # Ensure image is RGB

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # With return_tensors="pt" the processor output already has a leading
    # batch dimension; adding another one produces a 5-D tensor that the
    # model rejects. Only move the tensors to the target device.
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Perform inference
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # The index-to-label mapping lives on the model config; the image
    # processor has no `labels` attribute.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]

    # Upload in the background so the caller gets the prediction immediately.
    thread = threading.Thread(target=save_image_to_azure, args=(image, predicted_class))
    thread.start()

    return {"class": predicted_class}
if __name__ == "__main__":
    # Manual smoke test: classify one sample image located under ROOT_DIR.
    sample_path = os.path.join(ROOT_DIR, "input/Ecomate_Dataset/metal/ecomate_metal_41.jpg")
    print("Prediction result:", predict(sample_path))