sindhuhegde committed
Commit 2d17a01
1 Parent(s): a448fa6

Update app

Files changed (1):
  1. app_v1.py +954 -0
app_v1.py ADDED
@@ -0,0 +1,954 @@
import gradio as gr
import argparse
import os, subprocess
from shutil import rmtree

import numpy as np
import cv2
import librosa
import torch

from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *

import sys
if sys.version_info > (3, 0): long, unicode, basestring = int, str, str

from tqdm import tqdm
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
mp_holistic = mp.solutions.holistic
from ultralytics import YOLO
from decord import VideoReader, cpu

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Set the path to the checkpoint file
CHECKPOINT_PATH = "model_rgb.pth"

# Initialize global variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
n_negative_samples = 100
print("Using CUDA: ", use_cuda, device)

def preprocess_video(path, result_folder, apply_preprocess, padding=20):

    '''
    This function preprocesses the input video to extract the audio and crop the frames using the YOLO model

    Args:
        - path (string) : Path of the input video file
        - result_folder (string) : Path of the folder to save the extracted audio and cropped video
        - apply_preprocess (bool) : Whether to run person detection and crop the frames
        - padding (int) : Padding to add to the bounding box
    Returns:
        - wav_file (string) : Path of the extracted audio file
        - fps (int) : FPS of the input video
        - video_output (string) : Path of the cropped video file
        - msg (string) : Message to be returned
    '''

    # Load all video frames
    try:
        vr = VideoReader(path, ctx=cpu(0))
        fps = vr.get_avg_fps()
        frame_count = len(vr)
    except:
        msg = "Oops! Could not load the video. Please check the input video and try again."
        return None, None, None, msg

    if frame_count < 25:
        msg = "Not enough frames to process! Please give a longer video as input"
        return None, None, None, msg

    # Extract the audio from the input video file using ffmpeg
    wav_file = os.path.join(result_folder, "audio.wav")

    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
                    -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)

    if status != 0:
        msg = "Oops! Could not load the audio file. Please check the input video and try again."
        return None, None, None, msg
    print("Extracted the audio from the video")

    if apply_preprocess:
        all_frames = []
        for k in range(len(vr)):
            all_frames.append(vr[k].asnumpy())
        all_frames = np.asarray(all_frames)
        print("Extracted the frames for pre-processing")

        # Load YOLOv9 model (pre-trained on COCO dataset)
        yolo_model = YOLO("yolov9s.pt")
        print("Loaded the YOLO model")

        person_videos = {}
        person_tracks = {}

        print("Processing the frames...")
        for frame_idx in tqdm(range(frame_count)):

            frame = all_frames[frame_idx]

            # Perform person detection
            results = yolo_model(frame, verbose=False)
            detections = results[0].boxes

            for i, det in enumerate(detections):
                x1, y1, x2, y2 = det.xyxy[0]
                cls = det.cls[0]
                if int(cls) == 0:  # Class 0 is 'person' in the COCO dataset

                    x1 = max(0, int(x1) - padding)
                    y1 = max(0, int(y1) - padding)
                    x2 = min(frame.shape[1], int(x2) + padding)
                    y2 = min(frame.shape[0], int(y2) + padding)

                    if i not in person_videos:
                        person_videos[i] = []
                        person_tracks[i] = []

                    person_videos[i].append(frame)
                    person_tracks[i].append([x1, y1, x2, y2])

        num_persons = 0
        for i in person_videos.keys():
            if len(person_videos[i]) >= frame_count//2:
                num_persons += 1

        if num_persons == 0:
            msg = "No person detected in the video! Please give a video with one person as input"
            return None, None, None, msg
        if num_persons > 1:
            msg = "More than one person detected in the video! Please give a video with only one person as input"
            return None, None, None, msg

        # For the detected person, crop the frames based on the union of the tracked bounding boxes
        if len(person_videos[0]) > frame_count-10:
            crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
            fourcc = cv2.VideoWriter_fourcc(*'DIVX')

            # Get the bounding-box coordinates that cover the person across all frames
            max_x1 = min([track[0] for track in person_tracks[0]])
            max_y1 = min([track[1] for track in person_tracks[0]])
            max_x2 = max([track[2] for track in person_tracks[0]])
            max_y2 = max([track[3] for track in person_tracks[0]])

            max_width = max_x2 - max_x1
            max_height = max_y2 - max_y1

            out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
            for frame in person_videos[0]:
                crop = frame[max_y1:max_y2, max_x1:max_x2]
                crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                out.write(crop)
            out.release()

            no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            video_output = crop_filename.split('.')[0] + '.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
                                     (wav_file, no_sound_video, video_output), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            os.remove(crop_filename)
            os.remove(no_sound_video)

            print("Successfully saved the pre-processed video: ", video_output)
        else:
            msg = "Could not track the person in the full video! Please give a single-speaker video as input"
            return None, None, None, msg

    else:
        video_output = path

    return wav_file, fps, video_output, "success"

def resample_video(video_file, video_fname, result_folder):

    '''
    This function resamples the video to 25 fps

    Args:
        - video_file (string) : Path of the input video file
        - video_fname (string) : Name of the input video file
        - result_folder (string) : Path of the folder to save the resampled video
    Returns:
        - video_file_25fps (string) : Path of the resampled video file
    '''
    video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

    # Resample the video to 25 fps
    command = "ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}".format(video_file, video_file_25fps)
    subprocess.call(command.split(' '))
    print('Resampled the video to 25 fps: {}'.format(video_file_25fps))

    return video_file_25fps

def load_checkpoint(path, model):
    '''
    This function loads the trained model from the checkpoint

    Args:
        - path (string) : Path of the checkpoint file
        - model (object) : Model object
    Returns:
        - model (object) : Model object with the weights loaded from the checkpoint
    '''

    # Load the checkpoint
    if use_cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location="cpu")

    s = checkpoint["state_dict"]
    new_s = {}

    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    if use_cuda:
        model.cuda()

    print("Loaded checkpoint from: {}".format(path))

    return model.eval()


def load_video_frames(video_file):
    '''
    This function extracts the frames from the video

    Args:
        - video_file (string) : Path of the video file
    Returns:
        - frames (list) : List of frames extracted from the video
        - msg (string) : Message to be returned
    '''

    # Read the video
    try:
        vr = VideoReader(video_file, ctx=cpu(0))
    except:
        msg = "Oops! Could not load the input video file"
        return None, msg

    # Extract the frames
    frames = []
    for k in range(len(vr)):
        frames.append(vr[k].asnumpy())

    frames = np.asarray(frames)

    return frames, "success"


def get_keypoints(frames):

    '''
    This function extracts the keypoints from the frames using the MediaPipe Holistic pipeline

    Args:
        - frames (list) : List of frames extracted from the video
    Returns:
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
        - msg (string) : Message to be returned
    '''

    try:
        holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)

        resolution = frames[0].shape
        all_frame_kps = []

        for frame in frames:

            results = holistic.process(frame)

            pose, left_hand, right_hand, face = None, None, None, None
            if results.pose_landmarks is not None:
                pose = protobuf_to_dict(results.pose_landmarks)['landmark']
            if results.left_hand_landmarks is not None:
                left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
            if results.right_hand_landmarks is not None:
                right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
            if results.face_landmarks is not None:
                face = protobuf_to_dict(results.face_landmarks)['landmark']

            frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face}

            all_frame_kps.append(frame_dict)

        kp_dict = {"kps":all_frame_kps, "resolution":resolution}
    except Exception as e:
        print("Error: ", e)
        return None, "Error: Could not extract keypoints from the frames"

    return kp_dict, "success"


def check_visible_gestures(kp_dict):

    '''
    This function checks if the gestures in the video are visible

    Args:
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
    Returns:
        - msg (string) : Message to be returned
    '''

    keypoints = kp_dict['kps']
    keypoints = np.array(keypoints)

    if len(keypoints) < 25:
        msg = "Not enough keypoints to process! Please give a longer video as input"
        return msg

    pose_count, hand_count = 0, 0
    for frame_kp_dict in keypoints:

        pose = frame_kp_dict["pose"]
        left_hand = frame_kp_dict["left_hand"]
        right_hand = frame_kp_dict["right_hand"]

        if pose is None:
            pose_count += 1

        if left_hand is None and right_hand is None:
            hand_count += 1

    if hand_count/len(keypoints) > 0.7 or pose_count/len(keypoints) > 0.7:
        msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
        return msg

    print("Successfully verified the input video - Gestures are visible!")

    return "success"

def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, width=480, height=270):

    '''
    This function masks the faces using the keypoints extracted from the frames

    Args:
        - input_frames (list) : List of frames extracted from the video
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
        - stride (int) : Stride to extract the frames
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - width (int) : Width of the frames
        - height (int) : Height of the frames
    Returns:
        - input_frames (array) : Frame window to be given as input to the model
        - num_frames (int) : Number of frames to extract
        - orig_masked_frames (array) : Masked frames extracted from the video
        - msg (string) : Message to be returned
    '''

    # Face indices to extract the face-coordinates needed for masking
    face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
                     176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]

    input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
    print("Input keypoints: ", len(input_keypoints))

    print("Creating masked input frames...")
    input_frames_masked = []
    for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):

        img = input_frames[i]
        face = frame_kp_dict["face"]

        if face is None:
            img = cv2.resize(img, (width, height))
            masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
        else:
            face_kps = []
            for idx in range(len(face)):
                if idx in face_oval_idx:
                    x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0])
                    face_kps.append((x,y))

            face_kps = np.array(face_kps)
            x1, y1 = min(face_kps[:,0]), min(face_kps[:,1])
            x2, y2 = max(face_kps[:,0]), max(face_kps[:,1])
            masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1)

        # Resize to (width x height) if the masked frame is not already at the model resolution
        if masked_img.shape[0] != height or masked_img.shape[1] != width:
            masked_img = cv2.resize(masked_img, (width, height))

        input_frames_masked.append(masked_img)

    orig_masked_frames = np.array(input_frames_masked)
    input_frames = np.array(input_frames_masked) / 255.
    print("Input images full: ", input_frames.shape)		# num_frames x 270 x 480 x 3

    input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0, input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
    print("Input images window: ", input_frames.shape)		# T x 25 x 270 x 480 x 3

    num_frames = input_frames.shape[0]

    if num_frames < 10:
        msg = "Not enough frames to process! Please give a longer video as input."
        return None, None, None, msg

    return input_frames, num_frames, orig_masked_frames, "success"

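
# Sliding-window arithmetic for the windows created above: with stride=1 and
# window_frames=25, a clip of N masked frames yields N - 24 overlapping windows
# (e.g. 250 frames, i.e. 10 s at 25 fps, give 226 windows of shape 25 x 270 x 480 x 3),
# so clips shorter than 34 frames fail the num_frames < 10 check.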
def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):

    '''
    This function extracts the spectrogram from the audio file

    Args:
        - wav_file (string) : Path of the extracted audio file
        - num_frames (int) : Number of frames to extract
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - stride (int) : Stride to extract the audio frames
    Returns:
        - spec (array) : Spectrogram array window to be used as input to the model
        - orig_spec (array) : Spectrogram array extracted from the audio file
        - msg (string) : Message to be returned
    '''

    # Load the audio extracted from the input video file
    try:
        wav = librosa.load(wav_file, sr=16000)[0]
    except:
        msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
        return None, None, msg

    # Convert to tensor
    wav = torch.FloatTensor(wav).unsqueeze(0)
    mel, _, _, _ = wav2filterbanks(wav.to(device))
    spec = mel.squeeze(0).cpu().numpy()
    orig_spec = spec
    spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])

    if len(spec) != num_frames:
        spec = spec[:num_frames]
        frame_diff = np.abs(len(spec) - num_frames)
        if frame_diff > 60:
            print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

    return spec, orig_spec, "success"

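
# Note on the audio-video alignment assumed above: each 25-frame video window is paired
# with window_frames*stride = 100 spectrogram time-steps, which presumes wav2filterbanks
# (from utils.audio_utils) yields 4 mel frames per video frame (100 steps/sec for 25 fps
# video); if its hop size differs, the audio and video windows would no longer line up.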
def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
    '''
    This function calculates the audio-visual offset between the video and audio

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - model (object) : Model object
    Returns:
        - offset (int) : Optimal audio-visual offset
        - msg (string) : Message to be returned
    '''

    pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
    if status != "success":
        return None, status
    scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
    offset = scores.argmax()*stride - pos_idx

    return offset.item(), "success"

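
# Example of the offset computation above (illustrative values): with stride=5 and a
# reference window starting at pos_idx = 50 frames, a best-scoring candidate at index 12
# gives offset = 12*5 - 50 = +10, i.e. the best-matching audio window starts 10 frames
# later in the audio stream than the video reference window; sync_correct_video then
# compensates by trimming the corresponding amount of audio from the start.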
def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):

    '''
    This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - stride (int) : Stride to extract the negative windows
    Returns:
        - vid_emb_pos (array) : Positive video embedding array
        - aud_emb_posneg (array) : All possible combinations of audio embedding array
        - pos_idx_frame (int) : Start frame of the positive video embedding window
        - stride (int) : Stride used to extract the negative windows
        - msg (string) : Message to be returned
    '''

    slice_size = num_avg_frames
    aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
    aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
    aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]

    pos_idx = (aud_emb_posneg.shape[1]//2)
    pos_idx_frame = pos_idx*stride

    min_offset_frames = -(pos_idx)*stride
    max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
    print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))

    vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
    if vid_emb_pos.shape[2] != slice_size:
        msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
        return None, None, None, None, msg

    return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"

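
# With the defaults n_negative_samples=100 and stride=5 (and a clip long enough to yield
# every window), the unfold above keeps int(100/5)+1 = 21 candidate audio windows, the
# positive index is 21//2 = 10 (pos_idx_frame = 50), and the printed search range is
# [-50, +50] frames, i.e. up to 2 seconds in either direction at 25 fps.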
def calc_av_scores(vid_emb, aud_emb, model):

    '''
    This function calculates the audio-visual similarity scores and the attention map between the video and audio embeddings

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores
        - att_map (array) : Attention map
    '''

    scores = calc_att_map(vid_emb, aud_emb, model)
    att_map = logsoftmax_2d(scores)
    scores = scores.mean(-1)

    return scores, att_map

def calc_att_map(vid_emb, aud_emb, model):

    '''
    This function calculates the similarity between the video and audio embeddings

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores
    '''

    vid_emb = vid_emb[:, :, None]
    aud_emb = aud_emb.transpose(1, 2)

    scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
                               vid_emb,
                               aud_emb,
                               part_len=10,
                               dim=3,
                               device=device)

    scores = model.logits_scale(scores[..., None]).squeeze(-1)

    return scores

def generate_video(frames, audio_file, video_fname):

    '''
    This function generates the video from the frames and audio file

    Args:
        - frames (array) : Frames to be used to generate the video
        - audio_file (string) : Path of the audio file
        - video_fname (string) : Path of the video file
    Returns:
        - video_output (string) : Path of the video file
    '''

    fname = 'inference.avi'
    video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))

    for i in range(len(frames)):
        video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
    video.release()

    no_sound_video = video_fname + '_nosound.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
    if status != 0:
        print("Oops! Could not generate the video. Please check the input video and try again.")
        return None

    video_output = video_fname + '.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
                             (audio_file, no_sound_video, video_output), shell=True)
    if status != 0:
        print("Oops! Could not generate the video. Please check the input video and try again.")
        return None

    os.remove(fname)
    os.remove(no_sound_video)

    return video_output

def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):

    '''
    This function corrects the video and audio to sync with each other

    Args:
        - video_path (string) : Path of the video file
        - frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - offset (int) : Predicted sync-offset to be used to correct the video
        - result_folder (string) : Path of the result folder to save the output sync-corrected video
        - sample_rate (int) : Sample rate of the audio
        - fps (int) : Frames per second of the video
    Returns:
        - video_output (string) : Path of the video file
    '''

    if offset == 0:
        print("The input audio and video are in-sync! No need to perform sync correction.")
        return video_path

    print("Performing Sync Correction...")
    corrected_frames = np.zeros_like(frames)
    if offset > 0:
        audio_offset = int(offset*(sample_rate/fps))
        wav = librosa.core.load(wav_file, sr=sample_rate)[0]
        corrected_wav = wav[audio_offset:]
        corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
        write(corrected_wav_file, sample_rate, corrected_wav)
        wav_file = corrected_wav_file
        corrected_frames = frames
    elif offset < 0:
        corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):]
        corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]

    corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
    video_output = generate_video(corrected_frames, wav_file, corrected_video_path)

    return video_output

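
# Illustration of the correction above: for a predicted offset of +10 frames with fps=25 and
# sample_rate=16000, audio_offset = 10*(16000/25) = 6400, so the first 6400 audio samples
# (0.4 s) are dropped before re-muxing; for an offset of -10, the first 10 video frames are
# dropped instead and the frame array is shortened accordingly.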
class Logger:
    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False


def process_video(video_path, num_avg_frames, apply_preprocess):
    try:
        # Extract the video filename
        video_fname = os.path.splitext(os.path.basename(video_path))[0]

        # Create folders to save the inputs and results
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")

        if os.path.exists(result_folder):
            rmtree(result_folder)

        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        # Preprocess the video
        print("Applying preprocessing: ", apply_preprocess)
        wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
        if status != "success":
            return status, None
        print("Successfully preprocessed the video")

        # Resample the video to 25 fps if it is not already 25 fps
        print("FPS of video: ", fps)
        if fps != 25:
            vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
            orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
        else:
            vid_path = vid_path_processed
            orig_vid_path_25fps = video_path

        # Load the original video frames (before pre-processing) - Needed for the final sync-correction
        orig_frames, status = load_video_frames(orig_vid_path_25fps)
        if status != "success":
            return status, None

        # Load the pre-processed video frames
        frames, status = load_video_frames(vid_path)
        if status != "success":
            return status, None
        print("Successfully extracted the video frames")

        if len(frames) < num_avg_frames:
            return "Error: The input video is too short. Please use a longer input video.", None

        # Load keypoints and check if gestures are visible
        kp_dict, status = get_keypoints(frames)
        if status != "success":
            return status, None
        print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))

        status = check_visible_gestures(kp_dict)
        if status != "success":
            return status, None

        # Load RGB frames
        rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, window_frames=25, width=480, height=270)
        if status != "success":
            return status, None
        print("Successfully loaded the RGB frames")

        # Convert frames to tensor
        rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
        rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
        B = rgb_frames.size(0)
        print("Successfully converted the frames to tensor")

        # Load spectrograms
        spec, orig_spec, status = load_spectrograms(wav_file, num_frames, window_frames=25)
        if status != "success":
            return status, None
        spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
        print("Successfully loaded the spectrograms")

        # Create input windows
        video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
        audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)
        print("Successfully loaded the model")

        # Process in batches
        batch_size = 12
        video_emb = []
        audio_emb = []

        for i in tqdm(range(0, len(video_sequences), batch_size)):
            video_inp = video_sequences[i:i+batch_size, ]
            audio_inp = audio_sequences[i:i+batch_size, ]

            vid_emb = model.forward_vid(video_inp.to(device))
            vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
            aud_emb = model.forward_aud(audio_inp.to(device))

            video_emb.append(vid_emb.detach())
            audio_emb.append(aud_emb.detach())

            torch.cuda.empty_cache()

        audio_emb = torch.cat(audio_emb, dim=0)
        video_emb = torch.cat(video_emb, dim=0)

        # L2 normalize embeddings
        video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
        audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)

        audio_emb = torch.split(audio_emb, B, dim=0)
        audio_emb = torch.stack(audio_emb, dim=2)
        audio_emb = audio_emb.squeeze(3)
        audio_emb = audio_emb[:, None]

        video_emb = torch.split(video_emb, B, dim=0)
        video_emb = torch.stack(video_emb, dim=2)
        video_emb = video_emb.squeeze(3)
        print("Successfully extracted GestSync embeddings")

        # Calculate sync offset
        pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
        if status != "success":
            return status, None
        print("Predicted offset: ", pred_offset)

        # Generate sync-corrected video
        video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
        print("Successfully generated the video:", video_output)

        return f"Predicted offset: {pred_offset}", video_output

    except Exception as e:
        return f"Error: {str(e)}", None

def read_logs():
    sys.stdout.flush()
    with open("output.log", "r") as f:
        return f.read()


if __name__ == "__main__":

    sys.stdout = Logger("output.log")

    # Define the custom CSS and HTML for the header
    custom_css = """
    <style>
    body {
        background-color: #ffffff;
        color: #333333;  /* Default text color */
    }
    .container {
        max-width: 100% !important;
        padding-left: 0 !important;
        padding-right: 0 !important;
    }
    .header {
        background-color: #f0f0f0;
        color: #333333;
        padding: 30px;
        margin-bottom: 30px;
        text-align: center;
        font-family: 'Helvetica Neue', Arial, sans-serif;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .header h1 {
        font-size: 36px;
        margin-bottom: 15px;
        font-weight: bold;
        color: #333333;  /* Explicitly set heading color */
    }
    .header h2 {
        font-size: 24px;
        margin-bottom: 10px;
        color: #333333;  /* Explicitly set subheading color */
    }
    .header p {
        font-size: 18px;
        margin: 5px 0;
        color: #666666;
    }
    .blue-text {
        color: #4a90e2;
    }
    /* Custom styles for slider container */
    .slider-container {
        background-color: white !important;
        padding-top: 0.9em;
        padding-bottom: 0.9em;
    }
    /* Add gap before examples */
    .examples-holder {
        margin-top: 2em;
    }
    /* Set fixed size for example videos */
    .gradio-container .gradio-examples .gr-sample {
        width: 240px !important;
        height: 135px !important;
        object-fit: cover;
        display: inline-block;
        margin-right: 10px;
    }

    .gradio-container .gradio-examples {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
    }

    /* Ensure the parent container does not stretch */
    .gradio-container .gradio-examples {
        max-width: 100%;
        overflow: hidden;
    }

    /* Additional styles to ensure proper sizing in Safari */
    .gradio-container .gradio-examples .gr-sample img {
        width: 240px !important;
        height: 135px !important;
        object-fit: cover;
    }
    </style>
    """

    custom_html = custom_css + """
    <div class="header">
        <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
        <h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
        <p>Sindhu Hegde and Andrew Zisserman</p>
        <p>VGG, University of Oxford</p>
    </div>
    """

    # Define paths to sample videos
    sample_videos = [
        "samples/sync_sample_1.mp4",
        "samples/sync_sample_2.mp4",
    ]

    # Define Gradio interface
    with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
        gr.HTML(custom_html)
        with gr.Row():
            with gr.Column():
                with gr.Group(elem_classes="slider-container"):
                    num_avg_frames = gr.Slider(
                        minimum=50,
                        maximum=150,
                        step=5,
                        value=75,
                        label="Number of Average Frames",
                    )
                apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False)
                video_input = gr.Video(label="Upload Video", height=400)

            with gr.Column():
                result_text = gr.Textbox(label="Result")
                output_video = gr.Video(label="Sync Corrected Video", height=400)

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_button = gr.Button("Clear")

        submit_button.click(
            fn=process_video,
            inputs=[video_input, num_avg_frames, apply_preprocess],
            outputs=[result_text, output_video]
        )

        clear_button.click(
            fn=lambda: (None, 75, False, "", None),
            inputs=[],
            outputs=[video_input, num_avg_frames, apply_preprocess, result_text, output_video]
        )

        gr.HTML('<div class="examples-holder"></div>')

        # Add examples
        gr.Examples(
            examples=sample_videos,
            inputs=video_input,
            outputs=None,
            fn=None,
            cache_examples=False,
        )

        logs = gr.Textbox(label="Logs")
        demo.load(read_logs, None, logs, every=1)

    # Launch the interface
    demo.queue().launch(allowed_paths=["."], show_error=True)
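
# To try the demo locally (a sketch, assuming the checkpoint model_rgb.pth, the samples/
# videos, and ffmpeg are available): run `python app_v1.py` and open the local Gradio URL
# printed to the console (http://127.0.0.1:7860 by default).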