GazeCorrection / app.py
mbesinci's picture
Update app.py
172363f verified
import gradio as gr
import cv2
import torch
from transformers import VideoMAEForVideoClassification, AutoFeatureExtractor
# Göz teması algılama modelini ve özellik çıkarıcıyı yükleyin
model_name = "kanlo/videomae-base-ASD_Eye_Contact_v2"
model = VideoMAEForVideoClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
def preprocess_frames(frames):
# Her kareyi modele uygun şekilde işleyin
inputs = feature_extractor(frames, return_tensors="pt")
return inputs['pixel_values']
def detect_eye_contact(video_path):
# Video dosyasını aç
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Çıkış videosu için ayarlar
output_path = "eye_contact_output.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
# Modeli belirli bir aralıkta çalıştırarak göz temasını algılayın
if len(frames) >= 16: # 16 karede bir işlem yapıyoruz
inputs = preprocess_frames(frames)
with torch.no_grad():
outputs = model(pixel_values=inputs)
prediction = outputs.logits.argmax(-1).item()
# Göz teması varsa çerçeveye ek açıklama ekleyin
for frame in frames:
if prediction == 1: # 1 göz teması var anlamına gelir (modelde böyle olduğunu varsayıyoruz)
cv2.putText(frame, "Eye Contact", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
else:
cv2.putText(frame, "No Eye Contact", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
out.write(frame)
frames = [] # 16 karelik grubu işledikten sonra sıfırlayın
# Kaynakları serbest bırak
cap.release()
out.release()
return output_path
# Gradio arayüzü
iface = gr.Interface(
fn=detect_eye_contact,
inputs="file",
outputs="file",
title="Eye Contact Detection in Video",
description="Upload a video to detect eye contact in each frame."
)
iface.launch()