# TADBot / FER / detectfaces.py
# ryefoxlime: "importing only certain func from main" (8e0d21f)
# 4.25 kB
from models.PosterV2_7cls import pyramid_trans_expr2
import cv2
import torch
import os
import time
from PIL import Image
from main import RecorderMeter1, RecorderMeter # noqa: F401
# Path to the model checkpoint, resolved against the current working directory.
# Built with os.path.join so the separators are correct on every platform; a
# hard-coded backslash path only resolves on Windows.
model_path = os.path.abspath(
    os.path.join("FER", "models", "checkpoints", "raf-db-model_best.pth")
)

# Pick the execution device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize the model with img_size=224 and 7 output classes.
model = pyramid_trans_expr2(img_size=224, num_classes=7)
# Wrap the model with DataParallel for potential multi-GPU usage.
# NOTE(review): DataParallel prefixes state_dict keys with "module."; presumably
# the checkpoint loaded in main() was saved from a wrapped model - confirm.
model = torch.nn.DataParallel(model)
# Move the model to the chosen device.
model = model.to(device)

# Print the time at which the module was loaded (HH:MM:SS).
currtime = time.strftime("%H:%M:%S")
print(currtime)
def main():
    """Load the checkpoint into the module-level model and start webcam capture.

    If ``model_path`` points to an existing file, the checkpoint is loaded onto
    ``device``, its ``best_acc`` and ``epoch`` entries are printed, and its
    ``state_dict`` is loaded into ``model``. If the file is missing, a warning
    is printed and capture still starts with the model's current weights.

    Returns:
        None
    """
    if model_path is not None and os.path.isfile(model_path):
        print(f"=> loading checkpoint '{model_path}'")
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from a trusted
        # source. It is kept because the checkpoint may contain pickled project
        # objects (see the RecorderMeter imports at the top of the file).
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        # Printed as stored. The former ``best_acc.to()`` call was a no-op for
        # tensors and raised AttributeError for plain Python numbers.
        best_acc = checkpoint["best_acc"]
        print(f"best_acc:{best_acc}")
        model.load_state_dict(checkpoint["state_dict"])
        print(f"=> loaded checkpoint '{model_path}' (epoch {checkpoint['epoch']})")
    elif model_path is not None:
        print(f"[!] detectfaces.py => no checkpoint found at '{model_path}'")

    # Start webcam capture and prediction.
    imagecapture(model)
def imagecapture(model):
    """Show webcam frames until a face is detected, then run one prediction.

    Frames are read from camera index 0 and displayed in a window named
    "Webcam". The first face found by the Haar cascade detector is cropped,
    converted to an RGB PIL image and passed to ``prediction.predict``. The
    loop ends after that single prediction, when a frame cannot be read, or
    when the user presses 'q'. The camera and all OpenCV windows are always
    released on exit, even if an exception is raised.

    Args:
        model: The model forwarded unchanged to ``predict``.

    Returns:
        None
    """
    # Kept function-local as in the original, but imported once rather than on
    # every frame.
    from prediction import predict

    # Load the Haar cascade once; re-parsing the XML for every frame is wasteful.
    face_detector = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )

    cap = cv2.VideoCapture(0)
    try:
        time.sleep(5)  # Wait for 5 seconds to allow the camera to initialize

        # Keep trying to open the webcam until successful. The device must be
        # reopened explicitly: isOpened() never changes by just waiting.
        while not cap.isOpened():
            time.sleep(2)  # Wait for 2 seconds before retrying
            cap.open(0)

        capturing = True
        while capturing:
            ret, frame = cap.read()
            if not ret:
                print("Error: Could not read frame.")
                break

            # Haar cascades operate on a grayscale image.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_detector.detectMultiScale(
                gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30)
            )

            # Display the current frame.
            cv2.imshow("Webcam", frame)

            # len() is used on purpose: detectMultiScale may return an array,
            # whose truth value is ambiguous.
            if len(faces) > 0:
                currtimeimg = time.strftime("%H:%M:%S")
                print(f"[!]Face detected at {currtimeimg}")

                # Crop the first detected face; each detection is (x, y, w, h).
                x, y, w, h = faces[0]
                face_region = frame[y : y + h, x : x + w]

                # OpenCV frames are BGR; convert to RGB before building the
                # PIL image.
                face_pil_image = Image.fromarray(
                    cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)
                )

                print("[!]Start Expressions")
                starttime = time.strftime("%H:%M:%S")
                print(f"-->Prediction starting at {starttime}")
                # The keyword is named image_path but receives a PIL image.
                predict(model, image_path=face_pil_image)
                endtime = time.strftime("%H:%M:%S")
                print(f"-->Done prediction at {endtime}")

                # Stop capturing once prediction is complete.
                capturing = False

            # Exit the loop if the 'q' key is pressed.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Release webcam resources and close OpenCV windows.
        cap.release()
        cv2.destroyAllWindows()
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()