gestsync / app.py
sindhuhegde's picture
Update app
4b29652
raw
history blame
49 kB
import gradio as gr
import os
import torch
from shutil import rmtree
from torch import nn
from torch.nn import functional as F
import numpy as np
import subprocess
import cv2
import pickle
import librosa
from decord import VideoReader
from decord import cpu, gpu
from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *
from tqdm import tqdm
from glob import glob
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
# MediaPipe "holistic" solution (pose, hand and face landmarks from a single model)
mp_holistic = mp.solutions.holistic

# Keep the logs readable: silence deprecation / user warnings raised by third-party libraries
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
CHECKPOINT_PATH = "model_rgb.pth"          # trained GestSync RGB weights
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 12
fps = 25                                   # target frame rate (inputs are resampled to 25 fps)
n_negative_samples = 100                   # bounds the number of candidate audio windows in create_online_sync_negatives

# Module-level MediaPipe holistic keypoint detector.
# NOTE(review): get_keypoints builds its own detector; confirm whether this global is used elsewhere.
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
def preprocess_video(path, result_folder, apply_preprocess, padding=20):
    '''
    This function preprocesses the input video to extract the audio and crop the frames using YOLO model
    Args:
        - path (string) : Path of the input video file
        - result_folder (string) : Path of the folder to save the extracted audio and cropped video
        - apply_preprocess (string or bool) : "True" (or True) to detect the single person with YOLO and crop
                                              the video around them; any other value keeps the input video as-is
        - padding (int) : Padding (in pixels) to add around each person bounding box
    Returns:
        - wav_file (string) : Path of the extracted audio file
        - fps (float) : Average FPS of the input video, as reported by the video reader
        - video_output (string) : Path of the cropped video file (the input path when no preprocessing is applied)
        - msg (string) : "success", or an error message (the other three values are then None)
    '''

    def run_ffmpeg(args):
        # Pass an argument list instead of a shell string so that paths containing spaces or
        # shell metacharacters reach ffmpeg verbatim. A missing ffmpeg binary is reported as a
        # failed status (as the shell would do) rather than raising.
        try:
            return subprocess.call(["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y"] + list(args))
        except OSError:
            return 1

    # Load all video frames
    try:
        vr = VideoReader(path, ctx=cpu(0))
        fps = vr.get_avg_fps()
        frame_count = len(vr)
    except Exception:
        msg = "Oops! Could not load the video. Please check the input video and try again."
        return None, None, None, msg

    if frame_count < 25:
        msg = "Not enough frames to process! Please give a longer video as input"
        return None, None, None, msg

    # Extract the audio (mono, 16 kHz, 16-bit PCM) from the input video file using ffmpeg
    wav_file = os.path.join(result_folder, "audio.wav")
    status = run_ffmpeg(["-i", path, "-async", "1", "-ac", "1", "-vn",
                         "-acodec", "pcm_s16le", "-ar", "16000", wav_file])
    if status != 0:
        msg = "Oops! Could not load the audio file. Please check the input video and try again."
        return None, None, None, msg
    print("Extracted the audio from the video")

    # Accept both the string "True" (as sent by the UI) and a real boolean
    if str(apply_preprocess) != "True":
        return wav_file, fps, path, "success"

    all_frames = [vr[k].asnumpy() for k in range(frame_count)]
    print("Extracted the frames for pre-processing")

    # Load YOLOv9 model (pre-trained on COCO dataset)
    yolo_model = YOLO("yolov9s.pt")
    print("Loaded the YOLO model")

    # person_videos[p] / person_tracks[p] collect the frames and padded boxes of the p-th person
    # detected in a frame. Only 'person' detections advance p, so other COCO objects in the
    # scene do not move a person into a different slot from one frame to the next.
    person_videos = {}
    person_tracks = {}
    print("Processing the frames...")
    for frame in tqdm(all_frames):
        results = yolo_model(frame, verbose=False)
        person_idx = 0
        for det in results[0].boxes:
            if int(det.cls[0]) != 0:  # Class 0 is 'person' in COCO dataset
                continue
            x1, y1, x2, y2 = det.xyxy[0]
            x1 = max(0, int(x1) - padding)
            y1 = max(0, int(y1) - padding)
            x2 = min(frame.shape[1], int(x2) + padding)
            y2 = min(frame.shape[0], int(y2) + padding)
            person_videos.setdefault(person_idx, []).append(frame)
            person_tracks.setdefault(person_idx, []).append([x1, y1, x2, y2])
            person_idx += 1

    # A person only counts if they are detected in at least half of the frames
    persistent = [p for p, vids in person_videos.items() if len(vids) >= frame_count // 2]
    if len(persistent) == 0:
        msg = "No person detected in the video! Please give a video with one person as input"
        return None, None, None, msg
    if len(persistent) > 1:
        msg = "More than one person detected in the video! Please give a video with only one person as input"
        return None, None, None, msg
    person = persistent[0]

    # The person has to be tracked in (almost) the whole video
    if len(person_videos[person]) <= frame_count - 10:
        msg = "Could not track the person in the full video! Please give a single-speaker video as input"
        return None, None, None, msg

    # Crop every frame with the union of all the person's bounding boxes
    tracks = person_tracks[person]
    min_x1 = min(track[0] for track in tracks)
    min_y1 = min(track[1] for track in tracks)
    max_x2 = max(track[2] for track in tracks)
    max_y2 = max(track[3] for track in tracks)
    crop_width = max_x2 - min_x1
    crop_height = max_y2 - min_y1

    crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    out = cv2.VideoWriter(crop_filename, fourcc, fps, (crop_width, crop_height))
    for frame in person_videos[person]:
        crop = frame[min_y1:max_y2, min_x1:max_x2]
        # Swap the channel order for OpenCV's writer (assumes the decoded frames are RGB;
        # the BGR<->RGB swap is symmetric)
        out.write(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
    out.release()

    # splitext (rather than split('.')) so dots in directory names do not truncate the path
    stem = os.path.splitext(crop_filename)[0]
    no_sound_video = stem + '_nosound.mp4'
    status = run_ffmpeg(["-i", crop_filename, "-c", "copy", "-an", "-strict", "-2", no_sound_video])
    if status != 0:
        msg = "Oops! Could not preprocess the video. Please check the input video and try again."
        return None, None, None, msg

    # Mux the extracted audio back onto the cropped video
    video_output = stem + '.mp4'
    status = run_ffmpeg(["-i", wav_file, "-i", no_sound_video, "-strict", "-2", "-q:v", "1", video_output])
    if status != 0:
        msg = "Oops! Could not preprocess the video. Please check the input video and try again."
        return None, None, None, msg

    os.remove(crop_filename)
    os.remove(no_sound_video)
    print("Successfully saved the pre-processed video: ", video_output)
    return wav_file, fps, video_output, "success"
def resample_video(video_file, video_fname, result_folder):
    '''
    This function resamples the video to 25 fps
    Args:
        - video_file (string) : Path of the input video file
        - video_fname (string) : Name (without extension) of the resampled video file
        - result_folder (string) : Path of the folder to save the resampled video
    Returns:
        - video_file_25fps (string) : Path of the resampled video file (None on failure)
        - msg (string) : "success", or an error message
    '''
    video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

    # Resample the video to 25 fps with x264 at -crf 0 (its highest-quality setting).
    # An argument list (no shell) keeps paths with spaces/metacharacters intact.
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y", "-i", video_file,
           "-c:v", "libx264", "-preset", "veryslow", "-crf", "0",
           "-filter:v", "fps=25", "-pix_fmt", "yuv420p", video_file_25fps]
    try:
        status = subprocess.call(cmd)
    except OSError:
        # ffmpeg binary missing / not executable: treat as a failed resampling
        status = 1
    if status != 0:
        msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."
        return None, msg

    print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
    return video_file_25fps, "success"
def load_checkpoint(path, model):
    '''
    Load trained weights from a checkpoint file into the given model
    Args:
        - path (string) : Path of the checkpoint file (a dict holding a "state_dict" entry)
        - model (object) : Model object to receive the weights
    Returns:
        - model (object) : The same model with the weights loaded, moved to the GPU when CUDA is
                           available, and switched to eval mode
    '''
    # With CUDA keep tensors where they were saved; otherwise remap everything onto the CPU
    checkpoint = torch.load(path, map_location=None if use_cuda else "cpu")

    # Remove every "module." occurrence from parameter names (presumably the prefix added by
    # nn.DataParallel when the checkpoint was saved)
    state_dict = {name.replace('module.', ''): weights
                  for name, weights in checkpoint["state_dict"].items()}
    model.load_state_dict(state_dict)

    if use_cuda:
        model.cuda()
    print("Loaded checkpoint from: {}".format(path))
    return model.eval()
def load_video_frames(video_file):
    '''
    This function extracts the frames from the video
    Args:
        - video_file (string) : Path of the video file
    Returns:
        - frames (array) : All decoded frames stacked along axis 0 (None on failure)
        - msg (string) : "success", or an error message
    '''
    # Read the video; catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate
    try:
        vr = VideoReader(video_file, ctx=cpu(0))
    except Exception:
        msg = "Oops! Could not load the input video file"
        return None, msg

    # Decode every frame and stack them into one array
    frames = np.asarray([vr[k].asnumpy() for k in range(len(vr))])
    return frames, "success"
def get_keypoints(frames):
    '''
    This function extracts the keypoints from the frames using MediaPipe Holistic pipeline
    Args:
        - frames (list) : List of frames extracted from the video
    Returns:
        - kp_dict (dict) : {"kps": one dict per frame with "pose", "left_hand", "right_hand" and "face"
                           landmark lists (None when not detected), "resolution": shape of the first frame}
        - msg (string) : "success", or an error message (kp_dict is then None)
    '''
    error_msg = "Error: Could not extract keypoints from the frames"
    if frames is None or len(frames) == 0:
        print("Error: ", "no frames to process")
        return None, error_msg

    # MediaPipe result attribute -> key of the per-frame keypoint dict
    landmark_fields = (
        ("pose_landmarks", "pose"),
        ("left_hand_landmarks", "left_hand"),
        ("right_hand_landmarks", "right_hand"),
        ("face_landmarks", "face"),
    )
    try:
        # The context manager closes the MediaPipe graph when done, so repeated calls
        # do not accumulate detector instances
        with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
            resolution = frames[0].shape
            all_frame_kps = []
            for frame in frames:
                results = holistic.process(frame)
                frame_dict = {}
                for attr, key in landmark_fields:
                    landmarks = getattr(results, attr)
                    frame_dict[key] = None if landmarks is None else protobuf_to_dict(landmarks)['landmark']
                all_frame_kps.append(frame_dict)
        kp_dict = {"kps": all_frame_kps, "resolution": resolution}
    except Exception as e:
        print("Error: ", e)
        return None, error_msg
    return kp_dict, "success"
def check_visible_gestures(kp_dict):
    '''
    Verify that the body pose and at least one hand are visible in enough frames
    Args:
        - kp_dict (dict) : Keypoint dictionary from get_keypoints ("kps" holds one dict per frame)
    Returns:
        - msg (string) : "success", or the reason the video was rejected (fewer than 25 frames, or
                         pose / both hands missing in more than 60% of the frames)
    '''
    frame_kps = list(kp_dict['kps'])
    total = len(frame_kps)
    if total < 25:
        return "Not enough keypoints to process! Please give a longer video as input"

    frames_without_pose = sum(1 for kp in frame_kps if kp["pose"] is None)
    frames_without_hands = sum(
        1 for kp in frame_kps
        if all(hand is None for hand in (kp["left_hand"], kp["right_hand"]))
    )

    too_few_hands = frames_without_hands / total > 0.6
    too_few_poses = frames_without_pose / total > 0.6
    if too_few_hands or too_few_poses:
        return "The gestures in the input video are not visible! Please give a video with visible gestures as input."

    print("Successfully verified the input video - Gestures are visible!")
    return "success"
def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270):
    '''
    This function masks the faces using the keypoints extracted from the frames
    Args:
        - input_frames (list) : List of frames extracted from the video (not modified)
        - kp_dict (dict) : Keypoints and resolution from get_keypoints; if None, a fixed top band
                           of every resized frame is masked instead
        - asd (bool) : If True, edge-pad 12 frames at both ends before windowing
                       (with window_frames=25 this yields one window per input frame)
        - stride (int) : Stride to extract the frames
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - width (int) : Width of the frames
        - height (int) : Height of the frames
    Returns:
        - input_frames (array) : Frame windows (num_windows x window_frames x height x width x C), scaled to [0, 1]
        - num_frames (int) : Number of windows
        - orig_masked_frames (array) : Masked frames (uint8, resized) extracted from the video
        - msg (string) : "success", or an error message (other values are then None)
    '''
    print("Creating masked input frames...")
    input_frames_masked = []
    if kp_dict is None:
        for img in tqdm(input_frames):
            img = cv2.resize(img, (width, height))
            masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
            input_frames_masked.append(masked_img)
    else:
        # Face-oval indices into the face landmarks; only the lowest of them (the chin line) is used
        face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
                         176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]
        input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
        print("Input keypoints: ", len(input_keypoints))
        for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):
            face = frame_kp_dict["face"]
            if face is None:
                # No face found: resize first, then black out a fixed 110-pixel top band
                img = cv2.resize(input_frames[i], (width, height))
                masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
            else:
                # Lowest face-oval landmark, in pixels of the original resolution
                y2 = max(int(face[idx]["y"] * resolution[0]) for idx in face_oval_idx if idx < len(face))
                # Black out everything from the top of the frame to 15 px below the face.
                # Draw on a copy so the caller's frames are left untouched.
                img = input_frames[i].copy()
                masked_img = cv2.rectangle(img, (0,0), (resolution[1], y2+15), (0,0,0), -1)
            # numpy image shape is (rows, cols) == (height, width)
            if masked_img.shape[0] != height or masked_img.shape[1] != width:
                masked_img = cv2.resize(masked_img, (width, height))
            input_frames_masked.append(masked_img)

    if len(input_frames_masked) == 0:
        msg = "Not enough frames to process! Please give a longer video as input."
        return None, None, None, msg

    orig_masked_frames = np.array(input_frames_masked)
    input_frames = orig_masked_frames / 255.
    if asd:
        input_frames = np.pad(input_frames, ((12, 12), (0,0), (0,0), (0,0)), 'edge')

    # Sliding windows of `window_frames` consecutive frames, one every `stride` frames
    input_frames = np.array([input_frames[i:i+window_frames, :, :]
                             for i in range(0, input_frames.shape[0], stride)
                             if (i+window_frames <= input_frames.shape[0])])
    num_frames = input_frames.shape[0]
    if num_frames<10:
        msg = "Not enough frames to process! Please give a longer video as input."
        return None, None, None, msg
    return input_frames, num_frames, orig_masked_frames, "success"
def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4):
    '''
    This function extracts the spectrogram from the audio file
    Args:
        - wav_file (string) : Path of the extracted audio file
        - asd (bool) : If True, edge-pad window_frames//2 windows at both ends of the windowed spectrogram
        - num_frames (int) : Number of windows to keep (the video window count); None keeps all
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - stride (int) : Stride to extract the audio frames (spectrogram rows per video frame)
    Returns:
        - spec (array) : Spectrogram array window to be used as input to the model
        - orig_spec (array) : Full (unwindowed) spectrogram array extracted from the audio file
        - msg (string) : "success", or an error message (spec and orig_spec are then None)
    '''
    # Load the audio resampled to 16 kHz
    try:
        wav = librosa.load(wav_file, sr=16000)[0]
    except Exception:
        msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
        return None, None, msg

    # Compute the filterbank features (a batch axis is added for wav2filterbanks)
    wav = torch.FloatTensor(wav).unsqueeze(0)
    mel, _, _, _ = wav2filterbanks(wav.to(device))
    spec = mel.squeeze(0).cpu().numpy()
    orig_spec = spec

    # One window of window_frames*stride rows starting every `stride` rows
    win = window_frames * stride
    spec = np.array([spec[i:i+win, :] for i in range(0, spec.shape[0], stride) if (i+win <= spec.shape[0])])

    if num_frames is not None:
        # Measure the mismatch before truncating; afterwards an audio track that is longer
        # than the video would always report a difference of 0
        frame_diff = np.abs(len(spec) - num_frames)
        if len(spec) != num_frames:
            spec = spec[:num_frames]
        if frame_diff > 60:
            print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

    if asd:
        pad_frames = (window_frames//2)
        spec = np.pad(spec, ((pad_frames, pad_frames), (0,0), (0,0)), 'edge')

    return spec, orig_spec, "success"
def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
    '''
    Estimate the audio-visual offset by scoring a positive video window against shifted audio windows
    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - model (object) : Model object
    Returns:
        - offset (int) : Optimal audio-visual offset in frames (None on failure)
        - msg (string) : "success", or the error message from create_online_sync_negatives
    '''
    vid_pos, aud_candidates, pos_frame, win_stride, msg = create_online_sync_negatives(
        vid_emb, aud_emb, num_avg_frames)
    if msg != "success":
        return None, msg

    # Only the similarity scores are needed; the attention map is discarded
    scores = calc_av_scores(vid_pos, aud_candidates, model)[0]

    # Best candidate index -> frame position, relative to the positive window's frame
    best_frame = scores.argmax() * win_stride
    return (best_frame - pos_frame).item(), "success"
def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):
    '''
    Build every candidate (positive and shifted) audio window to compare with one video window
    Args:
        - vid_emb (array) : Video embedding array (time on axis 2)
        - aud_emb (array) : Audio embedding array (axis 1 is squeezed; time on the last axis)
        - num_avg_frames (int) : Length of every window, in frames
        - stride (int) : Step between consecutive candidate audio windows
    Returns:
        - vid_emb_pos (array) : Video window starting at the centre candidate's frame
        - aud_emb_posneg (array) : Candidate audio windows, indexed along axis 1
        - pos_idx_frame (int) : Frame at which the centre (zero-offset) candidate starts
        - stride (int) : Stride used to extract the candidate windows
        - msg (string) : "success", or an error message (the other values are then None)
    '''
    window = num_avg_frames

    # Slide a window over the last (time) axis, then permute so candidates sit on axis 1
    candidates = aud_emb.squeeze(1).unfold(-1, window, stride).permute([0, 2, 1, 3])
    # Keep at most int(n_negative_samples/stride)+1 candidates
    candidates = candidates[:, :int(n_negative_samples/stride)+1]

    num_candidates = candidates.shape[1]
    centre = num_candidates // 2          # the middle candidate is the positive (zero-offset) one
    centre_frame = centre * stride
    lowest_offset = -centre * stride
    highest_offset = (num_candidates - centre - 1) * stride
    print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(lowest_offset, highest_offset))

    vid_pos = vid_emb[:, :, centre_frame:centre_frame+window]
    if vid_pos.shape[2] != window:
        msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(window)
        return None, None, None, None, msg

    return vid_pos, candidates, centre_frame, stride, "success"
def calc_av_scores(vid_emb, aud_emb, model):
    '''
    Compute audio-visual similarity scores and the corresponding attention map
    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object (provides the logit scaling used by calc_att_map)
    Returns:
        - scores (array) : Similarity scores averaged over the last axis
        - att_map (array) : 2-D log-softmax of the unaveraged similarity scores
    '''
    raw_scores = calc_att_map(vid_emb, aud_emb, model)
    # The attention map uses the full similarity map, before any averaging
    att_map = logsoftmax_2d(raw_scores)
    return raw_scores.mean(-1), att_map
def calc_att_map(vid_emb, aud_emb, model):
    '''
    Compute the scaled similarity between the video and audio embeddings
    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object providing the `logits_scale` module
    Returns:
        - scores (array) : Audio-visual similarity scores after logit scaling
    '''
    # Add a singleton axis at dim 2 of the video embedding and swap audio dims 1 and 2 so
    # that the two broadcast against each other in the element-wise product below
    vid_expanded = vid_emb[:, :, None]
    aud_swapped = aud_emb.transpose(1, 2)

    def dot_over_dim1(v, a):
        # Element-wise product summed over dim 1 (the embedding channels)
        return (v * a).sum(1)

    # Evaluated in chunks of 10 along dim 3 (presumably to bound peak memory)
    raw = run_func_in_parts(dot_over_dim1, vid_expanded, aud_swapped,
                            part_len=10, dim=3, device=device)

    # Apply the model's learned scaling on a trailing singleton axis
    return model.logits_scale(raw.unsqueeze(-1)).squeeze(-1)
def generate_video(frames, audio_file, video_fname):
    '''
    This function generates the video from the frames and audio file
    Args:
        - frames (array) : Frames to be written at 25 fps (channel order is swapped for OpenCV's writer)
        - audio_file (string) : Path of the audio file
        - video_fname (string) : Output path WITHOUT extension; ".mp4" is appended
    Returns:
        - video_output (string) : Path of the video file (None on failure)
        - msg (string) : "success", or an error message
    '''
    msg = "Oops! Could not generate the video. Please check the input video and try again."
    if frames is None or len(frames) == 0:
        return None, msg

    def run_ffmpeg(args):
        # Argument list (no shell) so paths with spaces/metacharacters are passed verbatim;
        # a missing ffmpeg binary counts as a failure instead of raising
        try:
            return subprocess.call(["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y"] + list(args))
        except OSError:
            return 1

    # Intermediate files live next to the output (not a fixed name in the working directory),
    # so concurrent requests writing to different folders do not overwrite each other
    fname = video_fname + '_frames.avi'
    no_sound_video = video_fname + '_nosound.mp4'

    height, width = frames[0].shape[0], frames[0].shape[1]
    video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (width, height))
    for frame in frames:
        video.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    video.release()

    try:
        # Re-wrap the frames into an MP4 container without audio
        if run_ffmpeg(["-i", fname, "-c", "copy", "-an", "-strict", "-2", no_sound_video]) != 0:
            return None, msg

        # Mux with the audio track; -shortest trims to the shorter of the two streams
        video_output = video_fname + '.mp4'
        status = run_ffmpeg(["-i", audio_file, "-i", no_sound_video,
                             "-c:v", "libx264", "-preset", "veryslow", "-crf", "18", "-pix_fmt", "yuv420p",
                             "-strict", "-2", "-q:v", "1", "-shortest", video_output])
        if status != 0:
            return None, msg
    finally:
        # Always clean up the intermediate files, on success and on failure
        for tmp in (fname, no_sound_video):
            if os.path.exists(tmp):
                os.remove(tmp)

    return video_output, "success"
def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
    '''
    This function corrects the video and audio to sync with each other
    Args:
        - video_path (string) : Path of the video file (returned unchanged when offset is 0)
        - frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - offset (int) : Predicted sync-offset in frames. Positive: offset frames' worth of audio
                         samples are dropped from the start; negative: the first |offset| frames are dropped
        - result_folder (string) : Path of the result folder to save the output sync-corrected video
        - sample_rate (int) : Sample rate of the audio
        - fps (int) : Frame rate of `frames` (used to convert the offset into audio samples)
    Returns:
        - video_output (string) : Path of the video file (None on failure)
        - msg (string) : "success", or an error message
    '''
    if offset == 0:
        print("The input audio and video are in-sync! No need to perform sync correction.")
        return video_path, "success"

    print("Performing Sync Correction...")
    if offset > 0:
        # Drop `offset` frames' worth of audio samples from the start of the track
        audio_offset = int(offset*(sample_rate/fps))
        wav = librosa.load(wav_file, sr=sample_rate)[0]
        corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
        write(corrected_wav_file, sample_rate, wav[audio_offset:])
        wav_file = corrected_wav_file
        corrected_frames = frames
    else:
        shift = -offset
        if shift >= len(frames):
            msg = "Oops! The predicted offset is longer than the video. Could not perform sync correction."
            return None, msg
        # Drop the first |offset| video frames
        corrected_frames = frames[shift:]

    corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
    video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
    if status != "success":
        return None, status
    return video_output, "success"
def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
    '''
    This function loads the masked input frames from the video
    Args:
        - test_videos (list) : List of videos to be processed (speaker-specific tracks)
        - spec (array) : Spectrogram windows of the audio (window count on axis 2)
        - wav_file (string) : Path of the audio file (currently unused)
        - scene_num (int) : Scene number (currently unused)
        - result_folder (string) : Path of the result folder (currently unused)
    Returns:
        - all_frames (list) : One tensor of masked frame windows per video, shaped (1, C, T, window, H, W)
        - all_orig_frames (list) : List of original masked input frames
        - msg (string) : "success", or the first error message encountered (lists are then None)
    '''
    all_frames, all_orig_frames = [], []
    for video in test_videos:
        print("Processing video: ", video)

        # Load the video frames
        frames, status = load_video_frames(video)
        if status != "success":
            return None, None, status
        print("Successfully loaded the video frames")

        # Extract the keypoints from the frames
        kp_dict, status = get_keypoints(frames)
        if status != "success":
            return None, None, status
        print("Successfully extracted the keypoints")

        # Mask the frames using the keypoints and prepare the model input windows
        masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=True)
        if status != "success":
            return None, None, status
        print("Successfully loaded the masked frames")

        # Align the number of frame windows with the number of spectrogram windows
        num_spec_frames = spec.shape[2]
        if num_spec_frames != masked_frames.shape[0]:
            # Measure the mismatch before truncating (afterwards it would always be 0)
            frame_diff = np.abs(num_spec_frames - masked_frames.shape[0])
            masked_frames = masked_frames[:num_spec_frames]
            orig_masked_frames = orig_masked_frames[:num_spec_frames]
            if frame_diff > 60:
                print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

        # (T, window, H, W, C) -> (C, T, window, H, W), then add a batch axis
        frames = torch.FloatTensor(np.transpose(masked_frames, (4, 0, 1, 2, 3))).unsqueeze(0)
        print("Successfully converted the frames to tensor")

        all_frames.append(frames)
        all_orig_frames.append(orig_masked_frames)

    return all_frames, all_orig_frames, "success"
def extract_audio(video, result_folder):
    '''
    This function extracts the audio from the video file
    Args:
        - video (string) : Path of the video file
        - result_folder (string) : Path of the folder to save the extracted audio file
    Returns:
        - wav_file (string) : Path of the extracted audio file (None on failure)
        - msg (string) : "success", or an error message
    '''
    wav_file = os.path.join(result_folder, "audio.wav")

    # Mono, 16 kHz, 16-bit PCM. An argument list (no shell) keeps paths with spaces or
    # shell metacharacters intact.
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-threads", "1", "-y", "-i", video,
           "-async", "1", "-ac", "1", "-vn", "-acodec", "pcm_s16le", "-ar", "16000", wav_file]
    try:
        status = subprocess.call(cmd)
    except OSError:
        # ffmpeg binary missing / not executable
        status = 1
    if status != 0:
        msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again"
        return None, msg

    return wav_file, "success"
def get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True, batch_size=12):
    '''
    This function extracts the video and audio embeddings from the input frames and audio sequences
    Args:
        - video_sequences (array) : Array of video frames to be used as input to the model
        - audio_sequences (array) : Array of audio frames to be used as input to the model
        - model (object) : Model object
        - calc_aud_emb (bool) : Flag to calculate the audio embedding
        - batch_size (int) : Number of sequences per forward pass
    Returns:
        - video_emb (array) : Video embedding (model output averaged over its last axis)
        - audio_emb (array) : Audio embedding; only returned when calc_aud_emb is True
    '''
    video_emb = []
    audio_emb = []
    # Inference only: skip building autograd graphs, which saves a lot of memory
    with torch.no_grad():
        for i in range(0, len(video_sequences), batch_size):
            video_inp = video_sequences[i:i+batch_size]
            vid_emb = model.forward_vid(video_inp.to(device), return_feats=False)
            video_emb.append(torch.mean(vid_emb, dim=-1).detach())

            if calc_aud_emb:
                audio_inp = audio_sequences[i:i+batch_size]
                audio_emb.append(model.forward_aud(audio_inp.to(device)).detach())

            torch.cuda.empty_cache()

    video_emb = torch.cat(video_emb, dim=0)
    if calc_aud_emb:
        audio_emb = torch.cat(audio_emb, dim=0)
        return video_emb, audio_emb
    return video_emb
def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model):
    '''
    This function predicts the active speaker, for the whole video or per frame window
    Args:
        - all_video_embeddings (list) : Video embeddings, one per speaker track
        - audio_embedding (array) : Audio embedding (axis 2 is squeezed away)
        - global_score (string or bool) : "True" (or True) for a single prediction over the whole video,
                                          any other value for per-frame-window predictions
        - num_avg_frames (int) : Window length over which per-frame scores are averaged
                                 (clamped to the number of available frames)
        - model (object) : Model object providing the `logits_scale` module
    Returns:
        - pred_speaker (int or list) : Index of the active speaker (global), or one index per window
        - num_avg_frames (int) : The (possibly clamped) averaging window actually used
    '''
    # The UI passes the string "True"; also accept a real boolean as the docstring always promised
    use_global = str(global_score) == "True"

    cos = nn.CosineSimilarity(dim=1)
    audio_embedding = audio_embedding.squeeze(2)
    scores = []
    for video_embedding in all_video_embeddings:
        # Similarity of this speaker's video embeddings with the audio embedding
        sim = cos(video_embedding, audio_embedding)
        # Apply the logits scale to the similarity scores (scaling the scores)
        output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1)
        if use_global:
            score = output.mean(0)
        else:
            # Moving average over windows of num_avg_frames (step 1), clamped to the available length
            num_avg_frames = min(num_avg_frames, output.shape[0])
            score = output.unfold(0, num_avg_frames, 1).mean(dim=-1)
        scores.append(score.detach().cpu().numpy())

    if use_global:
        print("Using global predictions")
        return np.argmax(scores), num_avg_frames

    print("Using per-frame predictions")
    pred_speaker = [np.argmax([speaker_scores[t] for speaker_scores in scores])
                    for t in range(len(scores[0]))]
    return pred_speaker, num_avg_frames
def save_video(output_tracks, input_frames, wav_file, result_folder):
    '''
    Draw the active-speaker bounding boxes on the frames and write the output video with audio
    Args:
        - output_tracks (mapping) : Frame index -> bounding box (x1, y1, x2, y2) of the active speaker;
                                    frames without an entry are written unchanged.
                                    NOTE(review): the code tests `i in output_tracks` and then indexes
                                    output_tracks[i], which only works as intended for a mapping keyed
                                    by frame index (e.g. a dict) — confirm against the caller.
        - input_frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - result_folder (string) : Path of the result folder to save the output video
    Returns:
        - video_output (string) : Path of the output video (None on failure)
        - msg (string) : "success", or an error message
    '''
    try:
        annotated = []
        for idx, frame in enumerate(input_frames):
            if idx not in output_tracks:
                annotated.append(frame)
                continue
            # Draw a green box around the active speaker on a copy of the frame
            bbox = output_tracks[idx]
            x1, y1, x2, y2 = (int(bbox[k]) for k in range(4))
            annotated.append(cv2.rectangle(frame.copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3))

        # Generate the output video
        target = os.path.join(result_folder, "result_active_speaker_det")
        video_output, status = generate_video(annotated, wav_file, target)
        if status != "success":
            return None, status
    except Exception as e:
        return None, f"Error: {str(e)}"
    return video_output, "success"
def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
    '''
    Predict the audio-visual sync offset of a video and generate a sync-corrected copy
    Args:
        - video_path (string) : Path of the input video file
        - num_avg_frames (int) : Number of frames over which the sync scores are averaged
        - apply_preprocess (string or bool) : Forwarded to preprocess_video ("True" to crop around the single person)
    Returns:
        - video_output (string) : Path of the sync-corrected video (None on failure)
        - msg (string) : "Predicted offset: <offset>" on success, otherwise an error message
    '''
    try:
        # Name the result folder after the input file. splitext(basename) ignores dots in
        # directory names; remaining dots are replaced so the folder name stays a plain token
        # (path handling elsewhere splits paths at '.')
        video_fname = os.path.splitext(os.path.basename(video_path))[0].replace('.', '_')

        # Create folders to save the inputs and results
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")
        if os.path.exists(result_folder):
            rmtree(result_folder)
        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        # Preprocess the video
        print("Applying preprocessing: ", apply_preprocess)
        wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
        if status != "success":
            return None, status
        print("Successfully preprocessed the video")

        # Resample the video to 25 fps if it is not already 25 fps
        print("FPS of video: ", fps)
        if fps!=25:
            vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
            if status != "success":
                return None, status
            orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input)
            if status != "success":
                return None, status
        else:
            vid_path = vid_path_processed
            orig_vid_path_25fps = video_path

        # Load the original video frames (before pre-processing) - Needed for the final sync-correction
        orig_frames, status = load_video_frames(orig_vid_path_25fps)
        if status != "success":
            return None, status

        # Load the pre-processed video frames
        frames, status = load_video_frames(vid_path)
        if status != "success":
            return None, status
        print("Successfully extracted the video frames")

        if len(frames) < num_avg_frames:
            msg = "Error: The input video is too short. Please use a longer input video."
            return None, msg

        # Load keypoints and check if gestures are visible
        kp_dict, status = get_keypoints(frames)
        if status != "success":
            return None, status
        print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))

        status = check_visible_gestures(kp_dict)
        if status != "success":
            return None, status

        # Load RGB frames
        rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=False, window_frames=25, width=480, height=270)
        if status != "success":
            return None, status
        print("Successfully loaded the RGB frames")

        # Convert frames to tensor: (T, window, H, W, C) -> (1, C, T, window, H, W)
        rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
        rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
        B = rgb_frames.size(0)
        print("Successfully converted the frames to tensor")

        # Load spectrograms
        spec, orig_spec, status = load_spectrograms(wav_file, asd=False, num_frames=num_frames)
        if status != "success":
            return None, status
        spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
        print("Successfully loaded the spectrograms")

        # Create input windows
        video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
        audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)
        print("Successfully loaded the model")

        # Process in batches; inference only, so no autograd graph is built
        batch_size = 12
        video_emb = []
        audio_emb = []
        with torch.no_grad():
            for i in tqdm(range(0, len(video_sequences), batch_size)):
                video_inp = video_sequences[i:i+batch_size, ]
                audio_inp = audio_sequences[i:i+batch_size, ]

                vid_emb = model.forward_vid(video_inp.to(device))
                vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
                aud_emb = model.forward_aud(audio_inp.to(device))

                video_emb.append(vid_emb.detach())
                audio_emb.append(aud_emb.detach())

                torch.cuda.empty_cache()

        audio_emb = torch.cat(audio_emb, dim=0)
        video_emb = torch.cat(video_emb, dim=0)

        # L2 normalize embeddings
        video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
        audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)

        audio_emb = torch.split(audio_emb, B, dim=0)
        audio_emb = torch.stack(audio_emb, dim=2)
        audio_emb = audio_emb.squeeze(3)
        audio_emb = audio_emb[:, None]

        video_emb = torch.split(video_emb, B, dim=0)
        video_emb = torch.stack(video_emb, dim=2)
        video_emb = video_emb.squeeze(3)
        print("Successfully extracted GestSync embeddings")

        # Calculate sync offset
        pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
        if status != "success":
            return None, status
        print("Predicted offset: ", pred_offset)

        # Generate sync-corrected video. orig_frames were resampled to 25 fps above and the
        # predicted offset is counted in those frames, so convert it to audio samples at
        # 25 fps (not at the original video's frame rate).
        video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=25)
        if status != "success":
            return None, status
        print("Successfully generated the video:", video_output)

        return video_output, f"Predicted offset: {pred_offset}"

    except Exception as e:
        return None, f"Error: {str(e)}"
def _build_track_dict(tracks):
    '''
    Index the speaker tracks saved by pre-processing as {scene: {track: {frame: bbox}}}.

    Args:
        - tracks : Object un-pickled from metadata/tracks.pckl, indexed here as
          tracks[scene][track]['track']['frame' | 'bbox']
    Returns:
        - track_dict (dict) : Nested dict mapping scene index -> track index -> frame -> bbox
    '''
    track_dict = {}
    for scene_num in range(len(tracks)):
        scene_tracks = tracks[scene_num]
        track_dict[scene_num] = {
            track_num: dict(zip(scene_tracks[track_num]['track']['frame'],
                                scene_tracks[track_num]['track']['bbox']))
            for track_num in range(len(scene_tracks))
        }
    return track_dict


def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
    '''
    Detect the active speaker in every scene of the input video and save an output
    video with the active speaker's bounding box drawn on each frame.

    Args:
        - video_path (string) : Path of the input video file
        - global_speaker (string) : "per-frame-prediction" for a speaker label per frame;
          any other value (the UI sends "global-prediction") gives one label per scene
        - num_avg_frames (int) : Number of frames the predictions are averaged over
    Returns:
        - video_output (string) : Path of the saved output video, or None on failure
        - msg (string) : "success", or a message describing the failure
    '''

    # Decide the prediction mode once so the model call and the frame-labelling step
    # below can never disagree about which mode is active.
    per_frame = global_speaker == "per-frame-prediction"

    try:
        # Use the real file stem: splitting the whole path on "." breaks for
        # directories/file names containing dots (e.g. "my.dir/clip.v1.mp4").
        video_fname = os.path.splitext(os.path.basename(video_path))[0]
        if not video_fname:
            # An empty stem would make result_folder "results/" and rmtree would wipe it
            msg = "Oops! Could not load the input video file"
            return None, msg

        # Create (fresh) folders to save the inputs and results
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")

        if os.path.exists(result_folder):
            rmtree(result_folder)

        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        if per_frame and num_avg_frames < 25:
            msg = "Number of frames to average need to be set to a minimum of 25 frames. Atleast 1-second context is needed for the model. Please change the num_avg_frames and try again..."
            return None, msg

        # Read the video
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
        except Exception:
            msg = "Oops! Could not load the input video file"
            return None, msg

        # Get the FPS of the video
        fps = vr.get_avg_fps()
        print("FPS of video: ", fps)

        # Resample the video to 25 FPS if the original video is of a different frame-rate
        if fps != 25:
            test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input)
            if status != "success":
                return None, status
        else:
            test_video_25fps = video_path

        # Load the video frames
        orig_frames, status = load_video_frames(test_video_25fps)
        if status != "success":
            return None, status

        # Extract and save the audio file
        orig_wav_file, status = extract_audio(video_path, result_folder)
        if status != "success":
            return None, status

        # Pre-process and extract per-speaker tracks in each scene.
        # Arguments are passed as a list without a shell, so an uploaded file name
        # containing spaces or shell metacharacters is neither split nor executed.
        print("Pre-processing the input video...")
        preprocess_cmd = [
            "python", "preprocess/inference_preprocess.py",
            "--data_dir={}/temp".format(result_folder_input),
            "--sd_root={}/crops".format(result_folder_input),
            "--work_root={}/metadata".format(result_folder_input),
            "--data_root={}".format(video_path),
        ]
        status = subprocess.call(preprocess_cmd)
        if status != 0:
            msg = "Error in pre-processing the input video, please check the input video and try again..."
            return None, msg

        # Load the tracks file saved during pre-processing (written by our own script above)
        with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
            tracks = pickle.load(file)

        # Map every scene/track to its per-frame bounding boxes
        track_dict = _build_track_dict(tracks)

        # Get the total number of scenes
        test_scenes = os.listdir("{}/crops".format(result_folder_input))
        print("Total scenes found in the input video = ", len(test_scenes))

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)

        # Compute the active speaker in each scene
        output_tracks = {}
        for scene_num in tqdm(range(len(test_scenes))):
            scene_dir = os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)))
            test_videos = glob(os.path.join(scene_dir, "*.avi"))
            test_videos.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
            print("Scene {} -> Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos)))

            if len(test_videos) <= 1:
                msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
                return None, msg

            # Load the audio file
            audio_file = glob(os.path.join(scene_dir, "*.wav"))[0]
            spec, _, status = load_spectrograms(audio_file, asd=True)
            if status != "success":
                return None, status
            spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
            print("Successfully loaded the spectrograms")

            # Load the masked input frames
            all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
            if status != "success":
                return None, status
            print("Successfully loaded the masked input frames")

            # Prepare the audio and video sequences for the model
            audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

            print("Obtaining audio and video embeddings...")
            all_video_embs = []
            for idx in tqdm(range(len(all_masked_frames))):
                with torch.no_grad():
                    video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0)
                    # The audio is shared by all tracks of a scene: embed it only once
                    if idx == 0:
                        video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True)
                    else:
                        video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
                    all_video_embs.append(video_emb)
            print("Successfully extracted GestSync embeddings")

            # Predict the active speaker in each scene
            if per_frame:
                predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model)
            else:
                predictions, _ = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model)

            # Get the frames present in the scene
            frames_scene = tracks[scene_num][0]['track']['frame']
            n_frames = len(frames_scene)

            # Prepare the active speakers list (one track label per frame) to draw the bounding boxes
            if not per_frame:
                print("Aggregating scores using global predictions")
                active_speakers = [predictions] * n_frames
            else:
                print("Aggregating scores using per-frame predictions")
                mid = num_avg_frames // 2
                # Only frames with a full averaging window around them get a prediction
                if num_avg_frames % 2 == 0:
                    frame_pred = n_frames - (mid * 2) + 1
                    start, end = mid, n_frames - mid + 1
                else:
                    frame_pred = n_frames - (mid * 2)
                    start, end = mid, n_frames - mid

                print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred))
                if len(predictions) != frame_pred:
                    msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
                    return None, msg

                # Pad the leading/trailing frames that have no prediction with the majority
                # label of the first/last num_avg_frames predictions to cover the full video
                initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count)
                final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count)
                active_speakers = [initial_preds] * start + list(predictions) + [final_preds] * (n_frames - end)

            # Get the output tracks for each frame
            for frame, label in zip(frames_scene, active_speakers):
                output_tracks[frame] = track_dict[scene_num][label][frame]

        # Save the output video
        video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
        if status != "success":
            return None, status
        print("Successfully saved the output video: ", video_output)

        return video_output, "success"

    except Exception as e:
        return None, f"Error: {str(e)}"
if __name__ == "__main__":

    # Custom CSS and HTML
    custom_css = """
<style>
body {
background-color: #ffffff;
color: #333333; /* Default text color */
}
.container {
max-width: 100% !important;
padding-left: 0 !important;
padding-right: 0 !important;
}
.header {
background-color: #f0f0f0;
color: #333333;
padding: 30px;
margin-bottom: 30px;
text-align: center;
font-family: 'Helvetica Neue', Arial, sans-serif;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.header h1 {
font-size: 36px;
margin-bottom: 15px;
font-weight: bold;
color: #333333; /* Explicitly set heading color */
}
.header h2 {
font-size: 24px;
margin-bottom: 10px;
color: #333333; /* Explicitly set subheading color */
}
.header p {
font-size: 18px;
margin: 5px 0;
color: #666666;
}
.blue-text {
color: #4a90e2;
}
/* Custom styles for slider container */
.slider-container {
background-color: white !important;
padding-top: 0.9em;
padding-bottom: 0.9em;
}
/* Add gap before examples */
.examples-holder {
margin-top: 2em;
}
/* Set fixed size for example videos */
.gradio-container .gradio-examples .gr-sample {
width: 240px !important;
height: 135px !important;
object-fit: cover;
display: inline-block;
margin-right: 10px;
}
.gradio-container .gradio-examples {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
/* Ensure the parent container does not stretch */
.gradio-container .gradio-examples {
max-width: 100%;
overflow: hidden;
}
/* Additional styles to ensure proper sizing in Safari */
.gradio-container .gradio-examples .gr-sample img {
width: 240px !important;
height: 135px !important;
object-fit: cover;
}
</style>
"""

    custom_html = custom_css + """
<div class="header">
<h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
<h2>Synchronization and Active Speaker Detection Demo</h2>
<p><a href='https://www.robots.ox.ac.uk/~vgg/research/gestsync/'>Project Page</a> | <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> | <a href='https://arxiv.org/abs/2310.05304'>Paper</a></p>
</div>
"""

    tips_html = """
<div>
<br><br>
Please give us a 🌟 on <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> if you like our work!
Tips to get better results:
<ul>
<li>Number of Average Frames: Higher the number, better the results.</li>
<li>Clicking on "apply pre-processing" will give better results for synchronization, but this is an expensive operation and might take a while.</li>
<li>Input videos with clearly visible gestures work better.</li>
</ul>
</div>
"""

    # gr.Blocks(css=...) expects plain CSS rules. The literal "<style>"/"</style>" tags
    # would be parsed as part of the first selector and invalidate the `body` rule.
    # The tagged version is still injected as HTML through custom_html.
    blocks_css = custom_css.replace("<style>", "").replace("</style>", "")

    # Task names shown in the task-selection radio (also used for dispatching)
    SYNC_TASK = "Synchronization-correction"
    ASD_TASK = "Active-speaker-detection"

    # Define functions
    def toggle_slider(global_speaker, demo_choice=None):
        '''
        Show the "Number of Average Frames" slider when it is relevant: always for
        synchronization, and for active-speaker detection only with per-frame prediction.
        '''
        # global_speaker.change also fires when toggle_demo/clear_inputs reset the radio
        # programmatically, so the selected task must be checked too; otherwise switching
        # from per-frame ASD to synchronization would hide the slider sync relies on.
        show = demo_choice == SYNC_TASK or global_speaker == "per-frame-prediction"
        return gr.update(visible=show)

    def toggle_demo(demo_choice):
        '''
        Reset the widgets and show only those relevant to the selected task.
        With no task selected (initial state, or after "Clear"), everything is hidden.
        The returned tuple order must match the demo_choice.change outputs list.
        '''
        is_sync = demo_choice == SYNC_TASK
        is_asd = demo_choice == ASD_TASK
        task_selected = is_sync or is_asd
        return (
            gr.update(value=None, visible=task_selected),                # video_input
            gr.update(value=75, visible=task_selected),                  # num_avg_frames
            gr.update(value=False, visible=is_sync),                     # apply_preprocess
            gr.update(value="global-prediction", visible=is_asd),        # global_speaker
            gr.update(value=None, visible=task_selected),                # output_video
            gr.update(value="", visible=task_selected),                  # result_text
            gr.update(visible=task_selected),                            # submit_button
            gr.update(visible=task_selected),                            # clear_button
            gr.update(visible=is_sync),                                  # sync_examples
            gr.update(visible=is_asd),                                   # asd_examples
            gr.update(visible=task_selected),                            # tips
        )

    def clear_inputs():
        '''
        Reset all inputs/outputs; order matches the clear_button.click outputs list:
        demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess,
        result_text, output_video.
        '''
        return None, None, "global-prediction", 75, False, "", None

    def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess):
        '''
        Dispatch the submitted video to the selected task.
        Returns (output video path or None, message) for output_video / result_text.
        '''
        if video_input is None:
            return None, "Please upload an input video (or select an example) and try again."
        if demo_choice == SYNC_TASK:
            return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess)
        return process_video_activespeaker(video_input, global_speaker, num_avg_frames)

    # Define paths to sample videos
    sync_sample_videos = [
        ["samples/sync_sample_1.mp4"],
        ["samples/sync_sample_2.mp4"]
    ]
    asd_sample_videos = [
        ["samples/asd_sample_1.mp4"],
        ["samples/asd_sample_2.mp4"]
    ]

    # Define Gradio interface
    with gr.Blocks(css=blocks_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
        gr.HTML(custom_html)

        demo_choice = gr.Radio(
            choices=[SYNC_TASK, ASD_TASK],
            label="Please select the task you want to perform"
        )

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video", height=400, visible=False)
                num_avg_frames = gr.Slider(
                    minimum=50,
                    maximum=150,
                    step=5,
                    value=75,
                    label="Number of Average Frames",
                    visible=False
                )
                apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False)
                global_speaker = gr.Radio(
                    choices=["global-prediction", "per-frame-prediction"],
                    value="global-prediction",
                    label="Global Speaker Prediction",
                    visible=False
                )
                # The task is passed too, so that programmatic resets of global_speaker
                # do not hide the slider while in synchronization mode
                global_speaker.change(
                    fn=toggle_slider,
                    inputs=[global_speaker, demo_choice],
                    outputs=num_avg_frames
                )
            with gr.Column():
                output_video = gr.Video(label="Output Video", height=400, visible=False)
                result_text = gr.Textbox(label="Result", visible=False)

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary", visible=False)
            clear_button = gr.Button("Clear", visible=False)

        # Add a gap before examples
        gr.HTML('<div class="examples-holder"></div>')

        # Add examples that only populate the video input
        sync_examples = gr.Dataset(
            samples=sync_sample_videos,
            components=[video_input],
            type="values",
            visible=False
        )
        asd_examples = gr.Dataset(
            samples=asd_sample_videos,
            components=[video_input],
            type="values",
            visible=False
        )

        tips = gr.Markdown(tips_html, visible=False)

        demo_choice.change(
            fn=toggle_demo,
            inputs=demo_choice,
            outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips]
        )

        # Each dataset sample is a one-element list holding the video path
        sync_examples.select(
            fn=lambda x: gr.update(value=x[0], visible=True),
            inputs=sync_examples,
            outputs=video_input
        )
        asd_examples.select(
            fn=lambda x: gr.update(value=x[0], visible=True),
            inputs=asd_examples,
            outputs=video_input
        )

        submit_button.click(
            fn=process_video,
            inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess],
            outputs=[output_video, result_text]
        )

        clear_button.click(
            fn=clear_inputs,
            inputs=[],
            outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
        )

    # Launch the interface.
    # NOTE(review): allowed_paths=["."] lets Gradio serve any file under the working
    # directory (including the checkpoint); consider narrowing it to "samples"/"results".
    demo.launch(allowed_paths=["."], share=True)