import argparse
import time
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
from PIL import Image

from alert import SimpleANN
from config import build_dataset_config, build_model_config
from dataset.transforms import BaseTransform
from models import build_model
from utils.misc import load_weight
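
# Demo flow: frames are buffered into a sliding clip, YOWOv2 runs on every
# id_frame-th frame, the top class per box is kept, and the per-key-frame
# "fight" flags are smoothed over the last num_frame results to drive the
# on-screen alert.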

# Disable cuDNN globally; this also makes the cudnn.benchmark flag set in
# __main__ a no-op.
torch.backends.cudnn.enabled = False

# Auxiliary alert classifier, loaded once at module import and kept in eval mode.
model_alert = SimpleANN()
model_alert.load_state_dict(torch.load('./model_weights.pth', map_location='cpu'))
model_alert.eval()
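
# Note that model_alert is loaded globally but is not called anywhere below;
# the displayed alert comes from the sliding-window vote in run().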


def parse_args():
    parser = argparse.ArgumentParser(description='YOWOv2')

    parser.add_argument('-size', '--img_size', default=224, type=int,
                        help='the size of input frame')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization results.')
    parser.add_argument('--cuda', action='store_true', default=True,
                        help='use cuda.')
    parser.add_argument('--save_folder', default='det_results/', type=str,
                        help='Dir to save results')
    parser.add_argument('-vs', '--vis_thresh', default=0.1, type=float,
                        help='threshold for visualization')
    parser.add_argument('--video', default='9Y_l9NsnYE0.mp4', type=str,
                        help='AVA video name.')
    parser.add_argument('-d', '--dataset', default='ava_v2.2',
                        help='ava_v2.2')

    parser.add_argument('-v', '--version', default='yowo_v2_large', type=str,
                        help='build YOWOv2')
    parser.add_argument('--weight', default='./backup_dir/ava_v2.2/fps32_k16_bs16_yolo_large_newdata_p2/epoch4/yowo_v2_large_epoch_4.pth',
                        type=str, help='Trained state_dict file path to open')
    parser.add_argument('--topk', default=40, type=int,
                        help='max number of detections kept per frame')
    parser.add_argument('--threshold', default=0.1, type=float,
                        help='score threshold for keeping a detection')

    return parser.parse_args()
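
# Example invocation (script name and paths illustrative):
#   python demo.py --show -v yowo_v2_large --weight ./backup_dir/.../yowo_v2_large_epoch_4.pth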


def process_frame(frame, video_clip, num_frame, transform, list_count_fighter,
                  model, device, class_names, args, count_n_frames):
    count_n_frames += 1

    fight = 0
    max_score = 0
    frame_pil = Image.fromarray(frame.astype(np.uint8))

    # Bootstrap the clip buffer with copies of the first frame, then keep a
    # sliding window of the most recent num_frame frames.
    if len(video_clip) == 0:
        for _ in range(num_frame):
            video_clip.append(frame_pil)
    video_clip.append(frame_pil)
    video_clip.pop(0)

    orig_h, orig_w = frame.shape[:2]

    # Preprocess the clip and stack it into a (1, C, T, H, W) batch tensor.
    x = transform(video_clip)
    x = torch.stack(x, dim=1)
    x = x.unsqueeze(0).to(device)

    t0 = time.time()
    batch_bboxes = model(x)
    t1 = time.time()

    # Batch size is 1, so take the detections for the single clip.
    bboxes = batch_bboxes[0]

    for bbox in bboxes:
        x1, y1, x2, y2 = bbox[:4]
        det_conf = bbox[4]

        # Per-class confidences = detection confidence * class scores.
        cls_out = det_conf * bbox[5:]

        # Rescale normalized box coordinates to the original frame size.
        x1, x2 = int(x1 * orig_w), int(x2 * orig_w)
        y1, y2 = int(y1 * orig_h), int(y2 * orig_h)

        cls_scores = np.array(cls_out)

        # Skip boxes whose best class score is below the threshold.
        if cls_scores.max() < args.threshold:
            continue

        # Keep only the single top-scoring class for this box.
        top_idx = int(np.argmax(cls_scores))
        indices = [top_idx]
        scores = [float(cls_scores[top_idx])]

        if len(scores) > 0:
            blk = np.zeros(frame.shape, np.uint8)
            font = cv2.FONT_HERSHEY_SIMPLEX
            coord = []
            text = []

            # Class index 0 is the fighting/bullying class.
            if indices[0] == 0:
                fight += 1
                max_score = max(scores[0], max_score)

            for cls_ind in indices:
                class_name = class_names[cls_ind]
                if class_name == "bully":
                    color = (0, 0, 255)
                elif class_name == "victim":
                    color = (255, 0, 0)
                else:
                    color = (0, 255, 0)

                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

                # Label each box with its class name and score.
                text.append(f"{class_name}: {cls_scores[cls_ind]:.2f}")
                coord.append((x1, max(y1 - 5, 0)))

            frame = cv2.addWeighted(frame, 1.0, blk, 0.25, 1)
            for t in range(len(text)):
                cv2.putText(frame, text[t], coord[t], font, 0.75, (0, 0, 255), 2)

    # Clamp to a binary flag and keep a sliding history of the last num_frame
    # key-frame results for the majority vote in run().
    if fight >= 1:
        fight = 1
    list_count_fighter.append(fight)
    if len(list_count_fighter) > num_frame:
        list_count_fighter.pop(0)

    return frame, list_count_fighter, fight, max_score, count_n_frames

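
# process_frame usage sketch (variable names illustrative): it is called once
# per key frame inside run(), threading the clip buffer, smoothing history,
# and key-frame counter through successive calls:
#   frame, history, fight, score, n = process_frame(
#       frame, clip, 8, transform, history, model, device, class_names, args, n)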
@torch.no_grad()
def run(args, d_cfg, model, device, transform, class_names):
    # Per-key-frame predictions are written back into this CSV.
    csv_file = "D:/yowov2V7/YOWOv2/alert_test.csv"
    df = pd.read_csv(csv_file)
    video_value = "test_17"

    path_to_video = f"D:/NO/Django_code/video_test/{video_value}.mp4"
    video = cv2.VideoCapture(path_to_video)

    id_frame = 7     # run detection on every id_frame-th frame
    num_frame = 8    # clip length and smoothing-window size
    video_clip = []
    list_count_fighter = []
    alert = "Normal"
    color = (0, 255, 0)
    count_fight = 0
    count_frame = 0
    count_n_frames = -1

    while True:
        ret, frame = video.read()
        now = datetime.now()
        formatted_time = now.strftime("%Y-%m-%d %H:%M:%S")
        if not ret:
            break

        start_time = time.time()
        count_frame += 1

        # Only every id_frame-th frame is treated as a key frame.
        if count_frame % id_frame == 0:
            count_frame = 0
            frame, list_count_fighter, fight, max_score, count_n_frames = process_frame(
                frame, video_clip, num_frame, transform, list_count_fighter,
                model, device, class_names, args, count_n_frames)

            # Record the prediction and confidence for this key frame.
            row = (df['video'] == video_value) & (df['id'] == count_n_frames)
            df.loc[row, 'predict_8'] = fight
            df.loc[row, 'conf_score_8'] = max_score

            # Majority vote over the last num_frame key frames.
            if len(list_count_fighter) == num_frame:
                count_fight = sum(list_count_fighter)
                if count_fight >= num_frame / 2:
                    alert = "Bullying"
                    color = (0, 0, 255)
                else:
                    alert = "Normal"
                    color = (0, 255, 0)

            df.to_csv(csv_file, index=False)

        elapsed_time = time.time() - start_time
        fps = 1 / max(elapsed_time, 1e-6)

        cv2.putText(frame, f"Time: {formatted_time}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        cv2.putText(frame, f"FPS: {fps:.2f}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(frame, f"Alert: {alert}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

        if args.show:
            cv2.namedWindow('key-frame detection', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('key-frame detection', 1280, 720)
            cv2.imshow('key-frame detection', frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    video.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    args = parse_args()

    if args.cuda and torch.cuda.is_available():
        cudnn.benchmark = True
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    d_cfg = build_dataset_config(args)
    m_cfg = build_model_config(args)

    class_names = d_cfg['label_map']
    num_classes = 3

    basetransform = BaseTransform(img_size=d_cfg['test_size'])

    model = build_model(
        args=args,
        d_cfg=d_cfg,
        m_cfg=m_cfg,
        device=device,
        num_classes=num_classes,
        trainable=False
    )

    model = load_weight(model=model, path_to_ckpt=args.weight)
    model = model.to(device).eval()

    run(args=args, d_cfg=d_cfg, model=model, device=device,
        transform=basetransform, class_names=class_names)