sindhuhegde committed on
Commit
3860ffa
1 Parent(s): 43bd4b0

Update app

Files changed (5)
  1. .gitignore +2 -0
  2. app.py +518 -79
  3. app_v1.py +954 -0
  4. preprocess/inference_preprocess.py +326 -0
  5. yolov9c.pt +0 -3
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *__pycache__*
2
+ *yolov9c.pt*
app.py CHANGED
@@ -1,40 +1,40 @@
1
  import gradio as gr
2
- import argparse
3
- import os, subprocess
4
  from shutil import rmtree
5
-
 
6
  import numpy as np
 
7
  import cv2
 
8
  import librosa
9
- import torch
10
-
11
  from utils.audio_utils import *
12
  from utils.inference_utils import *
13
  from sync_models.gestsync_models import *
14
-
15
- import sys
16
- if sys.version_info > (3, 0): long, unicode, basestring = int, str, str
17
-
18
  from tqdm import tqdm
19
- from scipy.io.wavfile import write
20
  import mediapipe as mp
21
  from protobuf_to_dict import protobuf_to_dict
22
- mp_holistic = mp.solutions.holistic
23
- from ultralytics import YOLO
24
- from decord import VideoReader, cpu
25
-
26
  import warnings
27
- warnings.filterwarnings("ignore", category=DeprecationWarning)
28
- warnings.filterwarnings("ignore", category=UserWarning)
29
 
30
- # Set the path to checkpoint file
31
- CHECKPOINT_PATH = "model_rgb.pth"
 
32
 
33
  # Initialize global variables
 
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  use_cuda = torch.cuda.is_available()
 
 
36
  n_negative_samples = 100
37
- print("Using CUDA: ", use_cuda, device)
38
 
39
  def preprocess_video(path, result_folder, apply_preprocess, padding=20):
40
 
@@ -641,24 +641,216 @@ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, samp
641
 
642
  return video_output
643
 
644
- class Logger:
645
- def __init__(self, filename):
646
- self.terminal = sys.stdout
647
- self.log = open(filename, "w")
648
 
649
- def write(self, message):
650
- self.terminal.write(message)
651
- self.log.write(message)
652
-
653
- def flush(self):
654
- self.terminal.flush()
655
- self.log.flush()
 
656
 
657
- def isatty(self):
658
- return False
659
 
660
 
661
- def process_video(video_path, num_avg_frames, apply_preprocess):
662
  try:
663
  # Extract the video filename
664
  video_fname = os.path.basename(video_path.split(".")[0])
@@ -794,18 +986,184 @@ def process_video(video_path, num_avg_frames, apply_preprocess):
794
  except Exception as e:
795
  return f"Error: {str(e)}", None
796
 
797
- def read_logs():
798
- sys.stdout.flush()
799
- with open("output.log", "r") as f:
800
- return f.read()
801
802
 
803
- if __name__ == "__main__":
804
 
805
- sys.stdout = Logger("output.log")
806
 
 
807
 
808
- # Define the custom HTML for the header
 
809
  custom_css = """
810
  <style>
811
  body {
@@ -855,6 +1213,7 @@ if __name__ == "__main__":
855
  .examples-holder {
856
  margin-top: 2em;
857
  }
 
858
  /* Set fixed size for example videos */
859
  .gradio-container .gradio-examples .gr-sample {
860
  width: 240px !important;
@@ -888,67 +1247,147 @@ if __name__ == "__main__":
888
  custom_html = custom_css + """
889
  <div class="header">
890
  <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
891
- <h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
892
  <p>Sindhu Hegde and Andrew Zisserman</p>
893
  <p>VGG, University of Oxford</p>
894
  </div>
895
  """
896
 
897
  # Define paths to sample videos
898
- sample_videos = [
899
- "samples/sync_sample_1.mp4",
900
- "samples/sync_sample_2.mp4",
901
- ]
902
-
903
  # Define Gradio interface
904
  with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
905
  gr.HTML(custom_html)
906
  with gr.Row():
907
  with gr.Column():
908
- with gr.Group(elem_classes="slider-container"):
909
- num_avg_frames = gr.Slider(
910
- minimum=50,
911
- maximum=150,
912
- step=5,
913
- value=75,
914
- label="Number of Average Frames",
915
- )
916
- apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False)
917
- video_input = gr.Video(label="Upload Video", height=400)
918
-
919
  with gr.Column():
920
- result_text = gr.Textbox(label="Result")
921
- output_video = gr.Video(label="Sync Corrected Video", height=400)
922
 
923
  with gr.Row():
924
- submit_button = gr.Button("Submit", variant="primary")
925
- clear_button = gr.Button("Clear")
926
-
927
- submit_button.click(
928
- fn=process_video,
929
- inputs=[video_input, num_avg_frames, apply_preprocess],
930
- outputs=[result_text, output_video]
931
- )
932
-
933
- clear_button.click(
934
- fn=lambda: (None, 75, False, "", None),
935
- inputs=[],
936
- outputs=[video_input, num_avg_frames, result_text, output_video]
937
- )
938
 
 
939
  gr.HTML('<div class="examples-holder"></div>')
940
 
941
- # Add examples
942
- gr.Examples(
943
- examples=sample_videos,
944
  inputs=video_input,
945
  outputs=None,
946
  fn=None,
947
  cache_examples=False,
 
948
  )
949
 
950
- logs = gr.Textbox(label="Logs")
951
- demo.load(read_logs, None, logs, every=1)
952
953
  # Launch the interface
954
- demo.queue().launch(allowed_paths=["."], show_error=True)
 
1
  import gradio as gr
2
+ import os
3
+ import torch
4
  from shutil import rmtree
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
  import numpy as np
8
+ import subprocess
9
  import cv2
10
+ import pickle
11
  import librosa
12
+ from decord import VideoReader
13
+ from decord import cpu, gpu
14
  from utils.audio_utils import *
15
  from utils.inference_utils import *
16
  from sync_models.gestsync_models import *
17
  from tqdm import tqdm
18
+ from glob import glob
19
  import mediapipe as mp
20
  from protobuf_to_dict import protobuf_to_dict
21
  import warnings
 
 
22
 
23
+ mp_holistic = mp.solutions.holistic
24
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
25
+ warnings.filterwarnings("ignore", category=UserWarning)
26
 
27
  # Initialize global variables
28
+ CHECKPOINT_PATH = "model_rgb.pth"
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  use_cuda = torch.cuda.is_available()
31
+ batch_size = 12
32
+ fps = 25
33
  n_negative_samples = 100
34
+
35
+ # Initialize the mediapipe holistic keypoint detection model
36
+ holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
37
+
38
 
39
  def preprocess_video(path, result_folder, apply_preprocess, padding=20):
40
 
 
641
 
642
  return video_output
643
644
 
645
+ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
646
+
647
+ '''
648
+ This function loads the masked input frames from the video
649
+
650
+ Args:
651
+ - test_videos (list) : List of videos to be processed (speaker-specific tracks)
652
+ - spec (array) : Spectrogram of the audio
653
+ - wav_file (string) : Path of the audio file
654
+ - scene_num (int) : Scene number to be used to save the input masked video
655
+ - result_folder (string) : Path of the folder to save the input masked video
656
+ Returns:
657
+ - all_frames (list) : List of masked input frames window to be used as input to the model
658
+ - all_orig_frames (list) : List of original masked input frames
659
+ '''
660
+
661
+ all_frames, all_orig_frames = [], []
662
+ for video_num, video in enumerate(test_videos):
663
+
664
+ # Load the video frames
665
+ frames, status = load_video_frames(video)
666
+ if status != "success":
667
+ return None, None, status
668
+
669
+ # Extract the keypoints from the frames
670
+ kp_dict, status = get_keypoints(frames)
671
+ if status != "success":
672
+ return None, None, status
673
+
674
+ # Mask the frames using the keypoints extracted from the frames and prepare the input to the model
675
+ masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict)
676
+ if status != "success":
677
+ return None, None, status
678
+
679
+ input_masked_vid_path = os.path.join(result_folder, "input_masked_scene_{}_speaker_{}".format(scene_num, video_num))
680
+ generate_video(orig_masked_frames, wav_file, input_masked_vid_path)
681
+
682
+ # Check if the length of the input frames is equal to the length of the spectrogram
683
+ if spec.shape[2]!=masked_frames.shape[0]:
684
+ num_frames = spec.shape[2]
685
+ masked_frames = masked_frames[:num_frames]
686
+ orig_masked_frames = orig_masked_frames[:num_frames]
687
+ frame_diff = np.abs(spec.shape[2] - num_frames)
688
+ if frame_diff > 60:
689
+ print("The input video and audio lengths do not match - the results can be unreliable! Please check the input video.")
690
+
691
+ # Transpose the frames to the correct format
692
+ frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
693
+ frames = torch.FloatTensor(np.array(frames)).unsqueeze(0)
694
+
695
+ all_frames.append(frames)
696
+ all_orig_frames.append(orig_masked_frames)
697
+
698
+
699
+ return all_frames, all_orig_frames, "success"
700
+
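A small shape sketch of the transpose/unsqueeze step in load_masked_input_frames above (toy dimensions only; per the prints in load_rgb_masked_frames, the real pipeline works on 25-frame windows of 270x480 RGB frames):

# Illustrative shape sketch only (tiny toy sizes): the windowed masked frames
# (T, 25, H, W, 3) are moved to channels-first and given a batch dimension.
import numpy as np
import torch

T, H, W = 2, 27, 48                                    # toy sizes for the sketch
masked_frames = np.zeros((T, 25, H, W, 3), dtype=np.float32)
frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))  # (3, T, 25, H, W)
frames = torch.FloatTensor(frames).unsqueeze(0)        # (1, 3, T, 25, H, W)
print(frames.shape)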
701
+ def extract_audio(video, result_folder):
702
+
703
+ '''
704
+ This function extracts the audio from the video file
705
+
706
+ Args:
707
+ - video (string) : Path of the video file
708
+ - result_folder (string) : Path of the folder to save the extracted audio file
709
+ Returns:
710
+ - wav_file (string) : Path of the extracted audio file
711
+ '''
712
+
713
+ wav_file = os.path.join(result_folder, "audio.wav")
714
+
715
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn \
716
+ -acodec pcm_s16le -ar 16000 %s' % (video, wav_file), shell=True)
717
+
718
+ if status != 0:
719
+ msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again"
720
+ return None, msg
721
+
722
+ return wav_file, "success"
723
+
724
+
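For reference, a minimal sketch (assuming ffmpeg is on PATH) of the same 16 kHz mono extraction as extract_audio above, written with an argument list so filenames need no shell quoting:

# Minimal sketch: equivalent ffmpeg call to extract_audio, but without shell=True.
import subprocess

def extract_audio_argv(video, wav_file):
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-threads", "1", "-y",
           "-i", video, "-async", "1", "-ac", "1", "-vn",
           "-acodec", "pcm_s16le", "-ar", "16000", wav_file]
    return subprocess.call(cmd)  # 0 on success, non-zero on failure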
725
+ def get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True):
726
+
727
+ '''
728
+ This function extracts the video and audio embeddings from the input frames and audio sequences
729
+
730
+ Args:
731
+ - video_sequences (array) : Array of video frames to be used as input to the model
732
+ - audio_sequences (array) : Array of audio frames to be used as input to the model
733
+ - model (object) : Model object
734
+ - calc_aud_emb (bool) : Flag to calculate the audio embedding
735
+ Returns:
736
+ - video_emb (array) : Video embedding
737
+ - audio_emb (array) : Audio embedding
738
+ '''
739
+
740
+ batch_size = 12
741
+ video_emb = []
742
+ audio_emb = []
743
+
744
+ for i in range(0, len(video_sequences), batch_size):
745
+ video_inp = video_sequences[i:i+batch_size, ]
746
+ vid_emb = model.forward_vid(video_inp.to(device), return_feats=False)
747
+ vid_emb = torch.mean(vid_emb, axis=-1)
748
+
749
+ video_emb.append(vid_emb.detach())
750
+
751
+ if calc_aud_emb:
752
+ audio_inp = audio_sequences[i:i+batch_size, ]
753
+ aud_emb = model.forward_aud(audio_inp.to(device))
754
+ audio_emb.append(aud_emb.detach())
755
 
756
+ torch.cuda.empty_cache()
757
+
758
+ video_emb = torch.cat(video_emb, dim=0)
759
+
760
+ if calc_aud_emb:
761
+ audio_emb = torch.cat(audio_emb, dim=0)
762
+
763
+ return video_emb, audio_emb
764
+
765
+ return video_emb
766
+
767
 
768
 
769
+ def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model):
770
+
771
+ '''
772
+ This function predicts the active speaker in each frame
773
+
774
+ Args:
775
+ - all_video_embeddings (array) : Array of video embeddings of all speakers
776
+ - audio_embedding (array) : Audio embedding
777
+ - global_score (bool) : Flag to calculate the global score
778
+ Returns:
779
+ - pred_speaker (list) : List of active speakers in each frame
780
+ '''
781
+
782
+ cos = nn.CosineSimilarity(dim=1)
783
+
784
+ audio_embedding = audio_embedding.squeeze(2)
785
+
786
+ scores = []
787
+ for i in range(len(all_video_embeddings)):
788
+ video_embedding = all_video_embeddings[i]
789
+
790
+ # Compute the similarity of each speaker's video embeddings with the audio embedding
791
+ sim = cos(video_embedding, audio_embedding)
792
+
793
+ # Apply the logits scale to the similarity scores (scaling the scores)
794
+ output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1)
795
+
796
+ if global_score=="True":
797
+ score = output.mean(0)
798
+ else:
799
+ output_batch = output.unfold(0, num_avg_frames, 1)
800
+ score = torch.mean(output_batch, axis=-1)
801
+
802
+ scores.append(score.detach().cpu().numpy())
803
+
804
+ if global_score=="True":
805
+ print("Using global predictions")
806
+ pred_speaker = np.argmax(scores)
807
+ else:
808
+ print("Using per-frame predictions")
809
+ pred_speaker = []
810
+ num_negs = list(range(0, len(all_video_embeddings)))
811
+ for frame_idx in range(len(scores[0])):
812
+ score = [scores[i][frame_idx] for i in num_negs]
813
+ pred_idx = np.argmax(score)
814
+ pred_speaker.append(pred_idx)
815
+
816
+ return pred_speaker
817
+
818
+
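A toy example (made-up scores) of the per-frame smoothing used in predict_active_speaker above: unfold(0, k, 1) turns a length-N score vector into N-k+1 overlapping windows, and the mean over each window gives one smoothed score per position:

# Toy example of the sliding-window averaging applied to the per-frame scores.
import torch

scores = torch.tensor([0.1, 0.9, 0.8, 0.7, 0.2, 0.3])
k = 3
windows = scores.unfold(0, k, 1)   # shape (4, 3): overlapping windows of 3 scores
smoothed = windows.mean(dim=-1)    # roughly [0.60, 0.80, 0.57, 0.40]
print(windows.shape, smoothed)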
819
+ def save_video(output_tracks, input_frames, wav_file, result_folder):
820
+
821
+ '''
822
+ This function saves the output video with the active speaker detections
823
+
824
+ Args:
825
+ - output_tracks (list) : List of active speakers in each frame
826
+ - input_frames (array) : Frames to be used to generate the video
827
+ - wav_file (string) : Path of the audio file
828
+ - result_folder (string) : Path of the result folder to save the output video
829
+ Returns:
830
+ - video_output (string) : Path of the output video
831
+ '''
832
+
833
+ output_frames = []
834
+ for i in range(len(input_frames)):
835
+
836
+ # If the active speaker is found, draw a bounding box around the active speaker
837
+ if i in output_tracks:
838
+ bbox = output_tracks[i]
839
+ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
840
+ out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
841
+ else:
842
+ out = input_frames[i]
843
+
844
+ output_frames.append(out)
845
+
846
+ # Generate the output video
847
+ output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
848
+ video_output = generate_video(output_frames, wav_file, output_video_fname)
849
+
850
+ return video_output
851
+
852
+ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
853
+
854
  try:
855
  # Extract the video filename
856
  video_fname = os.path.basename(video_path.split(".")[0])
 
986
  except Exception as e:
987
  return f"Error: {str(e)}", None
988
 
989
+ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
990
+ try:
991
+ # Extract the video filename
992
+ video_fname = os.path.basename(video_path.split(".")[0])
993
+
994
+ # Create folders to save the inputs and results
995
+ result_folder = os.path.join("results", video_fname)
996
+ result_folder_input = os.path.join(result_folder, "input")
997
+ result_folder_output = os.path.join(result_folder, "output")
998
 
999
+ if os.path.exists(result_folder):
1000
+ rmtree(result_folder)
1001
 
1002
+ os.makedirs(result_folder)
1003
+ os.makedirs(result_folder_input)
1004
+ os.makedirs(result_folder_output)
1005
+
1006
+ if global_speaker=="per-frame-prediction" and num_avg_frames<25:
1007
+ msg = "The number of frames to average needs to be set to a minimum of 25 frames. At least 1 second of context is needed for the model. Please change the num_avg_frames and try again..."
1008
+ return None, msg
1009
+
1010
+ # Read the video
1011
+ try:
1012
+ vr = VideoReader(video_path, ctx=cpu(0))
1013
+ except:
1014
+ msg = "Oops! Could not load the input video file"
1015
+ return None, msg
1016
+
1017
+ # Get the FPS of the video
1018
+ fps = vr.get_avg_fps()
1019
+ print("FPS of video: ", fps)
1020
+
1021
+ # Resample the video to 25 FPS if the original video is of a different frame-rate
1022
+ if fps!=25:
1023
+ test_video_25fps = resample_video(video_path, video_fname, result_folder_input)
1024
+ else:
1025
+ test_video_25fps = video_path
1026
+
1027
+ # Load the video frames
1028
+ orig_frames, status = load_video_frames(test_video_25fps)
1029
+ if status != "success":
1030
+ return None, status
1031
+
1032
+ # Extract and save the audio file
1033
+ orig_wav_file, status = extract_audio(video_path, result_folder)
1034
+ if status != "success":
1035
+ return None, status
1036
+
1037
+ # Pre-process and extract per-speaker tracks in each scene
1038
+ print("Pre-processing the input video...")
1039
+ status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
1040
+ if status != 0:
1041
+ return None, "Error in pre-processing the input video, please check the input video and try again..."
1042
+
1043
+ # Load the tracks file saved during pre-processing
1044
+ with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
1045
+ tracks = pickle.load(file)
1046
+
1047
+
1048
+ # Create a dictionary of all tracks found along with the bounding-boxes
1049
+ track_dict = {}
1050
+ for scene_num in range(len(tracks)):
1051
+ track_dict[scene_num] = {}
1052
+ for i in range(len(tracks[scene_num])):
1053
+ track_dict[scene_num][i] = {}
1054
+ for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']):
1055
+ track_dict[scene_num][i][frame_num] = bbox
1056
+
1057
+ # Get the total number of scenes
1058
+ test_scenes = os.listdir("{}/crops".format(result_folder_input))
1059
+ print("Total scenes found in the input video = ", len(test_scenes))
1060
+
1061
+ # Load the trained model
1062
+ model = Transformer_RGB()
1063
+ model = load_checkpoint(CHECKPOINT_PATH, model)
1064
+
1065
+ # Compute the active speaker in each scene
1066
+ output_tracks = {}
1067
+ for scene_num in tqdm(range(len(test_scenes))):
1068
+ test_videos = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.avi"))
1069
+ test_videos.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
1070
+ print("Scene {} -> Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos)))
1071
+
1072
+ if len(test_videos)<=1:
1073
+ msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
1074
+ return None, msg
1075
+
1076
+ # Load the audio file
1077
+ audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0]
1078
+ spec, _, status = load_spectrograms(audio_file, window_frames=25)
1079
+ if status != "success":
1080
+ return None, status
1081
+ spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3)
1082
+
1083
+ # Load the masked input frames
1084
+ all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
1085
+ if status != "success":
1086
+ return None, status
1087
+
1088
+
1089
+ # Prepare the audio and video sequences for the model
1090
+ audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
1091
+
1092
+ print("Obtaining audio and video embeddings...")
1093
+ all_video_embs = []
1094
+ for idx in tqdm(range(len(all_masked_frames))):
1095
+ with torch.no_grad():
1096
+ video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0)
1097
+
1098
+ if idx==0:
1099
+ video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True)
1100
+ else:
1101
+ video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
1102
+ all_video_embs.append(video_emb)
1103
+
1104
+ # Predict the active speaker in each scene
1105
+ if global_speaker=="per-frame-prediction":
1106
+ predictions = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model)
1107
+ else:
1108
+ predictions = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model)
1109
+
1110
+ # Get the frames present in the scene
1111
+ frames_scene = tracks[scene_num][0]['track']['frame']
1112
+
1113
+ # Prepare the active speakers list to draw the bounding boxes
1114
+ if global_speaker=="global-prediction":
1115
+ print("Aggregating scores using global predictions")
1116
+ active_speakers = [predictions]*len(frames_scene)
1117
+ start, end = 0, len(frames_scene)
1118
+ else:
1119
+ print("Aggregating scores using per-frame predictions")
1120
+ active_speakers = [0]*len(frames_scene)
1121
+ mid = num_avg_frames//2
1122
+
1123
+ if num_avg_frames%2==0:
1124
+ frame_pred = len(frames_scene)-(mid*2)+1
1125
+ start, end = mid, len(frames_scene)-mid+1
1126
+ else:
1127
+ frame_pred = len(frames_scene)-(mid*2)
1128
+ start, end = mid, len(frames_scene)-mid
1129
+
1130
+ if len(predictions) != frame_pred:
1131
+ msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
1132
+ return None, msg
1133
+
1134
+ active_speakers[start:end] = predictions[0:]
1135
+
1136
+ # Depending on the num_avg_frames, interpolate the initial and final frame predictions to get a full video output
1137
+ initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count)
1138
+ active_speakers[0:start] = [initial_preds] * start
1139
+
1140
+ final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count)
1141
+ active_speakers[end:] = [final_preds] * (len(frames_scene) - end)
1142
+ start, end = 0, len(active_speakers)
1143
+
1144
+ # Get the output tracks for each frame
1145
+ pred_idx = 0
1146
+ for frame in frames_scene[start:end]:
1147
+ label = active_speakers[pred_idx]
1148
+ pred_idx += 1
1149
+ output_tracks[frame] = track_dict[scene_num][label][frame]
1150
+
1151
+ # Save the output video
1152
+ video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
1153
+ if status != "success":
1154
+ return None, status
1155
+ print("Successfully saved the output video: ", video_output)
1156
+
1157
+ return video_output, "success"
1158
+
1159
+ except Exception as e:
1160
+ return None, f"Error: {str(e)}"
1161
 
 
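A toy example (made-up labels) of the boundary fill used in process_video_activespeaker above, where frames outside the averaging window take the majority label of the nearest num_avg_frames predictions:

# Toy example: majority vote over the leading/trailing predictions via max(set(...), key=...count).
predictions = [1, 1, 0, 1, 1, 2, 2, 2]
num_avg_frames = 4

initial_fill = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count)
final_fill = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count)
print(initial_fill, final_fill)   # 1 2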
1162
 
1163
+ if __name__ == "__main__":
1164
 
1165
+
1166
+ # Custom CSS and HTML
1167
  custom_css = """
1168
  <style>
1169
  body {
 
1213
  .examples-holder {
1214
  margin-top: 2em;
1215
  }
1216
+
1217
  /* Set fixed size for example videos */
1218
  .gradio-container .gradio-examples .gr-sample {
1219
  width: 240px !important;
 
1247
  custom_html = custom_css + """
1248
  <div class="header">
1249
  <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
1250
+ <h2>Synchronization and Active Speaker Detection Demo</h2>
1251
  <p>Sindhu Hegde and Andrew Zisserman</p>
1252
  <p>VGG, University of Oxford</p>
1253
  </div>
1254
  """
1255
 
1256
+ # Define functions
1257
+ def toggle_slider(global_speaker):
1258
+ if global_speaker == "per-frame-prediction":
1259
+ return gr.update(visible=True)
1260
+ else:
1261
+ return gr.update(visible=False)
1262
+
1263
+ def toggle_demo(demo_choice):
1264
+ if demo_choice == "Synchronization-correction":
1265
+ return (
1266
+ gr.update(value=None, visible=True), # video_input
1267
+ gr.update(value=75, visible=True), # num_avg_frames
1268
+ gr.update(value=None, visible=True), # apply_preprocess
1269
+ gr.update(value="global-prediction", visible=False), # global_speaker
1270
+ gr.update(value="", visible=True), # result_text
1271
+ gr.update(value=None, visible=True), # output_video
1272
+ gr.update(visible=True), # submit_button
1273
+ gr.update(visible=True), # clear_button
1274
+ gr.update(visible=True), # sync_examples
1275
+ gr.update(visible=False) # asd_examples
1276
+ )
1277
+ else:
1278
+ return (
1279
+ gr.update(value=None, visible=True), # video_input
1280
+ gr.update(value=75, visible=True), # num_avg_frames
1281
+ gr.update(value=None, visible=False), # apply_preprocess
1282
+ gr.update(value="global-prediction", visible=True), # global_speaker
1283
+ gr.update(value="", visible=True), # result_text
1284
+ gr.update(value=None, visible=True), # output_video
1285
+ gr.update(visible=True), # submit_button
1286
+ gr.update(visible=True), # clear_button
1287
+ gr.update(visible=False), # sync_examples
1288
+ gr.update(visible=True) # asd_examples
1289
+ )
1290
+
1291
+ def clear_inputs():
1292
+ return None, None, "global-prediction", 75, None, "", None
1293
+
1294
+ def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess):
1295
+ if demo_choice == "Synchronization-correction":
1296
+ return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess)
1297
+ else:
1298
+ return process_video_activespeaker(video_input, global_speaker, num_avg_frames)
1299
+
1300
+
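A minimal standalone sketch (not the demo itself) of the visibility pattern used by toggle_slider and toggle_demo above: a Radio choice drives gr.update(visible=...) on another component:

# Minimal sketch of radio-driven visibility toggling in Gradio Blocks.
import gradio as gr

def toggle(choice):
    return gr.update(visible=(choice == "per-frame-prediction"))

with gr.Blocks() as sketch:
    mode = gr.Radio(["global-prediction", "per-frame-prediction"], value="global-prediction")
    slider = gr.Slider(25, 150, value=75, visible=False)
    mode.change(fn=toggle, inputs=mode, outputs=slider)

# sketch.launch()  # uncomment to try it locally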
1301
  # Define paths to sample videos
1302
+ sync_sample_videos = [
1303
+ "samples/sync_sample_1.mp4",
1304
+ "samples/sync_sample_2.mp4",
1305
+ "samples/sync_sample_3.mp4"
1306
+ ]
1307
+
1308
+ asd_sample_videos = [
1309
+ "samples/asd_sample_1.mp4",
1310
+ "samples/asd_sample_2.mp4"
1311
+ ]
1312
+
1313
  # Define Gradio interface
1314
  with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
1315
  gr.HTML(custom_html)
1316
+ demo_choice = gr.Radio(
1317
+ choices=["Synchronization-correction", "Active-speaker-detection"],
1318
+ label="Please select the task you want to perform"
1319
+ )
1320
  with gr.Row():
1321
  with gr.Column():
1322
+ video_input = gr.Video(label="Upload Video", height=400, visible=False)
1323
+ num_avg_frames = gr.Slider(
1324
+ minimum=50,
1325
+ maximum=150,
1326
+ step=5,
1327
+ value=75,
1328
+ label="Number of Average Frames",
1329
+ visible=False
1330
+ )
1331
+ apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False)
1332
+ global_speaker = gr.Radio(
1333
+ choices=["global-prediction", "per-frame-prediction"],
1334
+ value="global-prediction",
1335
+ label="Global Speaker Prediction",
1336
+ visible=False
1337
+ )
1338
+ global_speaker.change(
1339
+ fn=toggle_slider,
1340
+ inputs=global_speaker,
1341
+ outputs=num_avg_frames
1342
+ )
1343
  with gr.Column():
1344
+ result_text = gr.Textbox(label="Result", visible=False)
1345
+ output_video = gr.Video(label="Output Video", height=400, visible=False)
1346
 
1347
  with gr.Row():
1348
+ submit_button = gr.Button("Submit", variant="primary", visible=False)
1349
+ clear_button = gr.Button("Clear", visible=False)
 
1350
 
1351
+ # Add a gap before examples
1352
  gr.HTML('<div class="examples-holder"></div>')
1353
 
1354
+ # Add examples that only populate the video input
1355
+ sync_examples = gr.Examples(
1356
+ examples=sync_sample_videos,
1357
+ inputs=video_input,
1358
+ outputs=None,
1359
+ fn=None,
1360
+ cache_examples=False,
1361
+ visible=False
1362
+ )
1363
+
1364
+ asd_examples = gr.Examples(
1365
+ examples=asd_sample_videos,
1366
  inputs=video_input,
1367
  outputs=None,
1368
  fn=None,
1369
  cache_examples=False,
1370
+ visible=False
1371
  )
1372
 
 
 
1373
 
1374
+ demo_choice.change(
1375
+ fn=toggle_demo,
1376
+ inputs=demo_choice,
1377
+ outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, result_text, output_video, submit_button, clear_button, sync_examples.dataset, asd_examples.dataset]
1378
+ )
1379
+
1380
+
1381
+ submit_button.click(
1382
+ fn=process_video,
1383
+ inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess],
1384
+ outputs=[result_text, output_video]
1385
+ )
1386
+
1387
+ clear_button.click(
1388
+ fn=clear_inputs,
1389
+ inputs=[],
1390
+ outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
1391
+ )
1392
  # Launch the interface
1393
+ demo.launch(allowed_paths=["."], server_name="0.0.0.0", server_port=7860, share=True)
app_v1.py ADDED
@@ -0,0 +1,954 @@
1
+ import gradio as gr
2
+ import argparse
3
+ import os, subprocess
4
+ from shutil import rmtree
5
+
6
+ import numpy as np
7
+ import cv2
8
+ import librosa
9
+ import torch
10
+
11
+ from utils.audio_utils import *
12
+ from utils.inference_utils import *
13
+ from sync_models.gestsync_models import *
14
+
15
+ import sys
16
+ if sys.version_info > (3, 0): long, unicode, basestring = int, str, str
17
+
18
+ from tqdm import tqdm
19
+ from scipy.io.wavfile import write
20
+ import mediapipe as mp
21
+ from protobuf_to_dict import protobuf_to_dict
22
+ mp_holistic = mp.solutions.holistic
23
+ from ultralytics import YOLO
24
+ from decord import VideoReader, cpu
25
+
26
+ import warnings
27
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
28
+ warnings.filterwarnings("ignore", category=UserWarning)
29
+
30
+ # Set the path to checkpoint file
31
+ CHECKPOINT_PATH = "model_rgb.pth"
32
+
33
+ # Initialize global variables
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ use_cuda = torch.cuda.is_available()
36
+ n_negative_samples = 100
37
+ print("Using CUDA: ", use_cuda, device)
38
+
39
+ def preprocess_video(path, result_folder, apply_preprocess, padding=20):
40
+
41
+ '''
42
+ This function preprocesses the input video to extract the audio and crop the frames using YOLO model
43
+
44
+ Args:
45
+ - path (string) : Path of the input video file
46
+ - result_folder (string) : Path of the folder to save the extracted audio and cropped video
47
+ - padding (int) : Padding to add to the bounding box
48
+ Returns:
49
+ - wav_file (string) : Path of the extracted audio file
50
+ - fps (int) : FPS of the input video
51
+ - video_output (string) : Path of the cropped video file
52
+ - msg (string) : Message to be returned
53
+ '''
54
+
55
+ # Load all video frames
56
+ try:
57
+ vr = VideoReader(path, ctx=cpu(0))
58
+ fps = vr.get_avg_fps()
59
+ frame_count = len(vr)
60
+ except:
61
+ msg = "Oops! Could not load the video. Please check the input video and try again."
62
+ return None, None, None, msg
63
+
64
+ if frame_count < 25:
65
+ msg = "Not enough frames to process! Please give a longer video as input"
66
+ return None, None, None, msg
67
+
68
+ # Extract the audio from the input video file using ffmpeg
69
+ wav_file = os.path.join(result_folder, "audio.wav")
70
+
71
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
72
+ -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)
73
+
74
+ if status != 0:
75
+ msg = "Oops! Could not load the audio file. Please check the input video and try again."
76
+ return None, None, None, msg
77
+ print("Extracted the audio from the video")
78
+
79
+ if apply_preprocess=="True":
80
+ all_frames = []
81
+ for k in range(len(vr)):
82
+ all_frames.append(vr[k].asnumpy())
83
+ all_frames = np.asarray(all_frames)
84
+ print("Extracted the frames for pre-processing")
85
+
86
+ # Load YOLOv9 model (pre-trained on COCO dataset)
87
+ yolo_model = YOLO("yolov9s.pt")
88
+ print("Loaded the YOLO model")
89
+
90
+
91
+
92
+ person_videos = {}
93
+ person_tracks = {}
94
+
95
+ print("Processing the frames...")
96
+ for frame_idx in tqdm(range(frame_count)):
97
+
98
+ frame = all_frames[frame_idx]
99
+
100
+ # Perform person detection
101
+ results = yolo_model(frame, verbose=False)
102
+ detections = results[0].boxes
103
+
104
+ for i, det in enumerate(detections):
105
+ x1, y1, x2, y2 = det.xyxy[0]
106
+ cls = det.cls[0]
107
+ if int(cls) == 0: # Class 0 is 'person' in COCO dataset
108
+
109
+ x1 = max(0, int(x1) - padding)
110
+ y1 = max(0, int(y1) - padding)
111
+ x2 = min(frame.shape[1], int(x2) + padding)
112
+ y2 = min(frame.shape[0], int(y2) + padding)
113
+
114
+ if i not in person_videos:
115
+ person_videos[i] = []
116
+ person_tracks[i] = []
117
+
118
+ person_videos[i].append(frame)
119
+ person_tracks[i].append([x1,y1,x2,y2])
120
+
121
+
122
+ num_persons = 0
123
+ for i in person_videos.keys():
124
+ if len(person_videos[i]) >= frame_count//2:
125
+ num_persons+=1
126
+
127
+ if num_persons==0:
128
+ msg = "No person detected in the video! Please give a video with one person as input"
129
+ return None, None, None, msg
130
+ if num_persons>1:
131
+ msg = "More than one person detected in the video! Please give a video with only one person as input"
132
+ return None, None, None, msg
133
+
134
+
135
+
136
+ # For the person detected, crop the frame based on the bounding box
137
+ if len(person_videos[0]) > frame_count-10:
138
+ crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
139
+ fourcc = cv2.VideoWriter_fourcc(*'DIVX')
140
+
141
+ # Get bounding box coordinates based on person_tracks[i]
142
+ max_x1 = min([track[0] for track in person_tracks[0]])
143
+ max_y1 = min([track[1] for track in person_tracks[0]])
144
+ max_x2 = max([track[2] for track in person_tracks[0]])
145
+ max_y2 = max([track[3] for track in person_tracks[0]])
146
+
147
+ max_width = max_x2 - max_x1
148
+ max_height = max_y2 - max_y1
149
+
150
+ out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
151
+ for frame in person_videos[0]:
152
+ crop = frame[max_y1:max_y2, max_x1:max_x2]
153
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
154
+ out.write(crop)
155
+ out.release()
156
+
157
+ no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
158
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
159
+ if status != 0:
160
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
161
+ return None, None, None, msg
162
+
163
+ video_output = crop_filename.split('.')[0] + '.mp4'
164
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
165
+ (wav_file , no_sound_video, video_output), shell=True)
166
+ if status != 0:
167
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
168
+ return None, None, None, msg
169
+
170
+ os.remove(crop_filename)
171
+ os.remove(no_sound_video)
172
+
173
+ print("Successfully saved the pre-processed video: ", video_output)
174
+ else:
175
+ msg = "Could not track the person in the full video! Please give a single-speaker video as input"
176
+ return None, None, None, msg
177
+
178
+ else:
179
+ video_output = path
180
+
181
+ return wav_file, fps, video_output, "success"
182
+
183
+ def resample_video(video_file, video_fname, result_folder):
184
+
185
+ '''
186
+ This function resamples the video to 25 fps
187
+
188
+ Args:
189
+ - video_file (string) : Path of the input video file
190
+ - video_fname (string) : Name of the input video file
191
+ - result_folder (string) : Path of the folder to save the resampled video
192
+ Returns:
193
+ - video_file_25fps (string) : Path of the resampled video file
194
+ '''
195
+ video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))
196
+
197
+ # Resample the video to 25 fps
198
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}".format(video_file, video_file_25fps))
199
+ from subprocess import call
200
+ cmd = command.split(' ')
201
+ print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
202
+ call(cmd)
203
+
204
+ return video_file_25fps
205
+
206
+ def load_checkpoint(path, model):
207
+ '''
208
+ This function loads the trained model from the checkpoint
209
+
210
+ Args:
211
+ - path (string) : Path of the checkpoint file
212
+ - model (object) : Model object
213
+ Returns:
214
+ - model (object) : Model object with the weights loaded from the checkpoint
215
+ '''
216
+
217
+ # Load the checkpoint
218
+ if use_cuda:
219
+ checkpoint = torch.load(path)
220
+ else:
221
+ checkpoint = torch.load(path, map_location="cpu")
222
+
223
+ s = checkpoint["state_dict"]
224
+ new_s = {}
225
+
226
+ for k, v in s.items():
227
+ new_s[k.replace('module.', '')] = v
228
+ model.load_state_dict(new_s)
229
+
230
+ if use_cuda:
231
+ model.cuda()
232
+
233
+ print("Loaded checkpoint from: {}".format(path))
234
+
235
+ return model.eval()
236
+
237
+
238
+ def load_video_frames(video_file):
239
+ '''
240
+ This function extracts the frames from the video
241
+
242
+ Args:
243
+ - video_file (string) : Path of the video file
244
+ Returns:
245
+ - frames (list) : List of frames extracted from the video
246
+ - msg (string) : Message to be returned
247
+ '''
248
+
249
+ # Read the video
250
+ try:
251
+ vr = VideoReader(video_file, ctx=cpu(0))
252
+ except:
253
+ msg = "Oops! Could not load the input video file"
254
+ return None, msg
255
+
256
+
257
+ # Extract the frames
258
+ frames = []
259
+ for k in range(len(vr)):
260
+ frames.append(vr[k].asnumpy())
261
+
262
+ frames = np.asarray(frames)
263
+
264
+ return frames, "success"
265
+
266
+
267
+
268
+ def get_keypoints(frames):
269
+
270
+ '''
271
+ This function extracts the keypoints from the frames using MediaPipe Holistic pipeline
272
+
273
+ Args:
274
+ - frames (list) : List of frames extracted from the video
275
+ Returns:
276
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
277
+ - msg (string) : Message to be returned
278
+ '''
279
+
280
+ try:
281
+ holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
282
+
283
+ resolution = frames[0].shape
284
+ all_frame_kps = []
285
+
286
+ for frame in frames:
287
+
288
+ results = holistic.process(frame)
289
+
290
+ pose, left_hand, right_hand, face = None, None, None, None
291
+ if results.pose_landmarks is not None:
292
+ pose = protobuf_to_dict(results.pose_landmarks)['landmark']
293
+ if results.left_hand_landmarks is not None:
294
+ left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
295
+ if results.right_hand_landmarks is not None:
296
+ right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
297
+ if results.face_landmarks is not None:
298
+ face = protobuf_to_dict(results.face_landmarks)['landmark']
299
+
300
+ frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face}
301
+
302
+ all_frame_kps.append(frame_dict)
303
+
304
+ kp_dict = {"kps":all_frame_kps, "resolution":resolution}
305
+ except Exception as e:
306
+ print("Error: ", e)
307
+ return None, "Error: Could not extract keypoints from the frames"
308
+
309
+ return kp_dict, "success"
310
+
311
+
312
+ def check_visible_gestures(kp_dict):
313
+
314
+ '''
315
+ This function checks if the gestures in the video are visible
316
+
317
+ Args:
318
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
319
+ Returns:
320
+ - msg (string) : Message to be returned
321
+ '''
322
+
323
+ keypoints = kp_dict['kps']
324
+ keypoints = np.array(keypoints)
325
+
326
+ if len(keypoints)<25:
327
+ msg = "Not enough keypoints to process! Please give a longer video as input"
328
+ return msg
329
+
330
+ pose_count, hand_count = 0, 0
331
+ for frame_kp_dict in keypoints:
332
+
333
+ pose = frame_kp_dict["pose"]
334
+ left_hand = frame_kp_dict["left_hand"]
335
+ right_hand = frame_kp_dict["right_hand"]
336
+
337
+ if pose is None:
338
+ pose_count += 1
339
+
340
+ if left_hand is None and right_hand is None:
341
+ hand_count += 1
342
+
343
+
344
+ if hand_count/len(keypoints) > 0.7 or pose_count/len(keypoints) > 0.7:
345
+ msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
346
+ return msg
347
+
348
+ print("Successfully verified the input video - Gestures are visible!")
349
+
350
+ return "success"
351
+
352
+ def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, width=480, height=270):
353
+
354
+ '''
355
+ This function masks the faces using the keypoints extracted from the frames
356
+
357
+ Args:
358
+ - input_frames (list) : List of frames extracted from the video
359
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
360
+ - stride (int) : Stride to extract the frames
361
+ - window_frames (int) : Number of frames in each window that is given as input to the model
362
+ - width (int) : Width of the frames
363
+ - height (int) : Height of the frames
364
+ Returns:
365
+ - input_frames (array) : Frame window to be given as input to the model
366
+ - num_frames (int) : Number of frames to extract
367
+ - orig_masked_frames (array) : Masked frames extracted from the video
368
+ - msg (string) : Message to be returned
369
+ '''
370
+
371
+ # Face indices to extract the face-coordinates needed for masking
372
+ face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
373
+ 176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]
374
+
375
+
376
+ input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
377
+ print("Input keypoints: ", len(input_keypoints))
378
+
379
+ print("Creating masked input frames...")
380
+ input_frames_masked = []
381
+ for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):
382
+
383
+ img = input_frames[i]
384
+ face = frame_kp_dict["face"]
385
+
386
+ if face is None:
387
+ img = cv2.resize(img, (width, height))
388
+ masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
389
+ else:
390
+ face_kps = []
391
+ for idx in range(len(face)):
392
+ if idx in face_oval_idx:
393
+ x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0])
394
+ face_kps.append((x,y))
395
+
396
+ face_kps = np.array(face_kps)
397
+ x1, y1 = min(face_kps[:,0]), min(face_kps[:,1])
398
+ x2, y2 = max(face_kps[:,0]), max(face_kps[:,1])
399
+ masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1)
400
+
401
+ if masked_img.shape[0] != width or masked_img.shape[1] != height:
402
+ masked_img = cv2.resize(masked_img, (width, height))
403
+
404
+ input_frames_masked.append(masked_img)
405
+
406
+ orig_masked_frames = np.array(input_frames_masked)
407
+ input_frames = np.array(input_frames_masked) / 255.
408
+ print("Input images full: ", input_frames.shape) # num_framesx270x480x3
409
+
410
+ input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
411
+ print("Input images window: ", input_frames.shape) # Tx25x270x480x3
412
+
413
+ num_frames = input_frames.shape[0]
414
+
415
+ if num_frames<10:
416
+ msg = "Not enough frames to process! Please give a longer video as input."
417
+ return None, None, None, msg
418
+
419
+ return input_frames, num_frames, orig_masked_frames, "success"
420
+
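A toy shape example (small sizes; the real pipeline uses 270x480 frames) of the window slicing in load_rgb_masked_frames above: with stride=1 and window_frames=25, N masked frames become N-24 overlapping 25-frame windows:

# Toy shape example of the sliding-window slicing over the masked frames.
import numpy as np

window_frames, stride = 25, 1
frames = np.zeros((40, 27, 48, 3), dtype=np.float32)   # 40 toy frames
windows = np.array([frames[i:i+window_frames] for i in range(0, frames.shape[0], stride)
                    if i + window_frames <= frames.shape[0]])
print(windows.shape)                                   # (16, 25, 27, 48, 3)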
421
+ def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
422
+
423
+ '''
424
+ This function extracts the spectrogram from the audio file
425
+
426
+ Args:
427
+ - wav_file (string) : Path of the extracted audio file
428
+ - num_frames (int) : Number of frames to extract
429
+ - window_frames (int) : Number of frames in each window that is given as input to the model
430
+ - stride (int) : Stride to extract the audio frames
431
+ Returns:
432
+ - spec (array) : Spectrogram array window to be used as input to the model
433
+ - orig_spec (array) : Spectrogram array extracted from the audio file
434
+ - msg (string) : Message to be returned
435
+ '''
436
+
437
+ # Extract the audio from the input video file using ffmpeg
438
+ try:
439
+ wav = librosa.load(wav_file, sr=16000)[0]
440
+ except:
441
+ msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
442
+ return None, None, msg
443
+
444
+ # Convert to tensor
445
+ wav = torch.FloatTensor(wav).unsqueeze(0)
446
+ mel, _, _, _ = wav2filterbanks(wav.to(device))
447
+ spec = mel.squeeze(0).cpu().numpy()
448
+ orig_spec = spec
449
+ spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])
450
+
451
+ if len(spec) != num_frames:
452
+ spec = spec[:num_frames]
453
+ frame_diff = np.abs(len(spec) - num_frames)
454
+ if frame_diff > 60:
455
+ print("The input video and audio lengths do not match - the results can be unreliable! Please check the input video.")
456
+
457
+ return spec, orig_spec, "success"
458
+
459
+
460
+ def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
461
+ '''
462
+ This function calculates the audio-visual offset between the video and audio
463
+
464
+ Args:
465
+ - vid_emb (array) : Video embedding array
466
+ - aud_emb (array) : Audio embedding array
467
+ - num_avg_frames (int) : Number of frames to average the scores
468
+ - model (object) : Model object
469
+ Returns:
470
+ - offset (int) : Optimal audio-visual offset
471
+ - msg (string) : Message to be returned
472
+ '''
473
+
474
+ pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
475
+ if status != "success":
476
+ return None, status
477
+ scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
478
+ offset = scores.argmax()*stride - pos_idx
479
+
480
+ return offset.item(), "success"
481
+
482
+ def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):
483
+
484
+ '''
485
+ This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset
486
+
487
+ Args:
488
+ - vid_emb (array) : Video embedding array
489
+ - aud_emb (array) : Audio embedding array
490
+ - num_avg_frames (int) : Number of frames to average the scores
491
+ - stride (int) : Stride to extract the negative windows
492
+ Returns:
493
+ - vid_emb_pos (array) : Positive video embedding array
494
+ - aud_emb_posneg (array) : All possible combinations of audio embedding array
495
+ - pos_idx_frame (int) : Positive video embedding array frame
496
+ - stride (int) : Stride used to extract the negative windows
497
+ - msg (string) : Message to be returned
498
+ '''
499
+
500
+ slice_size = num_avg_frames
501
+ aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
502
+ aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
503
+ aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]
504
+
505
+ pos_idx = (aud_emb_posneg.shape[1]//2)
506
+ pos_idx_frame = pos_idx*stride
507
+
508
+ min_offset_frames = -(pos_idx)*stride
509
+ max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
510
+ print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))
511
+
512
+ vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
513
+ if vid_emb_pos.shape[2] != slice_size:
514
+ msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
515
+ return None, None, None, None, msg
516
+
517
+ return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"
518
+
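A worked example of the offset range printed by create_online_sync_negatives above, using the defaults in this file (n_negative_samples=100, stride=5) and assuming the audio yields the full set of candidate windows:

# Worked example: how the predictable offset range follows from the defaults.
n_negative_samples, stride = 100, 5
num_windows = int(n_negative_samples / stride) + 1        # 21 candidate audio windows
pos_idx = num_windows // 2                                # 10 -> the in-sync window
min_offset_frames = -pos_idx * stride                     # -50
max_offset_frames = (num_windows - pos_idx - 1) * stride  # 50
print(min_offset_frames, max_offset_frames)               # -50 50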
519
+ def calc_av_scores(vid_emb, aud_emb, model):
520
+
521
+ '''
522
+ This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings
523
+
524
+ Args:
525
+ - vid_emb (array) : Video embedding array
526
+ - aud_emb (array) : Audio embedding array
527
+ - model (object) : Model object
528
+ Returns:
529
+ - scores (array) : Audio-visual similarity scores
530
+ - att_map (array) : Attention map
531
+ '''
532
+
533
+ scores = calc_att_map(vid_emb, aud_emb, model)
534
+ att_map = logsoftmax_2d(scores)
535
+ scores = scores.mean(-1)
536
+
537
+ return scores, att_map
538
+
539
+ def calc_att_map(vid_emb, aud_emb, model):
540
+
541
+ '''
542
+ This function calculates the similarity between the video and audio embeddings
543
+
544
+ Args:
545
+ - vid_emb (array) : Video embedding array
546
+ - aud_emb (array) : Audio embedding array
547
+ - model (object) : Model object
548
+ Returns:
549
+ - scores (array) : Audio-visual similarity scores
550
+ '''
551
+
552
+ vid_emb = vid_emb[:, :, None]
553
+ aud_emb = aud_emb.transpose(1, 2)
554
+
555
+ scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
556
+ vid_emb,
557
+ aud_emb,
558
+ part_len=10,
559
+ dim=3,
560
+ device=device)
561
+
562
+ scores = model.logits_scale(scores[..., None]).squeeze(-1)
563
+
564
+ return scores
565
+
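A toy sketch (assumed sizes) of the reduction that run_func_in_parts evaluates in calc_att_map above: an element-wise product summed over the channel dimension gives one similarity per (audio-window, frame) pair; the helper only computes it in chunks along dim=3:

# Toy sketch of the plain, unchunked form of the similarity reduction.
import torch

B, C, N, T = 1, 512, 21, 50     # batch, channels, audio windows, frames (assumed sizes)
vid = torch.randn(B, C, 1, T)   # video embedding, broadcast against every audio window
aud = torch.randn(B, C, N, T)   # candidate audio windows
scores = (vid * aud).sum(1)     # shape (B, N, T): per-window, per-frame similarity
print(scores.shape)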
566
+ def generate_video(frames, audio_file, video_fname):
567
+
568
+ '''
569
+ This function generates the video from the frames and audio file
570
+
571
+ Args:
572
+ - frames (array) : Frames to be used to generate the video
573
+ - audio_file (string) : Path of the audio file
574
+ - video_fname (string) : Path of the video file
575
+ Returns:
576
+ - video_output (string) : Path of the video file
577
+ '''
578
+
579
+ fname = 'inference.avi'
580
+ video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))
581
+
582
+ for i in range(len(frames)):
583
+ video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
584
+ video.release()
585
+
586
+ no_sound_video = video_fname + '_nosound.mp4'
587
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
588
+ if status != 0:
589
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
590
+ return None, msg
591
+
592
+ video_output = video_fname + '.mp4'
593
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
594
+ (audio_file, no_sound_video, video_output), shell=True)
595
+ if status != 0:
596
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
597
+ return None, msg
598
+
599
+ os.remove(fname)
600
+ os.remove(no_sound_video)
601
+
602
+ return video_output
603
+
604
+ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
605
+
606
+ '''
607
+ This function corrects the video and audio to sync with each other
608
+
609
+ Args:
610
+ - video_path (string) : Path of the video file
611
+ - frames (array) : Frames to be used to generate the video
612
+ - wav_file (string) : Path of the audio file
613
+ - offset (int) : Predicted sync-offset to be used to correct the video
614
+ - result_folder (string) : Path of the result folder to save the output sync-corrected video
615
+ - sample_rate (int) : Sample rate of the audio
616
+ - fps (int) : Frames per second of the video
617
+ Returns:
618
+ - video_output (string) : Path of the video file
619
+ '''
620
+
621
+ if offset == 0:
622
+ print("The input audio and video are in-sync! No need to perform sync correction.")
623
+ return video_path
624
+
625
+ print("Performing Sync Correction...")
626
+ corrected_frames = np.zeros_like(frames)
627
+ if offset > 0:
628
+ audio_offset = int(offset*(sample_rate/fps))
629
+ wav = librosa.core.load(wav_file, sr=sample_rate)[0]
630
+ corrected_wav = wav[audio_offset:]
631
+ corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
632
+ write(corrected_wav_file, sample_rate, corrected_wav)
633
+ wav_file = corrected_wav_file
634
+ corrected_frames = frames
635
+ elif offset < 0:
636
+ corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):]
637
+ corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]
638
+
639
+ corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
640
+ video_output = generate_video(corrected_frames, wav_file, corrected_video_path)
641
+
642
+ return video_output
643
+
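A worked example of the positive-offset branch in sync_correct_video above: a predicted offset in video frames is converted to audio samples with sample_rate/fps (16000/25 = 640 samples per frame):

# Worked example: converting a frame offset to an audio sample offset.
sample_rate, fps = 16000, 25
offset = 8                                   # hypothetical predicted offset in frames
audio_offset = int(offset * (sample_rate / fps))
print(audio_offset)                          # 5120 samples trimmed from the start of the audio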
644
+ class Logger:
645
+ def __init__(self, filename):
646
+ self.terminal = sys.stdout
647
+ self.log = open(filename, "w")
648
+
649
+ def write(self, message):
650
+ self.terminal.write(message)
651
+ self.log.write(message)
652
+
653
+ def flush(self):
654
+ self.terminal.flush()
655
+ self.log.flush()
656
+
657
+ def isatty(self):
658
+ return False
659
+
660
+
661
+ def process_video(video_path, num_avg_frames, apply_preprocess):
662
+ try:
663
+ # Extract the video filename
664
+ video_fname = os.path.basename(video_path.split(".")[0])
665
+
666
+ # Create folders to save the inputs and results
667
+ result_folder = os.path.join("results", video_fname)
668
+ result_folder_input = os.path.join(result_folder, "input")
669
+ result_folder_output = os.path.join(result_folder, "output")
670
+
671
+ if os.path.exists(result_folder):
672
+ rmtree(result_folder)
673
+
674
+ os.makedirs(result_folder)
675
+ os.makedirs(result_folder_input)
676
+ os.makedirs(result_folder_output)
677
+
678
+
679
+ # Preprocess the video
680
+ print("Applying preprocessing: ", apply_preprocess)
681
+ wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
682
+ if status != "success":
683
+ return status, None
684
+ print("Successfully preprocessed the video")
685
+
686
+ # Resample the video to 25 fps if it is not already 25 fps
687
+ print("FPS of video: ", fps)
688
+ if fps!=25:
689
+ vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
690
+ orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
691
+ else:
692
+ vid_path = vid_path_processed
693
+ orig_vid_path_25fps = video_path
694
+
695
+ # Load the original video frames (before pre-processing) - Needed for the final sync-correction
696
+ orig_frames, status = load_video_frames(orig_vid_path_25fps)
697
+ if status != "success":
698
+ return status, None
699
+
700
+ # Load the pre-processed video frames
701
+ frames, status = load_video_frames(vid_path)
702
+ if status != "success":
703
+ return status, None
704
+ print("Successfully extracted the video frames")
705
+
706
+ if len(frames) < num_avg_frames:
707
+ return "Error: The input video is too short. Please use a longer input video.", None
708
+
709
+ # Load keypoints and check if gestures are visible
710
+ kp_dict, status = get_keypoints(frames)
711
+ if status != "success":
712
+ return status, None
713
+ print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))
714
+
715
+ status = check_visible_gestures(kp_dict)
716
+ if status != "success":
717
+ return status, None
718
+
719
+ # Load RGB frames
720
+ rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, window_frames=25, width=480, height=270)
721
+ if status != "success":
722
+ return status, None
723
+ print("Successfully loaded the RGB frames")
724
+
725
+ # Convert frames to tensor
726
+ rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
727
+ rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
728
+ B = rgb_frames.size(0)
729
+ print("Successfully converted the frames to tensor")
730
+
731
+ # Load spectrograms
732
+ spec, orig_spec, status = load_spectrograms(wav_file, num_frames, window_frames=25)
733
+ if status != "success":
734
+ return status, None
735
+ spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
736
+ print("Successfully loaded the spectrograms")
737
+
738
+ # Create input windows
739
+ video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
740
+ audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
741
+
742
+ # Load the trained model
743
+ model = Transformer_RGB()
744
+ model = load_checkpoint(CHECKPOINT_PATH, model)
745
+ print("Successfully loaded the model")
746
+
747
+ # Process in batches
748
+ batch_size = 12
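+ # Processing the windows in chunks of 12 keeps GPU memory bounded; per-batch embeddings are detached and concatenated after the loop.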
749
+ video_emb = []
750
+ audio_emb = []
751
+
752
+ for i in tqdm(range(0, len(video_sequences), batch_size)):
753
+ video_inp = video_sequences[i:i+batch_size, ]
754
+ audio_inp = audio_sequences[i:i+batch_size, ]
755
+
756
+ vid_emb = model.forward_vid(video_inp.to(device))
757
+ vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
758
+ aud_emb = model.forward_aud(audio_inp.to(device))
759
+
760
+ video_emb.append(vid_emb.detach())
761
+ audio_emb.append(aud_emb.detach())
762
+
763
+ torch.cuda.empty_cache()
764
+
765
+ audio_emb = torch.cat(audio_emb, dim=0)
766
+ video_emb = torch.cat(video_emb, dim=0)
767
+
768
+ # L2 normalize embeddings
769
+ video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
770
+ audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)
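+ # L2-normalising both embeddings makes the audio-video similarity used in the offset search scale-invariant (effectively a cosine similarity).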
771
+
772
+ audio_emb = torch.split(audio_emb, B, dim=0)
773
+ audio_emb = torch.stack(audio_emb, dim=2)
774
+ audio_emb = audio_emb.squeeze(3)
775
+ audio_emb = audio_emb[:, None]
776
+
777
+ video_emb = torch.split(video_emb, B, dim=0)
778
+ video_emb = torch.stack(video_emb, dim=2)
779
+ video_emb = video_emb.squeeze(3)
780
+ print("Successfully extracted GestSync embeddings")
781
+
782
+ # Calculate sync offset
783
+ pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
784
+ if status != "success":
785
+ return status, None
786
+ print("Predicted offset: ", pred_offset)
787
+
788
+ # Generate sync-corrected video
789
+ video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
790
+ print("Successfully generated the video:", video_output)
791
+
792
+ return f"Predicted offset: {pred_offset}", video_output
793
+
794
+ except Exception as e:
795
+ return f"Error: {str(e)}", None
796
+
797
+ def read_logs():
798
+ sys.stdout.flush()
799
+ with open("output.log", "r") as f:
800
+ return f.read()
801
+
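+ # Logger and read_logs together form a simple stdout tee: everything printed goes to the
+ # terminal and to output.log, which the Gradio "Logs" textbox polls every second via
+ # demo.load(read_logs, ...) at the bottom of this file. Minimal sketch of the pattern:
+ #     sys.stdout = Logger("output.log")
+ #     print("hello")   # shows up in the terminal, in output.log, and hence in the Logs box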
802
+
803
+ if __name__ == "__main__":
804
+
805
+ sys.stdout = Logger("output.log")
806
+
807
+
808
+ # Define the custom HTML for the header
809
+ custom_css = """
810
+ <style>
811
+ body {
812
+ background-color: #ffffff;
813
+ color: #333333; /* Default text color */
814
+ }
815
+ .container {
816
+ max-width: 100% !important;
817
+ padding-left: 0 !important;
818
+ padding-right: 0 !important;
819
+ }
820
+ .header {
821
+ background-color: #f0f0f0;
822
+ color: #333333;
823
+ padding: 30px;
824
+ margin-bottom: 30px;
825
+ text-align: center;
826
+ font-family: 'Helvetica Neue', Arial, sans-serif;
827
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
828
+ }
829
+ .header h1 {
830
+ font-size: 36px;
831
+ margin-bottom: 15px;
832
+ font-weight: bold;
833
+ color: #333333; /* Explicitly set heading color */
834
+ }
835
+ .header h2 {
836
+ font-size: 24px;
837
+ margin-bottom: 10px;
838
+ color: #333333; /* Explicitly set subheading color */
839
+ }
840
+ .header p {
841
+ font-size: 18px;
842
+ margin: 5px 0;
843
+ color: #666666;
844
+ }
845
+ .blue-text {
846
+ color: #4a90e2;
847
+ }
848
+ /* Custom styles for slider container */
849
+ .slider-container {
850
+ background-color: white !important;
851
+ padding-top: 0.9em;
852
+ padding-bottom: 0.9em;
853
+ }
854
+ /* Add gap before examples */
855
+ .examples-holder {
856
+ margin-top: 2em;
857
+ }
858
+ /* Set fixed size for example videos */
859
+ .gradio-container .gradio-examples .gr-sample {
860
+ width: 240px !important;
861
+ height: 135px !important;
862
+ object-fit: cover;
863
+ display: inline-block;
864
+ margin-right: 10px;
865
+ }
866
+
867
+ .gradio-container .gradio-examples {
868
+ display: flex;
869
+ flex-wrap: wrap;
870
+ gap: 10px;
871
+ }
872
+
873
+ /* Ensure the parent container does not stretch */
874
+ .gradio-container .gradio-examples {
875
+ max-width: 100%;
876
+ overflow: hidden;
877
+ }
878
+
879
+ /* Additional styles to ensure proper sizing in Safari */
880
+ .gradio-container .gradio-examples .gr-sample img {
881
+ width: 240px !important;
882
+ height: 135px !important;
883
+ object-fit: cover;
884
+ }
885
+ </style>
886
+ """
887
+
888
+ custom_html = custom_css + """
889
+ <div class="header">
890
+ <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
891
+ <h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
892
+ <p>Sindhu Hegde and Andrew Zisserman</p>
893
+ <p>VGG, University of Oxford</p>
894
+ </div>
895
+ """
896
+
897
+ # Define paths to sample videos
898
+ sample_videos = [
899
+ "samples/sync_sample_1.mp4",
900
+ "samples/sync_sample_2.mp4",
901
+ ]
902
+
903
+ # Define Gradio interface
904
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
905
+ gr.HTML(custom_html)
906
+ with gr.Row():
907
+ with gr.Column():
908
+ with gr.Group(elem_classes="slider-container"):
909
+ num_avg_frames = gr.Slider(
910
+ minimum=50,
911
+ maximum=150,
912
+ step=5,
913
+ value=75,
914
+ label="Number of Average Frames",
915
+ )
916
+ apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False)
917
+ video_input = gr.Video(label="Upload Video", height=400)
918
+
919
+ with gr.Column():
920
+ result_text = gr.Textbox(label="Result")
921
+ output_video = gr.Video(label="Sync Corrected Video", height=400)
922
+
923
+ with gr.Row():
924
+ submit_button = gr.Button("Submit", variant="primary")
925
+ clear_button = gr.Button("Clear")
926
+
927
+ submit_button.click(
928
+ fn=process_video,
929
+ inputs=[video_input, num_avg_frames, apply_preprocess],
930
+ outputs=[result_text, output_video]
931
+ )
932
+
933
+ clear_button.click(
934
+ fn=lambda: (None, 75, False, "", None),
935
+ inputs=[],
936
+ outputs=[video_input, num_avg_frames, apply_preprocess, result_text, output_video]
937
+ )
938
+
939
+ gr.HTML('<div class="examples-holder"></div>')
940
+
941
+ # Add examples
942
+ gr.Examples(
943
+ examples=sample_videos,
944
+ inputs=video_input,
945
+ outputs=None,
946
+ fn=None,
947
+ cache_examples=False,
948
+ )
949
+
950
+ logs = gr.Textbox(label="Logs")
951
+ demo.load(read_logs, None, logs, every=1)
952
+
953
+ # Launch the interface
954
+ demo.queue().launch(allowed_paths=["."], show_error=True)
preprocess/inference_preprocess.py ADDED
@@ -0,0 +1,326 @@
1
+ #!/usr/bin/python
2
+
3
+ import sys, os, argparse, pickle, subprocess, cv2, math
4
+ import numpy as np
5
+ from shutil import rmtree, copy, copytree
6
+ from tqdm import tqdm
7
+
8
+ import scenedetect
9
+ from scenedetect.video_manager import VideoManager
10
+ from scenedetect.scene_manager import SceneManager
11
+ from scenedetect.stats_manager import StatsManager
12
+ from scenedetect.detectors import ContentDetector
13
+
14
+ from scipy.interpolate import interp1d
15
+ from scipy import signal
16
+
17
+ from ultralytics import YOLO
18
+
19
+ from decord import VideoReader
20
+
21
+ parser = argparse.ArgumentParser(description="FaceTracker")
22
+ parser.add_argument('--data_dir', type=str, help='directory to save intermediate temp results')
23
+ parser.add_argument('--facedet_scale', type=float, default=0.25, help='Scale factor for face detection')
24
+ parser.add_argument('--crop_scale', type=float, default=0, help='Scale bounding box')
25
+ parser.add_argument('--min_track', type=int, default=50, help='Minimum facetrack duration')
26
+ parser.add_argument('--frame_rate', type=int, default=25, help='Frame rate')
27
+ parser.add_argument('--num_failed_det', type=int, default=25, help='Number of missed detections allowed before tracking is stopped')
28
+ parser.add_argument('--min_frame_size', type=int, default=64, help='Minimum frame size in pixels')
29
+ parser.add_argument('--sd_root', type=str, required=True, help='Path to save crops')
30
+ parser.add_argument('--work_root', type=str, required=True, help='Path to save metadata files')
31
+ parser.add_argument('--data_root', type=str, required=True, help='Path of the full uncropped input video')
32
+ opt = parser.parse_args()
33
+
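+ # Example invocation (all paths below are placeholders, not the ones used by app.py):
+ #   python preprocess/inference_preprocess.py --data_dir temp \
+ #       --sd_root results/crops --work_root results/metadata \
+ #       --data_root /path/to/input_video.mp4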
34
+
35
+ def bb_intersection_over_union(boxA, boxB):
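+ # Standard IoU for two [x1, y1, x2, y2] boxes; used below to link detections of the same person across frames.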
36
+ xA = max(boxA[0], boxB[0])
37
+ yA = max(boxA[1], boxB[1])
38
+ xB = min(boxA[2], boxB[2])
39
+ yB = min(boxA[3], boxB[3])
40
+
41
+ interArea = max(0, xB - xA) * max(0, yB - yA)
42
+
43
+ boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
44
+ boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
45
+
46
+ iou = interArea / float(boxAArea + boxBArea - interArea)
47
+
48
+ return iou
49
+
50
+ def track_shot(opt, scenefaces):
51
+ print("Tracking video...")
52
+ iouThres = 0.5 # Minimum IOU between consecutive face detections
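+ # Greedy tracker: grow a track by appending the first detection within num_failed_det frames whose IoU with the track's last box exceeds the threshold, then linearly interpolate the boxes over any skipped frames.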
53
+ tracks = []
54
+
55
+ while True:
56
+ track = []
57
+ for framefaces in scenefaces:
58
+ for face in framefaces:
59
+ if track == []:
60
+ track.append(face)
61
+ framefaces.remove(face)
62
+ elif face['frame'] - track[-1]['frame'] <= opt.num_failed_det:
63
+ iou = bb_intersection_over_union(face['bbox'], track[-1]['bbox'])
64
+ if iou > iouThres:
65
+ track.append(face)
66
+ framefaces.remove(face)
67
+ continue
68
+ else:
69
+ break
70
+
71
+ if track == []:
72
+ break
73
+ elif len(track) > opt.min_track:
74
+ framenum = np.array([f['frame'] for f in track])
75
+ bboxes = np.array([np.array(f['bbox']) for f in track])
76
+
77
+ frame_i = np.arange(framenum[0], framenum[-1] + 1)
78
+
79
+ bboxes_i = []
80
+ for ij in range(0, 4):
81
+ interpfn = interp1d(framenum, bboxes[:, ij])
82
+ bboxes_i.append(interpfn(frame_i))
83
+ bboxes_i = np.stack(bboxes_i, axis=1)
84
+
85
+ if max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1])) > opt.min_frame_size:
86
+ tracks.append({'frame': frame_i, 'bbox': bboxes_i})
87
+
88
+ return tracks
89
+
90
+ def check_folder(folder):
91
+ if os.path.exists(folder):
92
+ return True
93
+ return False
94
+
95
+ def del_folder(folder):
96
+ if os.path.exists(folder):
97
+ rmtree(folder)
98
+
99
+ def read_video(path, start_idx):
100
+ with open(path, 'rb') as f:
101
+ video_stream = VideoReader(f)
102
+ if start_idx > 0:
103
+ video_stream.skip_frames(start_idx)
104
+ return video_stream
105
+
106
+ def crop_video(opt, track, cropfile, tight_scale=1):
107
+ print("Cropping video...")
108
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
109
+ vOut = cv2.VideoWriter(cropfile + '.avi', fourcc, opt.frame_rate, (480, 270))
110
+
111
+ dets = {'x': [], 'y': [], 's': [], 'bbox': track['bbox'], 'frame': track['frame']}
112
+
113
+ for det in track['bbox']:
114
+ # Reduce the size of the bounding box by a small factor if tighter crops are needed (default -> no reduction in size)
115
+ width = (det[2] - det[0]) * tight_scale
116
+ height = (det[3] - det[1]) * tight_scale
117
+ center_x = (det[0] + det[2]) / 2
118
+ center_y = (det[1] + det[3]) / 2
119
+
120
+ dets['s'].append(max(height, width) / 2)
121
+ dets['y'].append(center_y) # crop center y
122
+ dets['x'].append(center_x) # crop center x
123
+
124
+ # Smooth detections
125
+ dets['s'] = signal.medfilt(dets['s'], kernel_size=13)
126
+ dets['x'] = signal.medfilt(dets['x'], kernel_size=13)
127
+ dets['y'] = signal.medfilt(dets['y'], kernel_size=13)
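+ # Median filtering the crop centre and size over a 13-frame window suppresses detector jitter so the crop does not shake.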
128
+
129
+ videofile = os.path.join(opt.avi_dir, 'video.avi')
130
+ frame_no_to_start = track['frame'][0]
131
+ video_stream = cv2.VideoCapture(videofile)
132
+ video_stream.set(cv2.CAP_PROP_POS_FRAMES, frame_no_to_start)
133
+ for fidx, frame in enumerate(track['frame']):
134
+ cs = opt.crop_scale
135
+ bs = dets['s'][fidx] # Detection box size
136
+ bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount
137
+
138
+ image = video_stream.read()[1]
139
+ frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), 'constant', constant_values=(110, 110))
140
+
141
+ my = dets['y'][fidx] + bsi # BBox center Y
142
+ mx = dets['x'][fidx] + bsi # BBox center X
143
+
144
+ face = frame[int(my - bs):int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)):int(mx + bs * (1 + cs))]
145
+ vOut.write(cv2.resize(face, (480, 270)))
146
+ video_stream.release()
147
+ audiotmp = os.path.join(opt.tmp_dir, 'audio.wav')
148
+ audiostart = (track['frame'][0]) / opt.frame_rate
149
+ audioend = (track['frame'][-1] + 1) / opt.frame_rate
150
+
151
+ vOut.release()
152
+
153
+ # ========== CROP AUDIO FILE ==========
154
+
155
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ss %.3f -to %.3f %s" % (os.path.join(opt.avi_dir, 'audio.wav'), audiostart, audioend, audiotmp))
156
+ output = subprocess.call(command, shell=True, stdout=None)
157
+
158
+ copy(audiotmp, cropfile + '.wav')
159
+
160
+ # print('Written %s' % cropfile)
161
+ # print('Mean pos: x %.2f y %.2f s %.2f' % (np.mean(dets['x']), np.mean(dets['y']), np.mean(dets['s'])))
162
+
163
+ return {'track': track, 'proc_track': dets}
164
+
165
+ def inference_video(opt, padding=0):
166
+ videofile = os.path.join(opt.avi_dir, 'video.avi')
167
+ vidObj = cv2.VideoCapture(videofile)
168
+ yolo_model = YOLO("yolov9s.pt")
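+ # Only COCO class 0 ("person") detections with confidence above 0.7 are kept; each kept box is optionally padded before being stored.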
169
+
170
+ dets = []
171
+ fidx = 0
172
+ print("Detecting people in the video using YOLO...")
173
+ while True:
174
+ success, image = vidObj.read()
175
+ if not success:
176
+ break
177
+
178
+ image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
179
+
180
+ # Perform person detection
181
+ results = yolo_model(image_np, verbose=False)
182
+ detections = results[0].boxes
183
+
184
+ dets.append([])
185
+ for i, det in enumerate(detections):
186
+ x1, y1, x2, y2 = det.xyxy[0].detach().cpu().numpy()
187
+ cls = det.cls[0].detach().cpu().numpy()
188
+ conf = det.conf[0].detach().cpu().numpy()
189
+ if int(cls) == 0 and conf>0.7: # Class 0 is 'person' in COCO dataset
190
+ x1 = max(0, int(x1) - padding)
191
+ y1 = max(0, int(y1) - padding)
192
+ x2 = min(image_np.shape[1], int(x2) + padding)
193
+ y2 = min(image_np.shape[0], int(y2) + padding)
194
+ dets[-1].append({'frame': fidx, 'bbox': [x1, y1, x2, y2], 'conf': conf})
195
+
196
+ fidx += 1
197
+
198
+ savepath = os.path.join(opt.work_dir, 'faces.pckl')
199
+
200
+ with open(savepath, 'wb') as fil:
201
+ pickle.dump(dets, fil)
202
+
203
+ return dets
204
+
205
+ def scene_detect(opt):
206
+ print("Detecting scenes in the video...")
207
+ video_manager = VideoManager([os.path.join(opt.avi_dir, 'video.avi')])
208
+ stats_manager = StatsManager()
209
+ scene_manager = SceneManager(stats_manager)
210
+ scene_manager.add_detector(ContentDetector())
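+ # ContentDetector splits the video at hard cuts; if none are found, the whole video is treated as a single scene below.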
211
+ base_timecode = video_manager.get_base_timecode()
212
+
213
+ video_manager.set_downscale_factor()
214
+ video_manager.start()
215
+ scene_manager.detect_scenes(frame_source=video_manager)
216
+ scene_list = scene_manager.get_scene_list(base_timecode)
217
+
218
+ savepath = os.path.join(opt.work_dir, 'scene.pckl')
219
+
220
+ if scene_list == []:
221
+ scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())]
222
+
223
+ with open(savepath, 'wb') as fil:
224
+ pickle.dump(scene_list, fil)
225
+
226
+ # print('%s - scenes detected %d' % (os.path.join(opt.avi_dir, 'video.avi'), len(scene_list)))
227
+
228
+ return scene_list
229
+
230
+ def process_video(file):
231
+
232
+ video_file_name = os.path.basename(file.strip())
233
+ sd_dest_folder = opt.sd_root
234
+ work_dest_folder = opt.work_root
235
+
236
+
237
+ del_folder(sd_dest_folder)
238
+ del_folder(work_dest_folder)
239
+
240
+ setattr(opt, 'videofile', file)
241
+
242
+ if os.path.exists(opt.work_dir):
243
+ rmtree(opt.work_dir)
244
+
245
+ if os.path.exists(opt.crop_dir):
246
+ rmtree(opt.crop_dir)
247
+
248
+ if os.path.exists(opt.avi_dir):
249
+ rmtree(opt.avi_dir)
250
+
251
+ if os.path.exists(opt.frames_dir):
252
+ rmtree(opt.frames_dir)
253
+
254
+ if os.path.exists(opt.tmp_dir):
255
+ rmtree(opt.tmp_dir)
256
+
257
+ os.makedirs(opt.work_dir)
258
+ os.makedirs(opt.crop_dir)
259
+ os.makedirs(opt.avi_dir)
260
+ os.makedirs(opt.frames_dir)
261
+ os.makedirs(opt.tmp_dir)
262
+
263
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -qscale:v 2 -async 1 -r 25 %s" % (opt.videofile,
264
+ os.path.join(opt.avi_dir,
265
+ 'video.avi')))
266
+ output = subprocess.call(command, shell=True, stdout=None)
267
+ if output != 0:
268
+ return
269
+
270
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (os.path.join(opt.avi_dir,
271
+ 'video.avi'),
272
+ os.path.join(opt.avi_dir,
273
+ 'audio.wav')))
274
+ output = subprocess.call(command, shell=True, stdout=None)
275
+ if output != 0:
276
+ return
277
+
278
+ faces = inference_video(opt)
279
+
280
+ try:
281
+ scene = scene_detect(opt)
282
+ except scenedetect.video_stream.VideoOpenFailure:
283
+ return
284
+
285
+
286
+ allscenes = []
287
+ for shot in scene:
288
+ if shot[1].frame_num - shot[0].frame_num >= opt.min_track:
289
+ allscenes.append(track_shot(opt, faces[shot[0].frame_num:shot[1].frame_num]))
290
+
291
+ alltracks = []
292
+ for sc_num in range(len(allscenes)):
293
+ vidtracks = []
294
+ for ii, track in enumerate(allscenes[sc_num]):
295
+ os.makedirs(os.path.join(opt.crop_dir, 'scene_'+str(sc_num)), exist_ok=True)
296
+ vidtracks.append(crop_video(opt, track, os.path.join(opt.crop_dir, 'scene_'+str(sc_num), '%05d' % ii)))
297
+ alltracks.append(vidtracks)
298
+
299
+ savepath = os.path.join(opt.work_dir, 'tracks.pckl')
300
+
301
+ with open(savepath, 'wb') as fil:
302
+ pickle.dump(alltracks, fil)
303
+
304
+ rmtree(opt.tmp_dir)
305
+ rmtree(opt.avi_dir)
306
+ rmtree(opt.frames_dir)
307
+ copytree(opt.crop_dir, sd_dest_folder)
308
+ copytree(opt.work_dir, work_dest_folder)
309
+
310
+
311
+ if __name__ == "__main__":
312
+
313
+ file = opt.data_root
314
+
315
+ os.makedirs(opt.sd_root, exist_ok=True)
316
+ os.makedirs(opt.work_root, exist_ok=True)
317
+
318
+
319
+ setattr(opt, 'avi_dir', os.path.join(opt.data_dir, 'pyavi'))
320
+ setattr(opt, 'tmp_dir', os.path.join(opt.data_dir, 'pytmp'))
321
+ setattr(opt, 'work_dir', os.path.join(opt.data_dir, 'pywork'))
322
+ setattr(opt, 'crop_dir', os.path.join(opt.data_dir, 'pycrop'))
323
+ setattr(opt, 'frames_dir', os.path.join(opt.data_dir, 'pyframes'))
324
+
325
+ process_video(file)
326
+
yolov9c.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:876eb84f515d40c34a3b111f8fc1077d3aee59d3a243afd1cc5b77d520f237c7
3
- size 51794840