import gradio as gr
import os
import torch
from shutil import rmtree
from torch import nn
from torch.nn import functional as F
import numpy as np
import subprocess
import cv2
import pickle
import librosa
from decord import VideoReader
from decord import cpu, gpu
from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *
from tqdm import tqdm
from glob import glob
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
import spaces

mp_holistic = mp.solutions.holistic
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Initialize global variables
CHECKPOINT_PATH = "model_rgb.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
batch_size = 12
fps = 25
n_negative_samples = 100

# Initialize the mediapipe holistic keypoint detection model
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)


def _run_ffmpeg(args):

	'''
	This function runs ffmpeg quietly with the given arguments

	Args:
		- args (list) : Arguments placed after the common "-hide_banner -loglevel panic" prefix
	Returns:
		- ok (bool) : True if ffmpeg exited with status 0, False otherwise (including a missing ffmpeg binary)
	'''

	cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic"] + [str(a) for a in args]
	try:
		# An argument list (no shell) keeps paths with spaces or shell metacharacters intact
		return subprocess.call(cmd) == 0
	except OSError:
		return False


@spaces.GPU(duration=300)
def preprocess_video(path, result_folder, apply_preprocess, padding=20):

	'''
	This function preprocesses the input video to extract the audio and crop the frames using YOLO model

	Args:
		- path (string) : Path of the input video file
		- result_folder (string) : Path of the folder to save the extracted audio and cropped video
		- apply_preprocess (string) : The person-cropping step runs only when this equals "True"
		- padding (int) : Padding to add to the bounding box
	Returns:
		- wav_file (string) : Path of the extracted audio file
		- fps (int) : FPS of the input video
		- video_output (string) : Path of the cropped video file (the input path when no cropping is applied)
		- msg (string) : "success", or an error message (the other three values are then None)
	'''

	# Load all video frames
	try:
		vr = VideoReader(path, ctx=cpu(0))
		fps = vr.get_avg_fps()
		frame_count = len(vr)
	except Exception:
		msg = "Oops! Could not load the video. Please check the input video and try again."
		return None, None, None, msg

	if frame_count < 25:
		msg = "Not enough frames to process! Please give a longer video as input"
		return None, None, None, msg

	# Extract the audio from the input video file using ffmpeg
	wav_file = os.path.join(result_folder, "audio.wav")
	extracted = _run_ffmpeg(["-y", "-i", path, "-async", "1", "-ac", "1", "-vn",
							 "-acodec", "pcm_s16le", "-ar", "16000", wav_file])
	if not extracted:
		msg = "Oops! Could not load the audio file. Please check the input video and try again."
		return None, None, None, msg
	print("Extracted the audio from the video")

	if apply_preprocess != "True":
		return wav_file, fps, path, "success"

	all_frames = np.asarray([vr[k].asnumpy() for k in range(frame_count)])
	print("Extracted the frames for pre-processing")

	# Load YOLOv9 model (pre-trained on COCO dataset)
	# NOTE(review): YOLO is not imported by name in this file; presumably it comes from one of the
	# wildcard imports above - confirm
	yolo_model = YOLO("yolov9s.pt")
	print("Loaded the YOLO model")

	person_videos = {}
	person_tracks = {}

	print("Processing the frames...")
	for frame_idx in tqdm(range(frame_count)):
		frame = all_frames[frame_idx]

		# Perform person detection
		results = yolo_model(frame, verbose=False)
		detections = results[0].boxes

		for i, det in enumerate(detections):
			if int(det.cls[0]) != 0:  # Class 0 is 'person' in COCO dataset
				continue

			x1, y1, x2, y2 = det.xyxy[0]
			x1 = max(0, int(x1) - padding)
			y1 = max(0, int(y1) - padding)
			x2 = min(frame.shape[1], int(x2) + padding)
			y2 = min(frame.shape[0], int(y2) + padding)

			# Tracks are keyed by the per-frame detection index
			person_videos.setdefault(i, []).append(frame)
			person_tracks.setdefault(i, []).append([x1, y1, x2, y2])

	# A track counts as a person only if it is present in at least half of the frames
	main_tracks = [k for k, track in person_videos.items() if len(track) >= frame_count // 2]

	if len(main_tracks) == 0:
		msg = "No person detected in the video! Please give a video with one person as input"
		return None, None, None, msg
	if len(main_tracks) > 1:
		msg = "More than one person detected in the video! Please give a video with only one person as input"
		return None, None, None, msg

	# Use the track that actually qualified; index 0 may be absent or belong to a spurious detection
	track_id = main_tracks[0]
	if len(person_videos[track_id]) <= frame_count - 10:
		msg = "Could not track the person in the full video! Please give a single-speaker video as input"
		return None, None, None, msg

	# For the person detected, crop every frame to the union of all bounding boxes of the track
	boxes = person_tracks[track_id]
	left = min(box[0] for box in boxes)
	top = min(box[1] for box in boxes)
	right = max(box[2] for box in boxes)
	bottom = max(box[3] for box in boxes)

	crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
	fourcc = cv2.VideoWriter_fourcc(*'DIVX')
	out = cv2.VideoWriter(crop_filename, fourcc, fps, (right - left, bottom - top))
	for frame in person_videos[track_id]:
		crop = frame[top:bottom, left:right]
		out.write(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
	out.release()

	# splitext only strips the extension, so a '.' inside result_folder cannot truncate the path
	base_name = os.path.splitext(crop_filename)[0]

	no_sound_video = base_name + '_nosound.mp4'
	if not _run_ffmpeg(["-y", "-i", crop_filename, "-c", "copy", "-an", "-strict", "-2", no_sound_video]):
		msg = "Oops! Could not preprocess the video. Please check the input video and try again."
		return None, None, None, msg

	video_output = base_name + '.mp4'
	if not _run_ffmpeg(["-y", "-i", wav_file, "-i", no_sound_video, "-strict", "-2", "-q:v", "1", video_output]):
		msg = "Oops! Could not preprocess the video. Please check the input video and try again."
		return None, None, None, msg

	os.remove(crop_filename)
	os.remove(no_sound_video)

	print("Successfully saved the pre-processed video: ", video_output)

	return wav_file, fps, video_output, "success"


def resample_video(video_file, video_fname, result_folder):

	'''
	This function resamples the video to 25 fps

	Args:
		- video_file (string) : Path of the input video file
		- video_fname (string) : Name of the input video file
		- result_folder (string) : Path of the folder to save the resampled video
	Returns:
		- video_file_25fps (string) : Path of the resampled video file (None on failure)
		- msg (string) : "success" or an error message
	'''

	video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

	# Resample the video to 25 fps
	resampled = _run_ffmpeg(["-y", "-i", video_file, "-c:v", "libx264", "-preset", "veryslow", "-crf", "0",
							 "-filter:v", "fps=25", "-pix_fmt", "yuv420p", video_file_25fps])
	if not resampled:
		msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."
		return None, msg

	print('Resampled the video to 25 fps: {}'.format(video_file_25fps))

	return video_file_25fps, "success"


def load_checkpoint(path, model):

	'''
	This function loads the trained model from the checkpoint

	Args:
		- path (string) : Path of the checkpoint file
		- model (object) : Model object
	Returns:
		- model (object) : Model object with the weights loaded from the checkpoint, in eval mode
	'''

	# Load the checkpoint
	if use_cuda:
		checkpoint = torch.load(path)
	else:
		checkpoint = torch.load(path, map_location="cpu")

	# Drop the 'module.' fragment from the parameter names before loading
	state = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
	model.load_state_dict(state)

	if use_cuda:
		model.cuda()

	print("Loaded checkpoint from: {}".format(path))

	return model.eval()


def load_video_frames(video_file):

	'''
	This function extracts the frames from the video

	Args:
		- video_file (string) : Path of the video file
	Returns:
		- frames (array) : Frames extracted from the video (None on failure)
		- msg (string) : "success" or an error message
	'''

	# Read the video
	try:
		vr = VideoReader(video_file, ctx=cpu(0))
	except Exception:
		msg = "Oops! Could not load the input video file"
		return None, msg

	# Extract the frames
	frames = np.asarray([vr[k].asnumpy() for k in range(len(vr))])

	return frames, "success"


def get_keypoints(frames):

	'''
	This function extracts the keypoints from the frames using MediaPipe Holistic pipeline

	Args:
		- frames (list) : List of frames extracted from the video
	Returns:
		- kp_dict (dict) : Dictionary with "kps" (one dict per frame with "pose", "left_hand", "right_hand"
						   and "face" entries, each None when not detected) and "resolution" (shape of the first frame)
		- msg (string) : "success" or an error message
	'''

	try:
		resolution = frames[0].shape
		all_frame_kps = []

		# The context manager releases the detector once all the frames are processed
		with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as detector:
			for frame in frames:
				results = detector.process(frame)

				pose, left_hand, right_hand, face = None, None, None, None
				if results.pose_landmarks is not None:
					pose = protobuf_to_dict(results.pose_landmarks)['landmark']
				if results.left_hand_landmarks is not None:
					left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
				if results.right_hand_landmarks is not None:
					right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
				if results.face_landmarks is not None:
					face = protobuf_to_dict(results.face_landmarks)['landmark']

				all_frame_kps.append({"pose": pose, "left_hand": left_hand, "right_hand": right_hand, "face": face})

		kp_dict = {"kps": all_frame_kps, "resolution": resolution}
	except Exception as e:
		print("Error: ", e)
		return None, "Error: Could not extract keypoints from the frames"

	return kp_dict, "success"


def check_visible_gestures(kp_dict):

	'''
	This function checks if the gestures in the video are visible

	Args:
		- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
	Returns:
		- msg (string) : "success" or an error message
	'''

	keypoints = kp_dict['kps']

	if len(keypoints) < 25:
		msg = "Not enough keypoints to process! Please give a longer video as input"
		return msg

	# Count the frames without a pose and the frames without any hand
	pose_count = sum(1 for kp in keypoints if kp["pose"] is None)
	hand_count = sum(1 for kp in keypoints if kp["left_hand"] is None and kp["right_hand"] is None)

	if hand_count/len(keypoints) > 0.6 or pose_count/len(keypoints) > 0.6:
		msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
		return msg

	print("Successfully verified the input video - Gestures are visible!")

	return "success"


def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270):

	'''
	This function masks the faces using the keypoints extracted from the frames

	Args:
		- input_frames (list) : List of frames extracted from the video
		- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
						   (None masks a fixed band at the top of every frame)
		- asd (bool) : If True, the frame sequence is edge-padded by 12 frames on both sides before windowing
		- stride (int) : Stride to extract the frames
		- window_frames (int) : Number of frames in each window that is given as input to the model
		- width (int) : Width of the frames
		- height (int) : Height of the frames
	Returns:
		- input_frames (array) : Frame window to be given as input to the model
		- num_frames (int) : Number of frames to extract
		- orig_masked_frames (array) : Masked frames extracted from the video
		- msg (string) : "success" or an error message
	'''

	print("Creating masked input frames...")

	input_frames_masked = []
	if kp_dict is None:
		for img in tqdm(input_frames):
			img = cv2.resize(img, (width, height))
			masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
			input_frames_masked.append(masked_img)
	else:
		# Face indices to extract the face-coordinates needed for masking
		face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
						 176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]

		input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
		print("Input keypoints: ", len(input_keypoints))

		for i, frame_kp_dict in tqdm(enumerate(input_keypoints), total=len(input_keypoints)):
			img = input_frames[i]
			face = frame_kp_dict["face"]

			if not face:
				# No face keypoints: mask a fixed band at the top of the resized frame
				img = cv2.resize(img, (width, height))
				masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
			else:
				# Black out everything above the lowest face-oval keypoint (plus a 15 pixel margin)
				face_ys = [int(face[idx]["y"]*resolution[0]) for idx in face_oval_idx if idx < len(face)]
				y2 = max(face_ys)
				masked_img = cv2.rectangle(img, (0,0), (resolution[1], int(y2)+15), (0,0,0), -1)

			# Image arrays are (rows, cols): shape[0] is the height and shape[1] is the width
			if masked_img.shape[0] != height or masked_img.shape[1] != width:
				masked_img = cv2.resize(masked_img, (width, height))

			input_frames_masked.append(masked_img)

	if len(input_frames_masked) == 0:
		msg = "Not enough frames to process! Please give a longer video as input."
		return None, None, None, msg

	orig_masked_frames = np.array(input_frames_masked)
	input_frames = orig_masked_frames / 255.
	if asd:
		input_frames = np.pad(input_frames, ((12, 12), (0,0), (0,0), (0,0)), 'edge')
	# print("Input images full: ", input_frames.shape)      		# num_framesx270x480x3

	total = input_frames.shape[0]
	input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0, total - window_frames + 1, stride)])
	# print("Input images window: ", input_frames.shape)    		# Tx25x270x480x3

	num_frames = input_frames.shape[0]

	if num_frames < 10:
		msg = "Not enough frames to process! Please give a longer video as input."
		return None, None, None, msg

	return input_frames, num_frames, orig_masked_frames, "success"


def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4):

	'''
	This function extracts the spectrogram from the audio file

	Args:
		- wav_file (string) : Path of the extracted audio file
		- asd (bool) : If True, the spectrogram windows are edge-padded by window_frames//2 on both sides
		- num_frames (int) : Number of frames to extract
		- window_frames (int) : Number of frames in each window that is given as input to the model
		- stride (int) : Stride to extract the audio frames
	Returns:
		- spec (array) : Spectrogram array window to be used as input to the model
		- orig_spec (array) : Spectrogram array extracted from the audio file
		- msg (string) : "success" or an error message
	'''

	# Load the audio
	try:
		wav = librosa.load(wav_file, sr=16000)[0]
	except Exception:
		msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
		return None, None, msg

	# Convert to tensor
	wav = torch.FloatTensor(wav).unsqueeze(0)
	mel, _, _, _ = wav2filterbanks(wav.to(device))
	spec = mel.squeeze(0).cpu().numpy()
	orig_spec = spec

	window = window_frames*stride
	spec = np.array([spec[i:i+window, :] for i in range(0, spec.shape[0] - window + 1, stride)])

	if num_frames is not None and len(spec) != num_frames:
		# Measure the mismatch before truncating, otherwise a longer audio track is never reported
		frame_diff = abs(len(spec) - num_frames)
		spec = spec[:num_frames]
		if frame_diff > 60:
			print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

	if asd:
		pad_frames = (window_frames//2)
		spec = np.pad(spec, ((pad_frames, pad_frames), (0,0), (0,0)), 'edge')

	return spec, orig_spec, "success"


def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):

	'''
	This function calculates the audio-visual offset between the video and audio

	Args:
		- vid_emb (array) : Video embedding array
		- aud_emb (array) : Audio embedding array
		- num_avg_frames (int) : Number of frames to average the scores
		- model (object) : Model object
	Returns:
		- offset (int) : Optimal audio-visual offset (None on failure)
		- msg (string) : "success" or an error message
	'''

	pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
	if status != "success":
		return None, status

	scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
	offset = scores.argmax()*stride - pos_idx

	return offset.item(), "success"


def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):

	'''
	This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset

	Args:
		- vid_emb (array) : Video embedding array
		- aud_emb (array) : Audio embedding array
		- num_avg_frames (int) : Number of frames to average the scores
		- stride (int) : Stride to extract the negative windows
	Returns:
		- vid_emb_pos (array) : Positive video embedding array
		- aud_emb_posneg (array) : All possible combinations of audio embedding array
		- pos_idx_frame (int) : Positive video embedding array frame
		- stride (int) : Stride used to extract the negative windows
		- msg (string) : "success" or an error message (the other four values are then None)
	'''

	slice_size = num_avg_frames
	too_short_msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)

	aud_emb = aud_emb.squeeze(1)
	if aud_emb.shape[-1] < slice_size:
		# unfold() raises when the window is longer than the sequence; report it like any other short input
		return None, None, None, None, too_short_msg

	aud_emb_posneg = aud_emb.unfold(-1, slice_size, stride)
	aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
	aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]

	# The middle audio window is treated as the in-sync (positive) position
	pos_idx = (aud_emb_posneg.shape[1]//2)
	pos_idx_frame = pos_idx*stride

	min_offset_frames = -(pos_idx)*stride
	max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
	print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))

	vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
	if vid_emb_pos.shape[2] != slice_size:
		return None, None, None, None, too_short_msg

	return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"


def calc_av_scores(vid_emb, aud_emb, model):

	'''
	This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings

	Args:
		- vid_emb (array) : Video embedding array
		- aud_emb (array) : Audio embedding array
		- model (object) : Model object
	Returns:
		- scores (array) : Audio-visual similarity scores (averaged over the last dimension)
		- att_map (array) : Attention map
	'''

	scores = calc_att_map(vid_emb, aud_emb, model)
	att_map = logsoftmax_2d(scores)
	scores = scores.mean(-1)

	return scores, att_map


def calc_att_map(vid_emb, aud_emb, model):

	'''
	This function calculates the similarity between the video and audio embeddings

	Args:
		- vid_emb (array) : Video embedding array
		- aud_emb (array) : Audio embedding array
		- model (object) : Model object
	Returns:
		- scores (array) : Audio-visual similarity scores
	'''

	vid_emb = vid_emb[:, :, None]
	aud_emb = aud_emb.transpose(1, 2)

	# Dot product over dim 1, evaluated in parts of 10 along dim 3
	scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
							   vid_emb,
							   aud_emb,
							   part_len=10,
							   dim=3,
							   device=device)

	scores = model.logits_scale(scores[..., None]).squeeze(-1)

	return scores


def generate_video(frames, audio_file, video_fname):

	'''
	This function generates the video from the frames and audio file

	Args:
		- frames (array) : Frames to be used to generate the video
		- audio_file (string) : Path of the audio file
		- video_fname (string) : Path of the video file (without extension)
	Returns:
		- video_output (string) : Path of the video file (None on failure)
		- msg (string) : "success" or an error message
	'''

	msg = "Oops! Could not generate the video. Please check the input video and try again."
	if len(frames) == 0:
		return None, msg

	# Keep the intermediate files next to the output so that parallel runs do not overwrite each other
	fname = video_fname + '_inference.avi'
	no_sound_video = video_fname + '_nosound.mp4'
	video_output = video_fname + '.mp4'

	try:
		video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))
		for frame in frames:
			video.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
		video.release()

		if not _run_ffmpeg(["-y", "-i", fname, "-c", "copy", "-an", "-strict", "-2", no_sound_video]):
			return None, msg

		if not _run_ffmpeg(["-y", "-i", audio_file, "-i", no_sound_video, "-c:v", "libx264", "-preset", "veryslow",
							"-crf", "18", "-pix_fmt", "yuv420p", "-strict", "-2", "-q:v", "1", "-shortest", video_output]):
			return None, msg
	finally:
		# Remove the intermediate files on success and on failure
		for tmp_file in (fname, no_sound_video):
			if os.path.exists(tmp_file):
				os.remove(tmp_file)

	return video_output, "success"


def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):

	'''
	This function corrects the video and audio to sync with each other

	Args:
		- video_path (string) : Path of the video file
		- frames (array) : Frames to be used to generate the video
		- wav_file (string) : Path of the audio file
		- offset (int) : Predicted sync-offset to be used to correct the video
		- result_folder (string) : Path of the result folder to save the output sync-corrected video
		- sample_rate (int) : Sample rate of the audio
		- fps (int) : Frames per second of the video
	Returns:
		- video_output (string) : Path of the video file (the input path when the offset is 0, None on failure)
		- msg (string) : "success" or an error message
	'''

	if offset == 0:
		print("The input audio and video are in-sync! No need to perform sync correction.")
		return video_path, "success"

	print("Performing Sync Correction...")
	if offset > 0:
		# Positive offset: drop the leading audio samples and keep all the frames
		audio_offset = int(offset*(sample_rate/fps))
		wav = librosa.load(wav_file, sr=sample_rate)[0]
		corrected_wav = wav[audio_offset:]
		corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
		write(corrected_wav_file, sample_rate, corrected_wav)
		wav_file = corrected_wav_file
		corrected_frames = frames
	else:
		# Negative offset: drop the leading frames and keep the audio
		corrected_frames = frames[abs(offset):]

	corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
	video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
	if status != "success":
		return None, status

	return video_output, "success"


def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):

	'''
	This function loads the masked input frames from the video

	Args:
		- test_videos (list) : List of videos to be processed (speaker-specific tracks)
		- spec (array) : Spectrogram of the audio (its dimension 2 is compared with the number of frame windows)
		- wav_file (string) : Path of the audio file
		- scene_num (int) : Scene number to be used to save the input masked video
		- result_folder (string) : Path of the folder to save the input masked video
	Returns:
		- all_frames (list) : List of masked input frames window to be used as input to the model
		- all_orig_frames (list) : List of original masked input frames
		- msg (string) : "success" or an error message (the two lists are then None)
	'''

	all_frames, all_orig_frames = [], []
	for video in test_videos:
		print("Processing video: ", video)

		# Load the video frames
		frames, status = load_video_frames(video)
		if status != "success":
			return None, None, status
		print("Successfully loaded the video frames")

		# Extract the keypoints from the frames
		kp_dict, status = get_keypoints(frames)
		if status != "success":
			return None, None, status
		print("Successfully extracted the keypoints")

		# Mask the frames using the keypoints extracted from the frames and prepare the input to the model
		masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=True)
		if status != "success":
			return None, None, status
		print("Successfully loaded the masked frames")

		# Check if the length of the input frames is equal to the length of the spectrogram
		if spec.shape[2] != masked_frames.shape[0]:
			# Measure the mismatch before truncating, otherwise the warning below can never fire
			frame_diff = abs(spec.shape[2] - masked_frames.shape[0])
			num_frames = spec.shape[2]
			masked_frames = masked_frames[:num_frames]
			orig_masked_frames = orig_masked_frames[:num_frames]
			if frame_diff > 60:
				print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

		# Transpose the frames to the correct format (last axis first)
		frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
		frames = torch.FloatTensor(np.array(frames)).unsqueeze(0)
		print("Successfully converted the frames to tensor")

		all_frames.append(frames)
		all_orig_frames.append(orig_masked_frames)

	return all_frames, all_orig_frames, "success"


def extract_audio(video, result_folder):

	'''
	This function extracts the audio from the video file

	Args:
		- video (string) : Path of the video file
		- result_folder (string) : Path of the folder to save the extracted audio file
	Returns:
		- wav_file (string) : Path of the extracted audio file (None on failure)
		- msg (string) : "success" or an error message
	'''

	wav_file = os.path.join(result_folder, "audio.wav")

	extracted = _run_ffmpeg(["-threads", "1", "-y", "-i", video, "-async", "1", "-ac", "1", "-vn",
							 "-acodec", "pcm_s16le", "-ar", "16000", wav_file])
	if not extracted:
		msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again"
		return None, msg

	return wav_file, "success"


@spaces.GPU(duration=200)
def get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True):

	'''
	This function extracts the video and audio embeddings from the input frames and audio sequences

	Args:
		- video_sequences (array) : Array of video frames to be used as input to the model
		- audio_sequences (array) : Array of audio frames to be used as input to the model
		- model (object) : Model object
		- calc_aud_emb (bool) : Flag to calculate the audio embedding
	Returns:
		- video_emb (array) : Video embedding
		- audio_emb (array) : Audio embedding (returned only when calc_aud_emb is True)
	'''

	batch_size = 12
	video_emb = []
	audio_emb = []

	for i in range(0, len(video_sequences), batch_size):
		video_inp = video_sequences[i:i+batch_size, ]
		vid_emb = model.forward_vid(video_inp.to(device), return_feats=False)
		vid_emb = torch.mean(vid_emb, axis=-1)
		video_emb.append(vid_emb.detach())

		if calc_aud_emb:
			audio_inp = audio_sequences[i:i+batch_size, ]
			aud_emb = model.forward_aud(audio_inp.to(device))
			audio_emb.append(aud_emb.detach())

		torch.cuda.empty_cache()

	video_emb = torch.cat(video_emb, dim=0)

	if calc_aud_emb:
		audio_emb = torch.cat(audio_emb, dim=0)
		return video_emb, audio_emb

	return video_emb


# NOTE(review): only the beginning of predict_active_speaker falls inside this part of the file, and the
# text that follows it in the source is damaged. The visible statements are kept exactly as they are.
def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model):

	'''
	This function predicts the active speaker in each frame

	Args:
		- all_video_embeddings (array) : Array of video embeddings of all speakers
		- audio_embedding (array) : Audio embedding
		- global_score (bool) : Flag to calculate the global score
	Returns:
		- pred_speaker (list) : List of active speakers in each frame
	'''

	cos = nn.CosineSimilarity(dim=1)

	audio_embedding = audio_embedding.squeeze(2)

	scores = []
	for i in range(len(all_video_embeddings)):
		video_embedding = all_video_embeddings[i]

		# Compute the similarity of each speaker's video embeddings with the audio embedding
		sim = cos(video_embedding, audio_embedding)

		# Apply the logits scale
to the similarity scores (scaling the scores) output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1) if global_score=="True": score = output.mean(0) else: if output.shape[0] Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos))) if len(test_videos)<=1: msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..." return None, msg # Load the audio file audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0] spec, _, status = load_spectrograms(audio_file, asd=True) if status != "success": return None, status spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3) print("Successfully loaded the spectrograms") # Load the masked input frames all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input) if status != "success": return None, status print("Successfully loaded the masked input frames") # Prepare the audio and video sequences for the model audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0) print("Obtaining audio and video embeddings...") all_video_embs = [] for idx in tqdm(range(len(all_masked_frames))): with torch.no_grad(): video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0) if idx==0: video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True) else: video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False) all_video_embs.append(video_emb) print("Successfully extracted GestSync embeddings") # Predict the active speaker in each scene if global_speaker=="per-frame-prediction": predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model) else: predictions, _ = 
predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model) # Get the frames present in the scene frames_scene = tracks[scene_num][0]['track']['frame'] # Prepare the active speakers list to draw the bounding boxes if global_speaker=="global-prediction": print("Aggregating scores using global predictoins") active_speakers = [predictions]*len(frames_scene) start, end = 0, len(frames_scene) else: print("Aggregating scores using per-frame predictions") active_speakers = [0]*len(frames_scene) mid = num_avg_frames//2 if num_avg_frames%2==0: frame_pred = len(frames_scene)-(mid*2)+1 start, end = mid, len(frames_scene)-mid+1 else: frame_pred = len(frames_scene)-(mid*2) start, end = mid, len(frames_scene)-mid print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred)) if len(predictions) != frame_pred: msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred) return None, msg active_speakers[start:end] = predictions[0:] # Depending on the num_avg_frames, interpolate the intial and final frame predictions to get a full video output initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count) active_speakers[0:start] = [initial_preds] * start final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count) active_speakers[end:] = [final_preds] * (len(frames_scene) - end) start, end = 0, len(active_speakers) # Get the output tracks for each frame pred_idx = 0 for frame in frames_scene[start:end]: label = active_speakers[pred_idx] pred_idx += 1 output_tracks[frame] = track_dict[scene_num][label][frame] # Save the output video video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output) if status != "success": return None, status print("Successfully saved the output video: ", video_output) return video_output, "success" except Exception as e: 
return None, f"Error: {str(e)}" if __name__ == "__main__": # Custom CSS and HTML custom_css = """ """ custom_html = custom_css + """

GestSync: Determining who is speaking without a talking head

Synchronization and Active Speaker Detection Demo

Project Page | Github | Paper

""" tips = """


Please give us a 🌟 on Github if you like our work! Tips to get better results:
""" # Define functions def toggle_slider(global_speaker): if global_speaker == "per-frame-prediction": return gr.update(visible=True) else: return gr.update(visible=False) def toggle_demo(demo_choice): if demo_choice == "Synchronization-correction": return ( gr.update(value=None, visible=True), # video_input gr.update(value=75, visible=True), # num_avg_frames gr.update(value=None, visible=True), # apply_preprocess gr.update(value="global-prediction", visible=False), # global_speaker gr.update(value=None, visible=True), # output_video gr.update(value="", visible=True), # result_text gr.update(visible=True), # submit_button gr.update(visible=True), # clear_button gr.update(visible=True), # sync_examples gr.update(visible=False), # asd_examples gr.update(visible=True) # tips ) else: return ( gr.update(value=None, visible=True), # video_input gr.update(value=75, visible=True), # num_avg_frames gr.update(value=None, visible=False), # apply_preprocess gr.update(value="global-prediction", visible=True), # global_speaker gr.update(value=None, visible=True), # output_video gr.update(value="", visible=True), # result_text gr.update(visible=True), # submit_button gr.update(visible=True), # clear_button gr.update(visible=False), # sync_examples gr.update(visible=True), # asd_examples gr.update(visible=True) # tips ) def clear_inputs(): return None, None, "global-prediction", 75, None, "", None def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess): if demo_choice == "Synchronization-correction": return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess) else: return process_video_activespeaker(video_input, global_speaker, num_avg_frames) # Define paths to sample videos sync_sample_videos = [ ["samples/sync_sample_1.mp4"], ["samples/sync_sample_2.mp4"] ] asd_sample_videos = [ ["samples/asd_sample_1.mp4"], ["samples/asd_sample_2.mp4"] ] # Define Gradio interface with gr.Blocks(css=custom_css, 
theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo: gr.HTML(custom_html) demo_choice = gr.Radio( choices=["Synchronization-correction", "Active-speaker-detection"], label="Please select the task you want to perform" ) with gr.Row(): with gr.Column(): video_input = gr.Video(label="Upload Video", height=400, visible=False) num_avg_frames = gr.Slider( minimum=50, maximum=150, step=5, value=75, label="Number of Average Frames", visible=False ) apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False) global_speaker = gr.Radio( choices=["global-prediction", "per-frame-prediction"], value="global-prediction", label="Global Speaker Prediction", visible=False ) global_speaker.change( fn=toggle_slider, inputs=global_speaker, outputs=num_avg_frames ) with gr.Column(): output_video = gr.Video(label="Output Video", height=400, visible=False) result_text = gr.Textbox(label="Result", visible=False) with gr.Row(): submit_button = gr.Button("Submit", variant="primary", visible=False) clear_button = gr.Button("Clear", visible=False) # Add a gap before examples gr.HTML('
') # Add examples that only populate the video input sync_examples = gr.Dataset( samples=sync_sample_videos, components=[video_input], type="values", visible=False ) asd_examples = gr.Dataset( samples=asd_sample_videos, components=[video_input], type="values", visible=False ) tips = gr.Markdown(tips, visible=False) demo_choice.change( fn=toggle_demo, inputs=demo_choice, outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips] ) sync_examples.select( fn=lambda x: gr.update(value=x[0], visible=True), inputs=sync_examples, outputs=video_input ) asd_examples.select( fn=lambda x: gr.update(value=x[0], visible=True), inputs=asd_examples, outputs=video_input ) submit_button.click( fn=process_video, inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess], outputs=[output_video, result_text] ) clear_button.click( fn=clear_inputs, inputs=[], outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video] ) # Launch the interface demo.launch(allowed_paths=["."], share=True)