sindhuhegde committed on
Commit aa5ee46
1 Parent(s): 4b6d86c

Add sync-offset-prediction app

app.py ADDED
@@ -0,0 +1,899 @@
1
+ import gradio as gr
2
+ import argparse
3
+ import os, subprocess
4
+ from shutil import rmtree
5
+
6
+ import numpy as np
7
+ import cv2
8
+ import librosa
9
+ import torch
10
+
11
+ from utils.audio_utils import *
12
+ from utils.inference_utils import *
13
+ from sync_models.gestsync_models import *
14
+
15
+ from tqdm import tqdm
16
+ from scipy.io.wavfile import write
17
+ import mediapipe as mp
18
+ from protobuf_to_dict import protobuf_to_dict
19
+ mp_holistic = mp.solutions.holistic
20
+ from ultralytics import YOLO
21
+ from decord import VideoReader, cpu
22
+
23
+ import warnings
24
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
25
+ warnings.filterwarnings("ignore", category=UserWarning)
26
+
27
+ # Set the path to checkpoint file
28
+ CHECKPOINT_PATH = "checkpoints/model_rgb.pth" # Update this path
29
+
30
+ # Initialize global variables
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ use_cuda = torch.cuda.is_available()
33
+ n_negative_samples = 100
34
+
35
+ def preprocess_video(path, result_folder, padding=20):
36
+
37
+ '''
38
+ This function preprocesses the input video to extract the audio and crop the frames using YOLO model
39
+
40
+ Args:
41
+ - path (string) : Path of the input video file
42
+ - result_folder (string) : Path of the folder to save the extracted audio and cropped video
43
+ - padding (int) : Padding to add to the bounding box
44
+ Returns:
45
+ - wav_file (string) : Path of the extracted audio file
46
+ - fps (int) : FPS of the input video
47
+ - video_output (string) : Path of the cropped video file
48
+ - msg (string) : Message to be returned
49
+ '''
50
+
51
+ # Load all video frames
52
+ try:
53
+ vr = VideoReader(path, ctx=cpu(0))
54
+ fps = vr.get_avg_fps()
55
+ frame_count = len(vr)
56
+ except:
57
+ msg = "Oops! Could not load the video. Please check the input video and try again."
58
+ return None, None, None, msg
59
+
60
+ all_frames = []
61
+ for k in range(len(vr)):
62
+ all_frames.append(vr[k].asnumpy())
63
+ all_frames = np.asarray(all_frames)
64
+
65
+ # Load the YOLOv9 person detector (pre-trained on the COCO dataset)
66
+ yolo_model = YOLO("yolov9c.pt")
67
+
68
+
69
+ if frame_count < 25:
70
+ msg = "Not enough frames to process! Please give a longer video as input"
71
+ return None, None, None, msg
72
+
73
+ person_videos = {}
74
+ person_tracks = {}
75
+
76
+ for frame_idx in range(frame_count):
77
+
78
+ frame = all_frames[frame_idx]
79
+
80
+ # Perform person detection
81
+ results = yolo_model(frame, verbose=False)
82
+ detections = results[0].boxes
83
+
84
+ for i, det in enumerate(detections):
85
+ x1, y1, x2, y2 = det.xyxy[0]
86
+ cls = det.cls[0]
87
+ if int(cls) == 0: # Class 0 is 'person' in COCO dataset
88
+
89
+ x1 = max(0, int(x1) - padding)
90
+ y1 = max(0, int(y1) - padding)
91
+ x2 = min(frame.shape[1], int(x2) + padding)
92
+ y2 = min(frame.shape[0], int(y2) + padding)
93
+
94
+ if i not in person_videos:
95
+ person_videos[i] = []
96
+ person_tracks[i] = []
97
+
98
+ person_videos[i].append(frame)
99
+ person_tracks[i].append([x1,y1,x2,y2])
100
+
101
+
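+ # Keep only detections that persist for at least half of the frames; the demo expects
+ # exactly one such person track.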
102
+ num_persons = 0
103
+ for i in person_videos.keys():
104
+ if len(person_videos[i]) >= frame_count//2:
105
+ num_persons+=1
106
+
107
+ if num_persons==0:
108
+ msg = "No person detected in the video! Please give a video with one person as input"
109
+ return None, None, None, msg
110
+ if num_persons>1:
111
+ msg = "More than one person detected in the video! Please give a video with only one person as input"
112
+ return None, None, None, msg
113
+
114
+ # Extract the audio from the input video file using ffmpeg
115
+ wav_file = os.path.join(result_folder, "audio.wav")
116
+
117
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
118
+ -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)
119
+
120
+ if status != 0:
121
+ msg = "Oops! Could not load the audio file. Please check the input video and try again."
122
+ return None, None, None, msg
123
+
124
+ # For the person detected, crop the frame based on the bounding box
125
+ if len(person_videos[0]) > frame_count-10:
126
+ crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
127
+ fourcc = cv2.VideoWriter_fourcc(*'DIVX')
128
+
129
+ # Compute the union bounding box of the tracked person (person_tracks[0]) across all frames
130
+ max_x1 = min([track[0] for track in person_tracks[0]])
131
+ max_y1 = min([track[1] for track in person_tracks[0]])
132
+ max_x2 = max([track[2] for track in person_tracks[0]])
133
+ max_y2 = max([track[3] for track in person_tracks[0]])
134
+
135
+ max_width = max_x2 - max_x1
136
+ max_height = max_y2 - max_y1
137
+
138
+ out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
139
+ for frame in person_videos[0]:
140
+ crop = frame[max_y1:max_y2, max_x1:max_x2]
141
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
142
+ out.write(crop)
143
+ out.release()
144
+
145
+ no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
146
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
147
+ if status != 0:
148
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
149
+ return None, None, None, msg
150
+
151
+ video_output = crop_filename.split('.')[0] + '.mp4'
152
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
153
+ (wav_file , no_sound_video, video_output), shell=True)
154
+ if status != 0:
155
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
156
+ return None, None, None, msg
157
+
158
+ os.remove(crop_filename)
159
+ os.remove(no_sound_video)
160
+
161
+ print("Successfully saved the pre-processed video: ", video_output)
162
+ else:
163
+ msg = "Could not track the person in the full video! Please give a single-speaker video as input"
164
+ return None, None, None, msg
165
+
166
+ return wav_file, fps, video_output, "success"
167
+
168
+ def resample_video(video_file, video_fname, result_folder):
169
+
170
+ '''
171
+ This function resamples the video to 25 fps
172
+
173
+ Args:
174
+ - video_file (string) : Path of the input video file
175
+ - video_fname (string) : Name of the input video file
176
+ - result_folder (string) : Path of the folder to save the resampled video
177
+ Returns:
178
+ - video_file_25fps (string) : Path of the resampled video file
179
+ '''
180
+ video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))
181
+
182
+ # Resample the video to 25 fps
183
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}".format(video_file, video_file_25fps))
184
+ from subprocess import call
185
+ cmd = command.split(' ')
186
+ print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
187
+ call(cmd)
188
+
189
+ return video_file_25fps
190
+
191
+ def load_checkpoint(path, model):
192
+ '''
193
+ This function loads the trained model from the checkpoint
194
+
195
+ Args:
196
+ - path (string) : Path of the checkpoint file
197
+ - model (object) : Model object
198
+ Returns:
199
+ - model (object) : Model object with the weights loaded from the checkpoint
200
+ '''
201
+
202
+ # Load the checkpoint
203
+ if use_cuda:
204
+ checkpoint = torch.load(path)
205
+ else:
206
+ checkpoint = torch.load(path, map_location="cpu")
207
+
208
+ s = checkpoint["state_dict"]
209
+ new_s = {}
210
+
211
+ for k, v in s.items():
212
+ new_s[k.replace('module.', '')] = v
213
+ model.load_state_dict(new_s)
214
+ if use_cuda:
+     model.cuda()
215
+
216
+ print("Loaded checkpoint from: {}".format(path))
217
+
218
+ return model.eval()
219
+
220
+
221
+ def load_video_frames(video_file):
222
+ '''
223
+ This function extracts the frames from the video
224
+
225
+ Args:
226
+ - video_file (string) : Path of the video file
227
+ Returns:
228
+ - frames (list) : List of frames extracted from the video
229
+ - msg (string) : Message to be returned
230
+ '''
231
+
232
+ # Read the video
233
+ try:
234
+ vr = VideoReader(video_file, ctx=cpu(0))
235
+ except:
236
+ msg = "Oops! Could not load the input video file"
237
+ return None, msg
238
+
239
+
240
+ # Extract the frames
241
+ frames = []
242
+ for k in range(len(vr)):
243
+ frames.append(vr[k].asnumpy())
244
+
245
+ frames = np.asarray(frames)
246
+
247
+ return frames, "success"
248
+
249
+
250
+
251
+ def get_keypoints(frames):
252
+
253
+ '''
254
+ This function extracts the keypoints from the frames using MediaPipe Holistic pipeline
255
+
256
+ Args:
257
+ - frames (list) : List of frames extracted from the video
258
+ Returns:
259
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
260
+ - msg (string) : Message to be returned
261
+ '''
262
+
263
+ try:
264
+ holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
265
+
266
+ resolution = frames[0].shape
267
+ all_frame_kps = []
268
+
269
+ for frame in frames:
270
+
271
+ results = holistic.process(frame)
272
+
273
+ pose, left_hand, right_hand, face = None, None, None, None
274
+ if results.pose_landmarks is not None:
275
+ pose = protobuf_to_dict(results.pose_landmarks)['landmark']
276
+ if results.left_hand_landmarks is not None:
277
+ left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
278
+ if results.right_hand_landmarks is not None:
279
+ right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
280
+ if results.face_landmarks is not None:
281
+ face = protobuf_to_dict(results.face_landmarks)['landmark']
282
+
283
+ frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face}
284
+
285
+ all_frame_kps.append(frame_dict)
286
+
287
+ kp_dict = {"kps":all_frame_kps, "resolution":resolution}
288
+ except Exception as e:
289
+ print("Error: ", e)
290
+ return None, "Error: Could not extract keypoints from the frames"
291
+
292
+ return kp_dict, "success"
293
+
294
+
295
+ def check_visible_gestures(kp_dict):
296
+
297
+ '''
298
+ This function checks if the gestures in the video are visible
299
+
300
+ Args:
301
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
302
+ Returns:
303
+ - msg (string) : Message to be returned
304
+ '''
305
+
306
+ keypoints = kp_dict['kps']
307
+ keypoints = np.array(keypoints)
308
+
309
+ if len(keypoints)<25:
310
+ msg = "Not enough keypoints to process! Please give a longer video as input"
311
+ return msg
312
+
313
+ pose_count, hand_count = 0, 0
314
+ for frame_kp_dict in keypoints:
315
+
316
+ pose = frame_kp_dict["pose"]
317
+ left_hand = frame_kp_dict["left_hand"]
318
+ right_hand = frame_kp_dict["right_hand"]
319
+
320
+ if pose is None:
321
+ pose_count += 1
322
+
323
+ if left_hand is None and right_hand is None:
324
+ hand_count += 1
325
+
326
+
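+ # Reject the video if the pose or both hands are undetected in more than 70% of the frames.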
327
+ if hand_count/len(keypoints) > 0.7 or pose_count/len(keypoints) > 0.7:
328
+ msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
329
+ return msg
330
+
331
+ print("Successfully verified the input video - Gestures are visible!")
332
+
333
+ return "success"
334
+
335
+ def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, width=480, height=270):
336
+
337
+ '''
338
+ This function masks the faces using the keypoints extracted from the frames
339
+
340
+ Args:
341
+ - input_frames (list) : List of frames extracted from the video
342
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
343
+ - stride (int) : Stride to extract the frames
344
+ - window_frames (int) : Number of frames in each window that is given as input to the model
345
+ - width (int) : Width of the frames
346
+ - height (int) : Height of the frames
347
+ Returns:
348
+ - input_frames (array) : Frame window to be given as input to the model
349
+ - num_frames (int) : Number of frame windows extracted from the video
350
+ - orig_masked_frames (array) : Masked frames extracted from the video
351
+ - msg (string) : Message to be returned
352
+ '''
353
+
354
+ # Face indices to extract the face-coordinates needed for masking
355
+ face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
356
+ 176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]
357
+
358
+
359
+ input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
360
+
361
+ input_frames_masked = []
362
+ for i, frame_kp_dict in enumerate(input_keypoints):
363
+
364
+ img = input_frames[i]
365
+ face = frame_kp_dict["face"]
366
+
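+ # Black out the face region (or the top of the frame if no face is detected) so that the
+ # sync prediction relies on gestures rather than lip motion.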
367
+ if face is None:
368
+ img = cv2.resize(img, (width, height))
369
+ masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
370
+ else:
371
+ face_kps = []
372
+ for idx in range(len(face)):
373
+ if idx in face_oval_idx:
374
+ x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0])
375
+ face_kps.append((x,y))
376
+
377
+ face_kps = np.array(face_kps)
378
+ x1, y1 = min(face_kps[:,0]), min(face_kps[:,1])
379
+ x2, y2 = max(face_kps[:,0]), max(face_kps[:,1])
380
+ masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1)
381
+
382
+ if masked_img.shape[0] != height or masked_img.shape[1] != width:
383
+ masked_img = cv2.resize(masked_img, (width, height))
384
+
385
+ input_frames_masked.append(masked_img)
386
+
387
+ orig_masked_frames = np.array(input_frames_masked)
388
+ input_frames = np.array(input_frames_masked) / 255.
389
+ # print("Input images full: ", input_frames.shape) # num_framesx270x480x3
390
+
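+ # Slice the masked frames into sliding windows of window_frames (25) frames; each window is
+ # one visual clip fed to the model.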
391
+ input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
392
+ # print("Input images window: ", input_frames.shape) # Tx25x270x480x3
393
+
394
+ num_frames = input_frames.shape[0]
395
+
396
+ if num_frames<10:
397
+ msg = "Not enough frames to process! Please give a longer video as input."
398
+ return None, None, None, msg
399
+
400
+ return input_frames, num_frames, orig_masked_frames, "success"
401
+
402
+ def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):
403
+
404
+ '''
405
+ This function extracts the spectrogram from the audio file
406
+
407
+ Args:
408
+ - wav_file (string) : Path of the extracted audio file
409
+ - num_frames (int) : Number of video frame windows (used to match the number of audio windows)
410
+ - window_frames (int) : Number of frames in each window that is given as input to the model
411
+ - stride (int) : Stride to extract the audio frames
412
+ Returns:
413
+ - spec (array) : Spectrogram array window to be used as input to the model
414
+ - orig_spec (array) : Spectrogram array extracted from the audio file
415
+ - msg (string) : Message to be returned
416
+ '''
417
+
418
+ # Extract the audio from the input video file using ffmpeg
419
+ try:
420
+ wav = librosa.load(wav_file, sr=16000)[0]
421
+ except:
422
+ msg = "Oops! Could extract the spectrograms from the audio file. Please check the input and try again."
423
+ return None, None, msg
424
+
425
+ # Convert to tensor
426
+ wav = torch.FloatTensor(wav).unsqueeze(0)
427
+ mel, _, _, _ = wav2filterbanks(wav.to(device))
428
+ spec = mel.squeeze(0).cpu().numpy()
429
+ orig_spec = spec
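+ # The mel spectrogram has 4 frames per video frame (10 ms hop at 16 kHz vs. 25 fps video),
+ # so each window spans window_frames*stride spectrogram frames and the windows advance by
+ # stride (one video frame) at a time.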
430
+ spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])
431
+
432
+ frame_diff = np.abs(len(spec) - num_frames)
+ if len(spec) != num_frames:
+     spec = spec[:num_frames]
+ if frame_diff > 60:
+     print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")
437
+
438
+ return spec, orig_spec, "success"
439
+
440
+
441
+ def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
442
+ '''
443
+ This function calculates the audio-visual offset between the video and audio
444
+
445
+ Args:
446
+ - vid_emb (array) : Video embedding array
447
+ - aud_emb (array) : Audio embedding array
448
+ - num_avg_frames (int) : Number of frames to average the scores
449
+ - model (object) : Model object
450
+ Returns:
451
+ - offset (int) : Optimal audio-visual offset
452
+ - msg (string) : Message to be returned
453
+ '''
454
+
455
+ pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
456
+ if status != "success":
457
+ return None, status
458
+ scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
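+ # Map the index of the best-scoring audio window back to a frame offset: the argmax is in
+ # units of stride windows, and pos_idx marks the in-sync (zero-offset) frame position.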
459
+ offset = scores.argmax()*stride - pos_idx
460
+
461
+ return offset.item(), "success"
462
+
463
+ def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):
464
+
465
+ '''
466
+ This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset
467
+
468
+ Args:
469
+ - vid_emb (array) : Video embedding array
470
+ - aud_emb (array) : Audio embedding array
471
+ - num_avg_frames (int) : Number of frames to average the scores
472
+ - stride (int) : Stride to extract the negative windows
473
+ Returns:
474
+ - vid_emb_pos (array) : Positive video embedding array
475
+ - aud_emb_posneg (array) : All possible combinations of audio embedding array
476
+ - pos_idx_frame (int) : Positive video embedding array frame
477
+ - stride (int) : Stride used to extract the negative windows
478
+ - msg (string) : Message to be returned
479
+ '''
480
+
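+ # Slide a window of num_avg_frames audio frames over the audio embedding with the given
+ # stride; every window is a candidate alignment, the centre window is taken as the
+ # zero-offset (positive) position, and the number of windows bounds the range of
+ # predictable offsets printed below.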
481
+ slice_size = num_avg_frames
482
+ aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
483
+ aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
484
+ aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]
485
+
486
+ pos_idx = (aud_emb_posneg.shape[1]//2)
487
+ pos_idx_frame = pos_idx*stride
488
+
489
+ min_offset_frames = -(pos_idx)*stride
490
+ max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
491
+ print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))
492
+
493
+ vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
494
+ if vid_emb_pos.shape[2] != slice_size:
495
+ msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
496
+ return None, None, None, None, msg
497
+
498
+ return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"
499
+
500
+ def calc_av_scores(vid_emb, aud_emb, model):
501
+
502
+ '''
503
+ This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings
504
+
505
+ Args:
506
+ - vid_emb (array) : Video embedding array
507
+ - aud_emb (array) : Audio embedding array
508
+ - model (object) : Model object
509
+ Returns:
510
+ - scores (array) : Audio-visual similarity scores
511
+ - att_map (array) : Attention map
512
+ '''
513
+
514
+ scores = calc_att_map(vid_emb, aud_emb, model)
515
+ att_map = logsoftmax_2d(scores)
516
+ scores = scores.mean(-1)
517
+
518
+ return scores, att_map
519
+
520
+ def calc_att_map(vid_emb, aud_emb, model):
521
+
522
+ '''
523
+ This function calculates the similarity between the video and audio embeddings
524
+
525
+ Args:
526
+ - vid_emb (array) : Video embedding array
527
+ - aud_emb (array) : Audio embedding array
528
+ - model (object) : Model object
529
+ Returns:
530
+ - scores (array) : Audio-visual similarity scores
531
+ '''
532
+
533
+ vid_emb = vid_emb[:, :, None]
534
+ aud_emb = aud_emb.transpose(1, 2)
535
+
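+ # Similarity is the element-wise product summed over the (L2-normalised) embedding channel,
+ # computed in chunks via run_func_in_parts to bound GPU memory, then passed through the
+ # learned scalar scaling (logits_scale).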
536
+ scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
537
+ vid_emb,
538
+ aud_emb,
539
+ part_len=10,
540
+ dim=3,
541
+ device=device)
542
+
543
+ scores = model.logits_scale(scores[..., None]).squeeze(-1)
544
+
545
+ return scores
546
+
547
+ def generate_video(frames, audio_file, video_fname):
548
+
549
+ '''
550
+ This function generates the video from the frames and audio file
551
+
552
+ Args:
553
+ - frames (array) : Frames to be used to generate the video
554
+ - audio_file (string) : Path of the audio file
555
+ - video_fname (string) : Path of the video file
556
+ Returns:
557
+ - video_output (string) : Path of the video file
558
+ '''
559
+
560
+ fname = 'inference.avi'
561
+ video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))
562
+
563
+ for i in range(len(frames)):
564
+ video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
565
+ video.release()
566
+
567
+ no_sound_video = video_fname + '_nosound.mp4'
568
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
569
+ if status != 0:
570
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
571
+ return None, msg
572
+
573
+ video_output = video_fname + '.mp4'
574
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
575
+ (audio_file, no_sound_video, video_output), shell=True)
576
+ if status != 0:
577
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
578
+ return None, msg
579
+
580
+ os.remove(fname)
581
+ os.remove(no_sound_video)
582
+
583
+ return video_output
584
+
585
+ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
586
+
587
+ '''
588
+ This function corrects the video and audio to sync with each other
589
+
590
+ Args:
591
+ - video_path (string) : Path of the video file
592
+ - frames (array) : Frames to be used to generate the video
593
+ - wav_file (string) : Path of the audio file
594
+ - offset (int) : Predicted sync-offset to be used to correct the video
595
+ - result_folder (string) : Path of the result folder to save the output sync-corrected video
596
+ - sample_rate (int) : Sample rate of the audio
597
+ - fps (int) : Frames per second of the video
598
+ Returns:
599
+ - video_output (string) : Path of the video file
600
+ '''
601
+
602
+ if offset == 0:
603
+ print("The input audio and video are in-sync! No need to perform sync correction.")
604
+ return video_path
605
+
606
+ print("Performing Sync Correction...")
607
+ corrected_frames = np.zeros_like(frames)
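+ # offset > 0: advance the audio by trimming offset*(sample_rate/fps) samples from its start;
+ # offset < 0: advance the video by dropping its first |offset| frames instead.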
608
+ if offset > 0:
609
+ audio_offset = int(offset*(sample_rate/fps))
610
+ wav = librosa.core.load(wav_file, sr=sample_rate)[0]
611
+ corrected_wav = wav[audio_offset:]
612
+ corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
613
+ write(corrected_wav_file, sample_rate, corrected_wav)
614
+ wav_file = corrected_wav_file
615
+ corrected_frames = frames
616
+ elif offset < 0:
617
+ corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):]
618
+ corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]
619
+
620
+ corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
621
+ video_output = generate_video(corrected_frames, wav_file, corrected_video_path)
622
+
623
+ return video_output
624
+
625
+ def process_video(video_path, num_avg_frames):
626
+ try:
627
+ # Extract the video filename
628
+ video_fname = os.path.splitext(os.path.basename(video_path))[0]
629
+
630
+ # Create folders to save the inputs and results
631
+ result_folder = os.path.join("results", video_fname)
632
+ result_folder_input = os.path.join(result_folder, "input")
633
+ result_folder_output = os.path.join(result_folder, "output")
634
+
635
+ if os.path.exists(result_folder):
636
+ rmtree(result_folder)
637
+
638
+ os.makedirs(result_folder)
639
+ os.makedirs(result_folder_input)
640
+ os.makedirs(result_folder_output)
641
+
642
+
643
+ # Preprocess the video
644
+ wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input)
645
+ if status != "success":
646
+ return status, None
647
+
648
+ # Resample the video to 25 fps if it is not already 25 fps
649
+ print("FPS of video: ", fps)
650
+ if fps!=25:
651
+ vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
652
+ orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
653
+ else:
654
+ vid_path = vid_path_processed
655
+ orig_vid_path_25fps = video_path
656
+
657
+ # Load the original video frames (before pre-processing) - Needed for the final sync-correction
658
+ orig_frames, status = load_video_frames(orig_vid_path_25fps)
659
+ if status != "success":
660
+ return status, None
661
+
662
+ # Load the pre-processed video frames
663
+ frames, status = load_video_frames(vid_path)
664
+ if status != "success":
665
+ return status, None
666
+
667
+
668
+ if len(frames) < num_avg_frames:
669
+ return "Error: The input video is too short. Please use a longer input video.", None
670
+
671
+ # Load keypoints and check if gestures are visible
672
+ kp_dict, status = get_keypoints(frames)
673
+ if status != "success":
674
+ return status, None
675
+
676
+ status = check_visible_gestures(kp_dict)
677
+ if status != "success":
678
+ return status, None
679
+
680
+ # Load RGB frames
681
+ rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, window_frames=25, width=480, height=270)
682
+ if status != "success":
683
+ return status, None
684
+
685
+ # Convert frames to tensor
686
+ rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
687
+ rgb_frames = torch.FloatTensor(np.array(rgb_frames)).unsqueeze(0)
688
+ B = rgb_frames.size(0)
689
+
690
+ # Load spectrograms
691
+ spec, orig_spec, status = load_spectrograms(wav_file, num_frames, window_frames=25)
692
+ if status != "success":
693
+ return status, None
694
+ spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
695
+
696
+ # Create input windows
697
+ video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
698
+ audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
699
+
700
+ # Load the trained model
701
+ model = Transformer_RGB()
702
+ model = load_checkpoint(CHECKPOINT_PATH, model)
703
+
704
+ # Process in batches
705
+ batch_size = 12
706
+ video_emb = []
707
+ audio_emb = []
708
+
709
+ for i in tqdm(range(0, len(video_sequences), batch_size)):
710
+ video_inp = video_sequences[i:i+batch_size, ]
711
+ audio_inp = audio_sequences[i:i+batch_size, ]
712
+
713
+ vid_emb = model.forward_vid(video_inp.to(device))
714
+ vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
715
+ aud_emb = model.forward_aud(audio_inp.to(device))
716
+
717
+ video_emb.append(vid_emb.detach())
718
+ audio_emb.append(aud_emb.detach())
719
+
720
+ torch.cuda.empty_cache()
721
+
722
+ audio_emb = torch.cat(audio_emb, dim=0)
723
+ video_emb = torch.cat(video_emb, dim=0)
724
+
725
+ # L2 normalize embeddings
726
+ video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
727
+ audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)
728
+
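+ # Regroup the per-window embeddings along the time axis: video_emb -> (B, C, num_windows),
+ # audio_emb -> (B, 1, C, num_windows), ready for the sliding-window offset search.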
729
+ audio_emb = torch.split(audio_emb, B, dim=0)
730
+ audio_emb = torch.stack(audio_emb, dim=2)
731
+ audio_emb = audio_emb.squeeze(3)
732
+ audio_emb = audio_emb[:, None]
733
+
734
+ video_emb = torch.split(video_emb, B, dim=0)
735
+ video_emb = torch.stack(video_emb, dim=2)
736
+ video_emb = video_emb.squeeze(3)
737
+
738
+ # Calculate sync offset
739
+ pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
740
+ if status != "success":
741
+ return status, None
742
+ print("Predicted offset: ", pred_offset)
743
+
744
+ # Generate sync-corrected video
745
+ video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
746
+ print("Successfully generated the video:", video_output)
747
+
748
+ return f"Predicted offset: {pred_offset}", video_output
749
+
750
+ except Exception as e:
751
+ return f"Error: {str(e)}", None
752
+
753
+
754
+
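+ # Example (sketch) of calling the pipeline programmatically - the Gradio UI below simply wraps this function:
+ #   result_text, corrected_video = process_video("samples/sync_sample_1.mp4", num_avg_frames=75)
+ # result_text is either an error message or "Predicted offset: <n>" (in 25-fps video frames),
+ # and corrected_video is the path of the sync-corrected .mp4 (or None if processing failed).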
755
+ if __name__ == "__main__":
756
+
757
+ # Define the custom HTML for the header
758
+ custom_css = """
759
+ <style>
760
+ body {
761
+ background-color: #ffffff;
762
+ color: #333333; /* Default text color */
763
+ }
764
+ .container {
765
+ max-width: 100% !important;
766
+ padding-left: 0 !important;
767
+ padding-right: 0 !important;
768
+ }
769
+ .header {
770
+ background-color: #f0f0f0;
771
+ color: #333333;
772
+ padding: 30px;
773
+ margin-bottom: 30px;
774
+ text-align: center;
775
+ font-family: 'Helvetica Neue', Arial, sans-serif;
776
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
777
+ }
778
+ .header h1 {
779
+ font-size: 36px;
780
+ margin-bottom: 15px;
781
+ font-weight: bold;
782
+ color: #333333; /* Explicitly set heading color */
783
+ }
784
+ .header h2 {
785
+ font-size: 24px;
786
+ margin-bottom: 10px;
787
+ color: #333333; /* Explicitly set subheading color */
788
+ }
789
+ .header p {
790
+ font-size: 18px;
791
+ margin: 5px 0;
792
+ color: #666666;
793
+ }
794
+ .blue-text {
795
+ color: #4a90e2;
796
+ }
797
+ /* Custom styles for slider container */
798
+ .slider-container {
799
+ background-color: white !important;
800
+ padding-top: 0.9em;
801
+ padding-bottom: 0.9em;
802
+ }
803
+ /* Add gap before examples */
804
+ .examples-holder {
805
+ margin-top: 2em;
806
+ }
807
+ /* Set fixed size for example videos */
808
+ .gradio-container .gradio-examples .gr-sample {
809
+ width: 240px !important;
810
+ height: 135px !important;
811
+ object-fit: cover;
812
+ display: inline-block;
813
+ margin-right: 10px;
814
+ }
815
+
816
+ .gradio-container .gradio-examples {
817
+ display: flex;
818
+ flex-wrap: wrap;
819
+ gap: 10px;
820
+ }
821
+
822
+ /* Ensure the parent container does not stretch */
823
+ .gradio-container .gradio-examples {
824
+ max-width: 100%;
825
+ overflow: hidden;
826
+ }
827
+
828
+ /* Additional styles to ensure proper sizing in Safari */
829
+ .gradio-container .gradio-examples .gr-sample img {
830
+ width: 240px !important;
831
+ height: 135px !important;
832
+ object-fit: cover;
833
+ }
834
+ </style>
835
+ """
836
+
837
+ custom_html = custom_css + """
838
+ <div class="header">
839
+ <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
840
+ <h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
841
+ <p>Sindhu Hegde and Andrew Zisserman</p>
842
+ <p>VGG, University of Oxford</p>
843
+ </div>
844
+ """
845
+
846
+ # Define paths to sample videos
847
+ sample_videos = [
848
+ "samples/sync_sample_1.mp4",
849
+ "samples/sync_sample_2.mp4",
850
+ ]
851
+
852
+ # Define Gradio interface
853
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
854
+ gr.HTML(custom_html)
855
+ with gr.Row():
856
+ with gr.Column():
857
+ with gr.Group(elem_classes="slider-container"):
858
+ num_avg_frames = gr.Slider(
859
+ minimum=50,
860
+ maximum=150,
861
+ step=5,
862
+ value=75,
863
+ label="Number of Average Frames",
864
+ )
865
+ video_input = gr.Video(label="Upload Video", height=400)
866
+
867
+ with gr.Column():
868
+ result_text = gr.Textbox(label="Result")
869
+ output_video = gr.Video(label="Sync Corrected Video", height=400)
870
+
871
+ with gr.Row():
872
+ submit_button = gr.Button("Submit", variant="primary")
873
+ clear_button = gr.Button("Clear")
874
+
875
+ submit_button.click(
876
+ fn=process_video,
877
+ inputs=[video_input, num_avg_frames],
878
+ outputs=[result_text, output_video]
879
+ )
880
+
881
+ clear_button.click(
882
+ fn=lambda: (None, 75, "", None),
883
+ inputs=[],
884
+ outputs=[video_input, num_avg_frames, result_text, output_video]
885
+ )
886
+
887
+ gr.HTML('<div class="examples-holder"></div>')
888
+
889
+ # Add examples
890
+ gr.Examples(
891
+ examples=sample_videos,
892
+ inputs=video_input,
893
+ outputs=None,
894
+ fn=None,
895
+ cache_examples=False,
896
+ )
897
+
898
+ # Launch the interface
899
+ demo.launch(allowed_paths=["."], server_name="0.0.0.0", server_port=7860, share=True)
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ decord==0.5.2
2
+ ffmpeg==1.4
3
+ librosa==0.9.2
4
+ mediapipe==0.9.1.0
5
+ numpy==1.26.4
6
+ opencv-python==4.9.0.80
7
+ opencv-python-headless==4.10.0.84
8
+ protobuf==3.20.3
9
+ protobuf-to-dict==0.1.0
10
+ protobuf3-to-dict==0.1.5
11
+ python_speech_features==0.6
12
+ scenedetect==0.6.4
13
+ scikit-learn==1.5.1
14
+ torch==1.10.0
15
+ torchvision==0.11.1
16
+ tqdm==4.66.4
17
+ ultralytics==8.2.70
18
+ ultralytics-thop==2.0.0
19
+ urllib3==1.26.19
samples/sync_sample_1.mp4 ADDED
Binary file (401 kB).
 
samples/sync_sample_2.mp4 ADDED
Binary file (256 kB).
 
sync_models/gestsync_models.py ADDED
@@ -0,0 +1,169 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from sync_models.modules import *
5
+
6
+
7
+
8
+ class Transformer_RGB(nn.Module):
9
+
10
+ def __init__(self):
11
+ super().__init__()
12
+
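+ # Two-stream sync model: the visual stream is a 3D-CNN backbone followed by positional
+ # encoding, a 6-layer transformer encoder and an MLP head (512 -> 1024-D embedding per
+ # time step); the audio stream is a 2D-CNN over log-mel spectrograms with a 1x1-conv
+ # head projecting to a matching 1024-D embedding.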
13
+ self.net_vid = self.build_net_vid()
14
+ self.ff_vid = nn.Sequential(
15
+ nn.Linear(512, 512),
16
+ nn.ReLU(),
17
+ nn.Linear(512, 1024)
18
+ )
19
+
20
+ self.pos_encoder = PositionalEncoding_RGB(d_model=512)
21
+ encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
22
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
23
+
24
+ self.net_aud = self.build_net_aud()
25
+ self.lstm = nn.LSTM(512, 256, num_layers=1, bidirectional=True, batch_first=True)
26
+
27
+ self.ff_aud = NetFC_2D(input_dim=512, hidden_dim=512, embed_dim=1024)
28
+
29
+
30
+ self.logits_scale = nn.Linear(1, 1, bias=False)
31
+ torch.nn.init.ones_(self.logits_scale.weight)
32
+
33
+ self.fc = nn.Linear(1,1)
34
+
35
+ def build_net_vid(self):
36
+ layers = [
37
+ {
38
+ 'type': 'conv3d',
39
+ 'n_channels': 64,
40
+ 'kernel_size': (5, 7, 7),
41
+ 'stride': (1, 3, 3),
42
+ 'padding': (0),
43
+ 'maxpool': {
44
+ 'kernel_size': (1, 3, 3),
45
+ 'stride': (1, 2, 2)
46
+ }
47
+ },
48
+ {
49
+ 'type': 'conv3d',
50
+ 'n_channels': 128,
51
+ 'kernel_size': (1, 5, 5),
52
+ 'stride': (1, 2, 2),
53
+ 'padding': (0, 0, 0),
54
+ },
55
+ {
56
+ 'type': 'conv3d',
57
+ 'n_channels': 256,
58
+ 'kernel_size': (1, 3, 3),
59
+ 'stride': (1, 2, 2),
60
+ 'padding': (0, 1, 1),
61
+ },
62
+ {
63
+ 'type': 'conv3d',
64
+ 'n_channels': 256,
65
+ 'kernel_size': (1, 3, 3),
66
+ 'stride': (1, 1, 2),
67
+ 'padding': (0, 1, 1),
68
+ },
69
+ {
70
+ 'type': 'conv3d',
71
+ 'n_channels': 256,
72
+ 'kernel_size': (1, 3, 3),
73
+ 'stride': (1, 1, 1),
74
+ 'padding': (0, 1, 1),
75
+ 'maxpool': {
76
+ 'kernel_size': (1, 3, 3),
77
+ 'stride': (1, 2, 2)
78
+ }
79
+ },
80
+ {
81
+ 'type': 'fc3d',
82
+ 'n_channels': 512,
83
+ 'kernel_size': (1, 4, 4),
84
+ 'stride': (1, 1, 1),
85
+ 'padding': (0),
86
+ },
87
+ ]
88
+ return VGGNet(n_channels_in=3, layers=layers)
89
+
90
+ def build_net_aud(self):
91
+ layers = [
92
+ {
93
+ 'type': 'conv2d',
94
+ 'n_channels': 64,
95
+ 'kernel_size': (3, 3),
96
+ 'stride': (2, 2),
97
+ 'padding': (1, 1),
98
+ 'maxpool': {
99
+ 'kernel_size': (3, 3),
100
+ 'stride': (2, 2)
101
+ }
102
+ },
103
+ {
104
+ 'type': 'conv2d',
105
+ 'n_channels': 192,
106
+ 'kernel_size': (3, 3),
107
+ 'stride': (1, 2),
108
+ 'padding': (1, 1),
109
+ 'maxpool': {
110
+ 'kernel_size': (3, 3),
111
+ 'stride': (2, 2)
112
+ }
113
+ },
114
+ {
115
+ 'type': 'conv2d',
116
+ 'n_channels': 384,
117
+ 'kernel_size': (3, 3),
118
+ 'stride': (1, 1),
119
+ 'padding': (1, 1),
120
+ },
121
+ {
122
+ 'type': 'conv2d',
123
+ 'n_channels': 256,
124
+ 'kernel_size': (3, 3),
125
+ 'stride': (1, 1),
126
+ 'padding': (1, 1),
127
+ },
128
+ {
129
+ 'type': 'conv2d',
130
+ 'n_channels': 256,
131
+ 'kernel_size': (3, 3),
132
+ 'stride': (1, 1),
133
+ 'padding': (1, 1),
134
+ 'maxpool': {
135
+ 'kernel_size': (2, 3),
136
+ 'stride': (2, 2)
137
+ }
138
+ },
139
+ {
140
+ 'type': 'fc2d',
141
+ 'n_channels': 512,
142
+ 'kernel_size': (4, 2),
143
+ 'stride': (1, 1),
144
+ 'padding': (0, 0),
145
+ },
146
+ ]
147
+ return VGGNet(n_channels_in=1, layers=layers)
148
+
149
+ def forward_vid(self, x, return_feats=False):
150
+ out_conv = self.net_vid(x).squeeze(-1).squeeze(-1)
151
+ # print("Conv: ", out_conv.shape) # Bx1024x21x1x1
152
+
153
+ out = self.pos_encoder(out_conv.transpose(1,2))
154
+ out_trans = self.transformer_encoder(out)
155
+ # print("Transformer: ", out_trans.shape) # Bx21x1024
156
+
157
+ out = self.ff_vid(out_trans).transpose(1,2)
158
+ # print("MLP output: ", out.shape) # Bx1024
159
+
160
+ if return_feats:
161
+ return out, out_conv
162
+ else:
163
+ return out
164
+
165
+ def forward_aud(self, x):
166
+ out = self.net_aud(x)
167
+ out = self.ff_aud(out)
168
+ out = out.squeeze(-1)
169
+ return out
sync_models/modules.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import math
5
+
6
+ class PositionalEncoding_RGB(nn.Module):
7
+ "Implement the PE function."
8
+ def __init__(self, d_model, dropout=0.1, max_len=50):
9
+ super(PositionalEncoding_RGB, self).__init__()
10
+ self.dropout = nn.Dropout(p=dropout)
11
+
12
+ # Compute the positional encodings once in log space.
13
+ pe = torch.zeros(max_len, d_model)
14
+ position = torch.arange(0, max_len).unsqueeze(1)
15
+ div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0)
19
+ self.register_buffer('pe', pe)
20
+
21
+ def forward(self, x):
22
+ x = x + Variable(self.pe[:, :x.size(1)],
23
+ requires_grad=False)
24
+ return self.dropout(x)
25
+
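+ # Receptive-field bookkeeping: each layer is summarised as (output size, jump, receptive
+ # field, centre of the first feature), and outFromIn applies the standard convolution
+ # arithmetic to update these quantities layer by layer.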
26
+ def calc_receptive_field(layers, imsize, layer_names=None):
27
+ if layer_names is not None:
28
+ print("-------Net summary------")
29
+ currentLayer = [imsize, 1, 1, 0.5]
30
+
31
+ for l_id, layer in enumerate(layers):
32
+ conv = [
33
+ layer[key][-1] if type(layer[key]) in [list, tuple] else layer[key]
34
+ for key in ['kernel_size', 'stride', 'padding']
35
+ ]
36
+ currentLayer = outFromIn(conv, currentLayer)
37
+ if 'maxpool' in layer:
38
+ conv = [
39
+ (layer['maxpool'][key][-1] if type(layer['maxpool'][key])
40
+ in [list, tuple] else layer['maxpool'][key]) if
41
+ (not key == 'padding' or 'padding' in layer['maxpool']) else 0
42
+ for key in ['kernel_size', 'stride', 'padding']
43
+ ]
44
+ currentLayer = outFromIn(conv, currentLayer, ceil_mode=False)
45
+ return currentLayer
46
+
47
+ def outFromIn(conv, layerIn, ceil_mode=True):
48
+ n_in = layerIn[0]
49
+ j_in = layerIn[1]
50
+ r_in = layerIn[2]
51
+ start_in = layerIn[3]
52
+ k = conv[0]
53
+ s = conv[1]
54
+ p = conv[2]
55
+
56
+ n_out = math.floor((n_in - k + 2 * p) / s) + 1
57
+ actualP = (n_out - 1) * s - n_in + k
58
+ pR = math.ceil(actualP / 2)
59
+ pL = math.floor(actualP / 2)
60
+
61
+ j_out = j_in * s
62
+ r_out = r_in + (k - 1) * j_in
63
+ start_out = start_in + ((k - 1) / 2 - pL) * j_in
64
+ return n_out, j_out, r_out, start_out
65
+
66
+ class DebugModule(nn.Module):
67
+ """
68
+ Wrapper class for printing the activation dimensions
69
+ """
70
+
71
+ def __init__(self, name=None):
72
+ super().__init__()
73
+ self.name = name
74
+ self.debug_log = True
75
+
76
+ def debug_line(self, layer_str, output, memuse=1, final_call=False):
77
+ if self.debug_log:
78
+ namestr = '{}: '.format(self.name) if self.name is not None else ''
79
+ # print('{}{:80s}: dims {}'.format(namestr, repr(layer_str),
80
+ # output.shape))
81
+
82
+ if final_call:
83
+ self.debug_log = False
84
+ # print()
85
+
86
+ class VGGNet(DebugModule):
87
+
88
+ conv_dict = {
89
+ 'conv1d': nn.Conv1d,
90
+ 'conv2d': nn.Conv2d,
91
+ 'conv3d': nn.Conv3d,
92
+ 'fc1d': nn.Conv1d,
93
+ 'fc2d': nn.Conv2d,
94
+ 'fc3d': nn.Conv3d,
95
+ }
96
+
97
+ pool_dict = {
98
+ 'conv1d': nn.MaxPool1d,
99
+ 'conv2d': nn.MaxPool2d,
100
+ 'conv3d': nn.MaxPool3d,
101
+ }
102
+
103
+ norm_dict = {
104
+ 'conv1d': nn.BatchNorm1d,
105
+ 'conv2d': nn.BatchNorm2d,
106
+ 'conv3d': nn.BatchNorm3d,
107
+ 'fc1d': nn.BatchNorm1d,
108
+ 'fc2d': nn.BatchNorm2d,
109
+ 'fc3d': nn.BatchNorm3d,
110
+ }
111
+
112
+ def __init__(self, n_channels_in, layers):
113
+ super(VGGNet, self).__init__()
114
+
115
+ self.layers = layers
116
+
117
+ n_channels_prev = n_channels_in
118
+ for l_id, lr in enumerate(self.layers):
119
+ l_id += 1
120
+ name = 'fc' if 'fc' in lr['type'] else 'conv'
121
+ conv_type = self.conv_dict[lr['type']]
122
+ norm_type = self.norm_dict[lr['type']]
123
+ self.__setattr__(
124
+ '{:s}{:d}'.format(name, l_id),
125
+ conv_type(n_channels_prev,
126
+ lr['n_channels'],
127
+ kernel_size=lr['kernel_size'],
128
+ stride=lr['stride'],
129
+ padding=lr['padding']))
130
+ n_channels_prev = lr['n_channels']
131
+ self.__setattr__('bn{:d}'.format(l_id), norm_type(lr['n_channels']))
132
+ if 'maxpool' in lr:
133
+ pool_type = self.pool_dict[lr['type']]
134
+ padding = lr['maxpool']['padding'] if 'padding' in lr[
135
+ 'maxpool'] else 0
136
+ self.__setattr__(
137
+ 'mp{:d}'.format(l_id),
138
+ pool_type(kernel_size=lr['maxpool']['kernel_size'],
139
+ stride=lr['maxpool']['stride'],
140
+ padding=padding),
141
+ )
142
+
143
+ def forward(self, inp):
144
+ self.debug_line('Input', inp)
145
+ out = inp
146
+ for l_id, lr in enumerate(self.layers):
147
+ l_id += 1
148
+ name = 'fc' if 'fc' in lr['type'] else 'conv'
149
+ out = self.__getattr__('{:s}{:d}'.format(name, l_id))(out)
150
+ out = self.__getattr__('bn{:d}'.format(l_id))(out)
151
+ out = nn.ReLU(inplace=True)(out)
152
+ self.debug_line(self.__getattr__('{:s}{:d}'.format(name, l_id)),
153
+ out)
154
+ if 'maxpool' in lr:
155
+ out = self.__getattr__('mp{:d}'.format(l_id))(out)
156
+ self.debug_line(self.__getattr__('mp{:d}'.format(l_id)), out)
157
+
158
+ self.debug_line('Output', out, final_call=True)
159
+
160
+ return out
161
+
162
+
163
+
164
+ class NetFC(DebugModule):
165
+
166
+ def __init__(self, input_dim, hidden_dim, embed_dim):
167
+ super(NetFC, self).__init__()
168
+ self.fc7 = nn.Conv3d(input_dim, hidden_dim, kernel_size=(1, 1, 1))
169
+ self.bn7 = nn.BatchNorm3d(hidden_dim)
170
+ self.fc8 = nn.Conv3d(hidden_dim, embed_dim, kernel_size=(1, 1, 1))
171
+
172
+ def forward(self, inp):
173
+ out = self.fc7(inp)
174
+ self.debug_line(self.fc7, out)
175
+ out = self.bn7(out)
176
+ out = nn.ReLU(inplace=True)(out)
177
+ out = self.fc8(out)
178
+ self.debug_line(self.fc8, out, final_call=True)
179
+ return out
180
+
181
+ class NetFC_2D(DebugModule):
182
+
183
+ def __init__(self, input_dim, hidden_dim, embed_dim):
184
+ super(NetFC_2D, self).__init__()
185
+ self.fc7 = nn.Conv2d(input_dim, hidden_dim, kernel_size=(1, 1))
186
+ self.bn7 = nn.BatchNorm2d(hidden_dim)
187
+ self.fc8 = nn.Conv2d(hidden_dim, embed_dim, kernel_size=(1, 1))
188
+
189
+ def forward(self, inp):
190
+ out = self.fc7(inp)
191
+ self.debug_line(self.fc7, out)
192
+ out = self.bn7(out)
193
+ out = nn.ReLU(inplace=True)(out)
194
+ out = self.fc8(out)
195
+ self.debug_line(self.fc8, out, final_call=True)
196
+ return out
utils/audio_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import librosa
2
+ import torch
3
+ import numpy as np
4
+ from scipy.io import wavfile
5
+
6
+ import warnings
7
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
8
+ warnings.filterwarnings("ignore", category=FutureWarning)
9
+
10
+
11
+ audio_opts = {
12
+ 'sample_rate': 16000,
13
+ 'n_fft': 512,
14
+ 'win_length': 320,
15
+ 'hop_length': 160,
16
+ 'n_mel': 80,
17
+ }
18
+
19
+
20
+ def load_wav(path, fr=0, to=10000, sample_rate=16000):
21
+ """Loads Audio wav from path at time indices given by fr, to (seconds)"""
22
+
23
+ _, wav = wavfile.read(path)
24
+ fr_aud = int(np.round(fr * sample_rate))
25
+ to_aud = int(np.round((to) * sample_rate))
26
+
27
+ wav = wav[fr_aud:to_aud]
28
+
29
+ return wav
30
+
31
+
32
+ def wav2filterbanks(wav, mel_basis=None):
33
+ """
34
+ :param wav: Tensor b x T
35
+ """
36
+
37
+ assert len(wav.shape) == 2, 'Need batch of wavs as input'
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ # device = 'cpu'
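+ # 512-point STFT with a 160-sample (10 ms) hop at 16 kHz -> 100 feature frames per second,
+ # i.e. 4 log-mel frames per video frame at 25 fps.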
40
+ spect = torch.stft(wav,
41
+ n_fft=audio_opts['n_fft'],
42
+ hop_length=audio_opts['hop_length'],
43
+ win_length=audio_opts['win_length'],
44
+ window=torch.hann_window(audio_opts['win_length']).to(device),
45
+ center=True,
46
+ pad_mode='reflect',
47
+ normalized=False,
48
+ onesided=True) # b x F x T x 2
49
+ spect = spect[:, :, :-1, :]
50
+
51
+ # ----- Log filterbanks --------------
52
+ # mag spectrogram - # b x F x T
53
+ mag = power_spect = torch.norm(spect, dim=-1)
54
+ phase = torch.atan2(spect[..., 1], spect[..., 0])
55
+ if mel_basis is None:
56
+ # Build a Mel filter
57
+ mel_basis = torch.from_numpy(
58
+ librosa.filters.mel(audio_opts['sample_rate'],
59
+ audio_opts['n_fft'],
60
+ n_mels=audio_opts['n_mel'],
61
+ fmin=0,
62
+ fmax=int(audio_opts['sample_rate'] / 2)))
63
+ mel_basis = mel_basis.float().to(power_spect.device)
64
+ features = torch.log(torch.matmul(mel_basis, power_spect) +
65
+ 1e-20) # b x F x T
66
+ features = features.permute([0, 2, 1]).contiguous() # b x T x F
67
+ # -------------------
68
+
69
+ # norm_axis = 1 # normalize every sample over time
70
+ # mean = features.mean(dim=norm_axis, keepdim=True) # b x 1 x F
71
+ # std_dev = features.std(dim=norm_axis, keepdim=True) # b x 1 x F
72
+ # features = (features - mean) / std_dev # b x T x F
73
+
74
+ return features, mag, phase, mel_basis
75
+
76
+
77
+ def torch_mag_phase_2_np_complex(mag_spect, phase):
78
+ complex_spect_2d = torch.stack(
79
+ [mag_spect * torch.cos(phase), mag_spect * torch.sin(phase)], -1)
80
+ complex_spect_np = complex_spect_2d.cpu().detach().numpy()
81
+ complex_spect_np = complex_spect_np[..., 0] + 1j * complex_spect_np[..., 1]
82
+ return complex_spect_np
83
+
84
+
85
+ def torch_mag_phase_2_complex_as_2d(mag_spect, phase):
86
+ complex_spect_2d = torch.stack(
87
+ [mag_spect * torch.cos(phase), mag_spect * torch.sin(phase)], -1)
88
+ return complex_spect_2d
89
+
90
+
91
+ def torch_phase_from_normalized_complex(spect):
92
+ phase = torch.atan2(spect[..., 1], spect[..., 0])
93
+ return phase
94
+
95
+
96
+ def reconstruct_wav_from_mag_phase(mag, phase):
97
+ spect = torch_mag_phase_2_np_complex(mag, phase)
98
+ wav = np.stack([
99
+ librosa.core.istft(spect[ii],
100
+ hop_length=audio_opts['hop_length'],
101
+ win_length=audio_opts['win_length'],
102
+ center=True) for ii in range(spect.shape[0])
103
+ ])
104
+
105
+ return wav
utils/inference_utils.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def run_func_in_parts(func, vid_emb, aud_emb, part_len, dim, device):
5
+ """
6
+ Run the given function in parts, splitting the inputs along dimension dim.
7
+ This is used to save memory when the inputs are too large to process on the GPU in one pass.
8
+ """
9
+ dist_chunk = []
10
+ for v_spl, a_spl in list(
11
+ zip(vid_emb.split(part_len, dim=dim),
12
+ aud_emb.split(part_len, dim=dim))):
13
+ dist_chunk.append(func(v_spl.to(device), a_spl.to(device)))
14
+ dist = torch.cat(dist_chunk, dim - 1)
15
+ return dist
16
+
17
+ def logsoftmax_2d(logits):
18
+ # Log softmax on last 2 dims because torch won't allow multiple dims
19
+ orig_shape = logits.shape
20
+ logprobs = torch.nn.LogSoftmax(dim=-1)(
21
+ logits.reshape(list(logits.shape[:-2]) + [-1])).reshape(orig_shape)
22
+ return logprobs