|
from models.PosterV2_7cls import pyramid_trans_expr2 |
|
import cv2 |
|
import torch |
|
import os |
|
import time |
|
from PIL import Image |
|
from main import RecorderMeter1, RecorderMeter |
|
|
|
|
|
# Location of the trained checkpoint (presumably the RAF-DB model, per the
# filename), resolved relative to the current working directory. The path is
# assembled with os.path.join so the same relative location works on
# macOS/Linux as well as Windows; a literal backslash path only works on
# Windows.
model_path = os.path.abspath(
    os.path.join("FER", "models", "checkpoints", "raf-db-model_best.pth")
)
|
|
|
|
|
# Choose the compute device: Apple MPS if available, otherwise CUDA,
# otherwise fall back to the CPU. Availability is probed in that order and
# the first match wins.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
|
|
|
|
|
# Build the 7-class expression network for 224x224 inputs and wrap it in
# DataParallel. NOTE(review): the wrapper is presumably there so the
# "module."-prefixed keys of a checkpoint saved from a DataParallel model
# load cleanly — confirm. Module.to() returns the module itself, so the
# wrapped model is moved to `device` in the same expression.
model = torch.nn.DataParallel(
    pyramid_trans_expr2(img_size=224, num_classes=7)
).to(device)

# Log the wall-clock time at which model construction finished.
currtime = time.strftime("%H:%M:%S")
print(currtime)
|
|
|
|
|
def main():
    """Load the trained checkpoint into the global ``model``, then start capture.

    If ``model_path`` names an existing file, the checkpoint is loaded onto
    ``device``. Its ``best_acc`` and ``epoch`` entries are printed and its
    ``state_dict`` is copied into ``model``. If no checkpoint file is found,
    a warning is printed and capture still runs with the model's current
    (untrained) weights.

    Raises:
        KeyError: if the checkpoint lacks ``best_acc``, ``state_dict`` or
            ``epoch``.
    """
    if model_path is not None and os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (presumably why the RecorderMeter classes are imported from
        # ``main``). Only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        # best_acc may have been saved as a 0-d tensor or as a plain number;
        # float() turns either into a printable Python float.
        best_acc = float(checkpoint["best_acc"])
        print(f"best_acc:{best_acc}")

        model.load_state_dict(checkpoint["state_dict"])
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint["epoch"]
            )
        )
    elif model_path is not None:
        print("[!] detectfaces.py => no checkpoint found at '{}'".format(model_path))

    imagecapture(model)
|
|
|
|
|
def imagecapture(model, camera_index=0, max_open_attempts=5):
    """Read webcam frames until a face is detected, then classify its expression.

    Each frame is shown in a "Webcam" window and scanned with OpenCV's
    frontal-face Haar cascade. On the first frame with at least one
    detection, the first detected face is cropped, converted to an RGB PIL
    image and passed to ``prediction.predict`` together with ``model``, and
    capture then stops. The loop also ends on a failed frame read or when
    ``q`` is pressed in the preview window.

    Args:
        model: network passed through to ``prediction.predict``.
        camera_index: OpenCV device index of the camera (default 0).
        max_open_attempts: how many times to retry opening the camera,
            2 seconds apart, before giving up.

    The camera is released and the preview windows are destroyed even if
    prediction raises.
    """
    # Kept as a function-level (lazy) import, but done once rather than on
    # every frame.
    from prediction import predict

    # Load the face detector once, before touching the camera, so a missing
    # cascade file fails fast instead of erroring inside the capture loop.
    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )
    if face_cascade.empty():
        print("Error: Could not load Haar cascade for face detection.")
        return

    cap = cv2.VideoCapture(camera_index)
    # Give the camera time to warm up before polling it.
    time.sleep(5)

    # isOpened() does not become True by itself once opening has failed, so
    # explicitly re-open a bounded number of times instead of waiting forever.
    for _ in range(max_open_attempts):
        if cap.isOpened():
            break
        time.sleep(2)
        cap.open(camera_index)
    if not cap.isOpened():
        print(f"Error: Could not open camera {camera_index}.")
        cap.release()
        return

    try:
        capturing = True
        while capturing:
            ret, frame = cap.read()
            if not ret:
                print("Error: Could not read frame.")
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(
                gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30)
            )

            cv2.imshow("Webcam", frame)

            if len(faces) > 0:
                currtimeimg = time.strftime("%H:%M:%S")
                print(f"[!]Face detected at {currtimeimg}")

                # detectMultiScale returns (x, y, w, h) boxes; crop the first.
                x, y, w, h = faces[0]
                face_region = frame[y : y + h, x : x + w]

                # OpenCV frames are BGR; PIL expects RGB.
                face_pil_image = Image.fromarray(
                    cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)
                )
                print("[!]Start Expressions")

                starttime = time.strftime("%H:%M:%S")
                print(f"-->Prediction starting at {starttime}")

                predict(model, image_path=face_pil_image)

                endtime = time.strftime("%H:%M:%S")
                print(f"-->Done prediction at {endtime}")

                # One prediction per call: stop after the first detected face.
                capturing = False

            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
|
# Script entry point: load the checkpoint and start webcam capture only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|