Hrithik28 committed on
Commit b8739b2
1 Parent(s): b9602d7

Update app.py

Files changed (1)
  1. app.py +164 -122
app.py CHANGED
@@ -1,37 +1,54 @@
  import os
  import sys
  import json
  import numpy as np
- import dlib
- import cv2
  import skvideo.io
  import torch
  import fairseq
  from fairseq import checkpoint_utils, options, tasks, utils
  from fairseq.dataclass.configs import GenerationConfig
  from huggingface_hub import hf_hub_download
  import gradio as gr
  from pytube import YouTube
- from base64 import b64encode
- from tqdm import tqdm
- from argparse import Namespace

- # Ensure necessary directories and files exist
- required_paths = [
-     "/home/user/app/av_hubert/avhubert",
-     "/home/user/app/video",
-     "/home/user/app/mmod_human_face_detector.dat",
-     "/home/user/app/shape_predictor_68_face_landmarks.dat",
-     "/home/user/app/20words_mean_face.npy",
-     "/home/user/app/roi.mp4"
- ]
- for path in required_paths:
-     if not os.path.exists(path):
-         raise FileNotFoundError(f"Required path {path} does not exist")
-
- # Load model and setup task
  user_dir = "/home/user/app/av_hubert/avhubert"
- sys.path.append(user_dir)
  utils.import_user_module(Namespace(user_dir=user_dir))
  data_dir = "/home/user/app/video"

@@ -43,11 +60,6 @@ mouth_roi_path = "/home/user/app/roi.mp4"
  modalities = ["video"]
  gen_subset = "test"
  gen_cfg = GenerationConfig(beam=20)
-
- # Check if the model file exists
- if not os.path.exists(ckpt_path):
-     raise FileNotFoundError(f"Checkpoint file not found at {ckpt_path}")
-
  models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
  models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
  saved_cfg.task.modalities = modalities
@@ -56,7 +68,6 @@ saved_cfg.task.label_dir = data_dir
  task = tasks.setup_task(saved_cfg.task)
  generator = task.build_generator(models, gen_cfg)

- # Helper Functions
  def get_youtube(video_url):
      yt = YouTube(video_url)
      abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
@@ -66,53 +77,61 @@ def get_youtube(video_url):

  def detect_landmark(image, detector, predictor):
      gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-     face_locations = detector(gray, 1)
      coords = None
-     for face_location in face_locations:
-         rect = face_location.rect if torch.cuda.is_available() else face_location
          shape = predictor(gray, rect)
          coords = np.zeros((68, 2), dtype=np.int32)
-         for i in range(68):
              coords[i] = (shape.part(i).x, shape.part(i).y)
      return coords

- def landmarks_interpolate(landmarks):
-     landmarks = np.array(landmarks)
-     for i in range(landmarks.shape[1]):
-         if landmarks[:, i, :].size == 0:
-             continue
-         x = np.arange(len(landmarks))
-         y = landmarks[:, i, :]
-         valid = ~np.isnan(y)
-         y[~valid] = np.interp(x[~valid], x[valid], y[valid])
-     return landmarks
-
- def crop_patch(video_path, landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE, window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96):
-     video_capture = cv2.VideoCapture(video_path)
-     frames = []
-     for landmark in landmarks:
-         ret, frame = video_capture.read()
-         if not ret:
-             break
-         h, w, _ = frame.shape
-         x1, y1, x2, y2 = 100, 100, 200, 200  # Replace with actual ROI based on landmarks
-         roi = frame[y1:y2, x1:x2]
-         roi = cv2.resize(roi, STD_SIZE)
-         frames.append(roi)
-     return frames
-
- def write_video_ffmpeg(frames, output_path, ffmpeg_path):
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     height, width = frames[0].shape[:2]
-     out = cv2.VideoWriter(output_path, fourcc, 25, (width, height))
-     for frame in frames:
-         out.write(frame)
-     out.release()

- def preprocess_video(input_video_path):
-     if not input_video_path or not os.path.exists(input_video_path):
-         raise ValueError("Invalid video path provided.")

      if torch.cuda.is_available():
          detector = dlib.cnn_face_detection_model_v1(face_detector_path)
      else:
@@ -122,74 +141,97 @@ def preprocess_video(input_video_path):
      STD_SIZE = (256, 256)
      mean_face_landmarks = np.load(mean_face_path)
      stablePntsIDs = [33, 36, 39, 42, 45]
-
-     try:
-         videogen = skvideo.io.vread(input_video_path)
-         frames = np.array([frame for frame in videogen])
-     except Exception as e:
-         raise ValueError(f"Error reading video: {e}")
-
-     if frames.size == 0:
-         raise ValueError("No frames found in video")
-
      landmarks = []
      for frame in tqdm(frames):
          landmark = detect_landmark(frame, detector, predictor)
          landmarks.append(landmark)
-
      preprocessed_landmarks = landmarks_interpolate(landmarks)
      rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
-                       window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
      write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
      return mouth_roi_path

  def predict(process_video):
-     if not process_video or not os.path.exists(process_video):
-         raise ValueError("Invalid video path provided.")
-
-     features = []
-     with open(process_video, "rb") as f:
-         data = f.read()
-         features.append(data)
-
-     sample = next(iter(features))
-     output = task.forward(sample)
-     return output

- def get_youtube(video_url):
-     yt = YouTube(video_url)
-     abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
-     print("Success download video")
-     print(abs_video_path)
-     return abs_video_path

- # Gradio UI
- def process_ui():
-     with gr.Blocks() as demo:
-         youtube_url_in = gr.Textbox(label="YouTube URL")
-         video_in = gr.Video(label="Input Video")
-         video_out = gr.Video(label="Processed Video")
-         text_output = gr.Textbox(label="Prediction")
-
-         with gr.Row():
-             youtube_url_in.render()
-             download_youtube_btn = gr.Button("Download YouTube video")
-             download_youtube_btn.click(get_youtube, [youtube_url_in], [video_in])
-
-         with gr.Row():
-             video_in.render()
-             video_out.render()
-
-         with gr.Row():
-             detect_landmark_btn = gr.Button("Detect Landmark")
-             detect_landmark_btn.click(preprocess_video, [video_in], [video_out])
-             predict_btn = gr.Button("Predict")
-             predict_btn.click(predict, [video_out], [text_output])
-
-         with gr.Row():
-             text_output.render()

-     demo.launch(debug=True)

- if __name__ == "__main__":
-     process_ui()

  import os
  import sys
  import json
+
+
+ os.system('git clone https://github.com/facebookresearch/av_hubert.git')
+ os.chdir('/home/user/app/av_hubert')
+ os.system('git submodule init')
+ os.system('git submodule update')
+ os.chdir('/home/user/app/av_hubert/fairseq')
+ os.system('pip install ./')
+ os.system('pip install scipy')
+ os.system('pip install sentencepiece')
+ os.system('pip install python_speech_features')
+ os.system('pip install scikit-video')
+ os.system('pip install transformers')
+ os.system('pip install gradio==3.12')
+ os.system('pip install numpy==1.23.3')
+
+
+ # sys.path.append('/home/user/app/av_hubert')
+ sys.path.append('/home/user/app/av_hubert/avhubert')
+
+ print(sys.path)
+ print(os.listdir())
+ print(sys.argv, type(sys.argv))
+ sys.argv.append('dummy')
+
+
+
+ import dlib, cv2, os
  import numpy as np
+ import skvideo
  import skvideo.io
+ from tqdm import tqdm
+ from preparation.align_mouth import landmarks_interpolate, crop_patch, write_video_ffmpeg
+ from base64 import b64encode
  import torch
+ import cv2
+ import tempfile
+ from argparse import Namespace
  import fairseq
  from fairseq import checkpoint_utils, options, tasks, utils
  from fairseq.dataclass.configs import GenerationConfig
  from huggingface_hub import hf_hub_download
  import gradio as gr
  from pytube import YouTube

+ # os.chdir('/home/user/app/av_hubert/avhubert')
+
  user_dir = "/home/user/app/av_hubert/avhubert"
  utils.import_user_module(Namespace(user_dir=user_dir))
  data_dir = "/home/user/app/video"

  modalities = ["video"]
  gen_subset = "test"
  gen_cfg = GenerationConfig(beam=20)
  models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
  models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
  saved_cfg.task.modalities = modalities

  task = tasks.setup_task(saved_cfg.task)
  generator = task.build_generator(models, gen_cfg)

  def get_youtube(video_url):
      yt = YouTube(video_url)
      abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()

  def detect_landmark(image, detector, predictor):
      gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     face_locations = detector(gray, 1)
      coords = None
+     for (_, face_location) in enumerate(face_locations):
+         if torch.cuda.is_available():
+             rect = face_location.rect
+         else:
+             rect = face_location
          shape = predictor(gray, rect)
          coords = np.zeros((68, 2), dtype=np.int32)
+         for i in range(0, 68):
              coords[i] = (shape.part(i).x, shape.part(i).y)
      return coords
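The branch on torch.cuda.is_available() inside detect_landmark mirrors which dlib detector preprocess_video constructs: the CNN detector (dlib.cnn_face_detection_model_v1) returns mmod_rectangle objects that wrap the bounding box in a .rect attribute, while the HOG fallback (presumably dlib.get_frontal_face_detector(), in the else branch not shown in this diff) returns plain dlib.rectangle objects. A minimal, detector-agnostic sketch, not part of the commit; to_rect is a hypothetical helper:

def to_rect(detection):
    # mmod_rectangle (CNN detector) exposes the box via .rect;
    # a plain dlib.rectangle (HOG detector) is already the box
    return detection.rect if hasattr(detection, "rect") else detection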

+ # def predict_and_save(process_video):
+ #     num_frames = int(cv2.VideoCapture(process_video).get(cv2.CAP_PROP_FRAME_COUNT))
+
+ #     tsv_cont = ["/\n", f"test-0\t{process_video}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"]
+ #     label_cont = ["DUMMY\n"]
+ #     with open(f"{data_dir}/test.tsv", "w") as fo:
+ #         fo.write("".join(tsv_cont))
+ #     with open(f"{data_dir}/test.wrd", "w") as fo:
+ #         fo.write("".join(label_cont))
+ #     task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
+
+ #     def decode_fn(x):
+ #         dictionary = task.target_dictionary
+ #         symbols_ignore = generator.symbols_to_strip_from_output
+ #         symbols_ignore.add(dictionary.pad())
+ #         return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore)
+
+ #     itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False)
+ #     sample = next(itr)
+ #     if torch.cuda.is_available():
+ #         sample = utils.move_to_cuda(sample)
+ #     hypos = task.inference_step(generator, models, sample)
+ #     ref = decode_fn(sample['target'][0].int().cpu())
+ #     hypo = hypos[0][0]['tokens'].int().cpu()
+ #     hypo = decode_fn(hypo)

+ #     # Collect timestamps and texts
+ #     transcript = []
+ #     for i, (start, end) in enumerate(sample['net_input']['video_lengths'], 1):
+ #         start_time = float(start) / 16_000
+ #         end_time = float(end) / 16_000
+ #         text = hypo[i].strip()
+ #         transcript.append({"timestamp": [start_time, end_time], "text": text})
+
+ #     # Save transcript to a JSON file
+ #     with open('speech_transcript.json', 'w') as outfile:
+ #         json.dump(transcript, outfile, indent=4)
+
+ #     return hypo
+
+
+ def preprocess_video(input_video_path):
      if torch.cuda.is_available():
          detector = dlib.cnn_face_detection_model_v1(face_detector_path)
      else:

      STD_SIZE = (256, 256)
      mean_face_landmarks = np.load(mean_face_path)
      stablePntsIDs = [33, 36, 39, 42, 45]
+     videogen = skvideo.io.vread(input_video_path)
+     frames = np.array([frame for frame in videogen])
      landmarks = []
      for frame in tqdm(frames):
          landmark = detect_landmark(frame, detector, predictor)
          landmarks.append(landmark)
      preprocessed_landmarks = landmarks_interpolate(landmarks)
      rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
+                       window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
      write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
      return mouth_roi_path

  def predict(process_video):
+     num_frames = int(cv2.VideoCapture(process_video).get(cv2.CAP_PROP_FRAME_COUNT))

+     tsv_cont = ["/\n", f"test-0\t{process_video}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"]
+     label_cont = ["DUMMY\n"]
+     with open(f"{data_dir}/test.tsv", "w") as fo:
+         fo.write("".join(tsv_cont))
+     with open(f"{data_dir}/test.wrd", "w") as fo:
+         fo.write("".join(label_cont))
+     task.load_dataset(gen_subset, task_cfg=saved_cfg.task)

+     def decode_fn(x):
+         dictionary = task.target_dictionary
+         symbols_ignore = generator.symbols_to_strip_from_output
+         symbols_ignore.add(dictionary.pad())
+         return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore)

+     itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False)
+     sample = next(itr)
+     if torch.cuda.is_available():
+         sample = utils.move_to_cuda(sample)
+     hypos = task.inference_step(generator, models, sample)
+     ref = decode_fn(sample['target'][0].int().cpu())
+     hypo = hypos[0][0]['tokens'].int().cpu()
+     hypo = decode_fn(hypo)
+     return hypo
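For reference, the manifest row that predict() writes to {data_dir}/test.tsv has five tab-separated fields: an utterance id, the video path, the audio path (None here, since inference is video-only), the video frame count, and the audio sample count, which the code derives by assuming 25 fps video and 16 kHz audio (640 samples per frame). An illustration only, using a hypothetical 250-frame (10-second) clip and the roi.mp4 path defined elsewhere in this file:

num_frames = 250  # hypothetical 10-second clip at 25 fps
row = f"test-0\t/home/user/app/roi.mp4\t{None}\t{num_frames}\t{int(16_000 * num_frames / 25)}"
print(row)  # test-0  /home/user/app/roi.mp4  None  250  160000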
+
+
+ # ---- Gradio Layout -----
+ youtube_url_in = gr.Textbox(label="Youtube url", lines=1, interactive=True)
+ video_in = gr.Video(label="Input Video", mirror_webcam=False, interactive=True)
+ video_out = gr.Video(label="Audio Visual Video", mirror_webcam=False, interactive=True)
+ demo = gr.Blocks()
+ demo.encrypt = False
+ text_output = gr.Textbox()

+ with demo:
+     gr.Markdown('''
+         <div>
+         <h1 style='text-align: center'>Lip Reading Using Machine learning (Audio-Visual Hidden Unit BERT Model (AV-HuBERT))</h1>
+         </div>
+         ''')
+     with gr.Row():
+         gr.Markdown('''
+             ### Reading Lip movement with youtube link using Avhubert
+             ##### Step 1a. Download video from youtube (Note: the length of video should be less than 10 seconds if not it will be cut and the face should be stable for better result)
+             ##### Step 1b. Drag and drop videos to upload directly
+             ##### Step 2. Generating landmarks surrounding mouth area
+             ##### Step 3. Reading lip movement.
+             ''')
+     with gr.Row():
+         gr.Markdown('''
+             ### You can test by following examples:
+             ''')
+         examples = gr.Examples(examples=
+             ["https://www.youtube.com/watch?v=ZXVDnuepW2s",
+              "https://www.youtube.com/watch?v=X8_glJn1B8o",
+              "https://www.youtube.com/watch?v=80yqL2KzBVw"],
+             label="Examples", inputs=[youtube_url_in])
+     with gr.Column():
+         youtube_url_in.render()
+         download_youtube_btn = gr.Button("Download Youtube video")
+         download_youtube_btn.click(get_youtube, [youtube_url_in], [
+             video_in])
+         print(video_in)
+     with gr.Row():
+         video_in.render()
+         video_out.render()
+     with gr.Row():
+         detect_landmark_btn = gr.Button("Detect landmark")
+         detect_landmark_btn.click(preprocess_video, [video_in], [
+             video_out])
+         predict_btn = gr.Button("Predict")
+         #predict_btn.click(predict, [video_out], [text_output])
+         predict_btn.click(predict, [video_out], [text_output])
+     with gr.Row():
+         # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
+         text_output.render()
+
+
+
+ demo.launch(debug=True)
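The three buttons wire the pipeline together: get_youtube fills video_in, preprocess_video turns it into the cropped mouth ROI shown in video_out, and predict decodes that ROI into text. A rough sketch of the same flow without the UI, assuming the checkpoint, dlib model files, and ffmpeg path that app.py already expects are in place (the example URL is one of the gr.Examples entries above):

# Sketch only - the same functions as above, driven directly instead of via Gradio callbacks
raw_video = get_youtube("https://www.youtube.com/watch?v=ZXVDnuepW2s")  # download the clip
mouth_roi = preprocess_video(raw_video)                                 # detect landmarks and crop the mouth ROI
print(predict(mouth_roi))                                               # run AV-HuBERT and print the hypothesis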