sindhuhegde committed
Commit b63403f
1 Parent(s): 2d17a01

Update app

Files changed (1)
  1. app_v1.py +0 -954
app_v1.py DELETED
import gradio as gr
import argparse
import os, subprocess
from shutil import rmtree

import numpy as np
import cv2
import librosa
import torch

from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *

import sys
if sys.version_info > (3, 0): long, unicode, basestring = int, str, str

from tqdm import tqdm
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
mp_holistic = mp.solutions.holistic
from ultralytics import YOLO
from decord import VideoReader, cpu

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Set the path to checkpoint file
CHECKPOINT_PATH = "model_rgb.pth"

# Initialize global variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
n_negative_samples = 100
print("Using CUDA: ", use_cuda, device)

def preprocess_video(path, result_folder, apply_preprocess, padding=20):

    '''
    This function preprocesses the input video to extract the audio and crop the frames using the YOLO model

    Args:
        - path (string) : Path of the input video file
        - result_folder (string) : Path of the folder to save the extracted audio and cropped video
        - apply_preprocess (bool) : Whether to run person detection and cropping on the input video
        - padding (int) : Padding to add to the bounding box
    Returns:
        - wav_file (string) : Path of the extracted audio file
        - fps (int) : FPS of the input video
        - video_output (string) : Path of the cropped video file
        - msg (string) : Message to be returned
    '''

    # Load all video frames
    try:
        vr = VideoReader(path, ctx=cpu(0))
        fps = vr.get_avg_fps()
        frame_count = len(vr)
    except Exception:
        msg = "Oops! Could not load the video. Please check the input video and try again."
        return None, None, None, msg

    if frame_count < 25:
        msg = "Not enough frames to process! Please give a longer video as input"
        return None, None, None, msg

    # Extract the audio from the input video file using ffmpeg
    wav_file = os.path.join(result_folder, "audio.wav")

    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
                            -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)

    if status != 0:
        msg = "Oops! Could not load the audio file. Please check the input video and try again."
        return None, None, None, msg
    print("Extracted the audio from the video")

    if apply_preprocess:
        all_frames = []
        for k in range(len(vr)):
            all_frames.append(vr[k].asnumpy())
        all_frames = np.asarray(all_frames)
        print("Extracted the frames for pre-processing")

        # Load the YOLOv9 model (pre-trained on the COCO dataset)
        yolo_model = YOLO("yolov9s.pt")
        print("Loaded the YOLO model")

        person_videos = {}
        person_tracks = {}

        print("Processing the frames...")
        for frame_idx in tqdm(range(frame_count)):

            frame = all_frames[frame_idx]

            # Perform person detection
            results = yolo_model(frame, verbose=False)
            detections = results[0].boxes

            for i, det in enumerate(detections):
                x1, y1, x2, y2 = det.xyxy[0]
                cls = det.cls[0]
                if int(cls) == 0:  # Class 0 is 'person' in the COCO dataset

                    x1 = max(0, int(x1) - padding)
                    y1 = max(0, int(y1) - padding)
                    x2 = min(frame.shape[1], int(x2) + padding)
                    y2 = min(frame.shape[0], int(y2) + padding)

                    if i not in person_videos:
                        person_videos[i] = []
                        person_tracks[i] = []

                    person_videos[i].append(frame)
                    person_tracks[i].append([x1, y1, x2, y2])

        num_persons = 0
        for i in person_videos.keys():
            if len(person_videos[i]) >= frame_count // 2:
                num_persons += 1

        if num_persons == 0:
            msg = "No person detected in the video! Please give a video with one person as input"
            return None, None, None, msg
        if num_persons > 1:
            msg = "More than one person detected in the video! Please give a video with only one person as input"
            return None, None, None, msg

        # For the person detected, crop the frames based on the bounding box
        if len(person_videos[0]) > frame_count - 10:
            crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
            fourcc = cv2.VideoWriter_fourcc(*'DIVX')

            # Get the bounding box that covers the person across all frames
            max_x1 = min([track[0] for track in person_tracks[0]])
            max_y1 = min([track[1] for track in person_tracks[0]])
            max_x2 = max([track[2] for track in person_tracks[0]])
            max_y2 = max([track[3] for track in person_tracks[0]])

            max_width = max_x2 - max_x1
            max_height = max_y2 - max_y1

            out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
            for frame in person_videos[0]:
                crop = frame[max_y1:max_y2, max_x1:max_x2]
                # decord returns RGB frames; swap channels to BGR for OpenCV's VideoWriter
                crop = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)
                out.write(crop)
            out.release()

            no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            video_output = crop_filename.split('.')[0] + '.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
                                     (wav_file, no_sound_video, video_output), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            os.remove(crop_filename)
            os.remove(no_sound_video)

            print("Successfully saved the pre-processed video: ", video_output)
        else:
            msg = "Could not track the person in the full video! Please give a single-speaker video as input"
            return None, None, None, msg

    else:
        video_output = path

    return wav_file, fps, video_output, "success"

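# Illustrative usage (not part of the original app): preprocess_video can be run on its
# own to obtain the extracted audio and the person-cropped clip. The file name
# "sample.mp4" and the folder "results/demo" below are assumptions for this sketch.
def _example_preprocess_video():
    os.makedirs("results/demo", exist_ok=True)
    wav_file, fps, cropped_video, msg = preprocess_video("sample.mp4", "results/demo", apply_preprocess=True)
    if msg != "success":
        print("Preprocessing failed:", msg)
    else:
        print("Audio:", wav_file, "| FPS:", fps, "| Cropped video:", cropped_video)
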
183
- def resample_video(video_file, video_fname, result_folder):
184
-
185
- '''
186
- This function resamples the video to 25 fps
187
-
188
- Args:
189
- - video_file (string) : Path of the input video file
190
- - video_fname (string) : Name of the input video file
191
- - result_folder (string) : Path of the folder to save the resampled video
192
- Returns:
193
- - video_file_25fps (string) : Path of the resampled video file
194
- '''
195
- video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))
196
-
197
- # Resample the video to 25 fps
198
- command = ("ffmpeg -hide_banner -loglevel panic -y -i {} -q:v 1 -filter:v fps=25 {}".format(video_file, video_file_25fps))
199
- from subprocess import call
200
- cmd = command.split(' ')
201
- print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
202
- call(cmd)
203
-
204
- return video_file_25fps
205
-
def load_checkpoint(path, model):
    '''
    This function loads the trained model from the checkpoint

    Args:
        - path (string) : Path of the checkpoint file
        - model (object) : Model object
    Returns:
        - model (object) : Model object with the weights loaded from the checkpoint
    '''

    # Load the checkpoint
    if use_cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location="cpu")

    s = checkpoint["state_dict"]
    new_s = {}

    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    if use_cuda:
        model.cuda()

    print("Loaded checkpoint from: {}".format(path))

    return model.eval()

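# Illustrative usage (mirrors what process_video does further below): instantiate the
# Transformer_RGB sync model from sync_models.gestsync_models and load the released
# checkpoint defined by CHECKPOINT_PATH at the top of this file.
def _example_load_model():
    model = load_checkpoint(CHECKPOINT_PATH, Transformer_RGB())
    return model
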
def load_video_frames(video_file):
    '''
    This function extracts the frames from the video

    Args:
        - video_file (string) : Path of the video file
    Returns:
        - frames (list) : List of frames extracted from the video
        - msg (string) : Message to be returned
    '''

    # Read the video
    try:
        vr = VideoReader(video_file, ctx=cpu(0))
    except Exception:
        msg = "Oops! Could not load the input video file"
        return None, msg

    # Extract the frames
    frames = []
    for k in range(len(vr)):
        frames.append(vr[k].asnumpy())

    frames = np.asarray(frames)

    return frames, "success"

def get_keypoints(frames):

    '''
    This function extracts the keypoints from the frames using the MediaPipe Holistic pipeline

    Args:
        - frames (list) : List of frames extracted from the video
    Returns:
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
        - msg (string) : Message to be returned
    '''

    try:
        holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)

        resolution = frames[0].shape
        all_frame_kps = []

        for frame in frames:

            results = holistic.process(frame)

            pose, left_hand, right_hand, face = None, None, None, None
            if results.pose_landmarks is not None:
                pose = protobuf_to_dict(results.pose_landmarks)['landmark']
            if results.left_hand_landmarks is not None:
                left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
            if results.right_hand_landmarks is not None:
                right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
            if results.face_landmarks is not None:
                face = protobuf_to_dict(results.face_landmarks)['landmark']

            frame_dict = {"pose": pose, "left_hand": left_hand, "right_hand": right_hand, "face": face}

            all_frame_kps.append(frame_dict)

        kp_dict = {"kps": all_frame_kps, "resolution": resolution}
    except Exception as e:
        print("Error: ", e)
        return None, "Error: Could not extract keypoints from the frames"

    return kp_dict, "success"

def check_visible_gestures(kp_dict):

    '''
    This function checks if the gestures in the video are visible

    Args:
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
    Returns:
        - msg (string) : Message to be returned
    '''

    keypoints = kp_dict['kps']
    keypoints = np.array(keypoints)

    if len(keypoints) < 25:
        msg = "Not enough keypoints to process! Please give a longer video as input"
        return msg

    pose_count, hand_count = 0, 0
    for frame_kp_dict in keypoints:

        pose = frame_kp_dict["pose"]
        left_hand = frame_kp_dict["left_hand"]
        right_hand = frame_kp_dict["right_hand"]

        if pose is None:
            pose_count += 1

        if left_hand is None and right_hand is None:
            hand_count += 1

    if hand_count/len(keypoints) > 0.7 or pose_count/len(keypoints) > 0.7:
        msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
        return msg

    print("Successfully verified the input video - Gestures are visible!")

    return "success"

def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, width=480, height=270):

    '''
    This function masks the faces using the keypoints extracted from the frames

    Args:
        - input_frames (list) : List of frames extracted from the video
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
        - stride (int) : Stride to extract the frames
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - width (int) : Width of the frames
        - height (int) : Height of the frames
    Returns:
        - input_frames (array) : Frame window to be given as input to the model
        - num_frames (int) : Number of frames to extract
        - orig_masked_frames (array) : Masked frames extracted from the video
        - msg (string) : Message to be returned
    '''

    # Face indices to extract the face-coordinates needed for masking
    face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
                     176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]

    input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
    print("Input keypoints: ", len(input_keypoints))

    print("Creating masked input frames...")
    input_frames_masked = []
    for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):

        img = input_frames[i]
        face = frame_kp_dict["face"]

        if face is None:
            img = cv2.resize(img, (width, height))
            masked_img = cv2.rectangle(img, (0, 0), (width, 110), (0, 0, 0), -1)
        else:
            face_kps = []
            for idx in range(len(face)):
                if idx in face_oval_idx:
                    x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0])
                    face_kps.append((x, y))

            face_kps = np.array(face_kps)
            x1, y1 = min(face_kps[:, 0]), min(face_kps[:, 1])
            x2, y2 = max(face_kps[:, 0]), max(face_kps[:, 1])
            masked_img = cv2.rectangle(img, (0, 0), (resolution[1], y2+15), (0, 0, 0), -1)

        # Note: shape[0] is the frame height and shape[1] is the frame width
        if masked_img.shape[0] != height or masked_img.shape[1] != width:
            masked_img = cv2.resize(masked_img, (width, height))

        input_frames_masked.append(masked_img)

    orig_masked_frames = np.array(input_frames_masked)
    input_frames = np.array(input_frames_masked) / 255.
    print("Input images full: ", input_frames.shape)      # num_frames x 270 x 480 x 3

    input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0, input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
    print("Input images window: ", input_frames.shape)    # T x 25 x 270 x 480 x 3

    num_frames = input_frames.shape[0]

    if num_frames < 10:
        msg = "Not enough frames to process! Please give a longer video as input."
        return None, None, None, msg

    return input_frames, num_frames, orig_masked_frames, "success"

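# Worked example (added for clarity, not in the original file): with the defaults
# stride=1 and window_frames=25, a clip with N masked frames yields N - 24 overlapping
# windows (one starting at every index that still leaves room for a full window).
def _example_num_rgb_windows(num_masked_frames=250, window_frames=25, stride=1):
    # e.g. 250 frames (10 s at 25 fps) -> 226 windows of shape 25 x 270 x 480 x 3
    return len(range(0, num_masked_frames - window_frames + 1, stride))
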
def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):

    '''
    This function extracts the spectrogram from the audio file

    Args:
        - wav_file (string) : Path of the extracted audio file
        - num_frames (int) : Number of frames to extract
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - stride (int) : Stride to extract the audio frames
    Returns:
        - spec (array) : Spectrogram array window to be used as input to the model
        - orig_spec (array) : Spectrogram array extracted from the audio file
        - msg (string) : Message to be returned
    '''

    # Load the audio extracted from the input video file
    try:
        wav = librosa.load(wav_file, sr=16000)[0]
    except Exception:
        msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
        return None, None, msg

    # Convert to tensor and compute the mel-spectrogram windows
    wav = torch.FloatTensor(wav).unsqueeze(0)
    mel, _, _, _ = wav2filterbanks(wav.to(device))
    spec = mel.squeeze(0).cpu().numpy()
    orig_spec = spec
    spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])

    if len(spec) != num_frames:
        frame_diff = np.abs(len(spec) - num_frames)
        if frame_diff > 60:
            print("The input video and audio lengths do not match - the results can be unreliable! Please check the input video.")
        spec = spec[:num_frames]

    return spec, orig_spec, "success"

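# Added note (an assumption, since wav2filterbanks lives in utils.audio_utils and is not
# shown here): the stride of 4 implies the mel-spectrogram is produced at 100 frames per
# second, i.e. 4 spectrogram frames per video frame at 25 fps, so each window of
# window_frames*stride = 100 spectrogram frames spans the same 1 second of audio as a
# 25-frame video window.
def _example_num_spec_windows(num_spec_frames=1000, window_frames=25, stride=4):
    # e.g. 1000 mel frames (~10 s of audio) -> 226 windows of 100 mel frames each,
    # matching the 226 video windows from the 250-frame clip in the example above
    return len(range(0, num_spec_frames - window_frames*stride + 1, stride))
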
def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
    '''
    This function calculates the audio-visual offset between the video and audio

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - model (object) : Model object
    Returns:
        - offset (int) : Optimal audio-visual offset
        - msg (string) : Message to be returned
    '''

    pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
    if status != "success":
        return None, status
    scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
    offset = scores.argmax()*stride - pos_idx

    return offset.item(), "success"

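# Worked example (added for clarity, not in the original file): the predicted offset is
# argmax_window * stride - pos_idx, where pos_idx is the frame index of the "positive"
# (aligned) audio window. With stride=5 and pos_idx=50, an argmax at window 12 gives
# 12*5 - 50 = +10 frames, while an argmax at window 10 (the positive window itself)
# gives an offset of 0, i.e. the audio and video are already in sync.
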
def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):

    '''
    This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - stride (int) : Stride to extract the negative windows
    Returns:
        - vid_emb_pos (array) : Positive video embedding array
        - aud_emb_posneg (array) : All possible combinations of audio embedding array
        - pos_idx_frame (int) : Frame index of the positive audio window
        - stride (int) : Stride used to extract the negative windows
        - msg (string) : Message to be returned
    '''

    slice_size = num_avg_frames
    aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
    aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
    aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]

    pos_idx = (aud_emb_posneg.shape[1]//2)
    pos_idx_frame = pos_idx*stride

    min_offset_frames = -(pos_idx)*stride
    max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
    print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))

    vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
    if vid_emb_pos.shape[2] != slice_size:
        msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
        return None, None, None, None, msg

    return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"

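# Worked example (added for clarity, not in the original file): with the default
# n_negative_samples=100 and stride=5, at most int(100/5)+1 = 21 audio windows are kept,
# so pos_idx = 10 and the predictable offset range is [-50, 50] frames (+/- 2 seconds at
# 25 fps), provided the video is long enough to supply all 21 windows.
def _example_offset_range(num_audio_windows=21, stride=5):
    pos_idx = num_audio_windows // 2
    return (-pos_idx*stride, (num_audio_windows - pos_idx - 1)*stride)
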
def calc_av_scores(vid_emb, aud_emb, model):

    '''
    This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores
        - att_map (array) : Attention map
    '''

    scores = calc_att_map(vid_emb, aud_emb, model)
    att_map = logsoftmax_2d(scores)
    scores = scores.mean(-1)

    return scores, att_map

def calc_att_map(vid_emb, aud_emb, model):

    '''
    This function calculates the similarity between the video and audio embeddings

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores
    '''

    vid_emb = vid_emb[:, :, None]
    aud_emb = aud_emb.transpose(1, 2)

    scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
                               vid_emb,
                               aud_emb,
                               part_len=10,
                               dim=3,
                               device=device)

    scores = model.logits_scale(scores[..., None]).squeeze(-1)

    return scores

def generate_video(frames, audio_file, video_fname):

    '''
    This function generates the video from the frames and audio file

    Args:
        - frames (array) : Frames to be used to generate the video
        - audio_file (string) : Path of the audio file
        - video_fname (string) : Path of the output video file (without the extension)
    Returns:
        - video_output (string) : Path of the video file
    '''

    fname = 'inference.avi'
    video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))

    for i in range(len(frames)):
        # Frames are in RGB order; swap channels to BGR for OpenCV's VideoWriter
        video.write(cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR))
    video.release()

    no_sound_video = video_fname + '_nosound.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
    if status != 0:
        msg = "Oops! Could not generate the video. Please check the input video and try again."
        return None, msg

    video_output = video_fname + '.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 -shortest %s' %
                             (audio_file, no_sound_video, video_output), shell=True)
    if status != 0:
        msg = "Oops! Could not generate the video. Please check the input video and try again."
        return None, msg

    os.remove(fname)
    os.remove(no_sound_video)

    return video_output

def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):

    '''
    This function corrects the video and audio to sync with each other

    Args:
        - video_path (string) : Path of the video file
        - frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - offset (int) : Predicted sync-offset to be used to correct the video
        - result_folder (string) : Path of the result folder to save the output sync-corrected video
        - sample_rate (int) : Sample rate of the audio
        - fps (int) : Frames per second of the video
    Returns:
        - video_output (string) : Path of the video file
    '''

    if offset == 0:
        print("The input audio and video are in-sync! No need to perform sync correction.")
        return video_path

    print("Performing Sync Correction...")
    corrected_frames = np.zeros_like(frames)
    if offset > 0:
        audio_offset = int(offset*(sample_rate/fps))
        wav = librosa.core.load(wav_file, sr=sample_rate)[0]
        corrected_wav = wav[audio_offset:]
        corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
        write(corrected_wav_file, sample_rate, corrected_wav)
        wav_file = corrected_wav_file
        corrected_frames = frames
    elif offset < 0:
        corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):]
        corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]

    corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
    video_output = generate_video(corrected_frames, wav_file, corrected_video_path)

    return video_output

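# Worked example (added for clarity, not in the original file): a positive offset trims
# the start of the audio, a negative offset drops frames from the start of the video.
# At the defaults of 16 kHz audio and 25 fps, one video frame corresponds to
# 16000 / 25 = 640 audio samples.
def _example_audio_trim_samples(offset_frames=10, sample_rate=16000, fps=25):
    # e.g. an offset of +10 frames removes the first 6400 samples (0.4 s) of audio
    return int(offset_frames * (sample_rate / fps))
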
class Logger:
    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False

def process_video(video_path, num_avg_frames, apply_preprocess):
    try:
        # Extract the video filename (without the extension)
        video_fname = os.path.splitext(os.path.basename(video_path))[0]

        # Create folders to save the inputs and results
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")

        if os.path.exists(result_folder):
            rmtree(result_folder)

        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        # Preprocess the video
        print("Applying preprocessing: ", apply_preprocess)
        wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
        if status != "success":
            return status, None
        print("Successfully preprocessed the video")

        # Resample the video to 25 fps if it is not already 25 fps
        print("FPS of video: ", fps)
        if fps != 25:
            vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
            orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
        else:
            vid_path = vid_path_processed
            orig_vid_path_25fps = video_path

        # Load the original video frames (before pre-processing) - Needed for the final sync-correction
        orig_frames, status = load_video_frames(orig_vid_path_25fps)
        if status != "success":
            return status, None

        # Load the pre-processed video frames
        frames, status = load_video_frames(vid_path)
        if status != "success":
            return status, None
        print("Successfully extracted the video frames")

        if len(frames) < num_avg_frames:
            return "Error: The input video is too short. Please use a longer input video.", None

        # Load keypoints and check if gestures are visible
        kp_dict, status = get_keypoints(frames)
        if status != "success":
            return status, None
        print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))

        status = check_visible_gestures(kp_dict)
        if status != "success":
            return status, None

        # Load RGB frames
        rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, window_frames=25, width=480, height=270)
        if status != "success":
            return status, None
        print("Successfully loaded the RGB frames")

        # Convert frames to tensor
        rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
        rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
        B = rgb_frames.size(0)
        print("Successfully converted the frames to tensor")

        # Load spectrograms
        spec, orig_spec, status = load_spectrograms(wav_file, num_frames, window_frames=25)
        if status != "success":
            return status, None
        spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
        print("Successfully loaded the spectrograms")

        # Create input windows
        video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
        audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)
        print("Successfully loaded the model")

        # Process in batches
        batch_size = 12
        video_emb = []
        audio_emb = []

        for i in tqdm(range(0, len(video_sequences), batch_size)):
            video_inp = video_sequences[i:i+batch_size, ]
            audio_inp = audio_sequences[i:i+batch_size, ]

            vid_emb = model.forward_vid(video_inp.to(device))
            vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
            aud_emb = model.forward_aud(audio_inp.to(device))

            video_emb.append(vid_emb.detach())
            audio_emb.append(aud_emb.detach())

            torch.cuda.empty_cache()

        audio_emb = torch.cat(audio_emb, dim=0)
        video_emb = torch.cat(video_emb, dim=0)

        # L2 normalize embeddings
        video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
        audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)

        audio_emb = torch.split(audio_emb, B, dim=0)
        audio_emb = torch.stack(audio_emb, dim=2)
        audio_emb = audio_emb.squeeze(3)
        audio_emb = audio_emb[:, None]

        video_emb = torch.split(video_emb, B, dim=0)
        video_emb = torch.stack(video_emb, dim=2)
        video_emb = video_emb.squeeze(3)
        print("Successfully extracted GestSync embeddings")

        # Calculate sync offset
        pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
        if status != "success":
            return status, None
        print("Predicted offset: ", pred_offset)

        # Generate sync-corrected video
        video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
        print("Successfully generated the video:", video_output)

        return f"Predicted offset: {pred_offset}", video_output

    except Exception as e:
        return f"Error: {str(e)}", None

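# Illustrative usage (not part of the original app): process_video is what the Gradio
# "Submit" button calls below, but it can also be invoked directly. The file name
# "sample.mp4" is an assumption for this sketch; 75 average frames matches the slider's
# default value in the interface.
def _example_process_video():
    result_text, corrected_video = process_video("sample.mp4", num_avg_frames=75, apply_preprocess=True)
    print(result_text, corrected_video)
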
def read_logs():
    sys.stdout.flush()
    with open("output.log", "r") as f:
        return f.read()

if __name__ == "__main__":

    sys.stdout = Logger("output.log")

    # Define the custom CSS and HTML for the header
    custom_css = """
    <style>
        body {
            background-color: #ffffff;
            color: #333333;  /* Default text color */
        }
        .container {
            max-width: 100% !important;
            padding-left: 0 !important;
            padding-right: 0 !important;
        }
        .header {
            background-color: #f0f0f0;
            color: #333333;
            padding: 30px;
            margin-bottom: 30px;
            text-align: center;
            font-family: 'Helvetica Neue', Arial, sans-serif;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        .header h1 {
            font-size: 36px;
            margin-bottom: 15px;
            font-weight: bold;
            color: #333333;  /* Explicitly set heading color */
        }
        .header h2 {
            font-size: 24px;
            margin-bottom: 10px;
            color: #333333;  /* Explicitly set subheading color */
        }
        .header p {
            font-size: 18px;
            margin: 5px 0;
            color: #666666;
        }
        .blue-text {
            color: #4a90e2;
        }
        /* Custom styles for slider container */
        .slider-container {
            background-color: white !important;
            padding-top: 0.9em;
            padding-bottom: 0.9em;
        }
        /* Add gap before examples */
        .examples-holder {
            margin-top: 2em;
        }
        /* Set fixed size for example videos */
        .gradio-container .gradio-examples .gr-sample {
            width: 240px !important;
            height: 135px !important;
            object-fit: cover;
            display: inline-block;
            margin-right: 10px;
        }

        .gradio-container .gradio-examples {
            display: flex;
            flex-wrap: wrap;
            gap: 10px;
        }

        /* Ensure the parent container does not stretch */
        .gradio-container .gradio-examples {
            max-width: 100%;
            overflow: hidden;
        }

        /* Additional styles to ensure proper sizing in Safari */
        .gradio-container .gradio-examples .gr-sample img {
            width: 240px !important;
            height: 135px !important;
            object-fit: cover;
        }
    </style>
    """

    custom_html = custom_css + """
    <div class="header">
        <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
        <h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
        <p>Sindhu Hegde and Andrew Zisserman</p>
        <p>VGG, University of Oxford</p>
    </div>
    """

    # Define paths to sample videos
    sample_videos = [
        "samples/sync_sample_1.mp4",
        "samples/sync_sample_2.mp4",
    ]

    # Define Gradio interface
    with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
        gr.HTML(custom_html)
        with gr.Row():
            with gr.Column():
                with gr.Group(elem_classes="slider-container"):
                    num_avg_frames = gr.Slider(
                        minimum=50,
                        maximum=150,
                        step=5,
                        value=75,
                        label="Number of Average Frames",
                    )
                    apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False)
                video_input = gr.Video(label="Upload Video", height=400)

            with gr.Column():
                result_text = gr.Textbox(label="Result")
                output_video = gr.Video(label="Sync Corrected Video", height=400)

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_button = gr.Button("Clear")

        submit_button.click(
            fn=process_video,
            inputs=[video_input, num_avg_frames, apply_preprocess],
            outputs=[result_text, output_video]
        )

        clear_button.click(
            fn=lambda: (None, 75, False, "", None),
            inputs=[],
            outputs=[video_input, num_avg_frames, apply_preprocess, result_text, output_video]
        )

        gr.HTML('<div class="examples-holder"></div>')

        # Add examples
        gr.Examples(
            examples=sample_videos,
            inputs=video_input,
            outputs=None,
            fn=None,
            cache_examples=False,
        )

        logs = gr.Textbox(label="Logs")
        demo.load(read_logs, None, logs, every=1)

    # Launch the interface
    demo.queue().launch(allowed_paths=["."], show_error=True)