Spaces: Running on Zero

Commit 4ad47a9 by sindhuhegde: "Update app"
1 Parent(s): 8f3cd14

Files changed:
- app.py (+105, -78)
- preprocess/inference_preprocess.py (+1, -1)
app.py
CHANGED
@@ -16,6 +16,7 @@ from utils.inference_utils import *
 from sync_models.gestsync_models import *
 from tqdm import tqdm
 from glob import glob
+from scipy.io.wavfile import write
 import mediapipe as mp
 from protobuf_to_dict import protobuf_to_dict
 import warnings
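Note: `scipy.io.wavfile.write` is presumably imported here for saving audio arrays to disk elsewhere in the app. For reference, a minimal sketch of how it is used (file name and tone are illustrative):

    import numpy as np
    from scipy.io.wavfile import write

    sample_rate = 16000                                    # 16 kHz, as used by the sync models
    t = np.arange(sample_rate) / sample_rate               # one second of samples
    tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    write("tone.wav", sample_rate, tone)                   # illustrative output path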
@@ -25,7 +26,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=UserWarning)
 
 # Initialize global variables
-CHECKPOINT_PATH = "model_rgb.pth"
+CHECKPOINT_PATH = "model_rgb.pth"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 use_cuda = torch.cuda.is_available()
 batch_size = 12
@@ -195,13 +196,14 @@ def resample_video(video_file, video_fname, result_folder):
     video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))
 
     # Resample the video to 25 fps
+    # status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}'.format(video_file, video_file_25fps), shell=True)
+    status = subprocess.call("ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps), shell=True)
+    if status != 0:
+        msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."
+        return None, msg
     print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
-    call(cmd)
 
-    return video_file_25fps
+    return video_file_25fps, "success"
 
 def load_checkpoint(path, model):
     '''
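Note: the new command re-encodes losslessly (`-crf 0`) while forcing 25 fps, and the function now returns a `(path, status)` tuple instead of a bare path. A minimal sketch of the pattern this gives callers (paths are illustrative):

    import subprocess

    def resample_to_25fps(src, dst):
        # Lossless H.264 re-encode while forcing 25 fps, mirroring the command in the diff
        cmd = ("ffmpeg -hide_banner -loglevel panic -y -i {} "
               "-c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}").format(src, dst)
        if subprocess.call(cmd, shell=True) != 0:
            return None, "resample failed"
        return dst, "success"

    # Caller checks the status instead of assuming success
    path, status = resample_to_25fps("input.mp4", "input_25fps.mp4")  # illustrative paths
    if status != "success":
        print(status)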
@@ -418,7 +420,7 @@ def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, wi
 
     return input_frames, num_frames, orig_masked_frames, "success"
 
-def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
+def load_spectrograms(wav_file, num_frames=None, window_frames=25, stride=4):
 
     '''
     This function extracts the spectrogram from the audio file
@@ -448,8 +450,9 @@ def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
     orig_spec = spec
     spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])
 
-    if
-        spec
+    if num_frames is not None:
+        if len(spec) != num_frames:
+            spec = spec[:num_frames]
     frame_diff = np.abs(len(spec) - num_frames)
     if frame_diff > 60:
         print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")
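Note: the new guard makes `num_frames` optional and trims the stacked spectrogram windows when a target length is given; roughly:

    import numpy as np

    def trim_windows(spec_windows, num_frames=None):
        # spec_windows: array of shape (num_windows, window_len, mel_bins)
        if num_frames is not None and len(spec_windows) != num_frames:
            spec_windows = spec_windows[:num_frames]  # drop trailing windows
        return spec_windows

    windows = np.zeros((130, 100, 80))        # dummy spectrogram windows
    print(trim_windows(windows, 125).shape)   # (125, 100, 80)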
@@ -590,8 +593,9 @@ def generate_video(frames, audio_file, video_fname):
         return None, msg
 
     video_output = video_fname + '.mp4'
-    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
+    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' %
+                             (audio_file, no_sound_video, video_output), shell=True)
+
 if status != 0:
     msg = "Oops! Could not generate the video. Please check the input video and try again."
     return None, msg
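Note: `-crf 18` in the mux command above is a near-visually-lossless H.264 quality setting, while the resample step earlier uses `-crf 0` (mathematically lossless); both trade encode speed (`-preset veryslow`) for quality.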
@@ -599,7 +603,7 @@ def generate_video(frames, audio_file, video_fname):
     os.remove(fname)
     os.remove(no_sound_video)
 
-    return video_output
+    return video_output, "success"
 
 def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
 
@@ -637,9 +641,11 @@ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
     corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]
 
     corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
-    video_output = generate_video(corrected_frames, wav_file, corrected_video_path)
+    video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
+    if status != "success":
+        return None, status
 
-    return video_output
+    return video_output, "success"
 
 
 def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
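Note: these two hunks show the commit's recurring pattern: helpers now return a `(value, status)` pair and callers propagate the failure message instead of assuming success. A minimal sketch of the pattern (function names are illustrative stand-ins):

    def mux_audio():                      # stand-in for generate_video
        return "out.mp4", "success"

    def correct_sync():                   # stand-in for sync_correct_video
        video, status = mux_audio()
        if status != "success":
            return None, status           # propagate the error message unchanged
        return video, "success"

    print(correct_sync())                 # ('out.mp4', 'success')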
@@ -661,23 +667,26 @@ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
     all_frames, all_orig_frames = [], []
     for video_num, video in enumerate(test_videos):
 
+        print("Processing video: ", video)
+
         # Load the video frames
         frames, status = load_video_frames(video)
         if status != "success":
             return None, None, status
+        print("Successfully loaded the video frames")
 
         # Extract the keypoints from the frames
         kp_dict, status = get_keypoints(frames)
         if status != "success":
             return None, None, status
+        print("Successfully extracted the keypoints")
 
         # Mask the frames using the keypoints extracted from the frames and prepare the input to the model
         masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict)
         if status != "success":
             return None, None, status
+        print("Successfully loaded the masked frames")
 
-        input_masked_vid_path = os.path.join(result_folder, "input_masked_scene_{}_speaker_{}".format(scene_num, video_num))
-        generate_video(orig_masked_frames, wav_file, input_masked_vid_path)
 
         # Check if the length of the input frames is equal to the length of the spectrogram
         if spec.shape[2]!=masked_frames.shape[0]:
@@ -691,6 +700,7 @@ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
         # Transpose the frames to the correct format
         frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
         frames = torch.FloatTensor(np.array(frames)).unsqueeze(0)
+        print("Successfully converted the frames to tensor")
 
         all_frames.append(frames)
         all_orig_frames.append(orig_masked_frames)
@@ -830,24 +840,29 @@ def save_video(output_tracks, input_frames, wav_file, result_folder):
         - video_output (string) : Path of the output video
     '''
 
+    try:
+        output_frames = []
+        for i in range(len(input_frames)):
+
+            # If the active speaker is found, draw a bounding box around the active speaker
+            if i in output_tracks:
+                bbox = output_tracks[i]
+                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+                out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
+            else:
+                out = input_frames[i]
 
+            output_frames.append(out)
 
+        # Generate the output video
+        output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
+        video_output, status = generate_video(output_frames, wav_file, output_video_fname)
+        if status != "success":
+            return None, status
+    except Exception as e:
+        return None, f"Error: {str(e)}"
 
-    return video_output
+    return video_output, "success"
 
 def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
 
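Note: the overlay logic in the new `save_video` body boils down to `cv2.rectangle` on a copy of each frame; a self-contained sketch with a dummy frame and box:

    import numpy as np
    import cv2

    frame = np.zeros((270, 480, 3), dtype=np.uint8)   # dummy frame
    bbox = [100, 50, 220, 200]                        # illustrative x1, y1, x2, y2
    x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
    out = cv2.rectangle(frame.copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
    print(out.shape)                                  # (270, 480, 3)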
@@ -878,8 +893,12 @@ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
         # Resample the video to 25 fps if it is not already 25 fps
         print("FPS of video: ", fps)
         if fps!=25:
-            vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
+            vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
+            if status != "success":
+                return status, None
+            orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input)
+            if status != "success":
+                return status, None
         else:
             vid_path = vid_path_processed
             orig_vid_path_25fps = video_path
@@ -978,7 +997,9 @@ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
         print("Predicted offset: ", pred_offset)
 
         # Generate sync-corrected video
-        video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
+        video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
+        if status != "success":
+            return status, None
         print("Successfully generated the video:", video_output)
 
         return f"Predicted offset: {pred_offset}", video_output
@@ -1005,14 +1026,14 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
 
     if global_speaker=="per-frame-prediction" and num_avg_frames<25:
         msg = "Number of frames to average need to be set to a minimum of 25 frames. Atleast 1-second context is needed for the model. Please change the num_avg_frames and try again..."
-        return
+        return msg, None
 
     # Read the video
     try:
         vr = VideoReader(video_path, ctx=cpu(0))
     except:
         msg = "Oops! Could not load the input video file"
-        return
+        return msg, None
 
     # Get the FPS of the video
     fps = vr.get_avg_fps()
@@ -1020,29 +1041,28 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
 
     # Resample the video to 25 FPS if the original video is of a different frame-rate
     if fps!=25:
-        test_video_25fps = resample_video(video_path, video_fname, result_folder_input)
+        test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input)
+        if status != "success":
+            return status, None
     else:
         test_video_25fps = video_path
 
     # Load the video frames
     orig_frames, status = load_video_frames(test_video_25fps)
     if status != "success":
-        return
-    print("Successfully loaded the frames")
+        return status, None
 
     # Extract and save the audio file
     orig_wav_file, status = extract_audio(video_path, result_folder)
     if status != "success":
-        return
-    print("Successfully loaded the spectrograms")
+        return status, None
 
     # Pre-process and extract per-speaker tracks in each scene
     print("Pre-processing the input video...")
     status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
     if status != 0:
-        return
+        return "Error in pre-processing the input video, please check the input video and try again...", None
+
     # Load the tracks file saved during pre-processing
     with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
         tracks = pickle.load(file)
@@ -1056,7 +1076,6 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
             track_dict[scene_num][i] = {}
             for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']):
                 track_dict[scene_num][i][frame_num] = bbox
-    print("Successfully loaded the extracted person-tracks")
 
     # Get the total number of scenes
     test_scenes = os.listdir("{}/crops".format(result_folder_input))
@@ -1065,7 +1084,6 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
     # Load the trained model
     model = Transformer_RGB()
     model = load_checkpoint(CHECKPOINT_PATH, model)
-    print("Successfully loaded the model")
 
     # Compute the active speaker in each scene
     output_tracks = {}
@@ -1076,20 +1094,21 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
 
         if len(test_videos)<=1:
             msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
-            return
+            return msg, None
 
         # Load the audio file
         audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0]
         spec, _, status = load_spectrograms(audio_file, window_frames=25)
         if status != "success":
-            return
+            return status, None
         spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3)
+        print("Successfully loaded the spectrograms")
 
         # Load the masked input frames
         all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
         if status != "success":
-            return
+            return status, None
+        print("Successfully loaded the masked input frames")
 
         # Prepare the audio and video sequences for the model
         audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
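Note: the `torch.cat` over the window dimension above folds the per-window spectrograms into the batch dimension; a shape walk-through with illustrative sizes:

    import torch

    spec = torch.zeros(1, 1, 10, 80, 100)   # (batch, channel, windows, mels, time), sizes illustrative
    audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
    print(audio_sequences.shape)             # torch.Size([10, 1, 80, 100])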
@@ -1105,6 +1124,7 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
             else:
                 video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
                 all_video_embs.append(video_emb)
+        print("Successfully extracted GestSync embeddings")
 
         # Predict the active speaker in each scene
         if global_speaker=="per-frame-prediction":
@@ -1134,7 +1154,7 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
 
             if len(predictions) != frame_pred:
                 msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
-                return
+                return msg, None
 
             active_speakers[start:end] = predictions[0:]
 
@@ -1154,21 +1174,19 @@ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
                     output_tracks[frame] = track_dict[scene_num][label][frame]
 
         # Save the output video
-        print("Generating active-speaker detection output video...")
         video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
         if status != "success":
-            return
+            return status, None
         print("Successfully saved the output video: ", video_output)
 
-        return
+        return "success", video_output
 
     except Exception as e:
-        return
+        return f"Error: {str(e)}", None
 
 if __name__ == "__main__":
 
+
     # Custom CSS and HTML
     custom_css = """
     <style>
@@ -1291,7 +1309,7 @@ if __name__ == "__main__":
             gr.update(visible=True),   # submit_button
             gr.update(visible=True),   # clear_button
             gr.update(visible=False),  # sync_examples
-            gr.update(visible=True)
+            gr.update(visible=True)    # asd_examples
         )
 
     def clear_inputs():
@@ -1303,16 +1321,16 @@ if __name__ == "__main__":
         else:
             return process_video_activespeaker(video_input, global_speaker, num_avg_frames)
 
     # Define paths to sample videos
     sync_sample_videos = [
-        "samples/sync_sample_1.mp4",
-        "samples/sync_sample_2.mp4",
+        ["samples/sync_sample_1.mp4"],
+        ["samples/sync_sample_2.mp4"],
+        ["samples/sync_sample_3.mp4"]
     ]
 
     asd_sample_videos = [
-        "samples/asd_sample_1.mp4",
-        "samples/asd_sample_2.mp4"
+        ["samples/asd_sample_1.mp4"],
+        ["samples/asd_sample_2.mp4"]
    ]
 
    # Define Gradio interface
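Note: each sample path is now wrapped in its own list because `gr.Dataset` (used below) expects one row per sample, with one value per component; with a single video component, each row is a one-element list.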
@@ -1356,32 +1374,40 @@ if __name__ == "__main__":
         # Add a gap before examples
         gr.HTML('<div class="examples-holder"></div>')
 
+
         # Add examples that only populate the video input
-        sync_examples = gr.
-            fn=None,
-            cache_examples=False,
+        sync_examples = gr.Dataset(
+            samples=sync_sample_videos,
+            components=[video_input],
+            type="values",
             visible=False
         )
 
-        asd_examples = gr.
-            fn=None,
-            cache_examples=False,
+        asd_examples = gr.Dataset(
+            samples=asd_sample_videos,
+            components=[video_input],
+            type="values",
             visible=False
         )
 
         demo_choice.change(
             fn=toggle_demo,
             inputs=demo_choice,
-            outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, result_text, output_video, submit_button, clear_button, sync_examples
+            outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, result_text, output_video, submit_button, clear_button, sync_examples, asd_examples]
         )
 
+        sync_examples.select(
+            fn=lambda x: gr.update(value=x[0], visible=True),
+            inputs=sync_examples,
+            outputs=video_input
+        )
+
+        asd_examples.select(
+            fn=lambda x: gr.update(value=x[0], visible=True),
+            inputs=asd_examples,
+            outputs=video_input
+        )
+
         submit_button.click(
             fn=process_video,
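Note: a minimal, self-contained sketch of the `gr.Dataset` + `.select()` wiring introduced above (sample paths illustrative; assumes a Gradio version where passing the dataset as `inputs` hands the selected row to the handler):

    import gradio as gr

    with gr.Blocks() as demo:
        video_input = gr.Video(label="Input video")
        examples = gr.Dataset(
            samples=[["samples/clip_a.mp4"], ["samples/clip_b.mp4"]],  # one row per sample
            components=[video_input],
            type="values",
        )
        # Clicking a row copies its first field into the video component
        examples.select(
            fn=lambda x: gr.update(value=x[0], visible=True),
            inputs=examples,
            outputs=video_input,
        )

    demo.launch()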
@@ -1394,5 +1420,6 @@ if __name__ == "__main__":
             inputs=[],
             outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
         )
+
     # Launch the interface
-    demo.launch(allowed_paths=["."],
+    demo.launch(allowed_paths=["."], share=True)
preprocess/inference_preprocess.py
CHANGED
@@ -165,7 +165,7 @@ def crop_video(opt, track, cropfile, tight_scale=1):
 def inference_video(opt, padding=0):
     videofile = os.path.join(opt.avi_dir, 'video.avi')
     vidObj = cv2.VideoCapture(videofile)
-    yolo_model = YOLO("
+    yolo_model = YOLO("yolov9m.pt")
     global dets, fidx
     dets = []
     fidx = 0
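Note: `YOLO("yolov9m.pt")` refers to an Ultralytics YOLOv9 checkpoint that the library downloads on first use; a minimal usage sketch (image path illustrative):

    from ultralytics import YOLO

    model = YOLO("yolov9m.pt")      # weights are fetched automatically on first run
    results = model("frame.jpg")    # illustrative input image
    for box in results[0].boxes:    # detections feed the track/crop pipeline
        print(box.xyxy, box.conf, box.cls)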
|