add face detection model
Browse files- main.py +54 -4
- models/face_classifier.py +1 -1
- procedures.py +61 -0
- requirements.txt +0 -0
- utils/handlers.py +46 -0
main.py
CHANGED
@@ -1,15 +1,20 @@
|
|
1 |
-
from typing import Union
|
2 |
import dotenv
|
3 |
import traceback
|
4 |
import json
|
5 |
import io
|
6 |
import os
|
7 |
import base64
|
8 |
-
from fastapi import FastAPI, File, HTTPException, UploadFile, Response
|
|
|
9 |
import models.face_classifier as classifier
|
10 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
11 |
from PIL import Image
|
12 |
from rembg import remove
|
|
|
|
|
|
|
13 |
from utils.helpers import combine_images, image_to_base64, calculate_mask_area, process_image
|
14 |
|
15 |
|
@@ -31,7 +36,7 @@ app.add_middleware(
|
|
31 |
|
32 |
model = classifier.FaceSegmentationModel()
|
33 |
|
34 |
-
|
35 |
|
36 |
@app.post("/segment/", summary="Classify skin type based on image given",tags=["Classify"])
|
37 |
async def predict_image(file: UploadFile = File(...)):
|
@@ -91,4 +96,49 @@ async def predict_image(file: UploadFile = File(...)):
|
|
91 |
except Exception as e:
|
92 |
# Mendapatkan stack trace
|
93 |
error_traceback = traceback.format_exc()
|
94 |
-
raise HTTPException(status_code=500, detail=f"An error occurred: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
import dotenv
|
3 |
import traceback
|
4 |
import json
|
5 |
import io
|
6 |
import os
|
7 |
import base64
|
8 |
+
from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Response
|
9 |
+
from ultralytics import YOLO
|
10 |
import models.face_classifier as classifier
|
11 |
from fastapi.middleware.cors import CORSMiddleware
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
from PIL import Image
|
14 |
from rembg import remove
|
15 |
+
import procedures
|
16 |
+
from utils import handlers
|
17 |
+
from supervision import Detections
|
18 |
from utils.helpers import combine_images, image_to_base64, calculate_mask_area, process_image
|
19 |
|
20 |
|
|
|
36 |
|
37 |
model = classifier.FaceSegmentationModel()
|
38 |
|
39 |
+
yolo_model_path = hf_hub_download(repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt")
|
40 |
|
41 |
@app.post("/segment/", summary="Classify skin type based on image given",tags=["Classify"])
|
42 |
async def predict_image(file: UploadFile = File(...)):
|
|
|
96 |
except Exception as e:
|
97 |
# Mendapatkan stack trace
|
98 |
error_traceback = traceback.format_exc()
|
99 |
+
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
100 |
+
|
101 |
+
@app.post("/detect-face/", summary="Detect face from image", tags=["Classify"])
async def detect_face(file: Optional[UploadFile] = File(None), base64_string: Optional[str] = Form(None)):
    """Detect the first face in an image supplied as an upload or Base64 string.

    Delegates to ``procedures.detect_face`` so the detection pipeline lives in
    exactly one place instead of being duplicated between the route and the
    procedures module.

    Returns a dict with ``bounding_box`` and ``cropped_image`` (Base64 JPEG).
    The procedure raises HTTPException 400 on missing/unsupported input,
    404 when no face is found, and 500 for unexpected errors.
    """
    # yolo_model_path is resolved once at module import time via
    # hf_hub_download; the procedure handles validation, inference,
    # cropping and Base64 encoding.
    return await procedures.detect_face(yolo_model_path, file, base64_string)
|
models/face_classifier.py
CHANGED
@@ -12,7 +12,7 @@ def warning_with_traceback(message, category, filename, lineno, file=None, line=
|
|
12 |
traceback.print_stack(file=log)
|
13 |
log.write(warnings.formatwarning(message, category, filename, lineno, line))
|
14 |
|
15 |
-
warnings.showwarning = warning_with_traceback
|
16 |
|
17 |
|
18 |
class FaceSegmentationModel:
|
|
|
12 |
traceback.print_stack(file=log)
|
13 |
log.write(warnings.formatwarning(message, category, filename, lineno, line))
|
14 |
|
15 |
+
# warnings.showwarning = warning_with_traceback
|
16 |
|
17 |
|
18 |
class FaceSegmentationModel:
|
procedures.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# procedures.py
|
2 |
+
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
import traceback
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
from typing import Optional, Union
|
8 |
+
from utils import handlers
|
9 |
+
from fastapi import File, Form, HTTPException, UploadFile, Response
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from ultralytics import YOLO
|
12 |
+
from supervision import Detections
|
13 |
+
|
14 |
+
# Cache of loaded YOLO face-detection models keyed by weights path, so the
# expensive weight load from disk happens once per path, not once per request.
_model_cache = {}


async def detect_face(
    model_path: str,
    file: Optional[UploadFile],
    base64_string: Optional[str],
):
    """Detect the first face in an image and return its bounding box and crop.

    Args:
        model_path: Filesystem path to the YOLOv8 face-detection weights.
        file: Optional uploaded image file; takes precedence when provided.
        base64_string: Optional Base64-encoded image, used when *file* is None.

    Returns:
        dict with ``bounding_box`` (``[x_min, y_min, x_max, y_max]`` floats)
        and ``cropped_image`` (Base64-encoded JPEG of the cropped face).

    Raises:
        HTTPException: 400 if no input is given or the input is unsupported,
            404 if no face is detected, 500 on any unexpected error.
    """
    try:
        if file is None and base64_string is None:
            raise HTTPException(status_code=400, detail="No input data provided")

        # Chain of responsibility: try UploadFile handling first, then Base64.
        base64_handler = handlers.Base64Handler()
        image_handler = handlers.ImageFileHandler(successor=base64_handler)
        input_data: Union[UploadFile, str] = file if file is not None else base64_string
        pil_image = await image_handler.handle(input_data)
        if pil_image is None:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        # Load the YOLO model once per weights path and reuse it afterwards.
        model = _model_cache.get(model_path)
        if model is None:
            model = _model_cache.setdefault(model_path, YOLO(model_path))

        # Inference using the PIL image.
        output = model(pil_image)
        results = Detections.from_ultralytics(output[0])

        if len(results) == 0:
            raise HTTPException(status_code=404, detail="No face detected")

        # Use only the first detected bounding box.
        first_bbox = results[0].xyxy[0].tolist()

        # Crop the image using the integer bounding box.
        x_min, y_min, x_max, y_max = map(int, first_bbox)
        cropped_image = pil_image.crop((x_min, y_min, x_max, y_max))

        # Convert the cropped image to Base64 JPEG for the JSON response.
        buffered = io.BytesIO()
        cropped_image.save(buffered, format="JPEG")
        cropped_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"bounding_box": first_bbox, "cropped_image": cropped_image_base64}

    except HTTPException:
        # Re-raise HTTP errors unchanged (bare raise preserves the traceback)
        # so FastAPI reports the intended status code.
        raise
    except Exception as e:
        # Surface unexpected failures as a 500 with the original cause chained.
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
utils/handlers.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
import base64
|
3 |
+
import io
|
4 |
+
from typing import Optional, Union
|
5 |
+
from fastapi import UploadFile
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
class FileHandler(ABC):
    """Abstract link in a chain of responsibility that converts request
    input (an uploaded file or a Base64 string) into a PIL image.

    Concrete handlers try the conversion themselves and fall back to the
    next handler in the chain via ``super().handle(...)``.
    """

    def __init__(self, successor: Optional['FileHandler'] = None):
        # Next handler consulted when this one cannot process the input.
        self._successor = successor

    @abstractmethod
    async def handle(self, input_data: Union[UploadFile, str]) -> Optional[Image.Image]:
        # Default behaviour: give up with None at the end of the chain,
        # otherwise forward the input to the successor.
        if self._successor is None:
            return None
        return await self._successor.handle(input_data)
|
17 |
+
|
18 |
+
class ImageFileHandler(FileHandler):
    """Handles file-like inputs (e.g. FastAPI ``UploadFile``) by decoding
    their bytes into an RGB PIL image.

    The check is duck-typed rather than ``isinstance(input_data, UploadFile)``
    because the UploadFile class object can differ between import paths.
    """

    async def handle(self, input_data: Union[UploadFile, str]) -> Optional[Image.Image]:
        # Anything exposing both .read and .filename is treated as an upload;
        # everything else is passed along the chain.
        if not (hasattr(input_data, 'read') and hasattr(input_data, 'filename')):
            return await super().handle(input_data)
        print("Handling UploadFile")
        try:
            raw = await input_data.read()
            return Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as e:
            print(f"Error processing UploadFile: {e}")
            return None
|
34 |
+
|
35 |
+
class Base64Handler(FileHandler):
    """Handles plain-string inputs assumed to be Base64-encoded image data."""

    async def handle(self, input_data: Union[UploadFile, str]) -> Optional[Image.Image]:
        print(f"Base64Handler received: {type(input_data)}")
        if not isinstance(input_data, str):
            # Not a string — pass the input along the chain.
            return await super().handle(input_data)
        print("Handling Base64 string")
        try:
            raw_bytes = base64.b64decode(input_data)
            return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception:
            # Not decodable as a Base64 image — fall through to the successor.
            return await super().handle(input_data)
|