# MMESA-ZeroGPU: tabs/gaze_estimation.py
import tempfile
import cv2
import dlib
import numpy as np
from scipy.spatial import distance as dist
from imutils import face_utils
import gradio as gr


def detect_eye_movements(video_path):
    """Track per-frame eye centers and classify the overall gaze as fixed or scattered."""
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor("assets/models/shape_predictor_68_face_landmarks.dat")

    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Preserve the source frame rate when the backend reports one; fall back to 20 fps.
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0

    # Persistent temp file for the annotated output; close the handle before
    # handing the path to VideoWriter.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.avi') as temp_file:
        output_path = temp_file.name
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'XVID'), fps,
                          (frame_width, frame_height))

    gaze_points = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        for rect in detector(gray, 0):
            shape = face_utils.shape_to_np(predictor(gray, rect))
            # 68-landmark convention: indices 36-41 are the left eye, 42-47 the right.
            for eye in [shape[36:42], shape[42:48]]:
                eye_center = eye.mean(axis=0).astype("int")
                gaze_points.append(eye_center)
                cv2.circle(frame, (int(eye_center[0]), int(eye_center[1])), 3, (0, 255, 0), -1)
        out.write(frame)

    cap.release()
    out.release()

    # Heuristic: the gaze counts as "fixed" when most consecutive eye centers
    # move by less than fixed_threshold pixels between samples.
    fixed_threshold = 10
    fixed_gaze_count = sum(dist.euclidean(gaze_points[i - 1], gaze_points[i]) < fixed_threshold
                           for i in range(1, len(gaze_points)))
    gaze_type = "Fixed Gaze" if fixed_gaze_count > len(gaze_points) // 2 else "Scattered Gaze"
    return output_path, gaze_type
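
# Usage sketch (an assumption, not part of the original module): the detector can
# be exercised directly, outside the Gradio tab. "sample.mp4" is a placeholder path.
#
#   annotated_path, label = detect_eye_movements("sample.mp4")
#   print(label, "->", annotated_path)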


def create_gaze_estimation_tab():
    """Build the Gradio layout and event wiring for the gaze-estimation tab."""
    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Input Video")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Analyze", elem_classes="submit")
        with gr.Column(scale=1, elem_classes="dl4"):
            output_video = gr.Video(label="Processed Video", elem_classes="video2")
            output_gaze_type = gr.Label(label="Gaze Type")

    # Run the detector on submit; reset all three components on clear.
    submit_btn.click(detect_eye_movements, inputs=input_video,
                     outputs=[output_video, output_gaze_type], queue=True)
    clear_btn.click(lambda: (None, None, None),
                    outputs=[input_video, output_video, output_gaze_type], queue=True)
    gr.Examples(["./assets/videos/fitness.mp4"], inputs=[input_video])
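

# Minimal local harness (a sketch; assumes direct execution of this module, while
# the MMESA-ZeroGPU app presumably mounts this tab from a top-level Blocks elsewhere).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Gaze Estimation"):
            create_gaze_estimation_tab()
    demo.launch()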