Ryan-Pham committed on
Commit 5111a32
Parent: beb7843

Upload 3 files

Files changed (3)
  1. alert.py +24 -0
  2. model_weights.pth +3 -0
  3. test_video_ava.py +320 -0
alert.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from torch.utils.data import DataLoader, TensorDataset
+ import torch.nn as nn
+ import torch.optim as optim
+
+ class SimpleANN(nn.Module):
+     def __init__(self):
+         super(SimpleANN, self).__init__()
+         self.hidden = nn.Linear(16, 8)
+         self.output = nn.Linear(8, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = torch.relu(self.hidden(x))
+         x = self.sigmoid(self.output(x))
+         return x
+
+ # model = SimpleANN()
+ # model.load_state_dict(torch.load(r'D:\yowov2V7\YOWOv2\model_weights.pth'))
+ # model.eval()
+ # outputs = model(X_batch)  # feed the window of 16 fight/no-fight flags here
+ # predicted = outputs.round()
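
The commented-out lines above sketch the intended inference flow. Below is a minimal, runnable version of that sketch, assuming the checkpoint has been fetched next to alert.py and that the input is a window of 16 binary fight/no-fight flags (the width implied by nn.Linear(16, 8)); the path and the example window are placeholders:

    import torch
    from alert import SimpleANN

    model = SimpleANN()
    # placeholder path: point this at the LFS-fetched checkpoint
    model.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
    model.eval()

    # hypothetical window of 16 fight/no-fight flags, one per processed key frame
    X_batch = torch.tensor([[0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1]],
                           dtype=torch.float32)

    with torch.no_grad():
        outputs = model(X_batch)     # sigmoid output in [0, 1]
        predicted = outputs.round()  # 1 -> raise the bullying alert
    print(predicted.item())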
model_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eade0256195b458a21ae291877432f295ce2a6a8febe693c6fdba0732600677a
+ size 2688
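
model_weights.pth is stored as a Git LFS pointer; the 2,688-byte state dict itself is downloaded when the repository is fetched with LFS enabled. A quick sanity check, assuming the real file is present and was saved from SimpleANN in alert.py:

    import torch
    from alert import SimpleANN

    state = torch.load("model_weights.pth", map_location="cpu")
    # strict=True raises if the checkpoint's keys or shapes drift from alert.py
    SimpleANN().load_state_dict(state, strict=True)
    print(sorted(state.keys()))  # expect: hidden.bias, hidden.weight, output.bias, output.weight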
test_video_ava.py ADDED
@@ -0,0 +1,320 @@
+ import argparse
+ import cv2
+ import os
+ import time
+ import numpy as np
+ import torch
+ from PIL import Image
+ from alert import SimpleANN
+ from datetime import datetime
+ from dataset.transforms import BaseTransform
+ from utils.misc import load_weight
+ from config import build_dataset_config, build_model_config
+ from models import build_model
+ import pandas as pd
+ import csv
+
+ import torch.backends.cudnn as cudnn
+ # import torch.distributed as dist
+ # from torch.nn.parallel import DistributedDataParallel as DDP
+
+ # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+ # torch.cuda.set_device(0)
+ torch.backends.cudnn.enabled = False
+ # dist.init_process_group(backend='nccl')
+
+ # binary alert classifier (see alert.py), loaded once at import time
+ model_alert = SimpleANN()
+ model_alert.load_state_dict(torch.load(r'D:\yowov2V7\YOWOv2\model_weights.pth'))
+ model_alert.eval()
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='YOWOv2')
+
+     # basic
+     parser.add_argument('-size', '--img_size', default=224, type=int,
+                         help='the size of input frame')
+     parser.add_argument('--show', action='store_true', default=False,
+                         help='show the visualization results.')
+     parser.add_argument('--cuda', action='store_true', default=False,
+                         help='use cuda.')
+     parser.add_argument('--save_folder', default='det_results/', type=str,
+                         help='Dir to save results')
+     parser.add_argument('-vs', '--vis_thresh', default=0.1, type=float,
+                         help='threshold for visualization')
+     parser.add_argument('--video', default='9Y_l9NsnYE0.mp4', type=str,
+                         help='AVA video name.')
+     parser.add_argument('-d', '--dataset', default='ava_v2.2',
+                         help='ava_v2.2')
+
+     # model
+     parser.add_argument('-v', '--version', default='yowo_v2_large', type=str,
+                         help='build YOWOv2')
+     parser.add_argument('--weight', default='weight/',
+                         type=str, help='Trained state_dict file path to open')
+     parser.add_argument('--topk', default=40, type=int,
+                         help='keep the top-k detections')
+     parser.add_argument('--threshold', default=0.1, type=float,
+                         help='score threshold for keeping a detection')
+
+     return parser.parse_args()
+
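+ # Example invocation (flags as defined above; the checkpoint filename is an assumption):
+ #   python test_video_ava.py -v yowo_v2_large --weight weight/yowo_v2_large.pth --cuda --show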
+
+ def process_frame(frame, video_clip, num_frame, transform, list_count_fighter,
+                   model, device, class_names, args, count_n_frames):
+     count_n_frames += 1
+
+     # to PIL image
+     fight = 0
+     max_score = 0
+     frame_pil = Image.fromarray(frame.astype(np.uint8))
+
+     # on the first call, fill the clip buffer with copies of the current frame
+     if len(video_clip) <= 0:
+         for _ in range(num_frame):
+             video_clip.append(frame_pil)
+
+     # slide the clip window forward by one frame
+     video_clip.append(frame_pil)
+     video_clip.pop(0)
+
+     # orig size
+     orig_h, orig_w = frame.shape[:2]
+
+     # transform
+     x = transform(video_clip)
+
+     # List [T, 3, H, W] -> [3, T, H, W]
+     x = torch.stack(x, dim=1)
+     x = x.unsqueeze(0).to(device)  # [B, 3, T, H, W], B=1
+
+     # inference
+     batch_bboxes = model(x)
+
+     # batch size = 1
+     bboxes = batch_bboxes[0]
+
+     # visualize detection results
+     for bbox in bboxes:
+         x1, y1, x2, y2 = bbox[:4]
+         det_conf = bbox[4]
+         cls_out = det_conf * bbox[5:]
+
+         # rescale bbox
+         x1, x2 = int(x1 * orig_w), int(x2 * orig_w)
+         y1, y2 = int(y1 * orig_h), int(y2 * orig_h)
+
+         # numpy
+         cls_scores = np.array(cls_out)
+
+         if max(cls_scores) < args.threshold:
+             continue
+         # keep only the highest-scoring class for this box
+         indices = [np.argmax(cls_scores)]
+         scores = [cls_scores[indices[0]]]
+
+         if len(scores) > 0:
+             blk = np.zeros(frame.shape, np.uint8)
+             font = cv2.FONT_HERSHEY_SIMPLEX
+             coord = []
+             text = []
+             text_size = []
+
+             # class index 0 ("bully") counts as a fight detection
+             if indices[0] == 0:
+                 fight += 1
+                 max_score = max(scores[0], max_score)
+
+             for _, cls_ind in enumerate(indices):
+                 if class_names[cls_ind] == "bully":
+                     color = (0, 0, 255)
+                 elif class_names[cls_ind] == "victim":
+                     color = (255, 0, 0)
+                 else:
+                     color = (0, 255, 0)
+
+                 cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+                 # text.append("[{:.2f}] ".format(scores[_]) + str(class_names[cls_ind]))
+                 # text_size.append(cv2.getTextSize(text[-1], font, fontScale=0.75, thickness=2)[0])
+                 # coord.append((x1+3, y1+25+10*_))
+                 # cv2.rectangle(blk, (coord[-1][0]-1, coord[-1][1]-20), (coord[-1][0]+text_size[-1][0]+1, coord[-1][1]+text_size[-1][1]-8), (0, 255, 0), cv2.FILLED)
+             frame = cv2.addWeighted(frame, 1.0, blk, 0.25, 1)
+             for t in range(len(text)):
+                 cv2.putText(frame, text[t], coord[t], font, 0.75, (0, 0, 255), 2)
+
+     # record whether this key frame contained at least one fight detection
+     if fight >= 1:
+         fight = 1
+     list_count_fighter.append(fight)
+     if len(list_count_fighter) > num_frame:
+         list_count_fighter.pop(0)
+
+     return frame, list_count_fighter, fight, max_score, count_n_frames
+
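+ # process_frame returns the annotated frame together with the sliding window of
+ # per-keyframe fight flags; run() below votes over that window to set the alert.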
+
+ @torch.no_grad()
+ def run(args, d_cfg, model, device, transform, class_names):
+     csv_file = "D:/yowov2V7/YOWOv2/alert_test.csv"
+     with open(csv_file, 'r') as f:
+         reader = csv.reader(f)
+         data = list(reader)
+     df = pd.read_csv(csv_file)
+     video_value = "test_17"
+
+     path_to_video = f"D:/NO/Django_code/video_test/{video_value}.mp4"
+     name = path_to_video.split("/")[-1]
+     # capture device index 1 (camera); pass path_to_video instead to read the file
+     video = cv2.VideoCapture(1)
+     save_size = (1280, 720)
+     fps = 2
+     # id_frame = 30/fps
+     id_frame = 7    # run detection on every 7th frame
+     num_frame = 8   # sliding-window length for the alert vote
+     video_clip = []
+     list_count_fighter = []
+     alert = "Normal"
+     color = (0, 255, 0)
+     count_fight = 0
+     count_frame = 0
+     count_n_frames = -1
+
+     while True:
+         ret, frame = video.read()
+         now = datetime.now()
+         formatted_time = now.strftime("%Y-%m-%d %H:%M:%S")
+         if not ret:
+             break
+         start_time = time.time()
+         count_frame += 1
+         # prepare
+         if count_frame % id_frame == 0:
+             count_frame = 0
+             frame, list_count_fighter, fight, max_score, count_n_frames = process_frame(
+                 frame, video_clip, num_frame, transform, list_count_fighter,
+                 model, device, class_names, args, count_n_frames)
+             df.loc[(df['video'] == video_value) & (df['id'] == count_n_frames), 'predict_8'] = fight
+             df.loc[(df['video'] == video_value) & (df['id'] == count_n_frames), 'conf_score_8'] = max_score
+             if len(list_count_fighter) == num_frame:
+                 # outputs = model_alert(torch.tensor(list_count_fighter).float())
+                 # predicted = outputs.round()
+                 count_fight = sum(list_count_fighter)
+
+                 # majority vote over the window decides the alert
+                 if count_fight >= num_frame / 2:
+                     # if predicted == 1:
+                     alert = "Bullying"
+                     color = (0, 0, 255)
+                 else:
+                     alert = "Normal"
+                     color = (0, 255, 0)
+
+             df.to_csv(csv_file, index=False)
+             elapsed_time = time.time() - start_time
+             fps = 1 / elapsed_time
+             cv2.putText(frame, f"Time: {formatted_time}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
+             cv2.putText(frame, f"FPS: {fps:.2f}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+             cv2.putText(frame, f"Alert: {alert}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
+             # save
+             # out.write(frame)
+
+             if args.show:
+                 cv2.namedWindow('key-frame detection', cv2.WINDOW_NORMAL)
+                 # resize the window to (width, height)
+                 cv2.resizeWindow('key-frame detection', 1280, 720)
+                 # display the frame in the window
+                 cv2.imshow('key-frame detection', frame)
+                 if cv2.waitKey(1) & 0xFF == ord('q'):
+                     break
+
+     video.release()
+     # out.release()
+     cv2.destroyAllWindows()
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     # cuda
+     if args.cuda:
+         cudnn.benchmark = True
+         print('use cuda')
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     # config
+     d_cfg = build_dataset_config(args)
+     m_cfg = build_model_config(args)
+
+     class_names = d_cfg['label_map']
+     num_classes = 3
+
+     # transform
+     basetransform = BaseTransform(
+         img_size=d_cfg['test_size'],
+         # pixel_mean=d_cfg['pixel_mean'],
+         # pixel_std=d_cfg['pixel_std']
+     )
+
+     # build model
+     model = build_model(
+         args=args,
+         d_cfg=d_cfg,
+         m_cfg=m_cfg,
+         device=device,
+         num_classes=num_classes,
+         trainable=False
+     )
+
+     # load trained weight
+     model = load_weight(model=model, path_to_ckpt=args.weight)
+
+     # to eval
+     model = model.to(device).eval()
+
+     # run
+     run(args=args, d_cfg=d_cfg, model=model, device=device,
+         transform=basetransform, class_names=class_names)
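
The alert decision in run() reduces to a majority vote over the last num_frame per-keyframe fight flags (the SimpleANN path is commented out in this upload). A standalone sketch of that rule with hypothetical windows, to make the decision boundary explicit:

    def majority_alert(flags, num_frame=8):
        """Return 'Bullying' when at least half of the window's flags are 1."""
        window = flags[-num_frame:]
        return "Bullying" if sum(window) >= num_frame / 2 else "Normal"

    print(majority_alert([0, 0, 1, 1, 1, 0, 1, 1]))  # Bullying (5 of 8 >= 4)
    print(majority_alert([0, 0, 0, 1, 0, 0, 1, 0]))  # Normal   (2 of 8 < 4)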