""" |
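This script batch-detects faces: it runs the detector on every image named in
a list file (one path per line, relative to --path and without the ".jpg"
extension) and appends the detections to a result text file.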
""" |
import argparse |
import os |
import sys |
import cv2 |
from vision.ssd.config.fd_config import define_img_size |
def _build_parser():
    """Create the command-line parser for the batch detection script."""
    cli = argparse.ArgumentParser(description='detect_imgs')
    # One row per option: (flag, default, type, help). Rows are registered in
    # this order, which is also the order shown by --help.
    options = (
        ('--net_type', "RFB", str,
         'The network architecture ,optional: RFB (higher precision) or slim (faster)'),
        ('--input_size', 320, int,
         'define network input size,default optional value 128/160/320/480/640/1280'),
        ('--threshold', 0.65, float,
         'score threshold'),
        ('--candidate_size', 1500, int,
         'nms candidate size'),
        ('--path', "D:/Database/face_detect/test/originalPics", str,
         'imgs dir'),
        ('--test_device', "cpu", str,
         'cuda:0 or cpu'),
    )
    for flag, default, value_type, help_text in options:
        cli.add_argument(flag, default=default, type=value_type, help=help_text)
    return cli


parser = _build_parser()
# Parsed at import time: the module-level setup that follows reads `args`.
args = parser.parse_args()
# NOTE(review): the input size is registered *before* the model modules are
# imported; presumably those modules read the configured size at import time,
# so keep this call ahead of the two imports below -- confirm before moving.
define_img_size(args.input_size)
from vision.ssd.mb_tiny_fd import create_mb_tiny_fd, create_mb_tiny_fd_predictor
from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd, create_Mb_Tiny_RFB_fd_predictor

# Output directory created by the main entry point.
result_path = "./detect_imgs_results"
# Text file with one class label per line.
label_path = "./models/voc-model-labels.txt"
# Detections are appended to this file.
fd_result_path = 'D:/Database/face_detect/test/rfb_fd_result.txt'
# List of image paths (relative to --path, without extension) to process.
fddb_txt_path = 'D:/Database/face_detect/test/FDDB-folds/FDDB-fold-01-10_2845.txt'
test_device = args.test_device

# Read the labels through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(label_path) as label_file:
    class_names = [name.strip() for name in label_file]

# Build the requested network and its predictor; anything other than
# 'slim' or 'RFB' aborts the script with exit status 1.
if args.net_type == 'slim':
    model_path = "models/pretrained/version-slim-320.pth"
    net = create_mb_tiny_fd(len(class_names), is_test=True, device=test_device)
    predictor = create_mb_tiny_fd_predictor(net, candidate_size=args.candidate_size, device=test_device)
elif args.net_type == 'RFB':
    model_path = "models/pretrained/version-RFB-320.pth"
    net = create_Mb_Tiny_RFB_fd(len(class_names), is_test=True, device=test_device)
    predictor = create_Mb_Tiny_RFB_fd_predictor(net, candidate_size=args.candidate_size, device=test_device)
else:
    print("The net type is wrong!")
    sys.exit(1)
net.load(model_path)
def get_file_names(dir_path):
    """Return the paths of all non-directory entries under *dir_path*.

    Sub-directories are descended into recursively. Entries of each directory
    are visited in sorted name order, so the result is deterministic rather
    than following the arbitrary order of ``os.listdir``.

    Args:
        dir_path: Directory to scan.

    Returns:
        A list of paths, each built by joining *dir_path* with the names of
        the entries leading to the file. Empty if no files are found.

    Raises:
        OSError: If *dir_path* (or a sub-directory) cannot be listed.
    """
    collected = []
    for entry in sorted(os.listdir(dir_path)):
        full_path = os.path.join(dir_path, entry)
        if os.path.isdir(full_path):
            # extend() grows the list in place; rebuilding it with `+` on
            # every sub-directory would copy all earlier results each time.
            collected.extend(get_file_names(full_path))
        else:
            collected.append(full_path)
    return collected
def get_file_paths(txt_path):
    """Read *txt_path* and return its lines with surrounding whitespace removed.

    Every line yields one entry, so a blank line becomes an empty string.
    """
    with open(txt_path, "r") as listing:
        return [raw_line.strip() for raw_line in listing]
def main():
    """Run detection on every listed image and append the results to a file.

    For each entry read from ``fddb_txt_path`` the image ``<--path>/<entry>.jpg``
    is loaded and passed to the predictor. For every image that loads, the
    following is appended to ``fd_result_path``:

        <entry>
        <number of boxes>
        <x> <y> <width> <height> <score>     (one line per box)

    where x/y are the first two box values and width/height are the
    differences between the last two and the first two.
    """
    os.makedirs(result_path, exist_ok=True)

    # Floor division keeps top_k an int; `/` would hand the predictor a float
    # (and a fractional one for odd candidate sizes).
    top_k = args.candidate_size // 2

    # Open the result file once for the whole run instead of once per image.
    # Append mode is kept, so results from earlier runs are preserved.
    with open(fd_result_path, "a") as fd_result_file:
        for filename in get_file_paths(fddb_txt_path):
            img_path = os.path.join(args.path, filename) + ".jpg"
            orig_image = cv2.imread(img_path)
            if orig_image is None:
                # imread signals failure by returning None; report it so a
                # missing entry in the result file can be traced.
                print("skipping unreadable image: ", img_path, file=sys.stderr)
                continue
            print("filename: ", filename)

            # OpenCV loads BGR; convert to RGB before prediction.
            image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
            boxes, labels, probs = predictor.predict(image, top_k, args.threshold)

            print(filename, file=fd_result_file)
            print(boxes.size(0), file=fd_result_file)
            for box, prob in zip(boxes, probs):
                # Boxes hold two corner points; the output wants the first
                # corner plus width and height.
                print(
                    f"{box[0]:.3f}",
                    f"{box[1]:.3f}",
                    f"{box[2] - box[0]:.3f}",
                    f"{box[3] - box[1]:.3f}",
                    f"{prob:.3f}",
                    file=fd_result_file,
                )


if __name__ == '__main__':
    main()