ssaad5678 commited on
Commit
539512b
1 Parent(s): 448c116

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -63
app.py CHANGED
@@ -19,7 +19,7 @@ mtcnn = MTCNN(
19
  select_largest=False,
20
  post_process=False,
21
  device=DEVICE
22
- ).eval()
23
  model = InceptionResnetV1(
24
  pretrained="vggface2",
25
  classify=True,
@@ -38,88 +38,81 @@ def predict_frame(frame):
38
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
39
  frame_pil = Image.fromarray(frame)
40
 
41
- # Detect faces
42
- boxes, probs = mtcnn.detect(frame_pil)
43
-
44
- faces = []
45
- confidences = []
46
-
47
- if boxes is not None:
48
- for box, prob in zip(boxes, probs):
49
- # Extract face
50
- x1, y1, x2, y2 = box.astype(int)
51
- face = frame[y1:y2, x1:x2]
52
-
53
- # Preprocess the face
54
- face = cv2.resize(face, (160, 160)) # Resize to match InceptionResnetV1 input size
55
- face = torch.tensor(face).permute(2, 0, 1).unsqueeze(0).float().to(DEVICE) / 255.0
56
-
57
- # Predict
58
- with torch.no_grad():
59
- output = torch.sigmoid(model(face).squeeze())
60
- prediction = "real" if output.item() < 0.5 else "fake"
61
-
62
- # Confidence scores
63
- real_prediction = 1 - output.item()
64
- fake_prediction = output.item()
65
-
66
- confidences.append({
67
- 'prediction': prediction,
68
- 'confidence': fake_prediction if prediction == 'fake' else real_prediction
69
- })
70
-
71
- # Visualize
72
- target_layers = [model.block8.branch1[-1]]
73
- use_cuda = True if torch.cuda.is_available() else False
74
- cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
75
- targets = [ClassifierOutputTarget(0)]
76
- grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
77
- grayscale_cam = grayscale_cam[0, :]
78
- visualization = show_cam_on_image(face.squeeze().permute(1, 2, 0).cpu().numpy(), grayscale_cam, use_rgb=True)
79
- face_with_mask = cv2.addWeighted((face.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8), 1, (visualization * 255).astype(np.uint8), 0.5, 0)
80
- faces.append(face_with_mask)
81
-
82
- return faces, confidences
83
 
84
  def predict_video(input_video):
85
  cap = cv2.VideoCapture(input_video)
86
 
87
  frames = []
88
- all_confidences = []
 
 
89
 
90
  while True:
91
  ret, frame = cap.read()
92
  if not ret:
93
  break
 
 
 
94
 
95
- faces, confidences = predict_frame(frame)
96
 
97
- if faces:
98
- frames.extend(faces)
99
- all_confidences.extend(confidences)
100
 
101
  cap.release()
102
 
103
  # Determine the final prediction based on the maximum occurrence of predictions
104
- final_prediction = 'fake' if sum(1 for conf in all_confidences if conf['prediction'] == 'fake') > sum(1 for conf in all_confidences if conf['prediction'] == 'real') else 'real'
105
 
106
- return final_prediction, frames, all_confidences
107
 
108
  # Gradio Interface
109
- def show_detected_faces(video):
110
- prediction, frames, confidences = predict_video(video.name)
111
- return prediction, frames, confidences
112
-
113
- gr.Interface(
114
- fn=show_detected_faces,
115
  inputs=[
116
- gr.inputs.Video(label="Input Video", type="file")
117
  ],
118
  outputs=[
119
- gr.outputs.Label(label="Class"),
120
- gr.outputs.Image(label="Detected Faces", type="numpy", multiple=True),
121
- gr.outputs.Label(label="Confidences", type="json")
122
  ],
123
- title="Deep Fake Video Detection",
124
- description="Detect whether the Video is fake or real and visualize the detected faces with confidence scores."
125
- ).launch()
 
 
 
19
  select_largest=False,
20
  post_process=False,
21
  device=DEVICE
22
+ ).to(DEVICE).eval()
23
  model = InceptionResnetV1(
24
  pretrained="vggface2",
25
  classify=True,
 
38
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
39
  frame_pil = Image.fromarray(frame)
40
 
41
+ face = mtcnn(frame_pil)
42
+ if face is None:
43
+ return None, None # No face detected
44
+
45
+ # Preprocess the face
46
+ face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
47
+ face = face.to(DEVICE, dtype=torch.float32) / 255.0
48
+
49
+ # Predict
50
+ with torch.no_grad():
51
+ output = torch.sigmoid(model(face).squeeze(0))
52
+ prediction = "real" if output.item() < 0.5 else "fake"
53
+
54
+ # Confidence scores
55
+ real_prediction = 1 - output.item()
56
+ fake_prediction = output.item()
57
+
58
+ confidences = {
59
+ 'real': real_prediction,
60
+ 'fake': fake_prediction
61
+ }
62
+
63
+ # Visualize
64
+ target_layers = [model.block8.branch1[-1]]
65
+ use_cuda = True if torch.cuda.is_available() else False
66
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
67
+ targets = [ClassifierOutputTarget(0)]
68
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
69
+ grayscale_cam = grayscale_cam[0, :]
70
+ face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
71
+ visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
72
+ face_with_mask = cv2.addWeighted((face_np * 255).astype(np.uint8), 1, (visualization * 255).astype(np.uint8), 0.5, 0)
73
+
74
+ return prediction, face_with_mask
 
 
 
 
 
 
 
 
75
 
76
  def predict_video(input_video):
77
  cap = cv2.VideoCapture(input_video)
78
 
79
  frames = []
80
+ confidences = []
81
+ frame_count = 0
82
+ skip_frames = 20
83
 
84
  while True:
85
  ret, frame = cap.read()
86
  if not ret:
87
  break
88
+ frame_count+=1
89
+ if frame_count % skip_frames != 0: # Skip frames if not divisible by skip_frames
90
+ continue
91
 
92
+ prediction, frame_with_mask = predict_frame(frame)
93
 
94
+ frames.append(frame_with_mask)
95
+ confidences.append(prediction)
 
96
 
97
  cap.release()
98
 
99
  # Determine the final prediction based on the maximum occurrence of predictions
100
+ final_prediction = 'fake' if confidences.count('fake') > confidences.count('real') else 'real'
101
 
102
+ return final_prediction
103
 
104
  # Gradio Interface
105
+ interface = gr.Interface(
106
+ fn=predict_video,
 
 
 
 
107
  inputs=[
108
+ gr.Video(label="Input Video")
109
  ],
110
  outputs=[
111
+ gr.Label(label="Class"),
112
+
 
113
  ],
114
+ title="Deep fake video Detection",
115
+ description="Detect whether the Video is fake or real"
116
+ )
117
+
118
+ interface.launch()