Update README.md
Browse files
README.md
CHANGED
@@ -4,7 +4,9 @@ license: mit
|
|
4 |
|
5 |
Model convert from [https://github.com/KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)
|
6 |
|
7 |
-
Usage:
|
|
|
|
|
8 |
|
9 |
```python
|
10 |
import cv2
|
@@ -41,3 +43,107 @@ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
41 |
tags = tagger_predict(img, 0.5)
|
42 |
print(tags)
|
43 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
Model convert from [https://github.com/KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)
|
6 |
|
7 |
+
## Usage:
|
8 |
+
|
9 |
+
### Basic use
|
10 |
|
11 |
```python
|
12 |
import cv2
|
|
|
43 |
tags = tagger_predict(img, 0.5)
|
44 |
print(tags)
|
45 |
```
|
46 |
+
|
47 |
+
### Multi-gpu batch process
|
48 |
+
|
49 |
+
|
50 |
+
```python
|
51 |
+
import cv2
|
52 |
+
import torch
|
53 |
+
import os
|
54 |
+
import numpy as np
|
55 |
+
import onnxruntime as rt
|
56 |
+
from huggingface_hub import hf_hub_download
|
57 |
+
from torch.utils.data import DataLoader, Dataset
|
58 |
+
from PIL import Image
|
59 |
+
from tqdm import tqdm
|
60 |
+
from threading import Thread
|
61 |
+
|
62 |
+
|
63 |
+
class MyDataset(Dataset):
|
64 |
+
def __init__(self, image_list):
|
65 |
+
self.image_list = image_list
|
66 |
+
|
67 |
+
def __len__(self):
|
68 |
+
length = len(self.image_list)
|
69 |
+
return length
|
70 |
+
|
71 |
+
def __getitem__(self, index):
|
72 |
+
image = Image.open(self.image_list[index]).convert("RGB")
|
73 |
+
image = np.asarray(image)
|
74 |
+
s = 512
|
75 |
+
h, w = image.shape[:-1]
|
76 |
+
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
|
77 |
+
ph, pw = s - h, s - w
|
78 |
+
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
|
79 |
+
image = cv2.copyMakeBorder(image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
|
80 |
+
image = image.astype(np.float32) / 255
|
81 |
+
image = torch.from_numpy(image)
|
82 |
+
idx = torch.tensor([index], dtype=torch.int32)
|
83 |
+
return image, idx
|
84 |
+
|
85 |
+
|
86 |
+
def get_images(path):
|
87 |
+
def file_ext(fname):
|
88 |
+
return os.path.splitext(fname)[1].lower()
|
89 |
+
|
90 |
+
all_files = {
|
91 |
+
os.path.relpath(os.path.join(root, fname), path)
|
92 |
+
for root, _dirs, files in os.walk(path)
|
93 |
+
for fname in files
|
94 |
+
}
|
95 |
+
all_images = sorted(
|
96 |
+
os.path.join(path, fname) for fname in all_files if file_ext(fname) in [".png", ".jpg", ".jpeg"]
|
97 |
+
)
|
98 |
+
print(len(all_images))
|
99 |
+
return all_images
|
100 |
+
|
101 |
+
|
102 |
+
def process(all_images, batch_size=8, score_threshold=0.35):
|
103 |
+
predictions = {}
|
104 |
+
|
105 |
+
def work_fn(images, device_id):
|
106 |
+
dataset = MyDataset(images)
|
107 |
+
dataloader = DataLoader(
|
108 |
+
dataset,
|
109 |
+
batch_size=batch_size,
|
110 |
+
shuffle=False,
|
111 |
+
persistent_workers=True,
|
112 |
+
num_workers=4,
|
113 |
+
pin_memory=True,
|
114 |
+
)
|
115 |
+
for data in tqdm(dataloader):
|
116 |
+
image, idxs = data
|
117 |
+
image = image.numpy()
|
118 |
+
probs = tagger_model[device_id].run(None, {"input_1": image})[0]
|
119 |
+
probs = probs.astype(np.float32)
|
120 |
+
bs = probs.shape[0]
|
121 |
+
for i in range(bs):
|
122 |
+
tags = []
|
123 |
+
for prob, label in zip(probs[i].tolist(), tagger_tags):
|
124 |
+
if prob > score_threshold:
|
125 |
+
tags.append((label.replace("_", " "), prob))
|
126 |
+
predictions[images[idxs[i].item()]] = tags
|
127 |
+
|
128 |
+
gpu_num = len(tagger_model)
|
129 |
+
image_num = (len(all_images) // gpu_num) + 1
|
130 |
+
ts = [Thread(target=work_fn, args=(all_images[i * image_num:(i + 1) * image_num], i)) for i in range(gpu_num)]
|
131 |
+
for t in ts:
|
132 |
+
t.start()
|
133 |
+
for t in ts:
|
134 |
+
t.join()
|
135 |
+
return predictions
|
136 |
+
|
137 |
+
|
138 |
+
gpu_num = 4
|
139 |
+
batch_size = 8
|
140 |
+
tagger_model_path = hf_hub_download(repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")
|
141 |
+
tagger_model = [
|
142 |
+
rt.InferenceSession(tagger_model_path, providers=['CUDAExecutionProvider'], provider_options=[{'device_id': i}]) for
|
143 |
+
i in range(4)]
|
144 |
+
tagger_model_meta = tagger_model[0].get_modelmeta().custom_metadata_map
|
145 |
+
tagger_tags = eval(tagger_model_meta['tags'])
|
146 |
+
|
147 |
+
all_images = get_images("./data")
|
148 |
+
predictions = process(all_images, batch_size)
|
149 |
+
```
|