yotamsapi commited on
Commit
86de714
β€’
1 Parent(s): 2984b18

Create ops.py

Browse files
Files changed (1) hide show
  1. retinaface/ops.py +26 -0
retinaface/ops.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from retinaface.anchor import decode_tf, prior_box_tf
2
+ import tensorflow as tf
3
+
4
+
5
+ def extract_detections(bbox_regressions, landm_regressions, classifications, image_sizes, iou_th=0.4, score_th=0.02):
6
+ min_sizes = [[16, 32], [64, 128], [256, 512]]
7
+ steps = [8, 16, 32]
8
+ variances = [0.1, 0.2]
9
+ preds = tf.concat( # [bboxes, landms, landms_valid, conf]
10
+ [bbox_regressions,
11
+ landm_regressions,
12
+ tf.ones_like(classifications[:, 0][..., tf.newaxis]),
13
+ classifications[:, 1][..., tf.newaxis]], 1)
14
+ priors = prior_box_tf(image_sizes, min_sizes, steps, False)
15
+ decode_preds = decode_tf(preds, priors, variances)
16
+
17
+ selected_indices = tf.image.non_max_suppression(
18
+ boxes=decode_preds[:, :4],
19
+ scores=decode_preds[:, -1],
20
+ max_output_size=tf.shape(decode_preds)[0],
21
+ iou_threshold=iou_th,
22
+ score_threshold=score_th)
23
+
24
+ out = tf.gather(decode_preds, selected_indices)
25
+
26
+ return out