skytnt commited on
Commit
6586dd5
1 Parent(s): c583569

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -18
README.md CHANGED
@@ -19,27 +19,25 @@ tagger_model_meta = tagger_model.get_modelmeta().custom_metadata_map
19
  tagger_tags = eval(tagger_model_meta['tags'])
20
 
21
  def tagger_predict(image, score_threshold):
22
- h, w = image.shape[:2]
23
- r = min(512 / w, 512 / h)
24
- h, w = int(h * r), int(w * r)
25
- image = cv2.resize(image, (w, h))
26
- pdx = 512 - w
27
- pdy = 512 - h
28
- img_new = np.full([512, 512, 3], 1, dtype=np.float32)
29
- img_new[pdy // 2:pdy // 2 + h, pdx // 2:pdx // 2 + w] = image
30
- image = img_new[np.newaxis, :]
31
- probs = tagger_model.run(None, {"input_1": image})[0][0]
32
- probs = probs.astype(np.float32)
33
- res = []
34
- for prob, label in zip(probs.tolist(), tagger_tags):
35
- if prob < score_threshold:
36
- continue
37
- res.append(label)
38
- return res
39
 
40
  img = cv2.imread("test.jpg")
41
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
42
- img = img.astype(np.float32) / 255
43
  tags = tagger_predict(img, 0.5)
44
  print(tags)
45
  ```
 
19
  tagger_tags = eval(tagger_model_meta['tags'])
20
 
21
  def tagger_predict(image, score_threshold):
22
+ s = 512
23
+ h, w = image.shape[:-1]
24
+ h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
25
+ ph, pw = s - h, s - w
26
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
27
+ image = cv2.copyMakeBorder(image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
28
+ image = image.astype(np.float32) / 255
29
+ image = img_new[np.newaxis, :]
30
+ probs = tagger_model.run(None, {"input_1": image})[0][0]
31
+ probs = probs.astype(np.float32)
32
+ res = []
33
+ for prob, label in zip(probs.tolist(), tagger_tags):
34
+ if prob < score_threshold:
35
+ continue
36
+ res.append(label)
37
+ return res
 
38
 
39
  img = cv2.imread("test.jpg")
40
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
41
  tags = tagger_predict(img, 0.5)
42
  print(tags)
43
  ```