# face-detector / main.py
# Last synced by osrokas via the 'sync_with_huggingface' GitHub Action
# (commit 794e3ef, verified).
# Standard library
import base64
import logging
import os
import tempfile
from io import BytesIO

# Import Fast API
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.templating import Jinja2Templates

# Import machine learning
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Import project utilities and prediction helper
from src.predict import predict
from src.utils.utils import IMAGE_FORMATS
# Initializing the FastAPI application
app = FastAPI()
# Initializing templates; "templates" is a relative path, so it is resolved
# against the process working directory.
templates = Jinja2Templates(directory="templates")
# Initializing the module-level logger
logger = logging.getLogger(__name__)
# NOTE(review): no logging configuration is visible in this module, so this
# INFO message only appears if logging is configured elsewhere — confirm.
logger.info("Loading YOLO model...")
# Download the YOLOv8 face-detection weights from the Hugging Face Hub;
# hf_hub_download returns a local file path that YOLO can load.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)
# Load the YOLO model once at import time so all requests share it
model = YOLO(model_path)
# Index route: serve the upload page
@app.get("/")
async def root(request: Request):
    """Render ``index.html`` without a prediction result.

    Args:
        request: Incoming request; Jinja2Templates requires it in the context.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# Upload-and-predict route
@app.post("/predict-img")
def predict_image(request: Request, file: UploadFile = File(...)):
    """Run face detection on an uploaded image and render the annotated result.

    The upload is validated against ``IMAGE_FORMATS`` by filename extension,
    copied to a private temporary file, and passed to
    ``predict(model, path)``. The returned image is embedded in
    ``index.html`` as base64 JPEG under the ``img`` template variable.

    Args:
        request: Incoming request, forwarded to the template context.
        file: The uploaded image (multipart form field ``file``).

    Returns:
        On success, the rendered ``index.html`` template. If reading,
        validating or saving the upload fails, a one-element list holding
        the error message. This is the same JSON the endpoint sent before,
        when it returned a set literal that FastAPI serialized as a list.

    Raises:
        Any exception from ``predict`` or from encoding its result
        propagates to FastAPI. The temporary file is removed in every case.
    """
    try:
        image_path = _save_upload(file)
    except Exception as e:
        # Best-effort error reporting, kept from the original endpoint.
        return [f"{e}"]
    finally:
        file.file.close()

    try:
        results = predict(model, image_path)
        img_b64 = _encode_jpeg_base64(results)
    finally:
        # Delete the temporary copy even when prediction or encoding fails.
        _remove_temp_file(image_path)

    return templates.TemplateResponse(
        "index.html", {"request": request, "img": img_b64}
    )


def _save_upload(file: UploadFile) -> str:
    """Validate *file*'s extension and copy its bytes to a new temporary file.

    The client-supplied filename is used only to check the extension and to
    choose the temporary file's suffix. It is never used as a path, so names
    such as ``../../x.jpg`` cannot write outside the temp directory, and
    concurrent uploads with the same name do not overwrite each other.

    Returns:
        Path of the temporary file. The caller must delete it.

    Raises:
        ValueError: if the filename is missing or its extension is not in
            ``IMAGE_FORMATS`` (checked both as given and lower-cased).
    """
    name = file.filename or ""
    # Also accept e.g. "photo.JPG" when IMAGE_FORMATS lists lower-case entries.
    if not (name.endswith(IMAGE_FORMATS) or name.lower().endswith(IMAGE_FORMATS)):
        raise ValueError("Invalid image format")
    # Keep the extension on the temp file; presumably predict()/the image
    # loader infers the image type from the path suffix — TODO confirm.
    suffix = os.path.splitext(name)[1].lower()
    contents = file.file.read()
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(contents)
    return tmp.name


def _encode_jpeg_base64(image) -> str:
    """Return *image* serialized as JPEG and base64-encoded as ASCII text.

    Assumes *image* (the value returned by ``predict``) has a PIL-style
    ``save(fp, format)`` method — TODO confirm against src.predict.
    """
    # TODO: encode with the uploaded image's own format instead of always JPEG.
    buffer = BytesIO()
    image.save(buffer, "JPEG")
    return base64.b64encode(buffer.getvalue()).decode()


def _remove_temp_file(path: str) -> None:
    """Delete *path*, logging rather than raising if removal fails."""
    try:
        os.remove(path)
    except OSError as e:
        logger.error("Error deleting image: %s", e)