sindhuhegde committed on
Commit 4ad47a9
1 Parent(s): 8f3cd14

Update app

Files changed (2):
  1. app.py +105 -78
  2. preprocess/inference_preprocess.py +1 -1
app.py CHANGED
@@ -16,6 +16,7 @@ from utils.inference_utils import *
from sync_models.gestsync_models import *
from tqdm import tqdm
from glob import glob
+ from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
@@ -25,7 +26,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Initialize global variables
- CHECKPOINT_PATH = "model_rgb.pth"
+ CHECKPOINT_PATH = "model_rgb.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
batch_size = 12
@@ -195,13 +196,14 @@ def resample_video(video_file, video_fname, result_folder):
video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

# Resample the video to 25 fps
- command = ("ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}".format(video_file, video_file_25fps))
- from subprocess import call
- cmd = command.split(' ')
+ # status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}'.format(video_file, video_file_25fps), shell=True)
+ status = subprocess.call("ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps), shell=True)
+ if status != 0:
+ msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."
+ return None, msg
print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
- call(cmd)

- return video_file_25fps
+ return video_file_25fps, "success"

def load_checkpoint(path, model):
'''
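Note that the new call above builds the ffmpeg command with str.format and shell=True, so a file path containing spaces or shell metacharacters could break the invocation. A minimal sketch of the same resampling step using an argument list instead (the helper name resample_to_25fps is illustrative, not from the repo):

import subprocess

def resample_to_25fps(video_file, video_file_25fps):
    # Passing an argument list avoids shell quoting issues with unusual file names
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "panic", "-y",
        "-i", video_file,
        "-c:v", "libx264", "-preset", "veryslow", "-crf", "0",
        "-filter:v", "fps=25", "-pix_fmt", "yuv420p",
        video_file_25fps,
    ]
    status = subprocess.call(cmd)
    if status != 0:
        return None, "Oops! Could not resample the video to 25 FPS."
    return video_file_25fps, "success"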
@@ -418,7 +420,7 @@ def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, wi

return input_frames, num_frames, orig_masked_frames, "success"

- def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
+ def load_spectrograms(wav_file, num_frames=None, window_frames=25, stride=4):

'''
This function extracts the spectrogram from the audio file
@@ -448,8 +450,9 @@ def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
orig_spec = spec
spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])

- if len(spec) != num_frames:
- spec = spec[:num_frames]
+ if num_frames is not None:
+ if len(spec) != num_frames:
+ spec = spec[:num_frames]
frame_diff = np.abs(len(spec) - num_frames)
if frame_diff > 60:
print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")
@@ -590,8 +593,9 @@ def generate_video(frames, audio_file, video_fname):
return None, msg

video_output = video_fname + '.mp4'
- status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
- (audio_file, no_sound_video, video_output), shell=True)
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' %
+ (audio_file, no_sound_video, video_output), shell=True)
+
if status != 0:
msg = "Oops! Could not generate the video. Please check the input video and try again."
return None, msg
@@ -599,7 +603,7 @@ def generate_video(frames, audio_file, video_fname):
os.remove(fname)
os.remove(no_sound_video)

- return video_output
+ return video_output, "success"

def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):

@@ -637,9 +641,11 @@ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, samp
corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]

corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
- video_output = generate_video(corrected_frames, wav_file, corrected_video_path)
+ video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
+ if status != "success":
+ return None, status

- return video_output
+ return video_output, "success"


def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
@@ -661,23 +667,26 @@ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_fold
all_frames, all_orig_frames = [], []
for video_num, video in enumerate(test_videos):

+ print("Processing video: ", video)
+
# Load the video frames
frames, status = load_video_frames(video)
if status != "success":
return None, None, status
+ print("Successfully loaded the video frames")

# Extract the keypoints from the frames
kp_dict, status = get_keypoints(frames)
if status != "success":
return None, None, status
+ print("Successfully extracted the keypoints")

# Mask the frames using the keypoints extracted from the frames and prepare the input to the model
masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict)
if status != "success":
return None, None, status
+ print("Successfully loaded the masked frames")

- input_masked_vid_path = os.path.join(result_folder, "input_masked_scene_{}_speaker_{}".format(scene_num, video_num))
- generate_video(orig_masked_frames, wav_file, input_masked_vid_path)

# Check if the length of the input frames is equal to the length of the spectrogram
if spec.shape[2]!=masked_frames.shape[0]:
@@ -691,6 +700,7 @@ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_fold
# Transpose the frames to the correct format
frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
frames = torch.FloatTensor(np.array(frames)).unsqueeze(0)
+ print("Successfully converted the frames to tensor")

all_frames.append(frames)
all_orig_frames.append(orig_masked_frames)
@@ -830,24 +840,29 @@ def save_video(output_tracks, input_frames, wav_file, result_folder):
- video_output (string) : Path of the output video
'''

- output_frames = []
- for i in range(len(input_frames)):
-
- # If the active speaker is found, draw a bounding box around the active speaker
- if i in output_tracks:
- bbox = output_tracks[i]
- x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
- out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
- else:
- out = input_frames[i]
+ try:
+ output_frames = []
+ for i in range(len(input_frames)):
+
+ # If the active speaker is found, draw a bounding box around the active speaker
+ if i in output_tracks:
+ bbox = output_tracks[i]
+ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+ out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
+ else:
+ out = input_frames[i]

- output_frames.append(out)
+ output_frames.append(out)

- # Generate the output video
- output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
- video_output = generate_video(output_frames, wav_file, output_video_fname)
+ # Generate the output video
+ output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
+ video_output, status = generate_video(output_frames, wav_file, output_video_fname)
+ if status != "success":
+ return None, status
+ except Exception as e:
+ return None, f"Error: {str(e)}"

- return video_output
+ return video_output, "success"

def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):

@@ -878,8 +893,12 @@ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
# Resample the video to 25 fps if it is not already 25 fps
print("FPS of video: ", fps)
if fps!=25:
- vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
- orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
+ vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
+ if status != "success":
+ return status, None
+ orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input)
+ if status != "success":
+ return status, None
else:
vid_path = vid_path_processed
orig_vid_path_25fps = video_path
@@ -978,7 +997,9 @@ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
print("Predicted offset: ", pred_offset)

# Generate sync-corrected video
- video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
+ video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
+ if status != "success":
+ return status, None
print("Successfully generated the video:", video_output)

return f"Predicted offset: {pred_offset}", video_output
@@ -1005,14 +1026,14 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):

if global_speaker=="per-frame-prediction" and num_avg_frames<25:
msg = "Number of frames to average need to be set to a minimum of 25 frames. Atleast 1-second context is needed for the model. Please change the num_avg_frames and try again..."
- return None, msg
+ return msg, None

# Read the video
try:
vr = VideoReader(video_path, ctx=cpu(0))
except:
msg = "Oops! Could not load the input video file"
- return None, msg
+ return msg, None

# Get the FPS of the video
fps = vr.get_avg_fps()
@@ -1020,29 +1041,28 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):

# Resample the video to 25 FPS if the original video is of a different frame-rate
if fps!=25:
- test_video_25fps = resample_video(video_path, video_fname, result_folder_input)
+ test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input)
+ if status != "success":
+ return status, None
else:
test_video_25fps = video_path

# Load the video frames
orig_frames, status = load_video_frames(test_video_25fps)
if status != "success":
- return None, status
- print("Successfully loaded the frames")
+ return status, None

# Extract and save the audio file
orig_wav_file, status = extract_audio(video_path, result_folder)
if status != "success":
- return None, status
- print("Successfully loaded the spectrograms")
+ return status, None

# Pre-process and extract per-speaker tracks in each scene
print("Pre-processing the input video...")
status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
if status != 0:
- return None, "Error in pre-processing the input video, please check the input video and try again..."
- print("Successfully preprocessed the video")
-
+ return "Error in pre-processing the input video, please check the input video and try again...", None
+
# Load the tracks file saved during pre-processing
with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
tracks = pickle.load(file)
@@ -1056,7 +1076,6 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
track_dict[scene_num][i] = {}
for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']):
track_dict[scene_num][i][frame_num] = bbox
- print("Successfully loaded the extracted person-tracks")

# Get the total number of scenes
test_scenes = os.listdir("{}/crops".format(result_folder_input))
@@ -1065,7 +1084,6 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
# Load the trained model
model = Transformer_RGB()
model = load_checkpoint(CHECKPOINT_PATH, model)
- print("Successfully loaded the model")

# Compute the active speaker in each scene
output_tracks = {}
@@ -1076,20 +1094,21 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):

if len(test_videos)<=1:
msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
- return None, msg
+ return msg, None

# Load the audio file
audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0]
spec, _, status = load_spectrograms(audio_file, window_frames=25)
if status != "success":
- return None, status
+ return status, None
spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3)
+ print("Successfully loaded the spectrograms")

# Load the masked input frames
all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
if status != "success":
- return None, status
-
+ return status, None
+ print("Successfully loaded the masked input frames")

# Prepare the audio and video sequences for the model
audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
@@ -1105,6 +1124,7 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
else:
video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
all_video_embs.append(video_emb)
+ print("Successfully extracted GestSync embeddings")

# Predict the active speaker in each scene
if global_speaker=="per-frame-prediction":
@@ -1134,7 +1154,7 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):

if len(predictions) != frame_pred:
msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
- return None, msg
+ return msg, None

active_speakers[start:end] = predictions[0:]

@@ -1154,21 +1174,19 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
output_tracks[frame] = track_dict[scene_num][label][frame]

# Save the output video
- print("Generating active-speaker detection output video...")
video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
if status != "success":
- return None, status
+ return status, None
print("Successfully saved the output video: ", video_output)

- return video_output, "success"
+ return "success", video_output

except Exception as e:
- return None, f"Error: {str(e)}"
-
+ return f"Error: {str(e)}", None

if __name__ == "__main__":

-
+
# Custom CSS and HTML
custom_css = """
<style>
@@ -1291,7 +1309,7 @@ if __name__ == "__main__":
gr.update(visible=True), # submit_button
gr.update(visible=True), # clear_button
gr.update(visible=False), # sync_examples
- gr.update(visible=True) # asd_examples
+ gr.update(visible=True) # asd_examples
)

def clear_inputs():
@@ -1303,16 +1321,16 @@ if __name__ == "__main__":
else:
return process_video_activespeaker(video_input, global_speaker, num_avg_frames)

-
# Define paths to sample videos
sync_sample_videos = [
- "samples/sync_sample_1.mp4",
- "samples/sync_sample_2.mp4",
+ ["samples/sync_sample_1.mp4"],
+ ["samples/sync_sample_2.mp4"],
+ ["samples/sync_sample_3.mp4"]
]

asd_sample_videos = [
- "samples/asd_sample_1.mp4",
- "samples/asd_sample_2.mp4"
+ ["samples/asd_sample_1.mp4"],
+ ["samples/asd_sample_2.mp4"]
]

# Define Gradio interface
@@ -1356,32 +1374,40 @@ if __name__ == "__main__":
# Add a gap before examples
gr.HTML('<div class="examples-holder"></div>')

+
# Add examples that only populate the video input
- sync_examples = gr.Examples(
- examples=sync_sample_videos,
- inputs=video_input,
- outputs=None,
- fn=None,
- cache_examples=False,
+ sync_examples = gr.Dataset(
+ samples=sync_sample_videos,
+ components=[video_input],
+ type="values",
visible=False
)

- asd_examples = gr.Examples(
- examples=asd_sample_videos,
- inputs=video_input,
- outputs=None,
- fn=None,
- cache_examples=False,
+ asd_examples = gr.Dataset(
+ samples=asd_sample_videos,
+ components=[video_input],
+ type="values",
visible=False
)

-
demo_choice.change(
fn=toggle_demo,
inputs=demo_choice,
- outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, result_text, output_video, submit_button, clear_button, sync_examples.dataset, asd_examples.dataset]
+ outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, result_text, output_video, submit_button, clear_button, sync_examples, asd_examples]
)

+ sync_examples.select(
+ fn=lambda x: gr.update(value=x[0], visible=True),
+ inputs=sync_examples,
+ outputs=video_input
+ )
+
+ asd_examples.select(
+ fn=lambda x: gr.update(value=x[0], visible=True),
+ inputs=asd_examples,
+ outputs=video_input
+ )
+

submit_button.click(
fn=process_video,
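The gr.Examples blocks are replaced with gr.Dataset components here because the Dataset's select event passes the clicked sample to the handler: each sample is a one-element list holding a video path, so x[0] is the path written into the video input. A minimal standalone sketch of the same pattern (the sample path and Blocks layout are illustrative, not the app's actual interface):

import gradio as gr

with gr.Blocks() as demo:
    video_input = gr.Video()
    examples = gr.Dataset(
        components=[video_input],
        samples=[["samples/sync_sample_1.mp4"]],  # one list per sample, one entry per component
        type="values",
    )
    # With type="values", the select handler receives the selected sample's component values
    examples.select(
        fn=lambda sample: gr.update(value=sample[0], visible=True),
        inputs=examples,
        outputs=video_input,
    )

demo.launch()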
@@ -1394,5 +1420,6 @@ if __name__ == "__main__":
inputs=[],
outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
)
+
# Launch the interface
- demo.launch(allowed_paths=["."], server_name="0.0.0.0", server_port=7860, share=True)
+ demo.launch(allowed_paths=["."], share=True)
 
preprocess/inference_preprocess.py CHANGED
@@ -165,7 +165,7 @@ def crop_video(opt, track, cropfile, tight_scale=1):
def inference_video(opt, padding=0):
videofile = os.path.join(opt.avi_dir, 'video.avi')
vidObj = cv2.VideoCapture(videofile)
- yolo_model = YOLO("yolov9t.pt")
+ yolo_model = YOLO("yolov9m.pt")
global dets, fidx
dets = []
fidx = 0
 
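The YOLO(...) line above is the Ultralytics person detector used by inference_video to find speakers in each frame; the commit swaps the tiny yolov9t.pt weights for the larger yolov9m.pt. A minimal sketch of how such a detector is typically queried per frame (the confidence threshold and the per-box loop are illustrative assumptions, not code from this repo):

import cv2
from ultralytics import YOLO

yolo_model = YOLO("yolov9m.pt")  # same weights as the updated inference_video()

cap = cv2.VideoCapture("video.avi")
ok, frame = cap.read()
if ok:
    # Run detection on a single frame; class 0 is "person" in the COCO label set
    results = yolo_model(frame, verbose=False)
    for box in results[0].boxes:
        if int(box.cls) == 0 and float(box.conf) > 0.5:  # illustrative threshold
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            print("person at", (x1, y1, x2, y2))
cap.release()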