[update]add main
- .gitignore +10 -0
- README.md +2 -2
- data/2lnWoly.jpg +0 -0
- examples/detr_cppe5/step_2_train_model.py +271 -0
- examples/detr_cppe5/step_3_test_model.py +184 -0
- main.py +268 -0
- project_settings.py +12 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@

.git/
.idea/

cache/
flagged/
gradio_cached_examples/
hub_datasets/

**/__pycache__/
README.md
CHANGED
@@ -4,8 +4,8 @@ emoji: 🐨
 colorFrom: pink
 colorTo: purple
 sdk: gradio
-sdk_version: 3.
-app_file:
+sdk_version: 3.38.0
+app_file: main.py
 pinned: false
 ---

data/2lnWoly.jpg
ADDED
examples/detr_cppe5/step_2_train_model.py
ADDED
@@ -0,0 +1,271 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
reference:
https://huggingface.co/docs/transformers/tasks/object_detection

pip install -q datasets transformers evaluate timm albumentations

"""
from dataclasses import dataclass, field
import os
from pathlib import Path
import re
from typing import Dict, List

# from project_settings import project_path
# project_path = os.path.abspath(os.path.dirname(__file__))
project_path = os.path.abspath("./")
project_path = Path(project_path)

hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()

os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import albumentations
from datasets import load_dataset
import huggingface_hub
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import HfArgumentParser
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
from transformers import TrainingArguments
from transformers import Trainer


@dataclass
class ScriptArguments:
    # dataset
    dataset_path: str = field(default="qgyd2021/cppe-5")
    dataset_name: str = field(default=None)
    dataset_cache_dir: str = field(default=(project_path / "hub_datasets").as_posix())
    # dataset_cache_dir: str = field(default="hub_datasets")

    # model
    pretrained_model_name_or_path: str = field(default="facebook/detr-resnet-50")

    # training_args
    output_dir: str = field(default="output_dir")
    per_device_train_batch_size: int = field(default=8)
    gradient_accumulation_steps: int = field(default=4)
    num_train_epochs: float = field(default=20)
    fp16: bool = field(default=True)
    save_steps: int = field(default=200)
    logging_steps: int = field(default=50)
    learning_rate: float = field(default=1e-5)
    weight_decay: float = field(default=1e-4)
    save_total_limit: int = field(default=2)
    remove_unused_columns: bool = field(default=False)
    report_to: str = field(default="tensorboard")
    push_to_hub: bool = field(default=True)
    hub_model_id: str = field(default="detr_cppe5_object_detection")
    hub_strategy: str = field(default="every_save")

    # hf_token
    hf_token: str = field(default="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")


def get_args():
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
    return args


def show_first_image(example: dict, index_to_label: Dict[int, str]):
    image: Image = example["image"]
    annotations = example["objects"]

    draw = ImageDraw.Draw(image)

    for i in range(len(annotations["id"])):
        box = annotations["bbox"][i - 1]
        class_idx = annotations["category"][i - 1]
        x, y, w, h = tuple(box)
        draw.rectangle((x, y, x + w, y + h), outline="red", width=1)
        draw.text((x, y), index_to_label[class_idx], fill="white")
    return image


def formatted_annotations(image_id, category, area, bbox):
    annotations = []
    for i in range(0, len(category)):
        new_ann = {
            "image_id": image_id,
            "category_id": category[i],
            "isCrowd": 0,
            "area": area[i],
            "bbox": list(bbox[i]),
        }
        annotations.append(new_ann)

    return annotations


def train_model(local_rank, world_size, args):
    os.environ["RANK"] = f"{local_rank}"
    os.environ["LOCAL_RANK"] = f"{local_rank}"
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    huggingface_hub.login(token=args.hf_token)

    # dataset
    dataset_dict = load_dataset(
        path=args.dataset_path,
        cache_dir=args.dataset_cache_dir
    )
    train_dataset = dataset_dict["train"]

    remove_idx = [590, 821, 822, 875, 876, 878, 879]
    keep = [i for i in range(len(train_dataset)) if i not in remove_idx]
    train_dataset = train_dataset.select(keep)

    categories = ["Coverall", "Face_Shield", "Gloves", "Goggles", "Mask"]
    index_to_label = {index: x for index, x in enumerate(categories, start=0)}
    label_to_index = {v: k for k, v in index_to_label.items()}

    # first_example = train_dataset[0]
    # image: Image = show_first_image(example=first_example, index_to_label=index_to_label)
    # image.show()

    image_processor = AutoImageProcessor.from_pretrained(args.pretrained_model_name_or_path)

    transform = albumentations.Compose(
        [
            albumentations.Resize(480, 480),
            albumentations.HorizontalFlip(p=1.0),
            albumentations.RandomBrightnessContrast(p=1.0),
        ],
        bbox_params=albumentations.BboxParams(format="coco", label_fields=["category"]),
    )

    # transforming a batch
    def transform_aug_annotation(examples):
        image_ids = examples["image_id"]
        images, bboxes, area, categories = [], [], [], []
        for image, objects in zip(examples["image"], examples["objects"]):
            image = np.array(image.convert("RGB"))[:, :, ::-1]
            out = transform.__call__(image=image, bboxes=objects["bbox"], category=objects["category"])

            area.append(objects["area"])
            images.append(out["image"])
            bboxes.append(out["bboxes"])
            categories.append(out["category"])

        targets = [
            {"image_id": id_, "annotations": formatted_annotations(id_, cat_, ar_, box_)}
            for id_, cat_, ar_, box_ in zip(image_ids, categories, area, bboxes)
        ]

        return image_processor.__call__(images=images, annotations=targets, return_tensors="pt")

    train_dataset = train_dataset.with_transform(transform_aug_annotation)

    def collate_fn(batch):
        pixel_values = [item["pixel_values"] for item in batch]
        encoding = image_processor.pad(pixel_values, return_tensors="pt")
        labels = [item["labels"] for item in batch]
        batch = {
            "pixel_values": encoding["pixel_values"],
            "pixel_mask": encoding["pixel_mask"],
            "labels": labels
        }
        return batch

    model = AutoModelForObjectDetection.from_pretrained(
        args.pretrained_model_name_or_path,
        id2label=index_to_label,
        label2id=label_to_index,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        num_train_epochs=args.num_train_epochs,
        fp16=args.fp16,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        save_total_limit=args.save_total_limit,
        remove_unused_columns=args.remove_unused_columns,
        report_to=args.report_to,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_strategy=args.hub_strategy,
        local_rank=local_rank,
        ddp_backend="nccl",
        # fsdp="auto_wrap",
    )
    print(training_args)

    partial_state_str = f"""
    distributed_type: {training_args.distributed_state.distributed_type}
    local_process_index: {training_args.distributed_state.local_process_index}
    num_processes: {training_args.distributed_state.num_processes}
    process_index: {training_args.distributed_state.process_index}
    device: {training_args.distributed_state.device}
    """
    partial_state_str = re.sub(r"[\u0020]{4,}", "", partial_state_str)
    print(partial_state_str)

    environ = f"""
    RANK: {os.environ.get("RANK", -1)}
    WORLD_SIZE: {os.environ.get("WORLD_SIZE", -1)}
    LOCAL_RANK: {os.environ.get("LOCAL_RANK", -1)}
    """
    environ = re.sub(r"[\u0020]{4,}", "", environ)
    print(environ)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=train_dataset,
        tokenizer=image_processor,
    )
    trainer.train()
    trainer.push_to_hub()
    return


def single_gpu_train():
    args = get_args()

    train_model(0, 1, args)

    return


def train_on_kaggle_notebook():
    """
    train on kaggle notebook with GPU T4 x2

    from shutil import copyfile
    copyfile(src = "../input/tempdataset/step_2_train_model.py", dst = "../working/step_2_train_model.py")

    import step_2_train_model
    step_2_train_model.train_on_kaggle_notebook()

    """
    args = get_args()

    world_size = torch.cuda.device_count()
    print("world_size: {}".format(world_size))

    mp.spawn(train_model,
             args=(world_size, args),
             nprocs=world_size,
             join=True)

    return


if __name__ == '__main__':
    single_gpu_train()
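For reference, the COCO-style target layout that formatted_annotations() and transform_aug_annotation() hand to the DETR image processor looks like the following. This is a minimal sketch with made-up numbers, not part of the committed files:

# Hypothetical CPPE-5 style example; category 0 = Coverall, 4 = Mask.
objects = {
    "category": [0, 4],
    "area": [9216.0, 2048.0],
    "bbox": [[10.0, 20.0, 96.0, 96.0],     # COCO boxes: [x, y, width, height]
             [150.0, 60.0, 32.0, 64.0]],
}
annotations = [
    {"image_id": 1, "category_id": c, "isCrowd": 0, "area": a, "bbox": list(b)}
    for c, a, b in zip(objects["category"], objects["area"], objects["bbox"])
]
target = {"image_id": 1, "annotations": annotations}
# A list of such targets is what the script passes to
# image_processor(images=..., annotations=targets, return_tensors="pt").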
examples/detr_cppe5/step_3_test_model.py
ADDED
@@ -0,0 +1,184 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
reference:
https://huggingface.co/spaces/nickmuchi/license-plate-detection-with-YOLOS
https://huggingface.co/docs/transformers/tasks/object_detection
"""
import argparse
import io
import os
from typing import Dict

from project_settings import project_path

hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()

os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
import torch
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
import validators


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default="qgyd2021/detr_cppe5_object_detection",
        # default=(project_path / "trained_models/detr_cppe5_object_detection").as_posix(),
        type=str
    )
    parser.add_argument(
        "--image_url_or_path",
        default="https://i.imgur.com/2lnWoly.jpg",
        type=str
    )
    parser.add_argument(
        "--threshold",
        default=0.24,
        type=float
    )
    # 0.5, 0.6, 0.7
    parser.add_argument("--iou_threshold", default=0.6, type=float)
    args = parser.parse_args()
    return args


# colors for visualization
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933]
]


def get_original_image(url_input):
    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)
        return image


def figure2image(fig):
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    pil_image = Image.open(buf)
    base_width = 750
    width_percent = base_width / float(pil_image.size[0])
    height_size = (float(pil_image.size[1]) * float(width_percent))
    height_size = int(height_size)
    pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
    return pil_image


def non_max_suppression(boxes, scores, threshold):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: array of [xmin, ymin, xmax, ymax]
        scores: array of scores associated with each box.
        threshold: IoU threshold
    Return:
        keep: indices of the boxes to keep
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # get boxes with more confidence first

    keep = []
    while order.size > 0:
        i = order[0]  # pick max confidence box
        keep.append(i)

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)  # maximum width
        h = np.maximum(0.0, yy2 - yy1 + 1)  # maximum height
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(ovr <= threshold)[0]
        order = order[inds + 1]

    return keep


def draw_boxes(image, boxes, scores, labels, threshold: float, idx_to_label: Dict[int, str] = None):
    plt.figure(figsize=(50, 50))
    plt.imshow(image)

    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    axis = plt.gca()
    colors = COLORS * len(boxes)
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        if score < threshold:
            continue
        axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
        axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
    plt.axis("off")

    return figure2image(plt.gcf())


def main():
    args = get_args()

    feature_extractor = AutoFeatureExtractor.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForObjectDetection.from_pretrained(args.pretrained_model_name_or_path)
    image_processor = AutoImageProcessor.from_pretrained(args.pretrained_model_name_or_path)

    # image
    image = get_original_image(args.image_url_or_path)
    image_size = torch.tensor([tuple(reversed(image.size))])

    # infer
    # inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = model.forward(**inputs)

    processed_outputs = image_processor.post_process_object_detection(
        outputs, threshold=args.threshold, target_sizes=image_size)
    # processed_outputs = feature_extractor.post_process(outputs, target_sizes=image_size)
    processed_outputs = processed_outputs[0]

    # draw box
    boxes = processed_outputs["boxes"].detach().numpy()
    scores = processed_outputs["scores"].detach().numpy()
    labels = processed_outputs["labels"].detach().numpy()

    keep = non_max_suppression(boxes, scores, threshold=args.iou_threshold)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    viz_image: Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=args.threshold,
        idx_to_label=model.config.id2label
    )
    viz_image.show()

    return


if __name__ == '__main__':
    main()
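The non_max_suppression() helper keeps the highest-scoring box and discards every remaining box whose IoU with it exceeds the threshold. A small sanity check, assuming the function is imported from this script (for example by running from examples/detr_cppe5 with the repository root on PYTHONPATH so that project_settings resolves):

import numpy as np

from step_3_test_model import non_max_suppression

boxes = np.array([
    [0.0, 0.0, 10.0, 10.0],    # box 0, highest score
    [1.0, 1.0, 11.0, 11.0],    # box 1, IoU with box 0 is about 0.70
    [20.0, 20.0, 30.0, 30.0],  # box 2, disjoint from box 0
])
scores = np.array([0.9, 0.8, 0.7])

keep = non_max_suppression(boxes, scores, threshold=0.6)
print(keep)  # [0, 2]: box 1 is suppressed by box 0, box 2 survives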
main.py
ADDED
@@ -0,0 +1,268 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import io
import json
import os
import re
from typing import Dict, List

from project_settings import project_path

os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
import torch
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
import validators

from project_settings import project_path


# colors for visualization
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933]
]


def get_original_image(url_input):
    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)
        return image


def figure2image(fig):
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    pil_image = Image.open(buf)
    base_width = 750
    width_percent = base_width / float(pil_image.size[0])
    height_size = (float(pil_image.size[1]) * float(width_percent))
    height_size = int(height_size)
    pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
    return pil_image


def non_max_suppression(boxes, scores, threshold):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: array of [xmin, ymin, xmax, ymax]
        scores: array of scores associated with each box.
        threshold: IoU threshold
    Return:
        keep: indices of the boxes to keep
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # get boxes with more confidence first

    keep = []
    while order.size > 0:
        i = order[0]  # pick max confidence box
        keep.append(i)

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)  # maximum width
        h = np.maximum(0.0, yy2 - yy1 + 1)  # maximum height
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(ovr <= threshold)[0]
        order = order[inds + 1]

    return keep


def draw_boxes(image, boxes, scores, labels, threshold: float,
               idx_to_label: Dict[int, str] = None, labels_to_show: str = None):
    if isinstance(labels_to_show, str):
        if len(labels_to_show.strip()) == 0:
            labels_to_show = None
        else:
            labels_to_show = labels_to_show.split(",")
            labels_to_show = [label.strip().lower() for label in labels_to_show]
            labels_to_show = None if len(labels_to_show) == 0 else labels_to_show

    plt.figure(figsize=(50, 50))
    plt.imshow(image)

    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    axis = plt.gca()
    colors = COLORS * len(boxes)
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        if labels_to_show is not None and label.lower() not in labels_to_show:
            continue
        if score < threshold:
            continue
        axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
        axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
    plt.axis("off")

    return figure2image(plt.gcf())


def detr_object_detection(url_input: str,
                          image_input: Image,
                          pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
                          threshold: float = 0.5,
                          iou_threshold: float = 0.5,
                          labels_to_show: str = None,
                          ):
    # feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
    model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
    image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)

    # image
    if validators.url(url_input):
        image = get_original_image(url_input)
    elif image_input:
        image = image_input
    else:
        raise AssertionError("provide at least one of `url_input` and `image_input`")
    image_size = torch.tensor([tuple(reversed(image.size))])

    # infer
    # inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = model.forward(**inputs)

    processed_outputs = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=image_size)
    # processed_outputs = feature_extractor.post_process(outputs, target_sizes=image_size)
    processed_outputs = processed_outputs[0]

    # draw box
    boxes = processed_outputs["boxes"].detach().numpy()
    scores = processed_outputs["scores"].detach().numpy()
    labels = processed_outputs["labels"].detach().numpy()

    keep = non_max_suppression(boxes, scores, threshold=iou_threshold)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    viz_image: Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=threshold,
        idx_to_label=model.config.id2label,
        labels_to_show=labels_to_show
    )
    return viz_image


def main():

    title = "## Detr Cppe5 Object Detection"

    description = """
reference:
https://huggingface.co/docs/transformers/tasks/object_detection

"""

    example_urls = [
        *[
            [
                "https://huggingface.co/datasets/qgyd2021/cppe-5/resolve/main/data/images/{}.png".format(idx),
                "qgyd2021/detr_cppe5_object_detection",
                0.25, 0.6, None
            ] for idx in range(1001, 1030)
        ]
    ]

    example_images = [
        [
            "data/2lnWoly.jpg",
            "qgyd2021/detr_cppe5_object_detection",
            0.25, 0.6, None
        ]
    ]

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        gr.Markdown(value=description)

        model_name = gr.components.Dropdown(
            choices=[
                "qgyd2021/detr_cppe5_object_detection",
            ],
            value="qgyd2021/detr_cppe5_object_detection",
            label="model_name",
        )
        threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.01, value=0.5,
            label="Threshold"
        )
        iou_threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.1, value=0.5,
            label="IOU Threshold"
        )
        classes_to_detect = gr.Textbox(placeholder="e.g. person, truck (split by , comma).",
                                       label="labels to show")

        with gr.Tabs():
            with gr.TabItem("Image URL"):
                with gr.Row():
                    with gr.Column():
                        url_input = gr.Textbox(lines=1, label="Enter valid image URL here..")
                        original_image = gr.Image()
                        url_input.change(get_original_image, url_input, original_image)
                    with gr.Column():
                        img_output_from_url = gr.Image()

                with gr.Row():
                    gr.Examples(examples=example_urls,
                                inputs=[url_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

                url_but = gr.Button("Detect")

            with gr.TabItem("Image Upload"):
                with gr.Row():
                    img_input = gr.Image(type="pil")
                    img_output_from_upload = gr.Image()

                with gr.Row():
                    gr.Examples(examples=example_images,
                                inputs=[img_input, model_name, threshold_slider, iou_threshold_slider],
                                examples_per_page=5,
                                )

                img_but = gr.Button("Detect")

        inputs = [url_input, img_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect]
        url_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_url], queue=True)
        img_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_upload], queue=True)

    blocks.queue().launch()
    return


if __name__ == '__main__':
    main()
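Besides driving the Gradio UI, detr_object_detection() can be called directly. A minimal sketch, assuming the pinned requirements are installed, the snippet is run from the repository root (so project_settings and data/2lnWoly.jpg resolve), and the qgyd2021/detr_cppe5_object_detection checkpoint is reachable on the Hub:

from PIL import Image

import main

image = Image.open("data/2lnWoly.jpg")
result = main.detr_object_detection(
    url_input="",                   # no URL, so the uploaded image is used
    image_input=image,
    threshold=0.25,
    iou_threshold=0.6,
    labels_to_show="mask, gloves",  # empty string or None draws every class
)
result.save("detections.jpg")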
project_settings.py
ADDED
@@ -0,0 +1,12 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from pathlib import Path


project_path = os.path.abspath(os.path.dirname(__file__))
project_path = Path(project_path)


if __name__ == '__main__':
    pass
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==3.38.0
transformers==4.30.2
torch==1.13.1
validators==0.22.0
albumentations==1.3.1
timm==0.9.7